gem5 v24.1.0.1
Loading...
Searching...
No Matches
mxfp.hh
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Advanced Micro Devices, Inc.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * 3. Neither the name of the copyright holder nor the names of its
16 * contributors may be used to endorse or promote products derived from this
17 * software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 * POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#ifndef __ARCH_AMDGPU_COMMON_DTYPE_MXFP_HH__
33#define __ARCH_AMDGPU_COMMON_DTYPE_MXFP_HH__
34
35#include <cmath>
36#include <cstdint>
37#include <iostream>
38
40
41namespace gem5
42{
43
44namespace AMDGPU
45{
46
47// Base class for all microscaling types. The sizes of everything are
48// determined by the enum fields in the FMT struct. All of these share the
49// same operator overloads which convert to float before arithmetic and
50// convert back if assigned to a microscaling type.
51template<typename FMT>
52class mxfp
53{
54 public:
55 mxfp() = default;
57 {
59 }
60
61 // Set raw bits, used by gem5 to set a raw value read from VGPRs.
62 mxfp(const uint32_t& raw)
63 {
64 // The info unions end up being "left" aligned. For example, in FP4
65 // only the bits 31:28 are used. Shift the input by the storage size
66 // of 32 by the type size (sign + exponent + mantissa bits).
67 data = raw;
68 data <<= (32 - int(FMT::sbits) - int(FMT::ebits) - int(FMT::mbits));
69 }
70
71 mxfp(const mxfp& f)
72 {
73 FMT conv_out;
74 conv_out = convertMXFP<FMT, decltype(f.getFmt())>(f.getFmt());
75 data = conv_out.storage;
76 }
77
78 mxfp&
79 operator=(const float& f)
80 {
82 return *this;
83 }
84
85 mxfp&
87 {
88 FMT conv_out;
89 conv_out = convertMXFP<FMT, decltype(f.getFmt())>(f.getFmt());
90 data = conv_out.storage;
91 return *this;
92 }
93
94 operator float() const
95 {
96 binary32 out;
97 FMT in;
98 in.storage = data;
100
101 return out.fp32;
102 }
103
104 constexpr static int
106 {
107 return int(FMT::mbits) + int(FMT::ebits) + int(FMT::sbits);
108 }
109
110 // Intentionally use storage > size() so that a storage type is not needed
111 // as a template parameter.
113
114 FMT
115 getFmt() const
116 {
117 FMT out;
118 out.storage = data;
119 return out;
120 }
121
122 void
124 {
125 data = in.storage;
126 }
127
128 // Used for upcasting
129 void
130 scaleMul(const float& f)
131 {
133 bfp.fp32 = f;
134 int scale_val = bfp.exp;
135
136 // Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
137 // In this implementation, types without NaN define it as max().
138 if (scale_val == 0xFF) {
139 data = FMT::nan;
140 return;
141 }
142
143 scale_val -= bfp.bias;
144
145 FMT in = getFmt();
146 int exp = in.exp;
147
148 if (exp + scale_val > max_exp<FMT>()) {
149 in.exp = max_exp<FMT>();
150 } else if (exp + scale_val < min_exp<FMT>()) {
151 in.exp = min_exp<FMT>();
152 } else {
153 in.exp = exp + scale_val;
154 }
155
156 data = in.storage;
157 }
158
159 // Used for downcasting
160 void
161 scaleDiv(const float& f)
162 {
164 bfp.fp32 = f;
165 int scale_val = bfp.exp;
166
167 // Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
168 // In this implementation, types without NaN define it as max().
169 if (scale_val == 0xFF) {
170 data = FMT::nan;
171 return;
172 }
173
174 scale_val -= bfp.bias;
175
176 FMT in = getFmt();
177 int exp = in.exp;
178
179 if (exp - scale_val > max_exp<FMT>()) {
180 in.exp = max_exp<FMT>();
181 } else if (exp - scale_val < min_exp<FMT>()) {
182 in.exp = min_exp<FMT>();
183 } else {
184 in.exp = exp - scale_val;
185
186 // Output become denorm
187 if (in.exp == 0) {
188 uint32_t m = in.mant | 1 << FMT::mbits;
189 m >>= 1;
190 in.mant = m & mask(FMT::mbits);
191 }
192 }
193
194 data = in.storage;
195 }
196
197 private:
199
202 {
203 binary32 in;
204 in.fp32 = f;
205
206 FMT out;
207 out.storage = 0;
208
210
211 return out.storage;
212 }
213};
214
215// Unary operators
template<typename T>
inline T
operator+(T a)
{
    // Unary plus is the identity: hand back the by-value operand unchanged.
    return a;
}

template<typename T>
inline T
operator-(T a)
{
    // Values are stored left-aligned in the 32-bit data word (see the raw
    // uint32_t constructor), so a format's sign bit, when it has one, is
    // bit 31. Toggling that single bit negates the value.
    constexpr uint32_t sign_mask = uint32_t(1) << 31;
    a.data ^= sign_mask;
    return a;
}
229
// Increment / decrement operators.
//
// The operand is taken by reference so that ++x, --x, x++ and x-- update x
// itself, as the built-in operators do. The step is computed with T's binary
// operator+ / operator- and a T constructed from 1.0f.
template<typename T>
inline T&
operator++(T& a)
{
    a = a + T(1.0f);
    return a;
}

template<typename T>
inline T&
operator--(T& a)
{
    a = a - T(1.0f);
    return a;
}

// Postfix forms: update the operand and return the value it held before.
template<typename T>
inline T
operator++(T& a, int)
{
    T original = a;
    ++a;
    return original;
}

template<typename T>
inline T
operator--(T& a, int)
{
    T original = a;
    --a;
    return original;
}
259
260// Math operators
// Binary arithmetic: widen both operands to float, compute in single
// precision, then build the result with T's float constructor.
template<typename T>
inline T
operator+(T a, T b)
{
    const float result = static_cast<float>(a) + static_cast<float>(b);
    return T(result);
}

template<typename T>
inline T
operator-(T a, T b)
{
    const float result = static_cast<float>(a) - static_cast<float>(b);
    return T(result);
}

template<typename T>
inline T
operator*(T a, T b)
{
    const float result = static_cast<float>(a) * static_cast<float>(b);
    return T(result);
}

template<typename T>
inline T
operator/(T a, T b)
{
    const float result = static_cast<float>(a) / static_cast<float>(b);
    return T(result);
}
284
// Compound assignment: evaluate the matching binary operator and store the
// result in the left-hand operand. A reference to that operand is returned,
// as with the built-in operators, so chained forms such as (a += b) *= c
// act on `a` rather than on a temporary copy.
template<typename T>
inline T&
operator+=(T &a, T b)
{
    a = a + b;
    return a;
}

template<typename T>
inline T&
operator-=(T &a, T b)
{
    a = a - b;
    return a;
}

template<typename T>
inline T&
operator*=(T &a, T b)
{
    a = a * b;
    return a;
}

template<typename T>
inline T&
operator/=(T &a, T b)
{
    a = a / b;
    return a;
}
311}
312
313// Comparison operators
// Comparison operators: both operands are widened to float and the result
// is that of the corresponding float comparison.
template<typename T>
inline bool
operator<(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs < rhs;
}

template<typename T>
inline bool
operator>(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs > rhs;
}

template<typename T>
inline bool
operator<=(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs <= rhs;
}

template<typename T>
inline bool
operator>=(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs >= rhs;
}

template<typename T>
inline bool
operator==(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs == rhs;
}

template<typename T>
inline bool
operator!=(T a, T b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs != rhs;
}
349
350} // namespace AMDGPU
351
352} // namespace gem5
353
354#endif // __ARCH_AMDGPU_COMMON_DTYPE_MXFP_HH__
mxfp(const mxfp &f)
Definition mxfp.hh:71
FMT getFmt() const
Definition mxfp.hh:115
void scaleMul(const float &f)
Definition mxfp.hh:130
void setFmt(FMT in)
Definition mxfp.hh:123
mxfp & operator=(const mxfp &f)
Definition mxfp.hh:86
uint32_t float_to_mxfp(float f)
Definition mxfp.hh:201
static constexpr int size()
Definition mxfp.hh:105
mxfp(const uint32_t &raw)
Definition mxfp.hh:62
mxfpRoundingMode mode
Definition mxfp.hh:198
void scaleDiv(const float &f)
Definition mxfp.hh:161
mxfp & operator=(const float &f)
Definition mxfp.hh:79
mxfp(float f)
Definition mxfp.hh:56
uint32_t data
Definition mxfp.hh:112
T operator-=(T &a, T b)
Definition mxfp.hh:293
bool operator<(T a, T b)
Definition mxfp.hh:315
bool operator==(T a, T b)
Definition mxfp.hh:339
T operator*=(T &a, T b)
Definition mxfp.hh:300
T operator--(T a)
Definition mxfp.hh:238
T operator++(T a)
Definition mxfp.hh:231
T operator*(T a, T b)
Definition mxfp.hh:274
T operator-(T a)
Definition mxfp.hh:223
T operator+(T a)
Definition mxfp.hh:217
T operator+=(T &a, T b)
Definition mxfp.hh:286
dFMT convertMXFP(sFMT in, mxfpRoundingMode mode=roundTiesToEven, uint32_t seed=0)
bool operator!=(T a, T b)
Definition mxfp.hh:345
bool operator<=(T a, T b)
Definition mxfp.hh:327
T operator/(T a, T b)
Definition mxfp.hh:280
bool operator>=(T a, T b)
Definition mxfp.hh:333
bool operator>(T a, T b)
Definition mxfp.hh:321
T operator/=(T &a, T b)
Definition mxfp.hh:307
Bitfield< 3, 0 > mask
Definition pcstate.hh:63
Bitfield< 7 > b
Bitfield< 8 > a
Definition misc_types.hh:66
Bitfield< 6 > f
Definition misc_types.hh:68
Bitfield< 0 > m
Copyright (c) 2024 Arm Limited All rights reserved.
Definition binary32.hh:36

Generated on Mon Jan 13 2025 04:27:34 for gem5 by doxygen 1.9.8