From 99f58d37da334a92728618da3f353dae28567b63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Adserias?= <33020671+saul44203@users.noreply.github.com> Date: Tue, 14 May 2024 14:47:00 +0200 Subject: [PATCH] arch-riscv: add agnostic opt to vector tail/mask for arith insts Change-Id: I693b5f3a6cc8a8f320be26b214fd9b359e541f14 --- src/arch/riscv/isa/decoder.isa | 30 ++-- src/arch/riscv/isa/formats/vector_arith.isa | 130 ++++++++---------- src/arch/riscv/isa/templates/vector_arith.isa | 92 ++++++++----- 3 files changed, 127 insertions(+), 125 deletions(-) diff --git a/src/arch/riscv/isa/decoder.isa b/src/arch/riscv/isa/decoder.isa index 98eedc102d..6644d6fa93 100644 --- a/src/arch/riscv/isa/decoder.isa +++ b/src/arch/riscv/isa/decoder.isa @@ -2745,10 +2745,10 @@ decode QUADRANT default Unknown::unknown() { if (this->vm || elem_mask(v0, ei)) { const uint64_t idx = Vs1_vu[i] - vs2_elems * vs2_idx; - auto res = (Vs1_vu[i] >= vlmax) ? 0 - : (idx < vs2_elems) ? Vs2_vu[idx] - : Vs3_vu[i]; - Vd_vu[i] = res; + if (Vs1_vu[i] >= vlmax) + Vd_vu[i] = 0; + else if (idx < vs2_elems) + Vd_vu[i] = Vs2_vu[idx]; } } }}, OPIVV, SimdMiscOp); @@ -2758,10 +2758,10 @@ decode QUADRANT default Unknown::unknown() { if (this->vm || elem_mask(v0, ei)) { const uint32_t idx = Vs1_uh[i + vs1_bias] - vs2_elems * vs2_idx; - auto res = (Vs1_uh[i + vs1_bias] >= vlmax) ? 0 - : (idx < vs2_elems) ? Vs2_vu[idx] - : Vs3_vu[i + vd_bias]; - Vd_vu[i + vd_bias] = res; + if (Vs1_uh[i + vs1_bias] >= vlmax) + Vd_vu[i + vd_bias] = 0; + else if (idx < vs2_elems) + Vd_vu[i + vd_bias] = Vs2_vu[idx]; } } }}, OPIVV, SimdMiscOp); @@ -3664,9 +3664,10 @@ decode QUADRANT default Unknown::unknown() { uint64_t zextImm = rvZext(SIMM5); if (this->vm || elem_mask(v0, ei)) { const uint64_t idx = zextImm - vs2_elems * vs2_idx; - Vd_vu[i] = (zextImm >= vlmax) ? 0 - : (idx < vs2_elems) ? Vs2_vu[idx] - : Vs3_vu[i]; + if (zextImm >= vlmax) + Vd_vu[i] = 0; + else if (idx < vs2_elems) + Vd_vu[i] = Vs2_vu[idx]; } } }}, OPIVI, SimdMiscOp); @@ -3999,9 +4000,10 @@ decode QUADRANT default Unknown::unknown() { uint64_t zextRs1 = rvZext(Rs1); if (this->vm || elem_mask(v0, ei)) { const uint64_t idx = zextRs1 - vs2_elems * vs2_idx; - Vd_vu[i] = (zextRs1 >= vlmax) ? 0 - : (idx < vs2_elems) ? Vs2_vu[idx] - : Vs3_vu[i]; + if (zextRs1 >= vlmax) + Vd_vu[i] = 0; + else if (idx < vs2_elems) + Vd_vu[i] = Vs2_vu[idx]; } } }}, OPIVX, SimdMiscOp); diff --git a/src/arch/riscv/isa/formats/vector_arith.isa b/src/arch/riscv/isa/formats/vector_arith.isa index 3b7a57e208..3862b1de02 100644 --- a/src/arch/riscv/isa/formats/vector_arith.isa +++ b/src/arch/riscv/isa/formats/vector_arith.isa @@ -31,12 +31,17 @@ let {{ def setVlen(): return "uint32_t vlen = VlenbBits * 8;\n" def setVlenb(): - return "uint32_t vlenb = VlenbBits;\n" + return "[[maybe_unused]] uint32_t vlenb = VlenbBits;\n" def setDestWrapper(destRegId): return "setDestRegIdx(_numDestRegs++, " + destRegId + ");\n" + \ "_numTypedDestRegs[VecRegClass]++;\n" def setSrcWrapper(srcRegId): return "setSrcRegIdx(_numSrcRegs++, " + srcRegId + ");\n" + def tailMaskCondSetSrcWrapper(setSrcRegCode): + return f''' + if (!_machInst.vtype8.vta || (!_machInst.vm && !_machInst.vtype8.vma)) + {setSrcRegCode} + ''' def setSrcVm(): return "if (!this->vm)\n" + \ " setSrcRegIdx(_numSrcRegs++, vecRegClass[0]);" @@ -164,6 +169,7 @@ def format VectorIntFormat(code, category, *flags) {{ v0_required = inst_name not in ["vmv"] mask_cond = v0_required and (inst_suffix not in ['vvm', 'vxm', 'vim']) need_elem_idx = mask_cond or code.find("ei") != -1 + is_destructive_fused = iop.op_class == "SimdMultAccOp" dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" @@ -185,7 +191,6 @@ def format VectorIntFormat(code, category, *flags) {{ error("not supported category for VectorIntFormat: %s" % category) old_vd_idx = num_src_regs - src3_reg_id = "vecRegClass[_machInst.vd + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) @@ -193,7 +198,12 @@ def format VectorIntFormat(code, category, *flags) {{ if category != "OPIVI": set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + + dest_set_src_reg_idx = setSrcWrapper(dest_reg_id) + if not is_destructive_fused: + dest_set_src_reg_idx = tailMaskCondSetSrcWrapper(dest_set_src_reg_idx) + set_src_reg_idx += dest_set_src_reg_idx + if v0_required: set_src_reg_idx += setSrcVm() @@ -247,17 +257,17 @@ def format VectorIntExtFormat(code, category, *flags) {{ inst_name, inst_suffix = name.split("_", maxsplit=1) ext_div = int(inst_suffix[-1]) - old_vd_idx = 1 dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / " + \ str(ext_div) + "]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" + + old_vd_idx = 1 set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() code = maskCondWrapper(code) @@ -307,6 +317,8 @@ def format VectorIntWideningFormat(code, category, *flags) {{ v0_required = True mask_cond = v0_required need_elem_idx = mask_cond or code.find("ei") != -1 + is_destructive_fused = iop.op_class == "SimdMultAccOp" + old_vd_idx = 2 dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src1_reg_id = "" @@ -321,14 +333,18 @@ def format VectorIntWideningFormat(code, category, *flags) {{ src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]" elif inst_suffix in ["wv", "wx"]: src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + + dest_set_src_reg_idx = setSrcWrapper(dest_reg_id) + if not is_destructive_fused: + dest_set_src_reg_idx = tailMaskCondSetSrcWrapper(dest_set_src_reg_idx) + set_src_reg_idx += dest_set_src_reg_idx + if v0_required: set_src_reg_idx += setSrcVm() @@ -395,14 +411,13 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{ else: error("not supported category for VectorIntFormat: %s" % category) src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - old_dest_reg_id = "vecRegClass[_machInst.vs3 + _microIdx / 2]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" if category != "OPIVI": set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() # code code = maskCondWrapper(code) @@ -452,7 +467,6 @@ def format VectorIntMaskFormat(code, category, *flags) {{ mask_cond = inst_name not in ['vmadc', 'vmsbc'] need_elem_idx = mask_cond or code.find("ei") != -1 - old_vd_idx = 2 dest_reg_id = "vecRegClass[VecMemInternalReg0 + _microIdx]" src1_reg_id = "" if category == "OPIVV": @@ -460,17 +474,15 @@ def format VectorIntMaskFormat(code, category, *flags) {{ elif category == "OPIVX": src1_reg_id = "intRegClass[_machInst.rs1]" elif category == "OPIVI": - old_vd_idx = 1 + pass else: error("not supported category for VectorIntFormat: %s" % category) src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - old_dest_reg_id = "vecRegClass[_machInst.vd]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" if category != "OPIVI": set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) if v0_required: set_src_reg_idx += setSrcVm() @@ -497,7 +509,6 @@ def format VectorIntMaskFormat(code, category, *flags) {{ 'set_vlenb': set_vlenb, 'set_vlen': set_vlen, 'vm_decl_rd': vm_decl_rd, - 'copy_old_vd': copyOldVd(old_vd_idx), 'declare_varith_template': declareVArithTemplate(Name + "Micro")}, flags) @@ -522,7 +533,6 @@ def format VectorGatherFormat(code, category, *flags) {{ 'code': code, 'declare_varith_template': declareGatherTemplate(Name, idx_type)}, flags) - old_vd_idx = 2 dest_reg_id = "vecRegClass[_machInst.vd + vd_idx]" src1_reg_id = "" if category in ["OPIVV"]: @@ -534,7 +544,10 @@ def format VectorGatherFormat(code, category, *flags) {{ else: error("not supported category for VectorIntFormat: %s" % category) src2_reg_id = "vecRegClass[_machInst.vs2 + vs2_idx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + vd_idx]" + src2_reg_id = "vecRegClass[_machInst.vs2 + vs2_idx]" + + # vtmp0 as dummy src reg to create dependency with pin vd micro + src3_reg_id = "vecRegClass[VecMemInternalReg0 + vd_idx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) @@ -562,7 +575,6 @@ def format VectorGatherFormat(code, category, *flags) {{ 'set_vlenb': set_vlenb, 'set_vlen': set_vlen, 'vm_decl_rd': vm_decl_rd, - 'copy_old_vd': copyOldVd(old_vd_idx), 'idx_type': idx_type, 'declare_varith_template': varith_micro_declare}, flags) @@ -575,7 +587,6 @@ def format VectorGatherFormat(code, category, *flags) {{ VectorGatherMacroConstructor.subst(iop) exec_output = VectorGatherMicroExecute.subst(microiop) decode_block = VectorGatherDecodeBlock.subst(iop) - }}; def format VectorFloatFormat(code, category, *flags) {{ @@ -591,6 +602,7 @@ def format VectorFloatFormat(code, category, *flags) {{ v0_required = inst_name not in ["vfmv"] mask_cond = v0_required and (inst_suffix not in ['vvm', 'vfm']) need_elem_idx = mask_cond or code.find("ei") != -1 + is_destructive_fused = iop.op_class == "SimdFloatMultAccOp" dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src1_reg_id = "" @@ -601,16 +613,21 @@ def format VectorFloatFormat(code, category, *flags) {{ else: error("not supported category for VectorFloatFormat: %s" % category) src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + + dest_set_src_reg_idx = setSrcWrapper(dest_reg_id) + if not is_destructive_fused: + dest_set_src_reg_idx = tailMaskCondSetSrcWrapper(dest_set_src_reg_idx) + set_src_reg_idx += dest_set_src_reg_idx + if v0_required: set_src_reg_idx += setSrcVm() + # code if mask_cond: code = maskCondWrapper(code) @@ -663,13 +680,12 @@ def format VectorFloatCvtFormat(code, category, *flags) {{ old_vd_idx = 1 dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() code = maskCondWrapper(code) code = eiDeclarePrefix(code) @@ -719,6 +735,7 @@ def format VectorFloatWideningFormat(code, category, *flags) {{ v0_required = True mask_cond = v0_required need_elem_idx = mask_cond or code.find("ei") != -1 + is_destructive_fused = iop.op_class == "SimdFloatMultAccOp" dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src1_reg_id = "" @@ -733,14 +750,18 @@ def format VectorFloatWideningFormat(code, category, *flags) {{ src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]" elif inst_suffix in ["wv", "wf"]: src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + + dest_set_src_reg_idx = setSrcWrapper(dest_reg_id) + if not is_destructive_fused: + dest_set_src_reg_idx = tailMaskCondSetSrcWrapper(dest_set_src_reg_idx) + set_src_reg_idx += dest_set_src_reg_idx + if v0_required: set_src_reg_idx += setSrcVm() @@ -800,13 +821,12 @@ def format VectorFloatWideningCvtFormat(code, category, *flags) {{ old_vd_idx = 1 dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() code = maskCondWrapper(code) code = eiDeclarePrefix(code, widening=True) @@ -857,13 +877,12 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{ old_vd_idx = 1 dest_reg_id = "vecRegClass[_machInst.vd + _microIdx / 2]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx / 2]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() code = maskCondWrapper(code) code = eiDeclarePrefix(code, widening=True) @@ -923,7 +942,6 @@ def format VectorFloatMaskFormat(code, category, *flags) {{ set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) set_src_reg_idx += setSrcVm() vm_decl_rd = vmDeclAndReadData() set_vlenb = setVlenb() @@ -944,7 +962,6 @@ def format VectorFloatMaskFormat(code, category, *flags) {{ 'set_vlenb': set_vlenb, 'set_vlen': set_vlen, 'vm_decl_rd': vm_decl_rd, - 'copy_old_vd': copyOldVd(2), 'declare_varith_template': varith_micro_declare}, flags) @@ -991,14 +1008,10 @@ def format ViotaFormat(code, category, *flags){{ inst_name, inst_suffix = name.split("_", maxsplit=1) dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" src2_reg_id = "vecRegClass[_machInst.vs2]" - # The tail of vector mask inst should be treated as tail-agnostic. - # We treat it with tail-undisturbed policy, since - # the test suits only support undisturbed policy. - old_dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]" set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() set_dest_reg_idx = setDestWrapper(dest_reg_id) vm_decl_rd = vmDeclAndReadData() @@ -1036,13 +1049,8 @@ def format Vector1Vs1VdMaskFormat(code, category, *flags){{ inst_name, inst_suffix = name.split("_", maxsplit=1) dest_reg_id = "vecRegClass[_machInst.vd]" src2_reg_id = "vecRegClass[_machInst.vs2]" - # The tail of vector mask inst should be treated as tail-agnostic. - # We treat it with tail-undisturbed policy, since - # the test suits only support undisturbed policy. - old_dest_reg_id = "vecRegClass[_machInst.vd]" set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) set_dest_reg_idx = setDestWrapper(dest_reg_id) vm_decl_rd = vmDeclAndReadData() set_vm_idx = setSrcVm() @@ -1056,7 +1064,6 @@ def format Vector1Vs1VdMaskFormat(code, category, *flags){{ 'set_vlenb': set_vlenb, 'vm_decl_rd': vm_decl_rd, 'set_vm_idx': set_vm_idx, - 'copy_old_vd': copyOldVd(1), 'declare_varith_template': declareVArithTemplate(Name, 'uint', 8, 8), }, flags) @@ -1130,23 +1137,15 @@ def format VectorNonSplitFormat(code, category, *flags) {{ def format VectorMaskFormat(code, category, *flags) {{ inst_name, inst_suffix = name.split("_", maxsplit=1) - old_vd_idx = 2 if category not in ["OPMVV"]: error("not supported category for VectorIntFormat: %s" % category) dest_reg_id = "vecRegClass[_machInst.vd]" src1_reg_id = "vecRegClass[_machInst.vs1]" src2_reg_id = "vecRegClass[_machInst.vs2]" - # The tail of vector mask inst should be treated as tail-agnostic. - # We treat it with tail-undisturbed policy, since - # the test suits only support undisturbed policy. - # TODO: remove it - old_dest_reg_id = "vecRegClass[_machInst.vd]" - set_src_reg_idx = "" set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) set_dest_reg_idx = setDestWrapper(dest_reg_id) @@ -1161,7 +1160,6 @@ def format VectorMaskFormat(code, category, *flags) {{ 'set_dest_reg_idx': set_dest_reg_idx, 'set_src_reg_idx': set_src_reg_idx, 'set_vlenb': set_vlenb, - 'copy_old_vd': copyOldVd(old_vd_idx), 'declare_varith_template': declareVArithTemplate(Name, 'uint', 8, 8) }, flags) @@ -1185,13 +1183,10 @@ def format VectorReduceIntFormat(code, category, *flags) {{ dest_reg_id = "vecRegClass[_machInst.vd]" src1_reg_id = "vecRegClass[_machInst.vs1]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - old_dest_reg_id = "vecRegClass[_machInst.vd]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - # Treat tail undisturbed/agnostic as the same - # We always need old rd as src vreg - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() vm_decl_rd = vmDeclAndReadData() set_vlenb = setVlenb() @@ -1238,13 +1233,10 @@ def format VectorReduceFloatFormat(code, category, *flags) {{ dest_reg_id = "vecRegClass[_machInst.vd]" src1_reg_id = "vecRegClass[_machInst.vs1]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - old_dest_reg_id = "vecRegClass[_machInst.vd]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - # Treat tail undisturbed/agnostic as the same - # We always need old rd as src vreg - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() vm_decl_rd = vmDeclAndReadData() set_vlenb = setVlenb() @@ -1296,13 +1288,10 @@ def format VectorReduceFloatWideningFormat(code, category, *flags) {{ dest_reg_id = "vecRegClass[_machInst.vd]" src1_reg_id = "vecRegClass[_machInst.vs1]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - old_dest_reg_id = "vecRegClass[_machInst.vd]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - # Treat tail undisturbed/agnostic as the same - # We always need old rd as src vreg - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() vm_decl_rd = vmDeclAndReadData() set_vlenb = setVlenb() @@ -1362,14 +1351,13 @@ def format VectorIntVxsatFormat(code, category, *flags) {{ else: error("not supported category for VectorIntVxsatFormat: %s" % category) src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - src3_reg_id = "vecRegClass[_machInst.vs3 + _microIdx]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = "" if category != "OPIVI": set_src_reg_idx += setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - set_src_reg_idx += setSrcWrapper(src3_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() vm_decl_rd = vmDeclAndReadData() @@ -1416,13 +1404,10 @@ def format VectorReduceIntWideningFormat(code, category, *flags) {{ dest_reg_id = "vecRegClass[_machInst.vd]" src1_reg_id = "vecRegClass[_machInst.vs1]" src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]" - old_dest_reg_id = "vecRegClass[_machInst.vd]" set_dest_reg_idx = setDestWrapper(dest_reg_id) set_src_reg_idx = setSrcWrapper(src1_reg_id) set_src_reg_idx += setSrcWrapper(src2_reg_id) - # Treat tail undisturbed/agnostic as the same - # We always need old rd as src vreg - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) + set_src_reg_idx += tailMaskCondSetSrcWrapper(setSrcWrapper(dest_reg_id)) set_src_reg_idx += setSrcVm() vm_decl_rd = vmDeclAndReadData() set_vlenb = setVlenb() @@ -1480,12 +1465,8 @@ def VectorSlideBase(name, Name, category, code, flags, macro_construtor, src1_ireg_id = "intRegClass[_machInst.rs1]" src1_freg_id = "floatRegClass[_machInst.rs1]" - # The tail of vector mask inst should be treated as tail-agnostic. - # We treat it with tail-undisturbed policy, since - # the test suits only support undisturbed policy. num_src_regs = 0 - old_dest_reg_id = "vecRegClass[_machInst.vd + vdIdx]" set_src_reg_idx = "" if category in ["OPIVX", "OPMVX"]: set_src_reg_idx += setSrcWrapper(src1_ireg_id) @@ -1495,8 +1476,6 @@ def VectorSlideBase(name, Name, category, code, flags, macro_construtor, num_src_regs += 1 set_src_reg_idx += setSrcWrapper(src2_reg_id) num_src_regs += 1 - old_vd_idx = num_src_regs - set_src_reg_idx += setSrcWrapper(old_dest_reg_id) set_dest_reg_idx = setDestWrapper(dest_reg_id) vm_decl_rd = vmDeclAndReadData() set_src_reg_idx += setSrcVm() @@ -1518,7 +1497,6 @@ def VectorSlideBase(name, Name, category, code, flags, macro_construtor, 'set_vlenb': set_vlenb, 'set_vlen': set_vlen, 'vm_decl_rd': vm_decl_rd, - 'copy_old_vd': copyOldVd(old_vd_idx), 'declare_varith_template': varith_micro_declare}, flags) diff --git a/src/arch/riscv/isa/templates/vector_arith.isa b/src/arch/riscv/isa/templates/vector_arith.isa index d000113a89..31d2c26147 100644 --- a/src/arch/riscv/isa/templates/vector_arith.isa +++ b/src/arch/riscv/isa/templates/vector_arith.isa @@ -31,12 +31,14 @@ output header {{ #define ASSIGN_VD_BIT(idx, bit) \ ((Vd[(idx)/8] & ~(1 << (idx)%8)) | ((bit) << (idx)%8)) -#define COPY_OLD_VD(idx) \ - [[maybe_unused]] RiscvISA::vreg_t old_vd; \ - [[maybe_unused]] decltype(Vd) old_Vd = nullptr; \ - xc->getRegOperand(this, (idx), &old_vd); \ - old_Vd = old_vd.as >(); \ - memcpy(Vd, old_Vd, vlenb); +#define COPY_OLD_VD(idx) \ + if (!machInst.vtype8.vta || (!machInst.vm && !machInst.vtype8.vma)) { \ + RiscvISA::vreg_t old_vd; \ + xc->getRegOperand(this, idx, &old_vd); \ + tmp_d0 = old_vd; \ + } else { \ + tmp_d0.set(0xff); \ + } \ #define VRM_REQUIRED \ uint_fast8_t frm = xc->readMiscReg(MISCREG_FRM); \ @@ -987,7 +989,7 @@ def template Vector1Vs1VdMaskDeclare {{ template class %(class_name)s : public %(base_class)s { private: - RegId srcRegIdxArr[3]; + RegId srcRegIdxArr[2]; RegId destRegIdxArr[1]; bool vm; public: @@ -1040,7 +1042,6 @@ Fault %(op_rd)s; %(set_vlenb)s; %(vm_decl_rd)s; - %(copy_old_vd)s; %(code)s; %(op_wb)s; return NoFault; @@ -1167,9 +1168,9 @@ template class %(class_name)s : public %(base_class)s { private: - // vs1(rs1), vs2, old_vd, v0 for *.vv[m] or *.vx[m] - // vs2, old_vd, v0 for *.vi[m] - RegId srcRegIdxArr[4]; + // vs1(rs1), vs2, v0 for *.vv[m] or *.vx[m] + // vs2, v0 for *.vi[m] + RegId srcRegIdxArr[3]; RegId destRegIdxArr[1]; bool vm; public: @@ -1228,7 +1229,6 @@ Fault %(set_vlenb)s; %(set_vlen)s; %(vm_decl_rd)s; - %(copy_old_vd)s; const uint32_t bit_offset = vlenb / sizeof(ElemType); const uint32_t offset = bit_offset * microIdx; @@ -1293,8 +1293,8 @@ template class %(class_name)s : public %(base_class)s { private: - // vs1(rs1), vs2, old_vd, v0 for *.vv or *.vf - RegId srcRegIdxArr[4]; + // vs1(rs1), vs2, v0 for *.vv or *.vf + RegId srcRegIdxArr[3]; RegId destRegIdxArr[1]; bool vm; public: @@ -1353,7 +1353,6 @@ Fault %(set_vlenb)s; %(set_vlen)s; %(vm_decl_rd)s; - %(copy_old_vd)s; const uint32_t bit_offset = vlenb / sizeof(ElemType); const uint32_t offset = bit_offset * microIdx; @@ -1470,7 +1469,7 @@ def template VectorMaskDeclare {{ template class %(class_name)s : public %(base_class)s { private: - RegId srcRegIdxArr[3]; + RegId srcRegIdxArr[2]; RegId destRegIdxArr[1]; public: %(class_name)s(ExtMachInst _machInst); @@ -1516,11 +1515,14 @@ Fault status.vs = VPUStatus::DIRTY; xc->setMiscReg(MISCREG_STATUS, status); + %(op_decl)s; %(op_rd)s; - // TODO: remove it %(set_vlenb)s; - %(copy_old_vd)s; + + // mask tails are always treated as agnostic: writting 1s + tmp_d0.set(0xff); + %(code)s; %(op_wb)s; @@ -1773,7 +1775,7 @@ Fault auto reduce_loop = [&, this](const auto& f, const auto* _, const auto* vs2) { - ElemType microop_result = this->microIdx != 0 ? old_Vd[0] : Vs1[0]; + ElemType microop_result = Vs1[0]; for (uint32_t i = 0; i < this->microVl; i++) { uint32_t ei = i + vtype_VLMAX(vtype, vlen, true) * this->microIdx; @@ -1822,8 +1824,6 @@ Fault %(vm_decl_rd)s; %(copy_old_vd)s; - Vd[0] = this->microIdx != 0 ? old_Vd[0] : Vs1[0]; - auto reduce_loop = [&, this](const auto& f, const auto* _, const auto* vs2) { vu tmp_val = Vd[0]; @@ -1874,8 +1874,6 @@ Fault %(vm_decl_rd)s; %(copy_old_vd)s; - Vd[0] = this->microIdx != 0 ? old_Vd[0] : Vs1[0]; - auto reduce_loop = [&, this](const auto& f, const auto* _, const auto* vs2) { vwu tmp_val = Vd[0]; @@ -1923,10 +1921,9 @@ template constexpr uint32_t vd_eewb = sizeof(ElemType); constexpr uint32_t vs2_eewb = sizeof(ElemType); constexpr uint32_t vs1_eewb = sizeof(IndexType); - constexpr bool vs1_split = vd_eewb > vs1_eewb; const int8_t lmul = vtype_vlmul(vtype); - const int8_t vs1_emul = lmul + - (vs1_split ? -(vs2_eewb / vs1_eewb) : vs1_eewb / vs2_eewb); + const int8_t vs1_emul = lmul + __builtin_ctz(vs1_eewb) + - __builtin_ctz(vs2_eewb); const uint8_t vs2_vregs = lmul < 0 ? 1 : 1 << lmul; const uint8_t vs1_vregs = vs1_emul < 0 ? 1 : 1 << vs1_emul; const uint8_t vd_vregs = vs2_vregs; @@ -1940,6 +1937,21 @@ template microop = new VectorNopMicroInst(_machInst); this->microops.push_back(microop); } + + uint32_t vd_vlmax = vlenb / vd_eewb; + uint32_t vs1_vlmax = vlenb / vs1_eewb; + for (uint32_t i = 0; i < ceil((float) this->vl / vd_vlmax); i++) { + uint32_t pinvd_micro_vl = (vd_vlmax*(i+1) <= remaining_vl) + ? vd_vlmax : remaining_vl; + uint8_t num_vd_pins = ceil((float) pinvd_micro_vl/vs1_vlmax)*vs2_vregs; + microop = new VPinVdMicroInst(machInst, i, num_vd_pins); + microop->setFlag(IsDelayedCommit); + this->microops.push_back(microop); + + remaining_vl -= pinvd_micro_vl; + } + + remaining_vl = this->vl; for (uint32_t i = 0; i < std::max(vs1_vregs, vd_vregs) && micro_vl > 0; i++) { for (uint8_t j = 0; j < vs2_vregs; j++) { @@ -1965,7 +1977,7 @@ template class %(class_name)s : public %(base_class)s { private: - // vs2, vs1, vd, vm + // vs2, vs1, vtmp0, vm RegId srcRegIdxArr[4]; RegId destRegIdxArr[1]; bool vm; @@ -2037,7 +2049,6 @@ Fault %(set_vlenb)s; %(set_vlen)s; %(vm_decl_rd)s; - %(copy_old_vd)s; const uint32_t vlmax = vtype_VLMAX(vtype,vlen); constexpr uint32_t vd_eewb = sizeof(ElemType); constexpr uint32_t vs1_eewb = sizeof(IndexType); @@ -2059,6 +2070,7 @@ Fault [[maybe_unused]] const uint32_t vd_bias = vd_elems * (vs1_idx % vd_split_num) / vd_split_num; + %(code)s; %(op_wb)s; @@ -2216,8 +2228,6 @@ Fault %(vm_decl_rd)s; %(copy_old_vd)s; - Vd[0] = this->microIdx != 0 ? old_Vd[0] : Vs1[0]; - auto reduce_loop = [&, this](const auto& f, const auto* _, const auto* vs2) { vwu tmp_val = Vd[0]; @@ -2271,6 +2281,13 @@ template microop = new VectorNopMicroInst(_machInst); this->microops.push_back(microop); } + + for (uint32_t i = 0; i < ceil((float) this->vl/micro_vlmax); i++) { + microop = new VPinVdMicroInst(machInst, i, i+1, true); + microop->setFlag(IsDelayedCommit); + this->microops.push_back(microop); + } + // Todo static filter out useless uop int micro_idx = 0; for (int i = 0; i < num_microops && micro_vl > 0; ++i) { @@ -2308,6 +2325,13 @@ template microop = new VectorNopMicroInst(_machInst); this->microops.push_back(microop); } + + for (uint32_t i = 0; i < ceil((float) this->vl / micro_vlmax); i++) { + microop = new VPinVdMicroInst(machInst, i, num_microops-i, false); + microop->setFlag(IsDelayedCommit); + this->microops.push_back(microop); + } + // Todo static filter out useless uop int micro_idx = 0; for (int i = 0; i < num_microops && micro_vl > 0; ++i) { @@ -2333,9 +2357,9 @@ template class %(class_name)s : public %(base_class)s { private: - // vs2, vs1, vs3(old_vd), vm for *.vv, *.vx - // vs2, (old_vd), vm for *.vi - RegId srcRegIdxArr[4]; + // vs2, vs1, vm for *.vv, *.vx + // vs2, vm for *.vi + RegId srcRegIdxArr[3]; RegId destRegIdxArr[1]; bool vm; public: @@ -2398,7 +2422,6 @@ Fault [[maybe_unused]]const uint32_t vlmax = vtype_VLMAX(vtype, vlen); %(vm_decl_rd)s; - %(copy_old_vd)s; %(code)s; %(op_wb)s; @@ -2439,7 +2462,6 @@ Fault [[maybe_unused]]const uint32_t vlmax = vtype_VLMAX(vtype, vlen); %(vm_decl_rd)s; - %(copy_old_vd)s; %(code)s; %(op_wb)s;