From 62a2c09d4b9075fc98543faa4a1e0218df067703 Mon Sep 17 00:00:00 2001 From: Matthew Poremba Date: Sat, 10 Aug 2024 10:04:09 -0700 Subject: [PATCH] arch-vega: Rework rounding for microscaling conversions The current implementation does not correctly convert subnormal numbers (number that fill the underflow gap around zero in floating-point arithmetic). This commit reworks the rounding code to get correct results. First, the min_exp is set to 0 which allows for numbers to become subnormal when rounding. Second, the rounding code now uses something closer to "GRS" rounding (guard, round, sticky) which represent the first bit removed when rounding to a smaller type, the next second bit removed, and whether any of the other bits removed are one. More details can be found in the code comments. Change-Id: Idcd2f1e4383e4012fc3abf73b1f73c847d44f67b --- src/arch/amdgpu/common/dtype/mxfp_convert.hh | 96 +++++++++++--------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/src/arch/amdgpu/common/dtype/mxfp_convert.hh b/src/arch/amdgpu/common/dtype/mxfp_convert.hh index 11cd726720..e7e3613af0 100644 --- a/src/arch/amdgpu/common/dtype/mxfp_convert.hh +++ b/src/arch/amdgpu/common/dtype/mxfp_convert.hh @@ -128,6 +128,9 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, mant |= (1 << sFMT::mbits); } + // Save the value for rounding so we don't need to recompute it. + uint32_t saved_mant = mant; + mant >>= (sFMT::mbits - dFMT::mbits); // Output became subnormal @@ -143,18 +146,19 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, out.mant = mant; // roundTiesToEven is the only required rounding mode for MXFP - // types. Here we take the original mantissa and check the final - // bit which is shifted out when converting the mantissa. If that - // value is one, then we should round up to the next representable - // number. If the value is one and all other discarded mantissa - // bits are zero, round towards the number which has an even (0) - // bit value in the least significant mantissa bit. + // types. Here we take the input mantissa and check the first + // three bits that were shifted out. These are called guard, + // round, and sticky bits. The value of these three bits combined + // are used to determine if we should round up or down. If the + // value is directly in between, we look at the final bit of the + // output mantissa with guard, round, sticky shifted out. If the + // value is one, round to nearest even by rounding down (set it to + // zero). // - // For denormals, the process is similar however we check the nth - // bit of the converted mantissa, where n is the absolute value of - // the converted exponent. If the value of |exp| is larger than - // the max exponent, round to zero. If it is exactly equal, always - // round up. + // For denormals, the process is similar, but we shift the input + // mantissa by 1 - exp more bits before setting the value of guard, + // round, sticky. Note that for denormals exp < 1 (i.e., shift + // value is always positive). // // If the number of destination and source format mantissa bits are // the same, the mantissa is unchanged. @@ -162,45 +166,53 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, && mode == roundTiesToEven) { bool round_up = false; - int check_shift = sFMT::mbits - dFMT::mbits - 1; - uint32_t check_mant = in.mant & mask(sFMT::mbits); + // Round using guard, round, sticky bits. We want to make sure + // there are three bits remaining. This is currently true for + // all conversion instructions. This would need to be revisited + // if there are f4 <-> f6 or f6 <-> f8 conversions. + assert((sFMT::mbits - dFMT::mbits) > 2); + + int check_shift = sFMT::mbits - dFMT::mbits - 3; + uint32_t check_mant = saved_mant; + + // Sticky bit is 1 if *any* of the N-2 bits that get shifted + // off are one. Being zero implies we are directly between two + // floating point values. + int sticky = (check_mant & mask(check_shift + 1)) != 0; check_mant >>= check_shift; - - // out.exp == 0 means subnormal - if (out.exp == 0) { - check_mant = in.mant >> (sFMT::mbits - dFMT::mbits); - - uint32_t max_exp = mask(dFMT::ebits); - if (-exp > max_exp) { - // if exp < -(1 << dFMT::ebits), result should be 0 - round_up = false; - } else if (-exp == max_exp) { - // if exp == -(1 << dFMT::ebits), round up - round_up = true; - } else { - // Use the |exp|'th bit to determine rounding - int check_bit = 1 << -exp; - round_up = (check_mant & check_bit); - } - } else { - round_up = (check_mant & 0x1); + if (exp < 1) { + int shift = 1 - exp; + check_mant >>= shift; } - // For roundTiesToEven, if we are exactly between two - // representable numbers, pick the one with an even least - // significant mantissa bit. We are exactly between when - // all of the discarded mantissa bits are 0 (i.e., !sticky). - int sticky = in.mant & mask(sFMT::mbits - dFMT::mbits); - if (round_up && !sticky) { - if (!(out.mant & 1)) { - round_up = false; + // Combine guard, round, sticky into one 3-bit value. If that + // value is < 0b100 we round down (truncate -- nothing to do), + // if it is > 0b100 we round up. If it is == 0b100, round to + // nearest even. + uint32_t check_test = check_mant & 0x7; + + // Add sticky to the 3-bit check value. + check_test += sticky; + + if (check_test > 0x4) { + round_up = true; + } else if (check_test == 0x4) { + // We are exactly between two FP values. Round to nearest + // even by looking at the last bit of output mantissa. + // If the last bit of the output mantissa is 1, round to + // nearest even (0 in last bit) which would simply be + // rounding down. The bit position of the last bit in this + // case is 0x8 since we kept three extra bits for guard, + // round, sticky. + if (check_mant & 0x8) { + out.mant -= 1; } } if (round_up) { if (out.mant == mask(dFMT::mbits)) { - // mantissa at max value, increment exponent if not inf + // Mantissa at max value, increment exponent if not inf if (out.exp != mask(dFMT::ebits)) { out.exp++; } @@ -324,7 +336,7 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, template int min_exp() { - return 1; + return 0; } template