arch-vega: Update microscaling format scaling and denorm handling (#1451)
This PR has 3 commits: - Update scaling methods to scale by multiplication or division when upcasting or downcasting respectively. - Preserve the sign when a microscaling conversion results in NaN or infinity to match hardware. - Rework rounding to handle cases where conversion results in a denormal number in the output type so that the value is correct.
This commit is contained in:
@@ -125,20 +125,23 @@ class mxfp
|
||||
data = in.storage;
|
||||
}
|
||||
|
||||
// Used for upcasting
|
||||
void
|
||||
scale(const float& f)
|
||||
scaleMul(const float& f)
|
||||
{
|
||||
binary32 bfp;
|
||||
bfp.fp32 = f;
|
||||
int scale_val = bfp.exp - bfp.bias;
|
||||
int scale_val = bfp.exp;
|
||||
|
||||
// Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
|
||||
// In this implementation, types without NaN define it as zero.
|
||||
// In this implementation, types without NaN define it as max().
|
||||
if (scale_val == 0xFF) {
|
||||
data = FMT::nan;
|
||||
return;
|
||||
}
|
||||
|
||||
scale_val -= bfp.bias;
|
||||
|
||||
FMT in = getFmt();
|
||||
int exp = in.exp;
|
||||
|
||||
@@ -153,27 +156,49 @@ class mxfp
|
||||
data = in.storage;
|
||||
}
|
||||
|
||||
// Used for downcasting
|
||||
void
|
||||
scaleDiv(const float& f)
|
||||
{
|
||||
binary32 bfp;
|
||||
bfp.fp32 = f;
|
||||
int scale_val = bfp.exp;
|
||||
|
||||
// Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
|
||||
// In this implementation, types without NaN define it as max().
|
||||
if (scale_val == 0xFF) {
|
||||
data = FMT::nan;
|
||||
return;
|
||||
}
|
||||
|
||||
scale_val -= bfp.bias;
|
||||
|
||||
FMT in = getFmt();
|
||||
int exp = in.exp;
|
||||
|
||||
if (exp - scale_val > max_exp<FMT>()) {
|
||||
in.exp = max_exp<FMT>();
|
||||
} else if (exp - scale_val < min_exp<FMT>()) {
|
||||
in.exp = min_exp<FMT>();
|
||||
} else {
|
||||
in.exp = exp - scale_val;
|
||||
|
||||
// Output become denorm
|
||||
if (in.exp == 0) {
|
||||
uint32_t m = in.mant | 1 << FMT::mbits;
|
||||
m >>= 1;
|
||||
in.mant = m & mask(FMT::mbits);
|
||||
}
|
||||
}
|
||||
|
||||
data = in.storage;
|
||||
}
|
||||
|
||||
private:
|
||||
mxfpRoundingMode mode = roundTiesToEven;
|
||||
|
||||
uint32_t
|
||||
float_to_mxfp(float f)
|
||||
{
|
||||
if (std::isinf(f)) {
|
||||
assert(std::numeric_limits<FMT>::has_infinity);
|
||||
return FMT::inf;
|
||||
}
|
||||
|
||||
if (std::isnan(f)) {
|
||||
assert(std::numeric_limits<FMT>::has_quiet_NaN);
|
||||
return FMT::nan;
|
||||
}
|
||||
|
||||
return float_to_mxfp_nocheck(f);
|
||||
}
|
||||
|
||||
uint32_t
|
||||
float_to_mxfp_nocheck(float f)
|
||||
{
|
||||
binary32 in;
|
||||
in.fp32 = f;
|
||||
|
||||
@@ -86,15 +86,31 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
|
||||
// For types with no NaN return max value.
|
||||
if (std::numeric_limits<dFMT>::has_quiet_NaN) {
|
||||
out = std::numeric_limits<dFMT>::quiet_NaN();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
} else {
|
||||
out = std::numeric_limits<dFMT>::max();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
}
|
||||
} else if (std::isinf(in)) {
|
||||
// For types with no Inf return max value.
|
||||
if (std::numeric_limits<dFMT>::has_infinity) {
|
||||
out = std::numeric_limits<dFMT>::infinity();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
} else {
|
||||
out = std::numeric_limits<dFMT>::max();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
}
|
||||
} else if (in.mant == 0 && in.exp == 0) {
|
||||
// All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign.
|
||||
@@ -112,6 +128,9 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
|
||||
mant |= (1 << sFMT::mbits);
|
||||
}
|
||||
|
||||
// Save the value for rounding so we don't need to recompute it.
|
||||
uint32_t saved_mant = mant;
|
||||
|
||||
mant >>= (sFMT::mbits - dFMT::mbits);
|
||||
|
||||
// Output became subnormal
|
||||
@@ -127,18 +146,19 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
|
||||
out.mant = mant;
|
||||
|
||||
// roundTiesToEven is the only required rounding mode for MXFP
|
||||
// types. Here we take the original mantissa and check the final
|
||||
// bit which is shifted out when converting the mantissa. If that
|
||||
// value is one, then we should round up to the next representable
|
||||
// number. If the value is one and all other discarded mantissa
|
||||
// bits are zero, round towards the number which has an even (0)
|
||||
// bit value in the least significant mantissa bit.
|
||||
// types. Here we take the input mantissa and check the first
|
||||
// three bits that were shifted out. These are called guard,
|
||||
// round, and sticky bits. The value of these three bits combined
|
||||
// are used to determine if we should round up or down. If the
|
||||
// value is directly in between, we look at the final bit of the
|
||||
// output mantissa with guard, round, sticky shifted out. If the
|
||||
// value is one, round to nearest even by rounding down (set it to
|
||||
// zero).
|
||||
//
|
||||
// For denormals, the process is similar however we check the nth
|
||||
// bit of the converted mantissa, where n is the absolute value of
|
||||
// the converted exponent. If the value of |exp| is larger than
|
||||
// the max exponent, round to zero. If it is exactly equal, always
|
||||
// round up.
|
||||
// For denormals, the process is similar, but we shift the input
|
||||
// mantissa by 1 - exp more bits before setting the value of guard,
|
||||
// round, sticky. Note that for denormals exp < 1 (i.e., shift
|
||||
// value is always positive).
|
||||
//
|
||||
// If the number of destination and source format mantissa bits are
|
||||
// the same, the mantissa is unchanged.
|
||||
@@ -146,45 +166,53 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
|
||||
&& mode == roundTiesToEven) {
|
||||
bool round_up = false;
|
||||
|
||||
int check_shift = sFMT::mbits - dFMT::mbits - 1;
|
||||
uint32_t check_mant = in.mant & mask(sFMT::mbits);
|
||||
// Round using guard, round, sticky bits. We want to make sure
|
||||
// there are three bits remaining. This is currently true for
|
||||
// all conversion instructions. This would need to be revisited
|
||||
// if there are f4 <-> f6 or f6 <-> f8 conversions.
|
||||
assert((sFMT::mbits - dFMT::mbits) > 2);
|
||||
|
||||
int check_shift = sFMT::mbits - dFMT::mbits - 3;
|
||||
uint32_t check_mant = saved_mant;
|
||||
|
||||
// Sticky bit is 1 if *any* of the N-2 bits that get shifted
|
||||
// off are one. Being zero implies we are directly between two
|
||||
// floating point values.
|
||||
int sticky = (check_mant & mask(check_shift + 1)) != 0;
|
||||
|
||||
check_mant >>= check_shift;
|
||||
|
||||
// out.exp == 0 means subnormal
|
||||
if (out.exp == 0) {
|
||||
check_mant = in.mant >> (sFMT::mbits - dFMT::mbits);
|
||||
|
||||
uint32_t max_exp = mask(dFMT::ebits);
|
||||
if (-exp > max_exp) {
|
||||
// if exp < -(1 << dFMT::ebits), result should be 0
|
||||
round_up = false;
|
||||
} else if (-exp == max_exp) {
|
||||
// if exp == -(1 << dFMT::ebits), round up
|
||||
round_up = true;
|
||||
} else {
|
||||
// Use the |exp|'th bit to determine rounding
|
||||
int check_bit = 1 << -exp;
|
||||
round_up = (check_mant & check_bit);
|
||||
}
|
||||
} else {
|
||||
round_up = (check_mant & 0x1);
|
||||
if (exp < 1) {
|
||||
int shift = 1 - exp;
|
||||
check_mant >>= shift;
|
||||
}
|
||||
|
||||
// For roundTiesToEven, if we are exactly between two
|
||||
// representable numbers, pick the one with an even least
|
||||
// significant mantissa bit. We are exactly between when
|
||||
// all of the discarded mantissa bits are 0 (i.e., !sticky).
|
||||
int sticky = in.mant & mask(sFMT::mbits - dFMT::mbits);
|
||||
if (round_up && !sticky) {
|
||||
if (!(out.mant & 1)) {
|
||||
round_up = false;
|
||||
// Combine guard, round, sticky into one 3-bit value. If that
|
||||
// value is < 0b100 we round down (truncate -- nothing to do),
|
||||
// if it is > 0b100 we round up. If it is == 0b100, round to
|
||||
// nearest even.
|
||||
uint32_t check_test = check_mant & 0x7;
|
||||
|
||||
// Add sticky to the 3-bit check value.
|
||||
check_test += sticky;
|
||||
|
||||
if (check_test > 0x4) {
|
||||
round_up = true;
|
||||
} else if (check_test == 0x4) {
|
||||
// We are exactly between two FP values. Round to nearest
|
||||
// even by looking at the last bit of output mantissa.
|
||||
// If the last bit of the output mantissa is 1, round to
|
||||
// nearest even (0 in last bit) which would simply be
|
||||
// rounding down. The bit position of the last bit in this
|
||||
// case is 0x8 since we kept three extra bits for guard,
|
||||
// round, sticky.
|
||||
if (check_mant & 0x8) {
|
||||
out.mant -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (round_up) {
|
||||
if (out.mant == mask(dFMT::mbits)) {
|
||||
// mantissa at max value, increment exponent if not inf
|
||||
// Mantissa at max value, increment exponent if not inf
|
||||
if (out.exp != mask(dFMT::ebits)) {
|
||||
out.exp++;
|
||||
}
|
||||
@@ -243,15 +271,31 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
|
||||
// For types with no NaN return max value.
|
||||
if (std::numeric_limits<dFMT>::has_quiet_NaN) {
|
||||
out = std::numeric_limits<dFMT>::quiet_NaN();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
} else {
|
||||
out = std::numeric_limits<dFMT>::max();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
}
|
||||
} else if (std::isinf(in)) {
|
||||
// For types with no Inf return max value.
|
||||
if (std::numeric_limits<dFMT>::has_infinity) {
|
||||
out = std::numeric_limits<dFMT>::infinity();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
} else {
|
||||
out = std::numeric_limits<dFMT>::max();
|
||||
// Preserve sign bit
|
||||
if (in.storage & 0x80000000) {
|
||||
out.storage |= 0x80000000;
|
||||
}
|
||||
}
|
||||
} else if (in.mant == 0 && in.exp == 0) {
|
||||
// All MX formats FP32, and FP64 encode 0 as all zeros. Keep sign.
|
||||
@@ -292,7 +336,7 @@ dFMT convertMXFP(sFMT in, mxfpRoundingMode mode = roundTiesToEven,
|
||||
template<typename FMT>
|
||||
int min_exp()
|
||||
{
|
||||
return 1;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<typename FMT>
|
||||
|
||||
Reference in New Issue
Block a user