diff --git a/src/arch/amdgpu/common/dtype/mxfp_convert.hh b/src/arch/amdgpu/common/dtype/mxfp_convert.hh index 11cd726720..e7e3613af0 100644 --- a/src/arch/amdgpu/common/dtype/mxfp_convert.hh +++ b/src/arch/amdgpu/common/dtype/mxfp_convert.hh @@ -128,6 +128,9 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, mant |= (1 << sFMT::mbits); } + // Save the value for rounding so we don't need to recompute it. + uint32_t saved_mant = mant; + mant >>= (sFMT::mbits - dFMT::mbits); // Output became subnormal @@ -143,18 +146,19 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, out.mant = mant; // roundTiesToEven is the only required rounding mode for MXFP - // types. Here we take the original mantissa and check the final - // bit which is shifted out when converting the mantissa. If that - // value is one, then we should round up to the next representable - // number. If the value is one and all other discarded mantissa - // bits are zero, round towards the number which has an even (0) - // bit value in the least significant mantissa bit. + // types. Here we take the input mantissa and check the first + // three bits that were shifted out. These are called guard, + // round, and sticky bits. The value of these three bits combined + // is used to determine if we should round up or down. If the + // value is directly in between, we look at the final bit of the + // output mantissa with guard, round, sticky shifted out. If the + // value is one, round to nearest even by rounding down (set it to + // zero). // - // For denormals, the process is similar however we check the nth - // bit of the converted mantissa, where n is the absolute value of - // the converted exponent. If the value of |exp| is larger than - // the max exponent, round to zero. If it is exactly equal, always - // round up. 
+ // For denormals, the process is similar, but we shift the input + // mantissa by 1 - exp more bits before setting the value of guard, + // round, sticky. Note that for denormals exp < 1 (i.e., shift + // value is always positive). // // If the number of destination and source format mantissa bits are // the same, the mantissa is unchanged. @@ -162,45 +166,53 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, && mode == roundTiesToEven) { bool round_up = false; - int check_shift = sFMT::mbits - dFMT::mbits - 1; - uint32_t check_mant = in.mant & mask(sFMT::mbits); + // Round using guard, round, sticky bits. We want to make sure + // there are three bits remaining. This is currently true for + // all conversion instructions. This would need to be revisited + // if there are f4 <-> f6 or f6 <-> f8 conversions. + assert((sFMT::mbits - dFMT::mbits) > 2); + + int check_shift = sFMT::mbits - dFMT::mbits - 3; + uint32_t check_mant = saved_mant; + + // Sticky bit is 1 if *any* of the N-2 bits that get shifted + // off are one. Being zero implies we are directly between two + // floating point values. + int sticky = (check_mant & mask(check_shift + 1)) != 0; check_mant >>= check_shift; - - // out.exp == 0 means subnormal - if (out.exp == 0) { - check_mant = in.mant >> (sFMT::mbits - dFMT::mbits); - - uint32_t max_exp = mask(dFMT::ebits); - if (-exp > max_exp) { - // if exp < -(1 << dFMT::ebits), result should be 0 - round_up = false; - } else if (-exp == max_exp) { - // if exp == -(1 << dFMT::ebits), round up - round_up = true; - } else { - // Use the |exp|'th bit to determine rounding - int check_bit = 1 << -exp; - round_up = (check_mant & check_bit); - } - } else { - round_up = (check_mant & 0x1); + if (exp < 1) { + int shift = 1 - exp; + check_mant >>= shift; } - // For roundTiesToEven, if we are exactly between two - // representable numbers, pick the one with an even least - // significant mantissa bit. 
We are exactly between when - // all of the discarded mantissa bits are 0 (i.e., !sticky). - int sticky = in.mant & mask(sFMT::mbits - dFMT::mbits); - if (round_up && !sticky) { - if (!(out.mant & 1)) { - round_up = false; + // Combine guard, round, sticky into one 3-bit value. If that + // value is < 0b100 we round down (truncate -- nothing to do), + // if it is > 0b100 we round up. If it is == 0b100, round to + // nearest even. + uint32_t check_test = check_mant & 0x7; + + // Add sticky to the 3-bit check value. + check_test += sticky; + + if (check_test > 0x4) { + round_up = true; + } else if (check_test == 0x4) { + // We are exactly between two FP values. Round to nearest + // even by looking at the last bit of output mantissa. + // If the last bit of the output mantissa is 1, round to + // nearest even (0 in last bit) which would simply be + // rounding down. The bit position of the last bit in this + // case is 0x8 since we kept three extra bits for guard, + // round, sticky. + if (check_mant & 0x8) { + out.mant -= 1; } } if (round_up) { if (out.mant == mask(dFMT::mbits)) { - // mantissa at max value, increment exponent if not inf + // Mantissa at max value, increment exponent if not inf if (out.exp != mask(dFMT::ebits)) { out.exp++; } @@ -324,7 +336,7 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, template int min_exp() { - return 1; + return 0; } template