arch-vega: Update microscaling format scaling and denorm handling (#1451)

This PR has 3 commits:
- Update scaling methods to scale by multiplication or division when
upcasting or downcasting respectively.
- Preserve the sign when a microscaling conversion results in NaN or
infinity to match hardware.
- Rework rounding to handle cases where conversion results in a denormal
number in the output type so that the value is correct.
This commit is contained in:
Matthew Poremba
2024-08-12 07:00:26 -07:00
committed by GitHub
2 changed files with 130 additions and 61 deletions

View File

@@ -125,20 +125,23 @@ class mxfp
data = in.storage;
}
// Used for upcasting
void
scale(const float& f)
scaleMul(const float& f)
{
binary32 bfp;
bfp.fp32 = f;
int scale_val = bfp.exp - bfp.bias;
int scale_val = bfp.exp;
// Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
// In this implementation, types without NaN define it as zero.
// In this implementation, types without NaN define it as max().
if (scale_val == 0xFF) {
data = FMT::nan;
return;
}
scale_val -= bfp.bias;
FMT in = getFmt();
int exp = in.exp;
@@ -153,27 +156,49 @@ class mxfp
data = in.storage;
}
// Used for downcasting
void
scaleDiv(const float& f)
{
binary32 bfp;
bfp.fp32 = f;
int scale_val = bfp.exp;
// Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
// In this implementation, types without NaN define it as max().
if (scale_val == 0xFF) {
data = FMT::nan;
return;
}
scale_val -= bfp.bias;
FMT in = getFmt();
int exp = in.exp;
if (exp - scale_val > max_exp<FMT>()) {
in.exp = max_exp<FMT>();
} else if (exp - scale_val < min_exp<FMT>()) {
in.exp = min_exp<FMT>();
} else {
in.exp = exp - scale_val;
// Output become denorm
if (in.exp == 0) {
uint32_t m = in.mant | 1 << FMT::mbits;
m >>= 1;
in.mant = m & mask(FMT::mbits);
}
}
data = in.storage;
}
private:
mxfpRoundingMode mode = roundTiesToEven;
/**
 * Convert a float to the bit pattern of FMT.
 *
 * NaN and infinite inputs map directly to FMT::nan and FMT::inf; each
 * case asserts that FMT actually provides that special value. Every
 * other input is forwarded to float_to_mxfp_nocheck().
 *
 * @param f Value to convert.
 * @return The encoded value.
 */
uint32_t
float_to_mxfp(float f)
{
    if (std::isnan(f)) {
        assert(std::numeric_limits<FMT>::has_quiet_NaN);
        return FMT::nan;
    } else if (std::isinf(f)) {
        assert(std::numeric_limits<FMT>::has_infinity);
        return FMT::inf;
    }

    // Neither NaN nor infinite: take the general conversion path.
    return float_to_mxfp_nocheck(f);
}
uint32_t
float_to_mxfp_nocheck(float f)
{
binary32 in;
in.fp32 = f;

View File

@@ -86,15 +86,31 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
// For types with no NaN return max value.
if (std::numeric_limits<dFMT>::has_quiet_NaN) {
out = std::numeric_limits<dFMT>::quiet_NaN();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
} else {
out = std::numeric_limits<dFMT>::max();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
}
} else if (std::isinf(in)) {
// For types with no Inf return max value.
if (std::numeric_limits<dFMT>::has_infinity) {
out = std::numeric_limits<dFMT>::infinity();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
} else {
out = std::numeric_limits<dFMT>::max();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
}
} else if (in.mant == 0 && in.exp == 0) {
// All MX formats, FP32, and FP64 encode 0 as all zeros. Keep sign.
@@ -112,6 +128,9 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
mant |= (1 << sFMT::mbits);
}
// Save the value for rounding so we don't need to recompute it.
uint32_t saved_mant = mant;
mant >>= (sFMT::mbits - dFMT::mbits);
// Output became subnormal
@@ -127,18 +146,19 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
out.mant = mant;
// roundTiesToEven is the only required rounding mode for MXFP
// types. Here we take the original mantissa and check the final
// bit which is shifted out when converting the mantissa. If that
// value is one, then we should round up to the next representable
// number. If the value is one and all other discarded mantissa
// bits are zero, round towards the number which has an even (0)
// bit value in the least significant mantissa bit.
// types. Here we take the input mantissa and check the first
// three bits that were shifted out. These are called guard,
// round, and sticky bits. The value of these three bits combined
// is used to determine if we should round up or down. If the
// value is directly in between, we look at the final bit of the
// output mantissa with guard, round, sticky shifted out. If the
// value is one, round to nearest even by rounding down (set it to
// zero).
//
// For denormals, the process is similar however we check the nth
// bit of the converted mantissa, where n is the absolute value of
// the converted exponent. If the value of |exp| is larger than
// the max exponent, round to zero. If it is exactly equal, always
// round up.
// For denormals, the process is similar, but we shift the input
// mantissa by 1 - exp more bits before setting the value of guard,
// round, sticky. Note that for denormals exp < 1 (i.e., shift
// value is always positive).
//
// If the number of destination and source format mantissa bits are
// the same, the mantissa is unchanged.
@@ -146,45 +166,53 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
&& mode == roundTiesToEven) {
bool round_up = false;
int check_shift = sFMT::mbits - dFMT::mbits - 1;
uint32_t check_mant = in.mant & mask(sFMT::mbits);
// Round using guard, round, sticky bits. We want to make sure
// there are three bits remaining. This is currently true for
// all conversion instructions. This would need to be revisited
// if there are f4 <-> f6 or f6 <-> f8 conversions.
assert((sFMT::mbits - dFMT::mbits) > 2);
int check_shift = sFMT::mbits - dFMT::mbits - 3;
uint32_t check_mant = saved_mant;
// Sticky bit is 1 if *any* of the N-2 bits that get shifted
// off are one. Being zero implies we are directly between two
// floating point values.
int sticky = (check_mant & mask(check_shift + 1)) != 0;
check_mant >>= check_shift;
// out.exp == 0 means subnormal
if (out.exp == 0) {
check_mant = in.mant >> (sFMT::mbits - dFMT::mbits);
uint32_t max_exp = mask(dFMT::ebits);
if (-exp > max_exp) {
// if exp < -(1 << dFMT::ebits), result should be 0
round_up = false;
} else if (-exp == max_exp) {
// if exp == -(1 << dFMT::ebits), round up
round_up = true;
} else {
// Use the |exp|'th bit to determine rounding
int check_bit = 1 << -exp;
round_up = (check_mant & check_bit);
}
} else {
round_up = (check_mant & 0x1);
if (exp < 1) {
int shift = 1 - exp;
check_mant >>= shift;
}
// For roundTiesToEven, if we are exactly between two
// representable numbers, pick the one with an even least
// significant mantissa bit. We are exactly between when
// all of the discarded mantissa bits are 0 (i.e., !sticky).
int sticky = in.mant & mask(sFMT::mbits - dFMT::mbits);
if (round_up && !sticky) {
if (!(out.mant & 1)) {
round_up = false;
// Combine guard, round, sticky into one 3-bit value. If that
// value is < 0b100 we round down (truncate -- nothing to do),
// if it is > 0b100 we round up. If it is == 0b100, round to
// nearest even.
uint32_t check_test = check_mant & 0x7;
// Add sticky to the 3-bit check value.
check_test += sticky;
if (check_test > 0x4) {
round_up = true;
} else if (check_test == 0x4) {
// We are exactly between two FP values. Round to nearest
// even by looking at the last bit of output mantissa.
// If the last bit of the output mantissa is 1, round to
// nearest even (0 in last bit) which would simply be
// rounding down. The bit position of the last bit in this
// case is 0x8 since we kept three extra bits for guard,
// round, sticky.
if (check_mant & 0x8) {
out.mant -= 1;
}
}
if (round_up) {
if (out.mant == mask(dFMT::mbits)) {
// mantissa at max value, increment exponent if not inf
// Mantissa at max value, increment exponent if not inf
if (out.exp != mask(dFMT::ebits)) {
out.exp++;
}
@@ -243,15 +271,31 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
// For types with no NaN return max value.
if (std::numeric_limits<dFMT>::has_quiet_NaN) {
out = std::numeric_limits<dFMT>::quiet_NaN();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
} else {
out = std::numeric_limits<dFMT>::max();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
}
} else if (std::isinf(in)) {
// For types with no Inf return max value.
if (std::numeric_limits<dFMT>::has_infinity) {
out = std::numeric_limits<dFMT>::infinity();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
} else {
out = std::numeric_limits<dFMT>::max();
// Preserve sign bit
if (in.storage & 0x80000000) {
out.storage |= 0x80000000;
}
}
} else if (in.mant == 0 && in.exp == 0) {
// All MX formats, FP32, and FP64 encode 0 as all zeros. Keep sign.
@@ -292,7 +336,7 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
template<typename FMT>
int min_exp()
{
return 1;
return 0;
}
template<typename FMT>