diff --git a/ext/softfloat/softfloat_types.h b/ext/softfloat/softfloat_types.h index af1888f9b9..5123cd39c6 100644 --- a/ext/softfloat/softfloat_types.h +++ b/ext/softfloat/softfloat_types.h @@ -47,6 +47,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | the types below may, if desired, be defined as aliases for the native types | (typically 'float' and 'double', and possibly 'long double'). *----------------------------------------------------------------------------*/ +typedef struct { uint8_t v; } float8_t; typedef struct { uint16_t v; } float16_t; typedef struct { uint32_t v; } float32_t; typedef struct { uint64_t v; } float64_t; diff --git a/src/arch/riscv/isa/formats/vector_arith.isa b/src/arch/riscv/isa/formats/vector_arith.isa index 3d8afc5bda..3939e05851 100644 --- a/src/arch/riscv/isa/formats/vector_arith.isa +++ b/src/arch/riscv/isa/formats/vector_arith.isa @@ -580,7 +580,7 @@ def format VectorFloatFormat(code, category, *flags) {{ Name, 'VectorArithMacroInst', {'code': code, - 'declare_varith_template': declareVArithTemplate(Name, 'float', 32)}, + 'declare_varith_template': declareVArithTemplate(Name, 'float', 16)}, flags ) inst_name, inst_suffix = name.split("_", maxsplit=1) @@ -621,7 +621,7 @@ def format VectorFloatFormat(code, category, *flags) {{ set_vlenb = setVlenb(); - varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 32) + varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 16) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -650,7 +650,7 @@ def format VectorFloatCvtFormat(code, category, *flags) {{ Name, 'VectorArithMacroInst', {'code': code, - 'declare_varith_template': declareVArithTemplate(Name, 'float', 32)}, + 'declare_varith_template': declareVArithTemplate(Name, 'float', 16)}, flags ) @@ -674,7 +674,7 @@ def format VectorFloatCvtFormat(code, category, *flags) {{ set_vlenb = setVlenb(); - varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 32) + varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 16) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -698,7 +698,7 @@ def format VectorFloatCvtFormat(code, category, *flags) {{ }}; def format VectorFloatWideningFormat(code, category, *flags) {{ - varith_macro_declare = declareVArithTemplate(Name, 'float', 32, 32) + varith_macro_declare = declareVArithTemplate(Name, 'float', 16, 32) iop = InstObjParams( name, Name, @@ -754,7 +754,7 @@ def format VectorFloatWideningFormat(code, category, *flags) {{ set_vlen = setVlen(); varith_micro_declare = declareVArithTemplate( - Name + "Micro", 'float', 32, 32) + Name + "Micro", 'float', 16, 32) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -779,7 +779,7 @@ def format VectorFloatWideningFormat(code, category, *flags) {{ }}; def format VectorFloatWideningCvtFormat(code, category, *flags) {{ - varith_macro_declare = declareVArithTemplate(Name, 'float', 32, 32) + varith_macro_declare = declareVArithTemplate(Name, 'float', 8, 32) iop = InstObjParams( name, Name, @@ -811,7 +811,7 @@ def format VectorFloatWideningCvtFormat(code, category, *flags) {{ set_vlen = setVlen(); varith_micro_declare = declareVArithTemplate( - Name + "Micro", 'float', 32, 32) + Name + "Micro", 'float', 8, 32) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -832,11 +832,11 @@ def format VectorFloatWideningCvtFormat(code, category, *flags) {{ VectorFloatMicroConstructor.subst(microiop) + \ VectorIntWideningMacroConstructor.subst(iop) exec_output = VectorFloatWideningMicroExecute.subst(microiop) - decode_block = VectorFloatWideningDecodeBlock.subst(iop) + decode_block = VectorFloatWideningAndNarrowingCvtDecodeBlock.subst(iop) }}; def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{ - varith_macro_declare = declareVArithTemplate(Name, 'float', 32, 32) + varith_macro_declare = declareVArithTemplate(Name, 'float', 8, 32) iop = InstObjParams( name, Name, @@ -869,7 +869,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{ set_vlen = setVlen(); varith_micro_declare = declareVArithTemplate( - Name + "Micro", 'float', 32, 32) + Name + "Micro", 'float', 8, 32) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -890,7 +890,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{ VectorFloatMicroConstructor.subst(microiop) + \ VectorIntWideningMacroConstructor.subst(iop) exec_output = VectorFloatNarrowingMicroExecute.subst(microiop) - decode_block = VectorFloatWideningDecodeBlock.subst(iop) + decode_block = VectorFloatWideningAndNarrowingCvtDecodeBlock.subst(iop) }}; def format VectorFloatMaskFormat(code, category, *flags) {{ @@ -898,7 +898,7 @@ def format VectorFloatMaskFormat(code, category, *flags) {{ Name, 'VectorArithMacroInst', {'code': code, - 'declare_varith_template': declareVArithTemplate(Name, 'float', 32)}, + 'declare_varith_template': declareVArithTemplate(Name, 'float', 16)}, flags ) dest_reg_id = "vecRegClass[VecMemInternalReg0 + _microIdx]" @@ -925,7 +925,7 @@ def format VectorFloatMaskFormat(code, category, *flags) {{ code = loopWrapper(code) code = fflags_wrapper(code) - varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 32) + varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 16) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -1083,7 +1083,7 @@ def format VectorNonSplitFormat(code, category, *flags) {{ code = fflags_wrapper(code) if inst_name == "vfmv" : - varith_template = declareVArithTemplate(Name, 'float', 32) + varith_template = declareVArithTemplate(Name, 'float', 16) iop = InstObjParams(name, Name, 'VectorNonSplitInst', @@ -1217,7 +1217,7 @@ def format VectorReduceFloatFormat(code, category, *flags) {{ Name, 'VectorArithMacroInst', {'code': code, - 'declare_varith_template': declareVArithTemplate(Name, 'float', 32)}, + 'declare_varith_template': declareVArithTemplate(Name, 'float', 16)}, flags ) inst_name, inst_suffix = name.split("_", maxsplit=1) @@ -1243,7 +1243,7 @@ def format VectorReduceFloatFormat(code, category, *flags) {{ code = fflags_wrapper(code) - varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 32) + varith_micro_declare = declareVArithTemplate(Name + "Micro", 'float', 16) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -1269,7 +1269,7 @@ def format VectorReduceFloatFormat(code, category, *flags) {{ }}; def format VectorReduceFloatWideningFormat(code, category, *flags) {{ - varith_macro_declare = declareVArithTemplate(Name, 'float', 32, 32) + varith_macro_declare = declareVArithTemplate(Name, 'float', 16, 32) iop = InstObjParams( name, Name, @@ -1301,7 +1301,7 @@ def format VectorReduceFloatWideningFormat(code, category, *flags) {{ ''' varith_micro_declare = declareVArithTemplate( - Name + "Micro", 'float', 32, 32) + Name + "Micro", 'float', 16, 32) microiop = InstObjParams(name + "_micro", Name + "Micro", 'VectorArithMicroInst', @@ -1448,7 +1448,7 @@ def VectorSlideBase(name, Name, category, code, flags, macro_construtor, if decode_template is VectorIntDecodeBlock: varith_macro_declare = declareVArithTemplate(Name) elif decode_template is VectorFloatDecodeBlock: - varith_macro_declare = declareVArithTemplate(Name, 'float', 32) + varith_macro_declare = declareVArithTemplate(Name, 'float', 16) iop = InstObjParams( name, @@ -1491,7 +1491,7 @@ def VectorSlideBase(name, Name, category, code, flags, macro_construtor, varith_micro_declare = declareVArithTemplate(Name + "Micro") elif decode_template is VectorFloatDecodeBlock: varith_micro_declare = declareVArithTemplate( - Name + "Micro", 'float', 32) + Name + "Micro", 'float', 16) microiop = InstObjParams(name + "_micro", Name + "Micro", diff --git a/src/arch/riscv/isa/templates/vector_arith.isa b/src/arch/riscv/isa/templates/vector_arith.isa index 4a6b27cc2e..fcd7c58c83 100644 --- a/src/arch/riscv/isa/templates/vector_arith.isa +++ b/src/arch/riscv/isa/templates/vector_arith.isa @@ -665,6 +665,7 @@ Fault def template VectorFloatDecodeBlock {{ switch(machInst.vtype8.vsew) { +case 0b001: return new %(class_name)s(machInst, vlen); case 0b010: return new %(class_name)s(machInst, vlen); case 0b011: return new %(class_name)s(machInst, vlen); default: GEM5_UNREACHABLE; @@ -821,6 +822,19 @@ Fault def template VectorFloatWideningDecodeBlock {{ switch(machInst.vtype8.vsew) { +case 0b001: return new %(class_name)s(machInst, vlen); +case 0b010: return new %(class_name)s(machInst, vlen); +default: GEM5_UNREACHABLE; +} + +}}; + + +def template VectorFloatWideningAndNarrowingCvtDecodeBlock {{ + +switch(machInst.vtype8.vsew) { +case 0b000: return new %(class_name)s(machInst, vlen); +case 0b001: return new %(class_name)s(machInst, vlen); case 0b010: return new %(class_name)s(machInst, vlen); default: GEM5_UNREACHABLE; } @@ -1605,6 +1619,7 @@ Fault def template VectorFloatNonSplitDecodeBlock {{ switch(machInst.vtype8.vsew) { +case 0b001: return new %(class_name)s(machInst); case 0b010: return new %(class_name)s(machInst); case 0b011: return new %(class_name)s(machInst); default: GEM5_UNREACHABLE; diff --git a/src/arch/riscv/regs/float.hh b/src/arch/riscv/regs/float.hh index cca9e1be2f..0b4570bc71 100644 --- a/src/arch/riscv/regs/float.hh +++ b/src/arch/riscv/regs/float.hh @@ -211,6 +211,13 @@ const std::vector RegNames = { } // namespace float_reg +inline float16_t +fsgnj16(float16_t a, float16_t b, bool n, bool x) { + if (n) b.v = ~b.v; + else if (x) b.v = a.v ^ b.v; + return f16(insertBits(b.v, 14, 0, a.v)); +} + inline float32_t fsgnj32(float32_t a, float32_t b, bool n, bool x) { if (n) b.v = ~b.v; diff --git a/src/arch/riscv/utility.hh b/src/arch/riscv/utility.hh index 73cd7126ce..c6819d8bd7 100644 --- a/src/arch/riscv/utility.hh +++ b/src/arch/riscv/utility.hh @@ -75,10 +75,16 @@ template<> struct double_width { using type = int32_t; }; template<> struct double_width { using type = int64_t; }; template<> struct double_width { using type = __int128_t; }; template<> struct double_width { using type = float64_t;}; +template<> struct double_width { using type = float32_t;}; +template<> struct double_width { using type = float16_t;}; template struct double_widthf; template<> struct double_widthf { using type = float64_t;}; template<> struct double_widthf { using type = float64_t;}; +template<> struct double_widthf { using type = float32_t;}; +template<> struct double_widthf { using type = float32_t;}; +template<> struct double_widthf { using type = float16_t;}; +template<> struct double_widthf { using type = float16_t;}; template inline bool isquietnan(T val) @@ -324,6 +330,8 @@ ftype(IntType a) -> FloatType return f32(a); else if constexpr(std::is_same_v) return f64(a); + else if constexpr(std::is_same_v) + return f16(a); GEM5_UNREACHABLE; } @@ -336,6 +344,8 @@ ftype_freg(freg_t a) -> FloatType return f32(a); else if constexpr(std::is_same_v) return f64(a); + else if constexpr(std::is_same_v) + return f16(a); GEM5_UNREACHABLE; } @@ -346,6 +356,8 @@ fadd(FloatType a, FloatType b) return f32_add(a, b); else if constexpr(std::is_same_v) return f64_add(a, b); + else if constexpr(std::is_same_v) + return f16_add(a, b); GEM5_UNREACHABLE; } @@ -356,6 +368,8 @@ fsub(FloatType a, FloatType b) return f32_sub(a, b); else if constexpr(std::is_same_v) return f64_sub(a, b); + else if constexpr(std::is_same_v) + return f16_sub(a, b); GEM5_UNREACHABLE; } @@ -366,6 +380,8 @@ fmin(FloatType a, FloatType b) return f32_min(a, b); else if constexpr(std::is_same_v) return f64_min(a, b); + else if constexpr(std::is_same_v) + return f16_min(a, b); GEM5_UNREACHABLE; } @@ -376,6 +392,8 @@ fmax(FloatType a, FloatType b) return f32_max(a, b); else if constexpr(std::is_same_v) return f64_max(a, b); + else if constexpr(std::is_same_v) + return f16_max(a, b); GEM5_UNREACHABLE; } @@ -386,6 +404,8 @@ fdiv(FloatType a, FloatType b) return f32_div(a, b); else if constexpr(std::is_same_v) return f64_div(a, b); + else if constexpr(std::is_same_v) + return f16_div(a, b); GEM5_UNREACHABLE; } @@ -396,6 +416,8 @@ fmul(FloatType a, FloatType b) return f32_mul(a, b); else if constexpr(std::is_same_v) return f64_mul(a, b); + else if constexpr(std::is_same_v) + return f16_mul(a, b); GEM5_UNREACHABLE; } @@ -406,6 +428,8 @@ fsqrt(FloatType a) return f32_sqrt(a); else if constexpr(std::is_same_v) return f64_sqrt(a); + else if constexpr(std::is_same_v) + return f16_sqrt(a); GEM5_UNREACHABLE; } @@ -416,6 +440,8 @@ frsqrte7(FloatType a) return f32_rsqrte7(a); else if constexpr(std::is_same_v) return f64_rsqrte7(a); + else if constexpr(std::is_same_v) + return f16_rsqrte7(a); GEM5_UNREACHABLE; } @@ -426,6 +452,8 @@ frecip7(FloatType a) return f32_recip7(a); else if constexpr(std::is_same_v) return f64_recip7(a); + else if constexpr(std::is_same_v) + return f16_recip7(a); GEM5_UNREACHABLE; } @@ -436,6 +464,8 @@ fclassify(FloatType a) return f32(f32_classify(a)); else if constexpr(std::is_same_v) return f64(f64_classify(a)); + else if constexpr(std::is_same_v) + return f16(f16_classify(a)); GEM5_UNREACHABLE; } @@ -446,6 +476,8 @@ fsgnj(FloatType a, FloatType b, bool n, bool x) return fsgnj32(a, b, n, x); else if constexpr(std::is_same_v) return fsgnj64(a, b, n, x); + else if constexpr(std::is_same_v) + return fsgnj16(a, b, n, x); GEM5_UNREACHABLE; } @@ -456,6 +488,8 @@ fle(FloatType a, FloatType b) return f32_le(a, b); else if constexpr(std::is_same_v) return f64_le(a, b); + else if constexpr(std::is_same_v) + return f16_le(a, b); GEM5_UNREACHABLE; } @@ -466,6 +500,8 @@ feq(FloatType a, FloatType b) return f32_eq(a, b); else if constexpr(std::is_same_v) return f64_eq(a, b); + else if constexpr(std::is_same_v) + return f16_eq(a, b); GEM5_UNREACHABLE; } @@ -476,6 +512,8 @@ flt(FloatType a, FloatType b) return f32_lt(a, b); else if constexpr(std::is_same_v) return f64_lt(a, b); + else if constexpr(std::is_same_v) + return f16_lt(a, b); GEM5_UNREACHABLE; } @@ -486,6 +524,8 @@ fmadd(FloatType a, FloatType b, FloatType c) return f32_mulAdd(a, b, c); else if constexpr(std::is_same_v) return f64_mulAdd(a, b, c); + else if constexpr(std::is_same_v) + return f16_mulAdd(a, b, c); GEM5_UNREACHABLE; } @@ -496,6 +536,8 @@ fneg(FloatType a) return f32(a.v ^ uint32_t(mask(31, 31))); else if constexpr(std::is_same_v) return f64(a.v ^ mask(63, 63)); + else if constexpr(std::is_same_v) + return f16(a.v ^ uint16_t(mask(15, 15))); GEM5_UNREACHABLE; } @@ -504,6 +546,8 @@ fwiden(FT a) { if constexpr(std::is_same_v) return f32_to_f64(a); + else if constexpr(std::is_same_v) + return f16_to_f32(a); GEM5_UNREACHABLE; } @@ -514,6 +558,8 @@ f_to_ui(FloatType a, uint_fast8_t mode) return f32_to_ui32(a, mode, true); else if constexpr(std::is_same_v) return f64_to_ui64(a, mode, true); + else if constexpr(std::is_same_v) + return f16_to_ui16(a, mode, true); GEM5_UNREACHABLE; } @@ -525,6 +571,8 @@ f_to_wui(FloatType a, uint_fast8_t mode) { if constexpr(std::is_same_v) return f32_to_ui64(a, mode, true); + else if constexpr(std::is_same_v) + return f16_to_ui32(a, mode, true); GEM5_UNREACHABLE; } @@ -536,6 +584,10 @@ f_to_nui(FloatType a, uint_fast8_t mode) { if constexpr(std::is_same_v) return f64_to_ui32(a, mode, true); + else if constexpr(std::is_same_v) + return f32_to_ui16(a, mode, true); + else if constexpr(std::is_same_v) + return f16_to_ui8(a, mode, true); GEM5_UNREACHABLE; } @@ -546,6 +598,8 @@ f_to_i(FloatType a, uint_fast8_t mode) return (uint32_t)f32_to_i32(a, mode, true); else if constexpr(std::is_same_v) return (uint64_t)f64_to_i64(a, mode, true); + else if constexpr(std::is_same_v) + return (uint16_t)f16_to_i16(a, mode, true); GEM5_UNREACHABLE; } @@ -557,6 +611,8 @@ f_to_wi(FloatType a, uint_fast8_t mode) { if constexpr(std::is_same_v) return (uint64_t)f32_to_i64(a, mode, true); + else if constexpr(std::is_same_v) + return (uint32_t)f16_to_i32(a, mode, true); GEM5_UNREACHABLE; } @@ -568,6 +624,10 @@ f_to_ni(FloatType a, uint_fast8_t mode) { if constexpr(std::is_same_v) return (uint32_t)f64_to_i32(a, mode, true); + else if constexpr(std::is_same_v) + return (uint16_t)f32_to_i16(a, mode, true); + else if constexpr(std::is_same_v) + return (uint8_t)f16_to_i8(a, mode, true); GEM5_UNREACHABLE; } @@ -579,6 +639,8 @@ ui_to_f(IntType a) return ui32_to_f32(a); else if constexpr(std::is_same_v) return ui64_to_f64(a); + else if constexpr(std::is_same_v) + return ui32_to_f16(a); GEM5_UNREACHABLE; } @@ -590,6 +652,10 @@ ui_to_wf(IntType a) { if constexpr(std::is_same_v) return ui32_to_f64(a); + else if constexpr(std::is_same_v) + return ui32_to_f32(a); + else if constexpr(std::is_same_v) + return ui32_to_f16(a); GEM5_UNREACHABLE; } @@ -601,6 +667,8 @@ ui_to_nf(IntType a) { if constexpr(std::is_same_v) return ui64_to_f32(a); + else if constexpr(std::is_same_v) + return ui32_to_f16(a); GEM5_UNREACHABLE; } @@ -612,6 +680,8 @@ i_to_f(IntType a) return i32_to_f32((int32_t)a); else if constexpr(std::is_same_v) return i64_to_f64((int64_t)a); + else if constexpr(std::is_same_v) + return i32_to_f16((int16_t)a); GEM5_UNREACHABLE; } @@ -623,6 +693,10 @@ i_to_wf(IntType a) { if constexpr(std::is_same_v) return i32_to_f64((int32_t)a); + else if constexpr(std::is_same_v) + return i32_to_f32((int16_t)a); + else if constexpr(std::is_same_v) + return i32_to_f16((int8_t)a); GEM5_UNREACHABLE; } @@ -636,6 +710,8 @@ i_to_nf(IntType a) { if constexpr(std::is_same_v) return i64_to_f32(a); + else if constexpr(std::is_same_v) + return i32_to_f16(a); GEM5_UNREACHABLE; } @@ -647,6 +723,8 @@ f_to_wf(FloatType a) { if constexpr(std::is_same_v) return f32_to_f64(a); + else if constexpr(std::is_same_v) + return f16_to_f32(a); GEM5_UNREACHABLE; } @@ -658,6 +736,8 @@ f_to_nf(FloatType a) { if constexpr(std::is_same_v) return f64_to_f32(a); + else if constexpr(std::is_same_v) + return f32_to_f16(a); GEM5_UNREACHABLE; }