diff --git a/src/arch/amdgpu/common/dtype/mxfp.hh b/src/arch/amdgpu/common/dtype/mxfp.hh index d7edb32dbf..8a70b84012 100644 --- a/src/arch/amdgpu/common/dtype/mxfp.hh +++ b/src/arch/amdgpu/common/dtype/mxfp.hh @@ -125,20 +125,23 @@ class mxfp data = in.storage; } + // Used for upcasting void - scale(const float& f) + scaleMul(const float& f) { binary32 bfp; bfp.fp32 = f; - int scale_val = bfp.exp - bfp.bias; + int scale_val = bfp.exp; // Scale value of 0xFF is NaN. Scaling by NaN returns NaN. - // In this implementation, types without NaN define it as zero. + // In this implementation, types without NaN define it as max(). if (scale_val == 0xFF) { data = FMT::nan; return; } + scale_val -= bfp.bias; + FMT in = getFmt(); int exp = in.exp; @@ -153,27 +156,49 @@ class mxfp data = in.storage; } + // Used for downcasting + void + scaleDiv(const float& f) + { + binary32 bfp; + bfp.fp32 = f; + int scale_val = bfp.exp; + + // Scale value of 0xFF is NaN. Scaling by NaN returns NaN. + // In this implementation, types without NaN define it as max(). 
+ if (scale_val == 0xFF) { + data = FMT::nan; + return; + } + + scale_val -= bfp.bias; + + FMT in = getFmt(); + int exp = in.exp; + + if (exp - scale_val > max_exp()) { + in.exp = max_exp(); + } else if (exp - scale_val < min_exp()) { + in.exp = min_exp(); + } else { + in.exp = exp - scale_val; + + // Output became denorm + if (in.exp == 0) { + uint32_t m = in.mant | 1 << FMT::mbits; + m >>= 1; + in.mant = m & mask(FMT::mbits); + } + } + + data = in.storage; + } + private: mxfpRoundingMode mode = roundTiesToEven; uint32_t float_to_mxfp(float f) - { - if (std::isinf(f)) { - assert(std::numeric_limits::has_infinity); - return FMT::inf; - } - - if (std::isnan(f)) { - assert(std::numeric_limits::has_quiet_NaN); - return FMT::nan; - } - - return float_to_mxfp_nocheck(f); - } - - uint32_t - float_to_mxfp_nocheck(float f) { binary32 in; in.fp32 = f; diff --git a/src/arch/amdgpu/common/dtype/mxfp_convert.hh b/src/arch/amdgpu/common/dtype/mxfp_convert.hh index 641d5f5732..e7e3613af0 100644 --- a/src/arch/amdgpu/common/dtype/mxfp_convert.hh +++ b/src/arch/amdgpu/common/dtype/mxfp_convert.hh @@ -86,15 +86,31 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, // For types with no NaN return max value. if (std::numeric_limits::has_quiet_NaN) { out = std::numeric_limits::quiet_NaN(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } else { out = std::numeric_limits::max(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } } else if (std::isinf(in)) { // For types with no Inf return max value. 
if (std::numeric_limits::has_infinity) { out = std::numeric_limits::infinity(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } else { out = std::numeric_limits::max(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } } else if (in.mant == 0 && in.exp == 0) { // All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign. @@ -112,6 +128,9 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, mant |= (1 << sFMT::mbits); } + // Save the value for rounding so we don't need to recompute it. + uint32_t saved_mant = mant; + mant >>= (sFMT::mbits - dFMT::mbits); // Output became subnormal @@ -127,18 +146,19 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, out.mant = mant; // roundTiesToEven is the only required rounding mode for MXFP - // types. Here we take the original mantissa and check the final - // bit which is shifted out when converting the mantissa. If that - // value is one, then we should round up to the next representable - // number. If the value is one and all other discarded mantissa - // bits are zero, round towards the number which has an even (0) - // bit value in the least significant mantissa bit. + // types. Here we take the input mantissa and check the first + // three bits that were shifted out. These are called guard, + // round, and sticky bits. The value of these three bits combined + // is used to determine if we should round up or down. If the + // value is directly in between, we look at the final bit of the + // output mantissa with guard, round, sticky shifted out. If the + // value is one, round to nearest even by rounding down (set it to + // zero). // - // For denormals, the process is similar however we check the nth - // bit of the converted mantissa, where n is the absolute value of - // the converted exponent. If the value of |exp| is larger than - // the max exponent, round to zero. 
If it is exactly equal, always - // round up. + // For denormals, the process is similar, but we shift the input + // mantissa by 1 - exp more bits before setting the value of guard, + // round, sticky. Note that for denormals exp < 1 (i.e., shift + // value is always positive). // // If the number of destination and source format mantissa bits are // the same, the mantissa is unchanged. @@ -146,45 +166,53 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, && mode == roundTiesToEven) { bool round_up = false; - int check_shift = sFMT::mbits - dFMT::mbits - 1; - uint32_t check_mant = in.mant & mask(sFMT::mbits); + // Round using guard, round, sticky bits. We want to make sure + // there are three bits remaining. This is currently true for + // all conversion instructions. This would need to be revisited + // if there are f4 <-> f6 or f6 <-> f8 conversions. + assert((sFMT::mbits - dFMT::mbits) > 2); + + int check_shift = sFMT::mbits - dFMT::mbits - 3; + uint32_t check_mant = saved_mant; + + // Sticky bit is 1 if *any* of the N-2 bits that get shifted + // off are one. Being zero implies we are directly between two + // floating point values. + int sticky = (check_mant & mask(check_shift + 1)) != 0; check_mant >>= check_shift; - - // out.exp == 0 means subnormal - if (out.exp == 0) { - check_mant = in.mant >> (sFMT::mbits - dFMT::mbits); - - uint32_t max_exp = mask(dFMT::ebits); - if (-exp > max_exp) { - // if exp < -(1 << dFMT::ebits), result should be 0 - round_up = false; - } else if (-exp == max_exp) { - // if exp == -(1 << dFMT::ebits), round up - round_up = true; - } else { - // Use the |exp|'th bit to determine rounding - int check_bit = 1 << -exp; - round_up = (check_mant & check_bit); - } - } else { - round_up = (check_mant & 0x1); + if (exp < 1) { + int shift = 1 - exp; + check_mant >>= shift; } - // For roundTiesToEven, if we are exactly between two - // representable numbers, pick the one with an even least - // significant mantissa bit. 
We are exactly between when - // all of the discarded mantissa bits are 0 (i.e., !sticky). - int sticky = in.mant & mask(sFMT::mbits - dFMT::mbits); - if (round_up && !sticky) { - if (!(out.mant & 1)) { - round_up = false; + // Combine guard, round, sticky into one 3-bit value. If that + // value is < 0b100 we round down (truncate -- nothing to do), + // if it is > 0b100 we round up. If it is == 0b100, round to + // nearest even. + uint32_t check_test = check_mant & 0x7; + + // Add sticky to the 3-bit check value. + check_test += sticky; + + if (check_test > 0x4) { + round_up = true; + } else if (check_test == 0x4) { + // We are exactly between two FP values. Round to nearest + // even by looking at the last bit of output mantissa. + // If the last bit of the output mantissa is 1, round to + // nearest even (0 in last bit) which would simply be + // rounding down. The bit position of the last bit in this + // case is 0x8 since we kept three extra bits for guard, + // round, sticky. + if (check_mant & 0x8) { + out.mant -= 1; } } if (round_up) { if (out.mant == mask(dFMT::mbits)) { - // mantissa at max value, increment exponent if not inf + // Mantissa at max value, increment exponent if not inf if (out.exp != mask(dFMT::ebits)) { out.exp++; } @@ -243,15 +271,31 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, // For types with no NaN return max value. if (std::numeric_limits::has_quiet_NaN) { out = std::numeric_limits::quiet_NaN(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } else { out = std::numeric_limits::max(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } } else if (std::isinf(in)) { // For types with no Inf return max value. 
if (std::numeric_limits::has_infinity) { out = std::numeric_limits::infinity(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } else { out = std::numeric_limits::max(); + // Preserve sign bit + if (in.storage & 0x80000000) { + out.storage |= 0x80000000; + } } } else if (in.mant == 0 && in.exp == 0) { // All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign. @@ -292,7 +336,7 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven, template int min_exp() { - return 1; + return 0; } template