arch-vega: Introduce two scaling methods for microscaling types

Currently there is only a scale() method, which multiplies a microscaling
type by an int8 value. To match hardware, that multiplication should only
be applied after conversion, when upcasting to a larger type. When
downcasting to a smaller type, the scaling method should instead divide by
the int8 value before conversion.

This commit adds both scaling methods.

Change-Id: Ibafa8caa389cde4df609e536cd53bd2289959420
This commit is contained in:
Matthew Poremba
2024-08-10 09:58:03 -07:00
parent e980780efd
commit c1251f51c1

View File

@@ -125,20 +125,23 @@ class mxfp
data = in.storage;
}
// Used for upcasting
void
scale(const float& f)
scaleMul(const float& f)
{
binary32 bfp;
bfp.fp32 = f;
int scale_val = bfp.exp - bfp.bias;
int scale_val = bfp.exp;
// Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
// In this implementation, types without NaN define it as zero.
// In this implementation, types without NaN define it as max().
if (scale_val == 0xFF) {
data = FMT::nan;
return;
}
scale_val -= bfp.bias;
FMT in = getFmt();
int exp = in.exp;
@@ -153,27 +156,49 @@ class mxfp
data = in.storage;
}
// Used for downcasting
void
scaleDiv(const float& f)
{
binary32 bfp;
bfp.fp32 = f;
int scale_val = bfp.exp;
// Scale value of 0xFF is NaN. Scaling by NaN returns NaN.
// In this implementation, types without NaN define it as max().
if (scale_val == 0xFF) {
data = FMT::nan;
return;
}
scale_val -= bfp.bias;
FMT in = getFmt();
int exp = in.exp;
if (exp - scale_val > max_exp<FMT>()) {
in.exp = max_exp<FMT>();
} else if (exp - scale_val < min_exp<FMT>()) {
in.exp = min_exp<FMT>();
} else {
in.exp = exp - scale_val;
// Output become denorm
if (in.exp == 0) {
uint32_t m = in.mant | 1 << FMT::mbits;
m >>= 1;
in.mant = m & mask(FMT::mbits);
}
}
data = in.storage;
}
private:
// Rounding mode for this object; initialized to roundTiesToEven.
// NOTE(review): presumably consulted by the float-to-mxfp conversion
// routines that follow -- confirm against the full class definition.
mxfpRoundingMode mode = roundTiesToEven;
/**
 * Convert a binary32 value to the raw FMT encoding, dealing with the
 * special values before delegating to float_to_mxfp_nocheck().
 *
 * Infinite and NaN inputs map to FMT::inf and FMT::nan respectively; in
 * each case an assertion checks that std::numeric_limits<FMT> reports
 * the format as having that special value.
 *
 * @param f Value to convert.
 * @return FMT::inf, FMT::nan, or the result of float_to_mxfp_nocheck(f).
 */
uint32_t
float_to_mxfp(float f)
{
    switch (std::fpclassify(f)) {
      case FP_INFINITE:
        assert(std::numeric_limits<FMT>::has_infinity);
        return FMT::inf;
      case FP_NAN:
        assert(std::numeric_limits<FMT>::has_quiet_NaN);
        return FMT::nan;
      default:
        // Zero, subnormal and normal inputs take the regular path.
        return float_to_mxfp_nocheck(f);
    }
}
uint32_t
float_to_mxfp_nocheck(float f)
{
binary32 in;
in.fp32 = f;