gem5 v24.0.0.0
Loading...
Searching...
No Matches
mxfp_convert.hh
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Advanced Micro Devices, Inc.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * 3. Neither the name of the copyright holder nor the names of its
16 * contributors may be used to endorse or promote products derived from this
17 * software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 * POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#ifndef __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
33#define __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
34
35#include <cassert>
36
38#include "base/bitfield.hh"
39
40namespace gem5
41{
42
43namespace AMDGPU
44{
45
46// The various rounding modes for microscaling formats. roundTiesToEven must
47// be supported. Other rounding modes may be supported.
53
54// Conversion functions - For instructions that convert from one microscaling
55// format to another. We only need the conversion functions as there do not
56// appear to be any instructions yet which operate directly on the MX formats.
57//
58// in - An MXFP info struct type
59// mode - rounding mode
60// seed - input value for stochastic rounding function
61template<typename dFMT, typename sFMT>
63 uint32_t seed = 0)
64{
65 // We assume that *both* exponent and mantissa bits are both >= or <=
66 // the target type. Checkable at compile time.
67 //
68 // This is not necessarily a limitation, others just are not implemented.
69 // Figuring this out would be interesting for converting FP8 <-> BF8 for
70 // example. So far all GPU conversion instructions convert explicitly to
71 // a larger type from a smaller type or smaller to larger.
72 static_assert(((int(sFMT::mbits) >= int(dFMT::mbits)) &&
73 (int(sFMT::ebits) >= int(dFMT::ebits)))
74 || ((int(sFMT::mbits) <= int(dFMT::mbits)) &&
75 (int(sFMT::ebits) <= int(dFMT::ebits))));
76
77 dFMT out;
78 out.storage = 0;
79
80 if (int(sFMT::mbits) >= int(dFMT::mbits) &&
81 int(sFMT::ebits) >= int(dFMT::ebits)) {
82 // Input format is larger, truncate and round mantissa. MX formats
83 // are subnormal if exp == 0. Zero out exp in that case.
84
85 if (std::isnan(in)) {
86 // For types with no NaN return max value.
87 if (std::numeric_limits<dFMT>::has_quiet_NaN) {
88 out = std::numeric_limits<dFMT>::quiet_NaN();
89 } else {
90 out = std::numeric_limits<dFMT>::max();
91 }
92 } else if (std::isinf(in)) {
93 // For types with no Inf return max value.
94 if (std::numeric_limits<dFMT>::has_infinity) {
95 out = std::numeric_limits<dFMT>::infinity();
96 } else {
97 out = std::numeric_limits<dFMT>::max();
98 }
99 } else if (in.mant == 0 && in.exp == 0) {
100 // All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign.
101 out.mant = 0;
102 out.exp = 0;
103 out.sign = in.sign;
104 } else {
105 // Extra bits are needed for the mantissa conversion.
106 uint32_t mant = in.mant & mask(sFMT::mbits);
107 int32_t exp = in.exp - sFMT::bias + dFMT::bias;
108 out.sign = in.sign;
109
110 // Input is not subnormal, add the implicit 1 bit.
111 if (in.exp) {
112 mant |= (1 << sFMT::mbits);
113 }
114
115 mant >>= (sFMT::mbits - dFMT::mbits);
116
117 // Output became subnormal
118 if (exp < 1) {
119 int shift = 1 - exp;
120 mant >>= shift;
121 out.exp = 0;
122 } else {
123 out.exp = exp;
124 }
125
126 mant &= mask(dFMT::mbits);
127 out.mant = mant;
128
129 // roundTiesToEven is the only required rounding mode for MXFP
130 // types. Here we take the original mantissa and check the final
131 // bit which is shifted out when converting the mantissa. If that
132 // value is one, then we should round up to the next representable
133 // number. If the value is one and all other discarded mantissa
134 // bits are zero, round towards the number which has an even (0)
135 // bit value in the least significant mantissa bit.
136 //
137 // For denormals, the process is similar however we check the nth
138 // bit of the converted mantissa, where n is the absolute value of
139 // the converted exponent. If the value of |exp| is larger than
140 // the max exponent, round to zero. If it is exactly equal, always
141 // round up.
142 //
143 // If the number of destination and source format mantissa bits are
144 // the same, the mantissa is unchanged.
145 if (int(sFMT::mbits) > int(dFMT::mbits)
146 && mode == roundTiesToEven) {
147 bool round_up = false;
148
149 int check_shift = sFMT::mbits - dFMT::mbits - 1;
150 uint32_t check_mant = in.mant & mask(sFMT::mbits);
151
152 check_mant >>= check_shift;
153
154 // out.exp == 0 means subnormal
155 if (out.exp == 0) {
156 check_mant = in.mant >> (sFMT::mbits - dFMT::mbits);
157
158 uint32_t max_exp = mask(dFMT::ebits);
159 if (-exp > max_exp) {
160 // if exp < -(1 << dFMT::ebits), result should be 0
161 round_up = false;
162 } else if (-exp == max_exp) {
163 // if exp == -(1 << dFMT::ebits), round up
164 round_up = true;
165 } else {
166 // Use the |exp|'th bit to determine rounding
167 int check_bit = 1 << -exp;
168 round_up = (check_mant & check_bit);
169 }
170 } else {
171 round_up = (check_mant & 0x1);
172 }
173
174 // For roundTiesToEven, if we are exactly between two
175 // representable numbers, pick the one with an even least
176 // significant mantissa bit. We are exactly between when
177 // all of the discarded mantissa bits are 0 (i.e., !sticky).
178 int sticky = in.mant & mask(sFMT::mbits - dFMT::mbits);
179 if (round_up && !sticky) {
180 if (!(out.mant & 1)) {
181 round_up = false;
182 }
183 }
184
185 if (round_up) {
186 if (out.mant == mask(dFMT::mbits)) {
187 // mantissa at max value, increment exponent if not inf
188 if (out.exp != mask(dFMT::ebits)) {
189 out.exp++;
190 }
191 out.mant = 0;
192 } else {
193 out.mant++;
194 }
195 }
196 } else if (int(sFMT::mbits) > int(dFMT::mbits)
197 && mode == roundStochastic) {
198 // Use the discarded mantissa divided by the max mantissa of
199 // the source format to determine the probability of rounding
200 // up. An alternate implementation of this would be to get a
201 // random number and add that to the input mantissa. Then
202 // follow the normal rounding path above.
203 uint32_t discarded = in.mant & mask(sFMT::mbits - dFMT::mbits);
204 uint32_t max_mant = mask(sFMT::mbits);
205
206 float round_prob = float(discarded) / float(max_mant);
207
208 // Use a stochastic rounding function with the seed value to
209 // determine compare probability. This is implemented as a
210 // "Galois LFSR."
211 auto srFunc = [](uint32_t in) {
212 uint32_t bit = (in ^ (in >> 1) ^ (in >> 3) ^ (in >> 12));
213 return (in >> 1) | (bit << 15);
214 };
215
216 // Assume stochastic rounding returns up to max uint32_t.
217 // This will return an FP value between 0.0f and 1.0f.
218 float draw_prob = float(srFunc(seed))
219 / float(std::numeric_limits<uint32_t>::max());
220
221 // Round up if the number we drew is less than the rounding
222 // probability. E.g., if round_prob is 90% (0.9) we choose
223 // values 0.0f - 0.90f to round up.
224 if (round_prob >= draw_prob) {
225 if (out.mant == mask(dFMT::mbits)) {
226 // mantissa at max value, increment exponent if not inf
227 if (out.exp != mask(dFMT::ebits)) {
228 out.exp++;
229 }
230 out.mant = 0;
231 } else {
232 out.mant++;
233 }
234 }
235 }
236 }
237 } else if (int(sFMT::mbits) <= int(dFMT::mbits) &&
238 int(sFMT::ebits) <= int(dFMT::ebits)) {
239 // Input format is smaller. Extend mantissa / exponent and pad with 0.
240 // Should be the same for all non-stochastic rounding modes.
241
242 if (std::isnan(in)) {
243 // For types with no NaN return max value.
244 if (std::numeric_limits<dFMT>::has_quiet_NaN) {
245 out = std::numeric_limits<dFMT>::quiet_NaN();
246 } else {
247 out = std::numeric_limits<dFMT>::max();
248 }
249 } else if (std::isinf(in)) {
250 // For types with no Inf return max value.
251 if (std::numeric_limits<dFMT>::has_infinity) {
252 out = std::numeric_limits<dFMT>::infinity();
253 } else {
254 out = std::numeric_limits<dFMT>::max();
255 }
256 } else if (in.mant == 0 && in.exp == 0) {
257 // All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign.
258 out.mant = 0;
259 out.exp = 0;
260 out.sign = in.sign;
261 } else {
262 out.mant = in.mant << (dFMT::mbits - sFMT::mbits);
263 out.exp = in.exp + dFMT::bias - sFMT::bias;
264 out.sign = in.sign;
265
266 // Normalize input denormals
267 if (!in.exp && int(sFMT::ebits) != int(dFMT::ebits)) {
268 uint32_t m = out.mant;
269 if (m != 0) {
270 out.exp++;
271 while (!(m >> dFMT::mbits)) {
272 m <<= 1;
273 out.exp--;
274 }
275 out.mant = m & mask(dFMT::mbits);
276 }
277 } else if (!in.exp) {
278 // Exponent is the same, but output is not denorm, so add
279 // implicit 1. This is specific mainly to bf16 -> f32.
280 uint32_t m = out.mant;
281 m <<= 1;
282 out.mant = m & mask(dFMT::mbits);
283 }
284 }
285 } else {
286 assert(false);
287 }
288
289 return out;
290}
291
292template<typename FMT>
294{
295 return 1;
296}
297
298template<typename FMT>
300{
301 return (1 << FMT::ebits) - 1;
302}
303
304
305} // namespace AMDGPU
306
307} // namespace gem5
308
309#endif // __ARCH_AMDGPU_COMMON_DTYPE_MXFP_CONVERT_HH__
dFMT convertMXFP(sFMT in, mxfpRoundingMode mode=roundTiesToEven, uint32_t seed=0)
Bitfield< 3, 0 > mask
Definition pcstate.hh:63
Bitfield< 4, 0 > mode
Definition misc_types.hh:74
Bitfield< 0 > m
Bitfield< 6, 5 > shift
Definition types.hh:117
Copyright (c) 2024 - Pranith Kumar Copyright (c) 2020 Inria All rights reserved.
Definition binary32.hh:36
constexpr bool isinf(gem5::AMDGPU::fp16_e5m10_info a)
Definition fp16_e5m10.hh:78
constexpr bool isnan(gem5::AMDGPU::fp16_e5m10_info a)
Definition fp16_e5m10.hh:83

Generated on Tue Jun 18 2024 16:23:39 for gem5 by doxygen 1.11.0