gem5 [DEVELOP-FOR-25.1]
Loading...
Searching...
No Matches
mxfp_convert.hh
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Advanced Micro Devices, Inc.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * 3. Neither the name of the copyright holder nor the names of its
16 * contributors may be used to endorse or promote products derived from this
17 * software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 * POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#ifndef __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
33#define __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
34
35#include <cassert>
36
38#include "base/bitfield.hh"
39
40namespace gem5
41{
42
43namespace AMDGPU
44{
45
46// The various rounding modes for microscaling formats. roundTiesToEven must
47// be supported. Other rounding modes may be supported.
53
54// Conversion functions - For instructions that convert from one microscaling
55// format to another. We only need the conversion functions as there do not
56// appear to be any instructions yet which operate directly on the MX formats.
57//
58// in - An MXFP info struct type
59// mode - rounding mode
60// seed - input value for stochastic rounding function
61template<typename dFMT, typename sFMT>
63 uint32_t seed = 0)
64{
65 // We assume that *both* exponent and mantissa bits are both >= or <=
66 // the target type. Checkable at compile time.
67 //
68 // This is not necessarily a limitation, others just are not implemented.
69 // Figuring this out would be interesting for converting FP8 <-> BF8 for
70 // example. So far all GPU conversion instructions convert explicitly to
71 // a larger type from a smaller type or smaller to larger.
72 static_assert(((int(sFMT::mbits) >= int(dFMT::mbits)) &&
73 (int(sFMT::ebits) >= int(dFMT::ebits)))
74 || ((int(sFMT::mbits) <= int(dFMT::mbits)) &&
75 (int(sFMT::ebits) <= int(dFMT::ebits))));
76
77 dFMT out;
78 out.storage = 0;
79
80 // Optimize self-conversions which may happen due to copy constructors.
81 if (std::is_same_v<sFMT, dFMT>) {
82 out.storage = in.storage;
83 return out;
84 }
85
86 if constexpr (int(sFMT::mbits) >= int(dFMT::mbits) &&
87 int(sFMT::ebits) >= int(dFMT::ebits)) {
88 // Input format is larger, truncate and round mantissa. MX formats
89 // are subnormal if exp == 0. Zero out exp in that case.
90
91 if (std::isnan(in)) {
92 // For types with no NaN return max value.
93 if (std::numeric_limits<dFMT>::has_quiet_NaN) {
94 out = std::numeric_limits<dFMT>::quiet_NaN();
95 // Preserve sign bit
96 if (in.storage & 0x80000000) {
97 out.storage |= 0x80000000;
98 }
99 } else {
100 out = std::numeric_limits<dFMT>::max();
101 // Preserve sign bit
102 if (in.storage & 0x80000000) {
103 out.storage |= 0x80000000;
104 }
105 }
106 } else if (std::isinf(in)) {
107 // For types with no Inf return max value.
108 if (std::numeric_limits<dFMT>::has_infinity) {
109 out = std::numeric_limits<dFMT>::infinity();
110 // Preserve sign bit
111 if (in.storage & 0x80000000) {
112 out.storage |= 0x80000000;
113 }
114 } else {
115 out = std::numeric_limits<dFMT>::max();
116 // Preserve sign bit
117 if (in.storage & 0x80000000) {
118 out.storage |= 0x80000000;
119 }
120 }
121 } else if (in.mant == 0 && in.exp == 0) {
122 // All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign.
123 out.mant = 0;
124 out.exp = 0;
125 out.sign = in.sign;
126 } else {
127 // Extra bits are needed for the mantissa conversion.
128 uint32_t mant = in.mant & mask(sFMT::mbits);
129 int32_t exp = in.exp - sFMT::bias + dFMT::bias;
130 out.sign = in.sign;
131
132 // Input is not subnormal, add the implicit 1 bit.
133 if (in.exp) {
134 mant |= (1 << sFMT::mbits);
135 }
136
137 // Save the value for rounding so we don't need to recompute it.
138 uint32_t saved_mant = mant;
139
140 mant >>= (sFMT::mbits - dFMT::mbits);
141
142 // Output became subnormal
143 if (exp < 1 && in.exp) {
144 int shift = 1 - exp;
145 mant >>= shift;
146 out.exp = 0;
147 } else {
148 out.exp = exp;
149 }
150
151 mant &= mask(dFMT::mbits);
152 out.mant = mant;
153
154 // roundTiesToEven is the only required rounding mode for MXFP
155 // types. Here we take the input mantissa and check the first
156 // three bits that were shifted out. These are called guard,
157 // round, and sticky bits. The value of these three bits combined
158 // are used to determine if we should round up or down. If the
159 // value is directly in between, we look at the final bit of the
160 // output mantissa with guard, round, sticky shifted out. If the
161 // value is one, round to nearest even by rounding down (set it to
162 // zero).
163 //
164 // For denormals, the process is similar, but we shift the input
165 // mantissa by 1 - exp more bits before setting the value of guard,
166 // round, sticky. Note that for denormals exp < 1 (i.e., shift
167 // value is always positive).
168 //
169 // If the number of destination and source format mantissa bits are
170 // the same, the mantissa is unchanged.
171 if (int(sFMT::mbits) > int(dFMT::mbits)
172 && mode == roundTiesToEven) {
173 bool round_up = false;
174
175 // Round using guard, round, sticky bits. We want to make sure
176 // there are three bits remaining. This is currently true for
177 // all conversion instructions. This would need to be revisited
178 // if there are f4 <-> f6 or f6 <-> f8 conversions.
179 assert((sFMT::mbits - dFMT::mbits) > 2);
180
181 int check_shift = sFMT::mbits - dFMT::mbits - 3;
182 uint32_t check_mant = saved_mant;
183
184 // Sticky bit is 1 if *any* of the N-2 bits that get shifted
185 // off are one. Being zero implies we are directly between two
186 // floating point values.
187 int sticky = (check_mant & mask(check_shift + 1)) != 0;
188
189 check_mant >>= check_shift;
190 if (exp < 1) {
191 int shift = 1 - exp;
192 check_mant >>= shift;
193 }
194
195 // Combine guard, round, sticky into one 3-bit value. If that
196 // value is < 0b100 we round down (truncate -- nothing to do),
197 // if it is > 0b100 we round up. If it is == 0b100, round to
198 // nearest even.
199 uint32_t check_test = check_mant & 0x7;
200
201 // Add sticky to the 3-bit check value.
202 check_test += sticky;
203
204 if (check_test > 0x4) {
205 round_up = true;
206 }
207
208 if (round_up) {
209 if (out.mant == mask(dFMT::mbits)) {
210 // Mantissa at max value, increment exponent if not inf
211 if (out.exp != mask(dFMT::ebits)) {
212 out.exp++;
213 }
214 out.mant = 0;
215 } else {
216 out.mant++;
217 }
218 }
219 } else if (int(sFMT::mbits) > int(dFMT::mbits)
220 && mode == roundStochastic) {
221 // Use the discarded mantissa divided by the max mantissa of
222 // the source format to determine the probability of rounding
223 // up. An alternate implementation of this would be to get a
224 // random number and add that to the input mantissa. Then
225 // follow the normal rounding path above.
226 uint32_t discarded = in.mant & mask(sFMT::mbits - dFMT::mbits);
227 uint32_t max_mant = mask(sFMT::mbits);
228
229 float round_prob = float(discarded) / float(max_mant);
230
231 // Use a stochastic rounding function with the seed value to
232 // determine compare probability. This is implemented as a
233 // "Galois LFSR."
234 auto srFunc = [](uint32_t in) {
235 uint32_t bit = (in ^ (in >> 1) ^ (in >> 3) ^ (in >> 12));
236 return (in >> 1) | (bit << 15);
237 };
238
239 // Assume stochastic rounding returns up to max uint32_t.
240 // This will return an FP value between 0.0f and 1.0f.
241 float draw_prob = float(srFunc(seed))
242 / float(std::numeric_limits<uint32_t>::max());
243
244 // Round up if the number we drew is less than the rounding
245 // probability. E.g., if round_prob is 90% (0.9) we choose
246 // values 0.0f - 0.90f to round up.
247 if (round_prob <= draw_prob) {
248 if (out.mant == mask(dFMT::mbits)) {
249 // mantissa at max value, increment exponent if not inf
250 if (out.exp != mask(dFMT::ebits)) {
251 out.exp++;
252 }
253 out.mant = 0;
254 } else {
255 out.mant++;
256 }
257 }
258 }
259 }
260 } else if constexpr (int(sFMT::mbits) <= int(dFMT::mbits) &&
261 int(sFMT::ebits) <= int(dFMT::ebits)) {
262 // Input format is smaller. Extend mantissa / exponent and pad with 0.
263 // Should be the same for all non-stochastic rounding modes.
264
265 if (std::isnan(in)) {
266 // For types with no NaN return max value.
267 if (std::numeric_limits<dFMT>::has_quiet_NaN) {
268 out = std::numeric_limits<dFMT>::quiet_NaN();
269 // Preserve sign bit
270 if (in.storage & 0x80000000) {
271 out.storage |= 0x80000000;
272 }
273 } else {
274 out = std::numeric_limits<dFMT>::max();
275 // Preserve sign bit
276 if (in.storage & 0x80000000) {
277 out.storage |= 0x80000000;
278 }
279 }
280 } else if (std::isinf(in)) {
281 // For types with no Inf return max value.
282 if (std::numeric_limits<dFMT>::has_infinity) {
283 out = std::numeric_limits<dFMT>::infinity();
284 // Preserve sign bit
285 if (in.storage & 0x80000000) {
286 out.storage |= 0x80000000;
287 }
288 } else {
289 out = std::numeric_limits<dFMT>::max();
290 // Preserve sign bit
291 if (in.storage & 0x80000000) {
292 out.storage |= 0x80000000;
293 }
294 }
295 } else if (in.mant == 0 && in.exp == 0) {
296 // All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign.
297 out.mant = 0;
298 out.exp = 0;
299 out.sign = in.sign;
300 } else {
301 out.mant = in.mant << (dFMT::mbits - sFMT::mbits);
302 out.exp = in.exp + dFMT::bias - sFMT::bias;
303 out.sign = in.sign;
304
305 // Normalize input denormals. This only applies when exponents are
306 // different. Otherwise, the operation is simply a zero extend of
307 // mantissa (e.g., bf16 -> f32).
308 if (!in.exp && int(sFMT::ebits) != int(dFMT::ebits)) {
309 uint32_t m = out.mant;
310 if (m != 0) {
311 out.exp++;
312 while (!(m >> dFMT::mbits)) {
313 m <<= 1;
314 out.exp--;
315 }
316 out.mant = m & mask(dFMT::mbits);
317 }
318 }
319 }
320 } else {
321 assert(false);
322 }
323
324 return out;
325}
326
327template<typename FMT>
329{
330 return 0;
331}
332
333template<typename FMT>
335{
336 return (1 << FMT::ebits) - 1;
337}
338
339
340} // namespace AMDGPU
341
342} // namespace gem5
343
344#endif // __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
dFMT convertMXFP(sFMT in, mxfpRoundingMode mode=roundTiesToEven, uint32_t seed=0)
Bitfield< 3, 0 > mask
Definition pcstate.hh:63
Bitfield< 4, 0 > mode
Definition misc_types.hh:74
Bitfield< 0 > m
Bitfield< 6, 5 > shift
Definition types.hh:117
Copyright (c) 2024 Arm Limited All rights reserved.
Definition binary32.hh:36
constexpr bool isinf(gem5::AMDGPU::fp16_e5m10_info a)
Definition fp16_e5m10.hh:78
constexpr bool isnan(gem5::AMDGPU::fp16_e5m10_info a)
Definition fp16_e5m10.hh:83

Generated on Mon Oct 27 2025 04:12:28 for gem5 by doxygen 1.14.0