gem5 v24.1.0.1
Loading...
Searching...
No Matches
mxfp_convert.hh
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Advanced Micro Devices, Inc.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * 3. Neither the name of the copyright holder nor the names of its
16 * contributors may be used to endorse or promote products derived from this
17 * software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 * POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#ifndef __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
33#define __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
34
35#include <cassert>
36
38#include "base/bitfield.hh"
39
40namespace gem5
41{
42
43namespace AMDGPU
44{
45
46// The various rounding modes for microscaling formats. roundTiesToEven must
47// be supported. Other rounding modes may be supported.
53
54// Conversion functions - For instructions that convert from one microscaling
55// format to another. We only need the conversion functions as there do not
56// appear to be any instructions yet which operate directly on the MX formats.
57//
// convertMXFP<dFMT, sFMT>() converts a value of format sFMT into format dFMT.
// Both format types are used below through their sign/exp/mant/storage
// members and their static mbits/ebits/bias constants, and dFMT needs a
// std::numeric_limits specialization providing has_quiet_NaN, has_infinity,
// quiet_NaN(), infinity() and max().
//
58// in - An MXFP info struct type holding the source value
59// mode - rounding mode; only consulted when sFMT has more mantissa bits
//        than dFMT (roundTiesToEven or roundStochastic)
60// seed - input value for stochastic rounding function (roundStochastic only)
//
// Returns the converted value. NaN and Inf inputs become the dFMT NaN/Inf
// when numeric_limits<dFMT> reports one, and numeric_limits<dFMT>::max()
// otherwise. Signed zero keeps its sign.
61template<typename dFMT, typename sFMT>
63 uint32_t seed = 0)
64{
65 // We assume that the source exponent and mantissa widths are either both
66 // >= or both <= those of the target type. Checked at compile time below.
67 //
68 // This is not a fundamental limitation; the mixed cases are simply not
69 // implemented. Figuring this out would be interesting for converting
70 // FP8 <-> BF8, for example. So far all GPU conversion instructions convert
71 // either from a smaller type to a larger one or from a larger to a smaller.
72 static_assert(((int(sFMT::mbits) >= int(dFMT::mbits)) &&
73 (int(sFMT::ebits) >= int(dFMT::ebits)))
74 || ((int(sFMT::mbits) <= int(dFMT::mbits)) &&
75 (int(sFMT::ebits) <= int(dFMT::ebits))));
76
77 dFMT out;
78 out.storage = 0;
79
80 if (int(sFMT::mbits) >= int(dFMT::mbits) &&
81 int(sFMT::ebits) >= int(dFMT::ebits)) {
82 // Input format is at least as wide (equal widths also land here):
83 // truncate and round the mantissa. MX formats are subnormal if
 // exp == 0; the output exponent is zeroed when the result is subnormal.
84
85 if (std::isnan(in)) {
 // NOTE(review): the sign handling in the NaN/Inf cases tests and sets
 // bit 31 of `storage` for both formats; this assumes the sign sits at
 // bit 31 of storage in every format used here — confirm for narrow
 // formats.
86 // For types with no NaN return max value.
87 if (std::numeric_limits<dFMT>::has_quiet_NaN) {
88 out = std::numeric_limits<dFMT>::quiet_NaN();
89 // Preserve sign bit
90 if (in.storage & 0x80000000) {
91 out.storage |= 0x80000000;
92 }
93 } else {
94 out = std::numeric_limits<dFMT>::max();
95 // Preserve sign bit
96 if (in.storage & 0x80000000) {
97 out.storage |= 0x80000000;
98 }
99 }
100 } else if (std::isinf(in)) {
101 // For types with no Inf return max value.
102 if (std::numeric_limits<dFMT>::has_infinity) {
103 out = std::numeric_limits<dFMT>::infinity();
104 // Preserve sign bit
105 if (in.storage & 0x80000000) {
106 out.storage |= 0x80000000;
107 }
108 } else {
109 out = std::numeric_limits<dFMT>::max();
110 // Preserve sign bit
111 if (in.storage & 0x80000000) {
112 out.storage |= 0x80000000;
113 }
114 }
115 } else if (in.mant == 0 && in.exp == 0) {
116 // All MX formats, FP32, and FP64 encode zero as all-zero exponent and
 // mantissa fields. Keep the sign.
117 out.mant = 0;
118 out.exp = 0;
119 out.sign = in.sign;
120 } else {
121 // Work in a 32-bit temporary so the implicit bit and the bits that
 // will be shifted out fit alongside the mantissa.
122 uint32_t mant = in.mant & mask(sFMT::mbits);
 // Re-bias the exponent for the destination format.
 // NOTE(review): for a subnormal input (in.exp == 0) the effective
 // unbiased exponent is 1 - sFMT::bias (the widening path below
 // accounts for this with out.exp++), but this uses 0 - sFMT::bias,
 // so subnormal inputs appear to lose an extra factor of 2 in the
 // exp < 1 path below — confirm.
123 int32_t exp = in.exp - sFMT::bias + dFMT::bias;
124 out.sign = in.sign;
125
126 // Input is not subnormal, add the implicit 1 bit.
127 if (in.exp) {
128 mant |= (1 << sFMT::mbits);
129 }
130
131 // Save the value for rounding so we don't need to recompute it.
132 uint32_t saved_mant = mant;
133
 // Drop the low mantissa bits that do not fit in dFMT.
134 mant >>= (sFMT::mbits - dFMT::mbits);
135
136 // Output became subnormal
 // NOTE(review): shift grows without bound as the input gets smaller.
 // Shifting a uint32_t by 32 or more is undefined behavior, so inputs
 // far below dFMT's subnormal range need the shift clamped or the
 // result flushed to zero.
137 if (exp < 1) {
138 int shift = 1 - exp;
139 mant >>= shift;
140 out.exp = 0;
141 } else {
 // NOTE(review): there is no check for exp exceeding the largest
 // exponent dFMT can encode; confirm whether the result should
 // saturate (max or Inf) instead of being assigned unchecked.
142 out.exp = exp;
143 }
144
145 mant &= mask(dFMT::mbits);
146 out.mant = mant;
147
148 // roundTiesToEven is the only required rounding mode for MXFP
149 // types. Here we take the input mantissa and check the first
150 // three bits that were shifted out. These are called guard,
151 // round, and sticky bits. The combined value of these three bits
152 // is used to determine if we should round up or down. If the
153 // value is directly in between, we look at the final bit of the
154 // output mantissa with guard, round, sticky shifted out. If the
155 // value is one, round to nearest even by rounding down (set it to
156 // zero).
 // NOTE(review): IEEE 754 roundTiesToEven resolves an exact tie with
 // an odd LSB by rounding *up* to the next (even) value; see the
 // check_test == 0x4 branch below.
157 //
158 // For denormals, the process is similar, but we shift the input
159 // mantissa by 1 - exp more bits before setting the value of guard,
160 // round, sticky. Note that for denormals exp < 1 (i.e., shift
161 // value is always positive).
162 //
163 // If the number of destination and source format mantissa bits are
164 // the same, the mantissa is unchanged.
165 if (int(sFMT::mbits) > int(dFMT::mbits)
166 && mode == roundTiesToEven) {
167 bool round_up = false;
168
169 // Round using guard, round, sticky bits. We want to make sure
170 // there are three bits remaining. This is currently true for
171 // all conversion instructions. This would need to be revisited
172 // if there are f4 <-> f6 or f6 <-> f8 conversions.
173 assert((sFMT::mbits - dFMT::mbits) > 2);
174
175 int check_shift = sFMT::mbits - dFMT::mbits - 3;
177
178 // Sticky bit is 1 if *any* of the N-2 bits that get shifted
179 // off are one. Being zero implies we are directly between two
180 // floating point values.
181 int sticky = (check_mant & mask(check_shift + 1)) != 0;
182
 // NOTE(review): sticky is computed before this extra denormal
 // shift, so 1 bits discarded by it appear not to be reflected in
 // sticky. The same shift-width (>= 32) concern as for mant above
 // applies here.
184 if (exp < 1) {
185 int shift = 1 - exp;
186 check_mant >>= shift;
187 }
188
189 // Combine guard, round, sticky into one 3-bit value. If that
190 // value is < 0b100 we round down (truncate -- nothing to do),
191 // if it is > 0b100 we round up. If it is == 0b100, round to
192 // nearest even.
194
195 // Add sticky to the 3-bit check value.
197
198 if (check_test > 0x4) {
199 round_up = true;
200 } else if (check_test == 0x4) {
201 // We are exactly between two FP values. Round to nearest
202 // even by looking at the last bit of output mantissa.
203 // If the last bit of the output mantissa is 1, round to
204 // nearest even (0 in last bit) which would simply be
205 // rounding down. The bit position of the last bit in this
206 // case is 0x8 since we kept three extra bits for guard,
207 // round, sticky.
 // NOTE(review): out.mant holds the truncated mantissa, so the
 // exact value lies between out.mant and out.mant + 1. With
 // an odd LSB the even neighbour is out.mant + 1 (i.e.
 // round_up = true); decrementing moves the result one ulp
 // below the truncated value — confirm.
208 if (check_mant & 0x8) {
209 out.mant -= 1;
210 }
211 }
212
213 if (round_up) {
214 if (out.mant == mask(dFMT::mbits)) {
215 // Mantissa at max value, increment exponent if not inf
 // NOTE(review): if out.exp is already all-ones the
 // mantissa still wraps to 0, which lowers the magnitude
 // for formats that use that exponent for finite values.
216 if (out.exp != mask(dFMT::ebits)) {
217 out.exp++;
218 }
219 out.mant = 0;
220 } else {
221 out.mant++;
222 }
223 }
224 } else if (int(sFMT::mbits) > int(dFMT::mbits)
225 && mode == roundStochastic) {
226 // Use the discarded mantissa divided by the max mantissa of
227 // the source format to determine the probability of rounding
228 // up. An alternate implementation of this would be to get a
229 // random number and add that to the input mantissa. Then
230 // follow the normal rounding path above.
 // NOTE(review): `discarded` spans only sFMT::mbits - dFMT::mbits
 // bits, so dividing it by mask(sFMT::mbits) yields a probability
 // far below 1. Stochastic rounding normally uses
 // discarded / 2^(sFMT::mbits - dFMT::mbits). Confirm how
 // round_prob is derived.
231 uint32_t discarded = in.mant & mask(sFMT::mbits - dFMT::mbits);
232 uint32_t max_mant = mask(sFMT::mbits);
233
235
236 // Use a stochastic rounding function with the seed value to
237 // determine compare probability. This is implemented as a
238 // "Galois LFSR."
 // NOTE(review): the XOR of the taps is inserted at the top after
 // a right shift, which resembles the Fibonacci LFSR form rather
 // than the Galois form. The feedback is also not masked to one
 // bit (`& 1`), so higher bits of `bit` are ORed into the result.
 // Confirm intent.
239 auto srFunc = [](uint32_t in) {
240 uint32_t bit = (in ^ (in >> 1) ^ (in >> 3) ^ (in >> 12));
241 return (in >> 1) | (bit << 15);
242 };
243
244 // Assume stochastic rounding returns up to max uint32_t.
245 // This will return an FP value between 0.0f and 1.0f.
246 float draw_prob = float(srFunc(seed))
247 / float(std::numeric_limits<uint32_t>::max());
248
249 // Round up if the number we drew is less than the rounding
250 // probability. E.g., if round_prob is 90% (0.9) we choose
251 // values 0.0f - 0.90f to round up.
 // NOTE(review): srFunc(0) == 0, so with the default seed of 0
 // draw_prob is 0.0f and `>=` rounds up even when round_prob is
 // 0 (presumably when nothing was discarded). `>` would leave
 // exactly representable values untouched.
252 if (round_prob >= draw_prob) {
253 if (out.mant == mask(dFMT::mbits)) {
254 // mantissa at max value, increment exponent if not inf
255 if (out.exp != mask(dFMT::ebits)) {
256 out.exp++;
257 }
258 out.mant = 0;
259 } else {
260 out.mant++;
261 }
262 }
263 }
264 }
265 } else if (int(sFMT::mbits) <= int(dFMT::mbits) &&
266 int(sFMT::ebits) <= int(dFMT::ebits)) {
267 // Input format is smaller. Extend mantissa / exponent and pad with 0.
268 // Widening needs no rounding, so `mode` and `seed` are not used here.
269
270 if (std::isnan(in)) {
271 // For types with no NaN return max value.
272 if (std::numeric_limits<dFMT>::has_quiet_NaN) {
273 out = std::numeric_limits<dFMT>::quiet_NaN();
274 // Preserve sign bit
275 if (in.storage & 0x80000000) {
276 out.storage |= 0x80000000;
277 }
278 } else {
279 out = std::numeric_limits<dFMT>::max();
280 // Preserve sign bit
281 if (in.storage & 0x80000000) {
282 out.storage |= 0x80000000;
283 }
284 }
285 } else if (std::isinf(in)) {
286 // For types with no Inf return max value.
287 if (std::numeric_limits<dFMT>::has_infinity) {
288 out = std::numeric_limits<dFMT>::infinity();
289 // Preserve sign bit
290 if (in.storage & 0x80000000) {
291 out.storage |= 0x80000000;
292 }
293 } else {
294 out = std::numeric_limits<dFMT>::max();
295 // Preserve sign bit
296 if (in.storage & 0x80000000) {
297 out.storage |= 0x80000000;
298 }
299 }
300 } else if (in.mant == 0 && in.exp == 0) {
301 // All MX formats, FP32, and FP64 encode zero as all-zero exponent and
 // mantissa fields. Keep the sign.
302 out.mant = 0;
303 out.exp = 0;
304 out.sign = in.sign;
305 } else {
 // Left-align the mantissa in the wider field and re-bias the
 // exponent for the destination format.
306 out.mant = in.mant << (dFMT::mbits - sFMT::mbits);
307 out.exp = in.exp + dFMT::bias - sFMT::bias;
308 out.sign = in.sign;
309
310 // Normalize input denormals
 // The out.exp++ accounts for subnormals having an effective exponent
 // of 1 - bias. Each left shift that moves the leading 1 toward the
 // implicit-bit position then lowers the exponent by one.
311 if (!in.exp && int(sFMT::ebits) != int(dFMT::ebits)) {
312 uint32_t m = out.mant;
313 if (m != 0) {
314 out.exp++;
315 while (!(m >> dFMT::mbits)) {
316 m <<= 1;
317 out.exp--;
318 }
319 out.mant = m & mask(dFMT::mbits);
320 }
321 } else if (!in.exp) {
322 // Exponent is the same, but output is not denorm, so add
323 // implicit 1. This is specific mainly to bf16 -> f32.
 // NOTE(review): with equal exponent widths out.exp is computed as
 // dFMT::bias - sFMT::bias (0 when the biases match), so the output
 // is still subnormal. The mantissa shifted above already has the
 // right value; the extra shift below doubles it and drops its top
 // bit. Confirm against the expected bf16 -> f32 result.
324 uint32_t m = out.mant;
325 m <<= 1;
326 out.mant = m & mask(dFMT::mbits);
327 }
328 }
329 } else {
 // Unreachable: the two branches above cover exactly the cases the
 // static_assert permits.
330 assert(false);
331 }
332
333 return out;
334}
335
// Returns the constant 0 for every FMT.
// NOTE(review): presumably this is the smallest biased exponent field value
// (the all-zero encoding that convertMXFP treats as zero/subnormal) — confirm
// against callers.
336template<typename FMT>
338{
339 return 0;
340}
341
// Returns (1 << FMT::ebits) - 1, i.e. the all-ones value of an
// FMT::ebits-wide exponent field.
// NOTE(review): for formats that reserve the all-ones exponent for Inf/NaN
// this is not a finite exponent — confirm how callers use it.
342template<typename FMT>
344{
345 return (1 << FMT::ebits) - 1;
346}
347
348
349} // namespace AMDGPU
350
351} // namespace gem5
352
353#endif // __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
dFMT convertMXFP(sFMT in, mxfpRoundingMode mode=roundTiesToEven, uint32_t seed=0)
Bitfield< 3, 0 > mask
Definition pcstate.hh:63
Bitfield< 4, 0 > mode
Definition misc_types.hh:74
Bitfield< 0 > m
Bitfield< 6, 5 > shift
Definition types.hh:117
Copyright (c) 2024 Arm Limited All rights reserved.
Definition binary32.hh:36
constexpr bool isinf(gem5::AMDGPU::fp16_e5m10_info a)
Definition fp16_e5m10.hh:78
constexpr bool isnan(gem5::AMDGPU::fp16_e5m10_info a)
Definition fp16_e5m10.hh:83

Generated on Mon Jan 13 2025 04:27:34 for gem5 by doxygen 1.9.8