53#include "debug/Branch.hh"
63 {1,3,4,5,7,8,9,11,12,14,15,17,19,21,23,25,27,29,32,34,37,41,45,49,53,58,63,
67 {0,4,5,7,9,11,12,14,16,17,19,22,28,33,39,45,};
70 int n_local_histories,
int local_history_length,
int assoc,
96 int max_modhist_idx = -1;
98 max_modhist_idx = (max_modhist_idx < elem) ? elem : max_modhist_idx;
100 if (max_modhist_idx >= 0) {
107 int max_modpath_idx = -1;
109 max_modpath_idx = (max_modpath_idx < elem) ? elem : max_modpath_idx;
111 if (max_modpath_idx >= 0) {
131 const MultiperspectivePerceptronParams &
p) :
BPredUnit(
p),
161 for (
auto &spec :
specs) {
167 for (
auto &spec :
specs) {
168 spec->setBitRequirements();
170 const MultiperspectivePerceptronParams &
p =
171 static_cast<const MultiperspectivePerceptronParams &
>(
params());
174 p.local_history_length,
p.ignore_path_size);
178 p.num_local_histories,
179 p.local_history_length,
assoc,
190 int nlocal_histories,
int local_history_length,
bool ignore_path_size)
194 totalbits += imli_bits;
197 if (!ignore_path_size) {
204 if (!ignore_path_size) {
206 totalbits += 16 *
len;
209 totalbits +=
doing_local ? (nlocal_histories * local_history_length) : 0;
213 for (
auto &bve : bv) {
217 totalbits += num_filter_entries * 2;
220 for (
auto &abj : abi) {
221 for (
auto abk : abj) {
231 for (
int i = 0;
i <
specs.size();
i +=1) {
241 int table_size_bits = (remaining / (
specs.size()-num_sized));
242 for (
int i = 0;
i <
specs.size();
i += 1) {
245 int my_table_size = table_size_bits /
252 DPRINTF(
Branch,
"%d bits of metadata so far, %d left out of "
253 "%d total budget\n", totalbits, remaining,
budgetbits);
254 DPRINTF(
Branch,
"table size is %d bits, %d entries for 5 bit, %d entries "
255 "for 6 bit\n", table_size_bits,
275 return mpreds < bp.mpreds;
279 for (
int i = 0;
i < best_preds.size();
i += 1) {
283 std::sort(pairs.begin(), pairs.end());
284 for (
int i = 0;
i < (std::min(
nbest, (
int) best_preds.size()));
i += 1) {
285 best_preds[
i] = pairs[
i].index;
294 unsigned long long int h =
g;
331 int lhist =
threadData[tid]->localHistories[
bi.getPC()];
332 int history_len =
threadData[tid]->localHistories.getLocalHistoryLength();
335 }
else if (lhist == ((1<<history_len)-1)) {
337 }
else if (lhist == (1<<(history_len-1))) {
339 }
else if (lhist == ((1<<(history_len-1))-1)) {
349 for (
int i = 0;
i <
specs.size();
i += 1) {
352 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
354 int counter =
threadData[tid]->tables[
i][hashed_idx];
359 int weight = spec.
coeff * ((spec.
width == 5) ?
362 int val = sign ? -weight : weight;
368 j < std::min(
nbest, (
int) best_preds.size());
371 if (best_preds[j] ==
i) {
385 int max_weight)
const
398 if (counter < max_weight) {
406 if (counter < max_weight) {
428 bool correct = (
bi.yout >= 1) == taken;
430 int abs_yout = abs(
bi.yout);
436 for (
int i = 0;
i <
specs.size();
i += 1) {
439 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
441 int counter = tables[
i][hashed_idx];
442 int weight = spec.
coeff * ((spec.
width == 5) ?
444 if (sign) weight = -weight;
445 bool pred = weight >= 1;
455 for (
int i = 0;
i <
specs.size();
i += 1) {
462 bool do_train = !correct || (abs_yout <=
theta);
463 if (!do_train)
return;
473 if (correct && abs_yout <
theta) {
484 for (
int i = 0;
i <
specs.size();
i += 1) {
487 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
488 int counter = tables[
i][hashed_idx];
494 tables[
i][hashed_idx] = counter;
496 int weight = ((spec.
width == 5) ?
xlat4[counter] :
xlat[counter]);
508 if ((newyout >= 1) != taken) {
510 int round_counter = 0;
515 int nrand =
rng->random<
int>() %
specs.size();
518 for (
int j = 0; j <
specs.size(); j += 1) {
519 int i = (nrand + j) %
specs.size();
521 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
522 int counter = tables[
i][hashed_idx];
525 int weight = ((spec.
width == 5) ?
527 int signed_weight = sign ? -weight : weight;
528 pout = newyout - signed_weight;
529 if ((pout >= 1) == taken) {
539 unsigned int hashed_idx =
getIndex(tid,
bi, spec,
i);
540 int counter = tables[
i][hashed_idx];
545 tables[
i][hashed_idx] = counter;
547 int weight = ((spec.
width == 5) ?
549 int signed_weight = sign ? -weight : weight;
550 int out = pout + signed_weight;
552 if ((out >= 1) != taken) {
564 bool uncond,
bool taken,
Addr target,
567 assert(uncond || bp_history);
578 bp_history = (
void *)
bi;
579 unsigned short int pc2 =
pc >> 2;
582 bool ab_new = (ghist_words[
i] >> (
blockSize - 1)) & 1;
583 ghist_words[
i] <<= 1;
584 ghist_words[
i] |= ab;
599 bp_history = (
void *)
bi;
601 bool use_static =
false;
604 unsigned int findex =
608 if (
f.alwaysNotTakenSoFar()) {
610 bi->prediction =
false;
612 }
else if (
f.alwaysTakenSoFar()) {
614 bi->prediction =
true;
624 bi->prediction =
false;
627 bi->prediction = (bestval >= 1);
629 bi->prediction = (
bi->yout >= 1);
633 return bi->prediction;
638 void * &bp_history,
bool squashed,
649 if (
bi->isUnconditional()) {
651 bp_history =
nullptr;
655 bool do_train =
true;
658 int findex =
bi->getHashFilter(
threadData[tid]->last_ghist_bit) %
664 bool transition =
false;
665 if (
f.alwaysNotTakenSoFar() ||
f.alwaysTakenSoFar()) {
669 if (
f.alwaysNotTakenSoFar()) {
674 if (
f.alwaysTakenSoFar()) {
677 f.seenUntaken =
true;
686 if (
decay && transition &&
688 int rnd =
rng->random<
int>() %
718 if (target < bi->getPC()) {
747 bool ab = hashed_taken;
748 assert(
threadData[tid]->ghist_words.size() > 0);
762 assert(
threadData[tid]->path_history.size() > 0);
771 threadData[tid]->updateAcyclic(hashed_taken,
bi->getHPC());
778 if (
bi->getHPC() % (
i + 2) == 0) {
792 for (
int i = 0;
i < blurrypath_histories.size();
i += 1)
794 if (blurrypath_histories[
i].size() > 0) {
795 unsigned int z =
bi->getPC() >>
i;
796 if (blurrypath_histories[
i][0] !=
z) {
797 memmove(&blurrypath_histories[
i][1],
798 &blurrypath_histories[
i][0],
799 sizeof(
unsigned int) *
800 (blurrypath_histories[
i].size() - 1));
801 blurrypath_histories[
i][0] =
z;
811 if (
bi->getHPC() % (
i + 2) == 0) {
816 threadData[tid]->mod_histories[
i][0] = hashed_taken;
830 threadData[tid]->localHistories.update(
bi->getPC(), hashed_taken);
837 bp_history =
nullptr;
846 bp_history =
nullptr;
Base class for branch operations.
BPredUnit(const Params &p)
Branch Predictor Unit (BPU) interface functions.
const unsigned numThreads
Number of the threads for which the branch history is maintained.
static unsigned int hash(const std::vector< unsigned int short > &recency_stack, const std::vector< int > &table_sizes, unsigned short int pc, int l, int t)
std::vector< ThreadData * > threadData
std::vector< int > imli_counter_bits
const unsigned long long int imli_mask4
std::vector< std::vector< std::vector< bool > > > acyclic_bits
int computeOutput(ThreadID tid, MPPBranchInfo &bi)
Computes the output of the predictor for a given branch and the resulting best value in case the pred...
const unsigned int record_mask
std::vector< int > modpath_lengths
std::vector< std::vector< int > > blurrypath_bits
const bool speculative_update
void findBest(ThreadID tid, std::vector< int > &best_preds) const
Finds the best subset of features to use in case of a low-confidence branch, returns the result as an...
const int blockSize
Predictor parameters.
static int xlat[]
Transfer function for 6-width tables.
bool doing_local
runtime values and data used to count the size in bits
MultiperspectivePerceptron(const MultiperspectivePerceptronParams &params)
unsigned int getIndex(ThreadID tid, const MPPBranchInfo &bi, const HistorySpec &spec, int index) const
Get the position index of a predictor table.
std::vector< int > modhist_indices
virtual void createSpecs()=0
Creates the tables of the predictor.
void satIncDec(bool taken, bool &sign, int &c, int max_weight) const
Auxiliary function to increase a table counter depending on the direction of the branch.
bool lookup(ThreadID tid, Addr branch_addr, void *&bp_history) override
Looks up a given conditional branch PC in the BP to see if it is taken or not taken.
std::vector< int > modhist_lengths
void computeBits(int num_filter_entries, int nlocal_histories, int local_history_length, bool ignore_path_size)
Computes the size in bits of the structures needed to keep track of the history and the predictor tab...
static int xlat4[]
Transfer function for 5-width tables.
std::vector< int > table_sizes
void init() override
init() is called after all C++ SimObjects have been created and all ports are connected.
void train(ThreadID tid, MPPBranchInfo &bi, bool taken)
Trains the branch predictor with the given branch and direction.
void updateHistories(ThreadID tid, Addr pc, bool uncond, bool taken, Addr target, const StaticInstPtr &inst, void *&bp_history) override
Once done with the prediction, this function updates the path and global history.
void setExtraBits(int bits)
Sets the starting number of storage bits to compute the number of table entries.
std::vector< HistorySpec * > specs
Predictor tables.
std::vector< int > modpath_indices
const unsigned long long int recencypos_mask
const unsigned long long int imli_mask1
void squash(ThreadID tid, void *&bp_history) override
void update(ThreadID tid, Addr branch_addr, bool taken, void *&bp_history, bool squashed, const StaticInstPtr &inst, Addr corrTarget) override
Updates the BP with taken/not taken information.
constexpr T bits(T val, unsigned first, unsigned last)
Extract the bitfield from position 'first' to 'last' (inclusive) from 'val' and right justify it.
#define fatal_if(cond,...)
Conditional fatal macro that checks the supplied condition and only causes a fatal error if the condi...
const Params & params() const
Copyright (c) 2024 Arm Limited All rights reserved.
int16_t ThreadID
Thread index/ID type.
uint64_t Addr
Address type. This will probably be moved somewhere else in the near future.
RefCountingPtr< StaticInst > StaticInstPtr
bool operator<(const Time &l, const Time &r)
Entry of the branch filter.
bool seenTaken
Has this branch been taken at least once?
bool seenUntaken
Has this branch been not taken at least once?
Base class to implement the predictor tables.
virtual unsigned int getHash(ThreadID tid, Addr pc, Addr pc2, int t) const =0
Gets the hash to index the table, using the pc of the branch, and the index of the table.
const int width
Width of the table in bits.
const double coeff
Coefficient of the feature, models the accuracy of the feature.
History data is kept for each thread.
std::vector< unsigned int > ghist_words
std::vector< std::vector< unsigned int > > acyclic2_histories
std::vector< std::vector< short int > > tables
std::vector< std::vector< unsigned int > > blurrypath_histories
ThreadData(int num_filter, int n_local_histories, int local_history_length, int assoc, const std::vector< std::vector< int > > &blurrypath_bits, int path_length, int ghist_length, int block_size, const std::vector< std::vector< std::vector< bool > > > &acyclic_bits, const std::vector< int > &modhist_indices, const std::vector< int > &modhist_lengths, const std::vector< int > &modpath_indices, const std::vector< int > &modpath_lengths, const std::vector< int > &table_sizes, int n_sign_bits)
std::vector< unsigned int > imli_counter
std::vector< FilterEntry > filterTable
std::vector< std::vector< std::array< bool, 2 > > > sign_bits
std::vector< std::vector< bool > > mod_histories
std::vector< std::vector< bool > > acyclic_histories
std::vector< int > mpreds
std::vector< unsigned short int > path_history
LocalHistories localHistories
std::vector< std::vector< unsigned short int > > modpath_histories
std::vector< unsigned int short > recency_stack