53#include "debug/Branch.hh"
58namespace branch_prediction
63 {1,3,4,5,7,8,9,11,12,14,15,17,19,21,23,25,27,29,32,34,37,41,45,49,53,58,63,
67 {0,4,5,7,9,11,12,14,16,17,19,22,28,33,39,45,};
70 int n_local_histories,
int local_history_length,
int assoc,
72 int ghist_length,
int block_size,
79 : filterTable(num_filters), acyclic_histories(acyclic_bits.size()),
80 acyclic2_histories(acyclic_bits.size()),
81 blurrypath_histories(blurrypath_bits.size()),
82 ghist_words(ghist_length/block_size+1, 0),
83 path_history(path_length, 0), imli_counter(4,0),
84 localHistories(n_local_histories, local_history_length),
85 recency_stack(assoc), last_ghist_bit(false), occupancy(0)
96 int max_modhist_idx = -1;
98 max_modhist_idx = (max_modhist_idx < elem) ? elem : max_modhist_idx;
100 if (max_modhist_idx >= 0) {
107 int max_modpath_idx = -1;
109 max_modpath_idx = (max_modpath_idx < elem) ? elem : max_modpath_idx;
111 if (max_modpath_idx >= 0) {
131 const MultiperspectivePerceptronParams &
p) :
BPredUnit(
p),
161 for (
auto &spec :
specs) {
167 for (
auto &spec :
specs) {
168 spec->setBitRequirements();
170 const MultiperspectivePerceptronParams &
p =
171 static_cast<const MultiperspectivePerceptronParams &
>(
params());
174 p.local_history_length,
p.ignore_path_size);
178 p.num_local_histories,
179 p.local_history_length,
assoc,
190 int nlocal_histories,
int local_history_length,
bool ignore_path_size)
194 totalbits += imli_bits;
197 if (!ignore_path_size) {
204 if (!ignore_path_size) {
206 totalbits += 16 *
len;
209 totalbits +=
doing_local ? (nlocal_histories * local_history_length) : 0;
213 for (
auto &bve : bv) {
217 totalbits += num_filter_entries * 2;
220 for (
auto &abj : abi) {
221 for (
auto abk : abj) {
231 for (
int i = 0;
i <
specs.size();
i +=1) {
241 int table_size_bits = (remaining / (
specs.size()-num_sized));
242 for (
int i = 0;
i <
specs.size();
i += 1) {
245 int my_table_size = table_size_bits /
252 DPRINTF(
Branch,
"%d bits of metadata so far, %d left out of "
253 "%d total budget\n", totalbits, remaining,
budgetbits);
254 DPRINTF(
Branch,
"table size is %d bits, %d entries for 5 bit, %d entries "
255 "for 6 bit\n", table_size_bits,
273 bool operator<(BestPair
const &bp)
const
275 return mpreds < bp.mpreds;
279 for (
int i = 0;
i < best_preds.size();
i += 1) {
283 std::sort(pairs.begin(), pairs.end());
284 for (
int i = 0;
i < (std::min(
nbest, (
int) best_preds.size()));
i += 1) {
285 best_preds[
i] = pairs[
i].index;
294 unsigned long long int h =
g;
331 int lhist =
threadData[tid]->localHistories[
bi.getPC()];
332 int history_len =
threadData[tid]->localHistories.getLocalHistoryLength();
335 }
else if (lhist == ((1<<history_len)-1)) {
337 }
else if (lhist == (1<<(history_len-1))) {
339 }
else if (lhist == ((1<<(history_len-1))-1)) {
349 for (
int i = 0;
i <
specs.size();
i += 1) {
352 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
354 int counter =
threadData[tid]->tables[
i][hashed_idx];
359 int weight = spec.
coeff * ((spec.
width == 5) ?
362 int val = sign ? -weight : weight;
368 j < std::min(
nbest, (
int) best_preds.size());
371 if (best_preds[j] ==
i) {
385 int max_weight)
const
398 if (counter < max_weight) {
406 if (counter < max_weight) {
428 bool correct = (
bi.yout >= 1) == taken;
430 int abs_yout = abs(
bi.yout);
436 for (
int i = 0;
i <
specs.size();
i += 1) {
439 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
441 int counter = tables[
i][hashed_idx];
442 int weight = spec.
coeff * ((spec.
width == 5) ?
444 if (sign) weight = -weight;
445 bool pred = weight >= 1;
455 for (
int i = 0;
i <
specs.size();
i += 1) {
462 bool do_train = !correct || (abs_yout <=
theta);
463 if (!do_train)
return;
473 if (correct && abs_yout <
theta) {
484 for (
int i = 0;
i <
specs.size();
i += 1) {
487 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
488 int counter = tables[
i][hashed_idx];
494 tables[
i][hashed_idx] = counter;
496 int weight = ((spec.
width == 5) ?
xlat4[counter] :
xlat[counter]);
508 if ((newyout >= 1) != taken) {
510 int round_counter = 0;
515 int nrand =
rng->random<
int>() %
specs.size();
518 for (
int j = 0; j <
specs.size(); j += 1) {
519 int i = (nrand + j) %
specs.size();
521 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
522 int counter = tables[
i][hashed_idx];
525 int weight = ((spec.
width == 5) ?
527 int signed_weight = sign ? -weight : weight;
528 pout = newyout - signed_weight;
529 if ((pout >= 1) == taken) {
539 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
540 int counter = tables[
i][hashed_idx];
545 tables[
i][hashed_idx] = counter;
547 int weight = ((spec.
width == 5) ?
549 int signed_weight = sign ? -weight : weight;
550 int out = pout + signed_weight;
552 if ((out >= 1) != taken) {
564 bool uncond,
bool taken,
Addr target,
void * &bp_history)
566 assert(uncond || bp_history);
577 bp_history = (
void *)
bi;
578 unsigned short int pc2 =
pc >> 2;
581 bool ab_new = (ghist_words[
i] >> (
blockSize - 1)) & 1;
582 ghist_words[
i] <<= 1;
583 ghist_words[
i] |= ab;
598 bp_history = (
void *)
bi;
600 bool use_static =
false;
603 unsigned int findex =
607 if (
f.alwaysNotTakenSoFar()) {
609 bi->prediction =
false;
611 }
else if (
f.alwaysTakenSoFar()) {
613 bi->prediction =
true;
623 bi->prediction =
false;
626 bi->prediction = (bestval >= 1);
628 bi->prediction = (
bi->yout >= 1);
632 return bi->prediction;
637 void * &bp_history,
bool squashed,
648 if (
bi->isUnconditional()) {
650 bp_history =
nullptr;
654 bool do_train =
true;
663 bool transition =
false;
664 if (
f.alwaysNotTakenSoFar() ||
f.alwaysTakenSoFar()) {
668 if (
f.alwaysNotTakenSoFar()) {
673 if (
f.alwaysTakenSoFar()) {
676 f.seenUntaken =
true;
685 if (
decay && transition &&
687 int rnd =
rng->random<
int>() %
717 if (target < bi->getPC()) {
746 bool ab = hashed_taken;
747 assert(
threadData[tid]->ghist_words.size() > 0);
761 assert(
threadData[tid]->path_history.size() > 0);
770 threadData[tid]->updateAcyclic(hashed_taken,
bi->getHPC());
777 if (
bi->getHPC() % (
i + 2) == 0) {
791 for (
int i = 0;
i < blurrypath_histories.size();
i += 1)
793 if (blurrypath_histories[
i].size() > 0) {
794 unsigned int z =
bi->getPC() >>
i;
795 if (blurrypath_histories[
i][0] !=
z) {
796 memmove(&blurrypath_histories[
i][1],
797 &blurrypath_histories[
i][0],
798 sizeof(
unsigned int) *
799 (blurrypath_histories[
i].size() - 1));
800 blurrypath_histories[
i][0] =
z;
810 if (
bi->getHPC() % (
i + 2) == 0) {
815 threadData[tid]->mod_histories[
i][0] = hashed_taken;
829 threadData[tid]->localHistories.update(
bi->getPC(), hashed_taken);
836 bp_history =
nullptr;
845 bp_history =
nullptr;
Base class for branch operations.
Basically a wrapper class to hold both the branch predictor and the BTB.
const unsigned numThreads
Number of the threads for which the branch history is maintained.
unsigned int getHashFilter(bool last_ghist_bit) const
static unsigned int hash(const std::vector< unsigned int short > &recency_stack, const std::vector< int > &table_sizes, unsigned short int pc, int l, int t)
std::vector< ThreadData * > threadData
std::vector< int > imli_counter_bits
const unsigned long long int imli_mask4
std::vector< std::vector< std::vector< bool > > > acyclic_bits
int computeOutput(ThreadID tid, MPPBranchInfo &bi)
Computes the output of the predictor for a given branch and the resulting best value in case the pred...
const unsigned int record_mask
std::vector< int > modpath_lengths
std::vector< std::vector< int > > blurrypath_bits
const bool speculative_update
void findBest(ThreadID tid, std::vector< int > &best_preds) const
Finds the best subset of features to use in case of a low-confidence branch, returns the result as an...
const int blockSize
Predictor parameters.
static int xlat[]
Transfer function for 6-width tables.
bool doing_local
runtime values and data used to count the size in bits
MultiperspectivePerceptron(const MultiperspectivePerceptronParams &params)
unsigned int getIndex(ThreadID tid, const MPPBranchInfo &bi, const HistorySpec &spec, int index) const
Get the position index of a predictor table.
std::vector< int > modhist_indices
virtual void createSpecs()=0
Creates the tables of the predictor.
void satIncDec(bool taken, bool &sign, int &c, int max_weight) const
Auxiliary function to increase a table counter depending on the direction of the branch.
bool lookup(ThreadID tid, Addr branch_addr, void *&bp_history) override
Looks up the PC of a given conditional branch in the BP to see if it is taken or not taken.
std::vector< int > modhist_lengths
void computeBits(int num_filter_entries, int nlocal_histories, int local_history_length, bool ignore_path_size)
Computes the size in bits of the structures needed to keep track of the history and the predictor tab...
static int xlat4[]
Transfer function for 5-width tables.
std::vector< int > table_sizes
void init() override
init() is called after all C++ SimObjects have been created and all ports are connected.
void update(ThreadID tid, Addr pc, bool taken, void *&bp_history, bool squashed, const StaticInstPtr &inst, Addr target) override
Updates the BP with taken/not taken information.
void train(ThreadID tid, MPPBranchInfo &bi, bool taken)
Trains the branch predictor with the given branch and direction.
void setExtraBits(int bits)
Sets the starting number of storage bits to compute the number of table entries.
std::vector< HistorySpec * > specs
Predictor tables.
std::vector< int > modpath_indices
const unsigned long long int recencypos_mask
void updateHistories(ThreadID tid, Addr pc, bool uncond, bool taken, Addr target, void *&bp_history) override
Once done with the prediction, this function updates the path and global history.
const unsigned long long int imli_mask1
void squash(ThreadID tid, void *&bp_history) override
constexpr T bits(T val, unsigned first, unsigned last)
Extract the bitfield from position 'first' to 'last' (inclusive) from 'val' and right justify it.
#define fatal_if(cond,...)
Conditional fatal macro that checks the supplied condition and only causes a fatal error if the condi...
const Params & params() const
Copyright (c) 2024 Arm Limited All rights reserved.
int16_t ThreadID
Thread index/ID type.
uint64_t Addr
Address type. This will probably be moved somewhere else in the near future.
Entry of the branch filter.
bool seenTaken
Has this branch been taken at least once?
bool seenUntaken
Has this branch been not taken at least once?
Base class to implement the predictor tables.
virtual unsigned int getHash(ThreadID tid, Addr pc, Addr pc2, int t) const =0
Gets the hash to index the table, using the pc of the branch, and the index of the table.
const int width
Width of the table in bits
const double coeff
Coefficient of the feature, models the accuracy of the feature.
History data is kept for each thread.
std::vector< std::vector< unsigned int > > acyclic2_histories
std::vector< std::vector< short int > > tables
std::vector< std::vector< unsigned int > > blurrypath_histories
ThreadData(int num_filter, int n_local_histories, int local_history_length, int assoc, const std::vector< std::vector< int > > &blurrypath_bits, int path_length, int ghist_length, int block_size, const std::vector< std::vector< std::vector< bool > > > &acyclic_bits, const std::vector< int > &modhist_indices, const std::vector< int > &modhist_lengths, const std::vector< int > &modpath_indices, const std::vector< int > &modpath_lengths, const std::vector< int > &table_sizes, int n_sign_bits)
std::vector< std::vector< std::array< bool, 2 > > > sign_bits
std::vector< std::vector< bool > > mod_histories
std::vector< std::vector< bool > > acyclic_histories
std::vector< int > mpreds
std::vector< std::vector< unsigned short int > > modpath_histories