42 #include "debug/Branch.hh"
47 namespace branch_prediction
52 {1,3,4,5,7,8,9,11,12,14,15,17,19,21,23,25,27,29,32,34,37,41,45,49,53,58,63,
56 {0,4,5,7,9,11,12,14,16,17,19,22,28,33,39,45,};
59 int n_local_histories,
int local_history_length,
int assoc,
61 int ghist_length,
int block_size,
68 : filterTable(num_filters), acyclic_histories(acyclic_bits.size()),
69 acyclic2_histories(acyclic_bits.size()),
70 blurrypath_histories(blurrypath_bits.size()),
71 ghist_words(ghist_length/block_size+1, 0),
72 path_history(path_length, 0), imli_counter(4,0),
73 localHistories(n_local_histories, local_history_length),
74 recency_stack(assoc), last_ghist_bit(false), occupancy(0)
85 int max_modhist_idx = -1;
87 max_modhist_idx = (max_modhist_idx < elem) ? elem : max_modhist_idx;
89 if (max_modhist_idx >= 0) {
96 int max_modpath_idx = -1;
98 max_modpath_idx = (max_modpath_idx < elem) ? elem : max_modpath_idx;
100 if (max_modpath_idx >= 0) {
120 const MultiperspectivePerceptronParams &
p) :
BPredUnit(
p),
150 for (
auto &spec :
specs) {
156 for (
auto &spec :
specs) {
157 spec->setBitRequirements();
159 const MultiperspectivePerceptronParams &
p =
160 static_cast<const MultiperspectivePerceptronParams &
>(
params());
163 p.local_history_length,
p.ignore_path_size);
167 p.num_local_histories,
168 p.local_history_length,
assoc,
179 int nlocal_histories,
int local_history_length,
bool ignore_path_size)
183 totalbits += imli_bits;
186 if (!ignore_path_size) {
193 if (!ignore_path_size) {
195 totalbits += 16 *
len;
198 totalbits +=
doing_local ? (nlocal_histories * local_history_length) : 0;
202 for (
auto &bve : bv) {
206 totalbits += num_filter_entries * 2;
209 for (
auto &abj : abi) {
210 for (
auto abk : abj) {
220 for (
int i = 0;
i <
specs.size();
i +=1) {
230 int table_size_bits = (remaining / (
specs.size()-num_sized));
231 for (
int i = 0;
i <
specs.size();
i += 1) {
234 int my_table_size = table_size_bits /
241 DPRINTF(
Branch,
"%d bits of metadata so far, %d left out of "
242 "%d total budget\n", totalbits, remaining,
budgetbits);
243 DPRINTF(
Branch,
"table size is %d bits, %d entries for 5 bit, %d entries "
244 "for 6 bit\n", table_size_bits,
264 return mpreds < bp.mpreds;
268 for (
int i = 0;
i < best_preds.size();
i += 1) {
272 std::sort(pairs.begin(), pairs.end());
273 for (
int i = 0;
i < (std::min(
nbest, (
int) best_preds.size()));
i += 1) {
274 best_preds[
i] = pairs[
i].index;
283 unsigned long long int h =
g;
320 int lhist =
threadData[tid]->localHistories[
bi.getPC()];
321 int history_len =
threadData[tid]->localHistories.getLocalHistoryLength();
324 }
else if (lhist == ((1<<history_len)-1)) {
326 }
else if (lhist == (1<<(history_len-1))) {
328 }
else if (lhist == ((1<<(history_len-1))-1)) {
338 for (
int i = 0;
i <
specs.size();
i += 1) {
341 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
343 int counter =
threadData[tid]->tables[
i][hashed_idx];
348 int weight = spec.
coeff * ((spec.
width == 5) ?
351 int val = sign ? -weight : weight;
357 j < std::min(
nbest, (
int) best_preds.size());
360 if (best_preds[
j] ==
i) {
374 int max_weight)
const
387 if (counter < max_weight) {
395 if (counter < max_weight) {
417 bool correct = (
bi.yout >= 1) == taken;
419 int abs_yout = abs(
bi.yout);
425 for (
int i = 0;
i <
specs.size();
i += 1) {
428 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
430 int counter = tables[
i][hashed_idx];
431 int weight = spec.
coeff * ((spec.
width == 5) ?
433 if (sign) weight = -weight;
434 bool pred = weight >= 1;
444 for (
int i = 0;
i <
specs.size();
i += 1) {
451 bool do_train = !correct || (abs_yout <=
theta);
452 if (!do_train)
return;
462 if (correct && abs_yout <
theta) {
473 for (
int i = 0;
i <
specs.size();
i += 1) {
476 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
477 int counter = tables[
i][hashed_idx];
483 tables[
i][hashed_idx] = counter;
485 int weight = ((spec.
width == 5) ?
xlat4[counter] :
xlat[counter]);
497 if ((newyout >= 1) != taken) {
499 int round_counter = 0;
507 for (
int j = 0;
j <
specs.size();
j += 1) {
508 int i = (nrand +
j) %
specs.size();
510 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
511 int counter = tables[
i][hashed_idx];
514 int weight = ((spec.
width == 5) ?
516 int signed_weight = sign ? -weight : weight;
517 pout = newyout - signed_weight;
518 if ((pout >= 1) == taken) {
528 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
529 int counter = tables[
i][hashed_idx];
534 tables[
i][hashed_idx] = counter;
536 int weight = ((spec.
width == 5) ?
538 int signed_weight = sign ? -weight : weight;
539 int out = pout + signed_weight;
541 if ((out >= 1) != taken) {
557 bp_history = (
void *)
bi;
558 unsigned short int pc2 =
pc >> 2;
561 bool ab_new = (ghist_words[
i] >> (
blockSize - 1)) & 1;
562 ghist_words[
i] <<= 1;
563 ghist_words[
i] |= ab;
578 bp_history = (
void *)
bi;
580 bool use_static =
false;
583 unsigned int findex =
587 if (
f.alwaysNotTakenSoFar()) {
589 bi->prediction =
false;
591 }
else if (
f.alwaysTakenSoFar()) {
593 bi->prediction =
true;
603 bi->prediction =
false;
606 bi->prediction = (bestval >= 1);
608 bi->prediction = (
bi->yout >= 1);
612 return bi->prediction;
617 void *bp_history,
bool squashed,
628 if (
bi->isUnconditional()) {
633 bool do_train =
true;
636 int findex =
bi->getHashFilter(
threadData[tid]->last_ghist_bit) %
642 bool transition =
false;
643 if (
f.alwaysNotTakenSoFar() ||
f.alwaysTakenSoFar()) {
647 if (
f.alwaysNotTakenSoFar()) {
652 if (
f.alwaysTakenSoFar()) {
655 f.seenUntaken =
true;
664 if (
decay && transition &&
696 unsigned int target = corrTarget;
697 if (target < bi->getPC()) {
726 bool ab = hashed_taken;
727 assert(
threadData[tid]->ghist_words.size() > 0);
741 assert(
threadData[tid]->path_history.size() > 0);
750 threadData[tid]->updateAcyclic(hashed_taken,
bi->getHPC());
757 if (
bi->getHPC() % (
i + 2) == 0) {
771 for (
int i = 0;
i < blurrypath_histories.size();
i += 1)
773 if (blurrypath_histories[
i].size() > 0) {
774 unsigned int z =
bi->getPC() >>
i;
775 if (blurrypath_histories[
i][0] !=
z) {
776 memmove(&blurrypath_histories[
i][1],
777 &blurrypath_histories[
i][0],
778 sizeof(
unsigned int) *
779 (blurrypath_histories[
i].size() - 1));
780 blurrypath_histories[
i][0] =
z;
790 if (
bi->getHPC() % (
i + 2) == 0) {
795 threadData[tid]->mod_histories[
i][0] = hashed_taken;
809 threadData[tid]->localHistories.update(
bi->getPC(), hashed_taken);
Base class for branch operations.
Basically a wrapper class to hold both the branch predictor and the BTB.
const unsigned numThreads
Number of threads for which the branch history is maintained.
static unsigned int hash(const std::vector< unsigned short int > &recency_stack, const std::vector< int > &table_sizes, unsigned short int pc, int l, int t)
std::vector< ThreadData * > threadData
std::vector< int > imli_counter_bits
const unsigned long long int imli_mask4
std::vector< std::vector< std::vector< bool > > > acyclic_bits
int computeOutput(ThreadID tid, MPPBranchInfo &bi)
Computes the output of the predictor for a given branch and the resulting best value in case the pred...
const unsigned int record_mask
bool lookup(ThreadID tid, Addr instPC, void *&bp_history) override
Looks up a given PC in the BP to see if it is taken or not taken.
std::vector< int > modpath_lengths
std::vector< std::vector< int > > blurrypath_bits
const bool speculative_update
void findBest(ThreadID tid, std::vector< int > &best_preds) const
Finds the best subset of features to use in case of a low-confidence branch, returns the result as an...
const int blockSize
Predictor parameters.
void btbUpdate(ThreadID tid, Addr branch_addr, void *&bp_history) override
If a branch is not taken, because the BTB address is invalid or missing, this function sets the appro...
static int xlat[]
Transfer function for 6-width tables.
bool doing_local
Runtime values and data used to count the size in bits.
MultiperspectivePerceptron(const MultiperspectivePerceptronParams ¶ms)
unsigned int getIndex(ThreadID tid, const MPPBranchInfo &bi, const HistorySpec &spec, int index) const
Get the position index of a predictor table.
std::vector< int > modhist_indices
virtual void createSpecs()=0
Creates the tables of the predictor.
void satIncDec(bool taken, bool &sign, int &c, int max_weight) const
Auxiliary function to saturately increase or decrease a table counter depending on the direction of the branch.
std::vector< int > modhist_lengths
void computeBits(int num_filter_entries, int nlocal_histories, int local_history_length, bool ignore_path_size)
Computes the size in bits of the structures needed to keep track of the history and the predictor tab...
static int xlat4[]
Transfer function for 5-width tables.
std::vector< int > table_sizes
void init() override
init() is called after all C++ SimObjects have been created and all ports are connected.
void train(ThreadID tid, MPPBranchInfo &bi, bool taken)
Trains the branch predictor with the given branch and direction.
void setExtraBits(int bits)
Sets the starting number of storage bits to compute the number of table entries.
void uncondBranch(ThreadID tid, Addr pc, void *&bp_history) override
std::vector< HistorySpec * > specs
Predictor tables.
std::vector< int > modpath_indices
void update(ThreadID tid, Addr instPC, bool taken, void *bp_history, bool squashed, const StaticInstPtr &inst, Addr corrTarget) override
Updates the BP with taken/not taken information.
void squash(ThreadID tid, void *bp_history) override
const unsigned long long int recencypos_mask
const unsigned long long int imli_mask1
std::enable_if_t< std::is_integral_v< T >, T > random()
Use the SFINAE idiom to choose an implementation based on whether the type is integral or floating po...
constexpr T bits(T val, unsigned first, unsigned last)
Extract the bitfield from position 'first' to 'last' (inclusive) from 'val' and right justify it.
#define fatal_if(cond,...)
Conditional fatal macro that checks the supplied condition and only causes a fatal error if the condi...
const Params & params() const
Reference material can be found at the JEDEC website: UFS standard http://www.jedec....
int16_t ThreadID
Thread index/ID type.
uint64_t Addr
Address type This will probably be moved somewhere else in the near future.
bool operator<(const Time &l, const Time &r)
Entry of the branch filter.
bool seenTaken
Has this branch been taken at least once?
bool seenUntaken
Has this branch been not taken at least once?
Base class to implement the predictor tables.
virtual unsigned int getHash(ThreadID tid, Addr pc, Addr pc2, int t) const =0
Gets the hash to index the table, using the pc of the branch, and the index of the table.
const int width
Width of the table in bits.
const double coeff
Coefficient of the feature, models the accuracy of the feature.
History data is kept for each thread.
std::vector< std::vector< unsigned int > > acyclic2_histories
std::vector< std::vector< short int > > tables
std::vector< std::vector< unsigned int > > blurrypath_histories
std::vector< std::vector< std::array< bool, 2 > > > sign_bits
std::vector< std::vector< bool > > mod_histories
std::vector< std::vector< bool > > acyclic_histories
std::vector< int > mpreds
std::vector< std::vector< unsigned short int > > modpath_histories
ThreadData(int num_filter, int n_local_histories, int local_history_length, int assoc, const std::vector< std::vector< int >> &blurrypath_bits, int path_length, int ghist_length, int block_size, const std::vector< std::vector< std::vector< bool >>> &acyclic_bits, const std::vector< int > &modhist_indices, const std::vector< int > &modhist_lengths, const std::vector< int > &modpath_indices, const std::vector< int > &modpath_lengths, const std::vector< int > &table_sizes, int n_sign_bits)