54#include "debug/Branch.hh"
59namespace branch_prediction
64 {1,3,4,5,7,8,9,11,12,14,15,17,19,21,23,25,27,29,32,34,37,41,45,49,53,58,63,
68 {0,4,5,7,9,11,12,14,16,17,19,22,28,33,39,45,};
71 int n_local_histories,
int local_history_length,
int assoc,
73 int ghist_length,
int block_size,
80 : filterTable(num_filters), acyclic_histories(acyclic_bits.size()),
81 acyclic2_histories(acyclic_bits.size()),
82 blurrypath_histories(blurrypath_bits.size()),
83 ghist_words(ghist_length/block_size+1, 0),
84 path_history(path_length, 0), imli_counter(4,0),
85 localHistories(n_local_histories, local_history_length),
86 recency_stack(assoc), last_ghist_bit(false), occupancy(0)
97 int max_modhist_idx = -1;
99 max_modhist_idx = (max_modhist_idx < elem) ? elem : max_modhist_idx;
101 if (max_modhist_idx >= 0) {
108 int max_modpath_idx = -1;
110 max_modpath_idx = (max_modpath_idx < elem) ? elem : max_modpath_idx;
112 if (max_modpath_idx >= 0) {
132 const MultiperspectivePerceptronParams &
p) :
BPredUnit(
p),
162 for (
auto &spec :
specs) {
168 for (
auto &spec :
specs) {
169 spec->setBitRequirements();
171 const MultiperspectivePerceptronParams &
p =
172 static_cast<const MultiperspectivePerceptronParams &
>(
params());
175 p.local_history_length,
p.ignore_path_size);
179 p.num_local_histories,
180 p.local_history_length,
assoc,
191 int nlocal_histories,
int local_history_length,
bool ignore_path_size)
195 totalbits += imli_bits;
198 if (!ignore_path_size) {
205 if (!ignore_path_size) {
207 totalbits += 16 *
len;
210 totalbits +=
doing_local ? (nlocal_histories * local_history_length) : 0;
214 for (
auto &bve : bv) {
218 totalbits += num_filter_entries * 2;
221 for (
auto &abj : abi) {
222 for (
auto abk : abj) {
232 for (
int i = 0;
i <
specs.size();
i +=1) {
242 int table_size_bits = (remaining / (
specs.size()-num_sized));
243 for (
int i = 0;
i <
specs.size();
i += 1) {
246 int my_table_size = table_size_bits /
253 DPRINTF(
Branch,
"%d bits of metadata so far, %d left out of "
254 "%d total budget\n", totalbits, remaining,
budgetbits);
255 DPRINTF(
Branch,
"table size is %d bits, %d entries for 5 bit, %d entries "
256 "for 6 bit\n", table_size_bits,
274 bool operator<(BestPair
const &bp)
const
276 return mpreds < bp.mpreds;
280 for (
int i = 0;
i < best_preds.size();
i += 1) {
284 std::sort(pairs.begin(), pairs.end());
285 for (
int i = 0;
i < (std::min(
nbest, (
int) best_preds.size()));
i += 1) {
286 best_preds[
i] = pairs[
i].index;
295 unsigned long long int h =
g;
332 int lhist =
threadData[tid]->localHistories[
bi.getPC()];
333 int history_len =
threadData[tid]->localHistories.getLocalHistoryLength();
336 }
else if (lhist == ((1<<history_len)-1)) {
338 }
else if (lhist == (1<<(history_len-1))) {
340 }
else if (lhist == ((1<<(history_len-1))-1)) {
350 for (
int i = 0;
i <
specs.size();
i += 1) {
353 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
355 int counter =
threadData[tid]->tables[
i][hashed_idx];
360 int weight = spec.
coeff * ((spec.
width == 5) ?
363 int val = sign ? -weight : weight;
369 j < std::min(
nbest, (
int) best_preds.size());
372 if (best_preds[j] ==
i) {
386 int max_weight)
const
399 if (counter < max_weight) {
407 if (counter < max_weight) {
429 bool correct = (
bi.yout >= 1) == taken;
431 int abs_yout = abs(
bi.yout);
437 for (
int i = 0;
i <
specs.size();
i += 1) {
440 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
442 int counter = tables[
i][hashed_idx];
443 int weight = spec.
coeff * ((spec.
width == 5) ?
445 if (sign) weight = -weight;
446 bool pred = weight >= 1;
456 for (
int i = 0;
i <
specs.size();
i += 1) {
463 bool do_train = !correct || (abs_yout <=
theta);
464 if (!do_train)
return;
474 if (correct && abs_yout <
theta) {
485 for (
int i = 0;
i <
specs.size();
i += 1) {
488 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
489 int counter = tables[
i][hashed_idx];
495 tables[
i][hashed_idx] = counter;
497 int weight = ((spec.
width == 5) ?
xlat4[counter] :
xlat[counter]);
509 if ((newyout >= 1) != taken) {
511 int round_counter = 0;
519 for (
int j = 0; j <
specs.size(); j += 1) {
520 int i = (nrand + j) %
specs.size();
522 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
523 int counter = tables[
i][hashed_idx];
526 int weight = ((spec.
width == 5) ?
528 int signed_weight = sign ? -weight : weight;
529 pout = newyout - signed_weight;
530 if ((pout >= 1) == taken) {
540 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
541 int counter = tables[
i][hashed_idx];
546 tables[
i][hashed_idx] = counter;
548 int weight = ((spec.
width == 5) ?
550 int signed_weight = sign ? -weight : weight;
551 int out = pout + signed_weight;
553 if ((out >= 1) != taken) {
565 bool uncond,
bool taken,
Addr target,
void * &bp_history)
567 assert(uncond || bp_history);
578 bp_history = (
void *)
bi;
579 unsigned short int pc2 =
pc >> 2;
582 bool ab_new = (ghist_words[
i] >> (
blockSize - 1)) & 1;
583 ghist_words[
i] <<= 1;
584 ghist_words[
i] |= ab;
599 bp_history = (
void *)
bi;
601 bool use_static =
false;
604 unsigned int findex =
608 if (
f.alwaysNotTakenSoFar()) {
610 bi->prediction =
false;
612 }
else if (
f.alwaysTakenSoFar()) {
614 bi->prediction =
true;
624 bi->prediction =
false;
627 bi->prediction = (bestval >= 1);
629 bi->prediction = (
bi->yout >= 1);
633 return bi->prediction;
638 void * &bp_history,
bool squashed,
649 if (
bi->isUnconditional()) {
651 bp_history =
nullptr;
655 bool do_train =
true;
664 bool transition =
false;
665 if (
f.alwaysNotTakenSoFar() ||
f.alwaysTakenSoFar()) {
669 if (
f.alwaysNotTakenSoFar()) {
674 if (
f.alwaysTakenSoFar()) {
677 f.seenUntaken =
true;
686 if (
decay && transition &&
718 if (target < bi->getPC()) {
747 bool ab = hashed_taken;
748 assert(
threadData[tid]->ghist_words.size() > 0);
762 assert(
threadData[tid]->path_history.size() > 0);
771 threadData[tid]->updateAcyclic(hashed_taken,
bi->getHPC());
778 if (
bi->getHPC() % (
i + 2) == 0) {
792 for (
int i = 0;
i < blurrypath_histories.size();
i += 1)
794 if (blurrypath_histories[
i].size() > 0) {
795 unsigned int z =
bi->getPC() >>
i;
796 if (blurrypath_histories[
i][0] !=
z) {
797 memmove(&blurrypath_histories[
i][1],
798 &blurrypath_histories[
i][0],
799 sizeof(
unsigned int) *
800 (blurrypath_histories[
i].size() - 1));
801 blurrypath_histories[
i][0] =
z;
811 if (
bi->getHPC() % (
i + 2) == 0) {
816 threadData[tid]->mod_histories[
i][0] = hashed_taken;
830 threadData[tid]->localHistories.update(
bi->getPC(), hashed_taken);
837 bp_history =
nullptr;
846 bp_history =
nullptr;
Base class for branch operations.
Basically a wrapper class to hold both the branch predictor and the BTB.
const unsigned numThreads
Number of the threads for which the branch history is maintained.
unsigned int getHashFilter(bool last_ghist_bit) const
static unsigned int hash(const std::vector< unsigned int short > &recency_stack, const std::vector< int > &table_sizes, unsigned short int pc, int l, int t)
std::vector< ThreadData * > threadData
std::vector< int > imli_counter_bits
const unsigned long long int imli_mask4
std::vector< std::vector< std::vector< bool > > > acyclic_bits
int computeOutput(ThreadID tid, MPPBranchInfo &bi)
Computes the output of the predictor for a given branch and the resulting best value in case the pred...
const unsigned int record_mask
std::vector< int > modpath_lengths
std::vector< std::vector< int > > blurrypath_bits
const bool speculative_update
void findBest(ThreadID tid, std::vector< int > &best_preds) const
Finds the best subset of features to use in case of a low-confidence branch, returns the result as an...
const int blockSize
Predictor parameters.
static int xlat[]
Transfer function for 6-width tables.
bool doing_local
runtime values and data used to count the size in bits
MultiperspectivePerceptron(const MultiperspectivePerceptronParams &params)
unsigned int getIndex(ThreadID tid, const MPPBranchInfo &bi, const HistorySpec &spec, int index) const
Get the position index of a predictor table.
std::vector< int > modhist_indices
virtual void createSpecs()=0
Creates the tables of the predictor.
void satIncDec(bool taken, bool &sign, int &c, int max_weight) const
Auxiliary function to increase a table counter depending on the direction of the branch.
bool lookup(ThreadID tid, Addr branch_addr, void *&bp_history) override
Looks up a given conditional branch PC in the BP to see if it is taken or not taken.
std::vector< int > modhist_lengths
void computeBits(int num_filter_entries, int nlocal_histories, int local_history_length, bool ignore_path_size)
Computes the size in bits of the structures needed to keep track of the history and the predictor tab...
static int xlat4[]
Transfer function for 5-width tables.
std::vector< int > table_sizes
void init() override
init() is called after all C++ SimObjects have been created and all ports are connected.
void update(ThreadID tid, Addr pc, bool taken, void *&bp_history, bool squashed, const StaticInstPtr &inst, Addr target) override
Updates the BP with taken/not taken information.
void train(ThreadID tid, MPPBranchInfo &bi, bool taken)
Trains the branch predictor with the given branch and direction.
void setExtraBits(int bits)
Sets the starting number of storage bits to compute the number of table entries.
std::vector< HistorySpec * > specs
Predictor tables.
std::vector< int > modpath_indices
const unsigned long long int recencypos_mask
void updateHistories(ThreadID tid, Addr pc, bool uncond, bool taken, Addr target, void *&bp_history) override
Once done with the prediction, this function updates the path and global history.
const unsigned long long int imli_mask1
void squash(ThreadID tid, void *&bp_history) override
std::enable_if_t< std::is_integral_v< T >, T > random()
Use the SFINAE idiom to choose an implementation based on whether the type is integral or floating po...
constexpr T bits(T val, unsigned first, unsigned last)
Extract the bitfield from position 'first' to 'last' (inclusive) from 'val' and right justify it.
#define fatal_if(cond,...)
Conditional fatal macro that checks the supplied condition and only causes a fatal error if the condi...
const Params & params() const
Copyright (c) 2024 - Pranith Kumar Copyright (c) 2020 Inria All rights reserved.
int16_t ThreadID
Thread index/ID type.
uint64_t Addr
Address type. This will probably be moved somewhere else in the near future.
Entry of the branch filter.
bool seenTaken
Has this branch been taken at least once?
bool seenUntaken
Has this branch been not taken at least once?
Base class to implement the predictor tables.
virtual unsigned int getHash(ThreadID tid, Addr pc, Addr pc2, int t) const =0
Gets the hash to index the table, using the pc of the branch, and the index of the table.
const int width
Width of the table in bits
const double coeff
Coefficient of the feature, models the accuracy of the feature.
History data is kept for each thread.
std::vector< std::vector< unsigned int > > acyclic2_histories
std::vector< std::vector< short int > > tables
std::vector< std::vector< unsigned int > > blurrypath_histories
ThreadData(int num_filter, int n_local_histories, int local_history_length, int assoc, const std::vector< std::vector< int > > &blurrypath_bits, int path_length, int ghist_length, int block_size, const std::vector< std::vector< std::vector< bool > > > &acyclic_bits, const std::vector< int > &modhist_indices, const std::vector< int > &modhist_lengths, const std::vector< int > &modpath_indices, const std::vector< int > &modpath_lengths, const std::vector< int > &table_sizes, int n_sign_bits)
std::vector< std::vector< std::array< bool, 2 > > > sign_bits
std::vector< std::vector< bool > > mod_histories
std::vector< std::vector< bool > > acyclic_histories
std::vector< int > mpreds
std::vector< std::vector< unsigned short int > > modpath_histories