diff --git a/src/arch/riscv/isa/formats/vector_arith.isa b/src/arch/riscv/isa/formats/vector_arith.isa
index 3862b1de02..dc831f1b6d 100644
--- a/src/arch/riscv/isa/formats/vector_arith.isa
+++ b/src/arch/riscv/isa/formats/vector_arith.isa
@@ -79,43 +79,87 @@ let {{
         uint32_t ei = i + vtype_VLMAX(vtype, vlen, true) * this->microIdx;
         ''' + code
 
-    def wideningOpRegisterConstraintChecks(code):
-        return '''
-            const uint32_t num_microops = 1 << std::max<int64_t>(0, vtype_vlmul(machInst.vtype8) + 1);
-            if ((machInst.vd % alignToPowerOfTwo(num_microops)) != 0) {
+    # Prepend widening-op legality checks to `code`: vd must be aligned to
+    # its register group, and a source narrower than vd may overlap vd only
+    # in the highest-numbered part of vd's group (and only if EMUL >= 1).
+    # src2_sew_mul / dest_sew_mul are the power-of-two EEW/SEW ratios of vs2
+    # and vd; src1_is_vec says whether vs1 is a vector register to check.
+    def wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul,
+                                           src1_is_vec):
+        def checkOverlap(vreg_name, vreg_emul):
+            check_code = '''
+                if ((({vreg_emul} < 0) && ({vreg_name} == VD)) ||
+                    (({vreg_emul} >= 0) &&
+                     ({vreg_name} < VD + num_microops - (1 << {vreg_emul})) &&
+                     (VD < {vreg_name} + (1 << {vreg_emul})))) {
+                    // A destination vector register group can overlap a source
+                    // vector register group if the destination EEW is greater than
+                    // the source EEW, the source EMUL is at least 1, and the
+                    // overlap is in the highest-numbered part of the destination
+                    // register group.
+                    std::string error =
+                        csprintf("Unsupported overlap in {vreg_name} and VD for "
+                                 "Widening op");
+                    return std::make_shared<IllegalInstFault>(error, machInst);
+                }
+            '''
+            check_code = check_code.replace("{vreg_name}", vreg_name)
+            check_code = check_code.replace("{vreg_emul}", vreg_emul)
+            return check_code
+        src2_sew_mul_bits = src2_sew_mul.bit_length() - 1
+        dest_sew_mul_bits = dest_sew_mul.bit_length() - 1
+        constraint_checks = '''
+            const uint32_t num_microops =
+                1 << std::max<int64_t>(0, vlmul + %d);
+            if ((machInst.vd %% alignToPowerOfTwo(num_microops)) != 0) {
                 std::string error =
                     csprintf("Unaligned Vd group in Widening op");
                 return std::make_shared<IllegalInstFault>(error, machInst);
             }
-            if ((machInst.vs2 <= machInst.vd) && (machInst.vd < (machInst.vs2 + num_microops - 1))) {
-                // A destination vector register group can overlap a source vector
-                // register group if The destination EEW is greater than the source
-                // EEW, the source EMUL is at least 1, and the overlap is in the
-                // highest- numbered part of the destination register group.
+        ''' % dest_sew_mul_bits
+        if src2_sew_mul_bits != dest_sew_mul_bits:
+            constraint_checks += (
+                "const int64_t vs2_emul = vlmul + %d;" % src2_sew_mul_bits
+            )
+            constraint_checks += checkOverlap("VS2", "vs2_emul")
+        if src1_is_vec:
+            constraint_checks += checkOverlap("VS1", "vlmul")
+        return constraint_checks + code
+
+    # Prepend narrowing-op legality checks to `code`: vs2 must be aligned to
+    # its register group, and vd may overlap the wider vs2 only in the
+    # lowest-numbered part of vs2's group.
+    def narrowingOpRegisterConstraintChecks(code, src2_sew_mul, src1_is_vec):
+        def checkOverlap(vreg_name):
+            check_code = '''
+            if (({vreg_name} < VD) &&
+                (VD <= ({vreg_name} + num_microops - 1))) {
+                // A destination vector register group can overlap a source
+                // vector register group if the destination EEW is smaller than
+                // the source EEW and the overlap is in the lowest-numbered
+                // part of the source register group
                 std::string error =
-                    csprintf("Unsupported overlap in Vs2 and Vd for Widening op");
+                    csprintf("Unsupported overlap in {vreg_name} and VD for "
+                             "Narrowing op");
                 return std::make_shared<IllegalInstFault>(error, machInst);
             }
-        ''' + code
-
-    def narrowingOpRegisterConstraintChecks(code):
-        return '''
-            const uint32_t num_microops = 1 << std::max<int64_t>(0, vtype_vlmul(machInst.vtype8) + 1);
-            if ((machInst.vs2 % alignToPowerOfTwo(num_microops)) != 0) {
+            '''
+            check_code = check_code.replace("{vreg_name}", vreg_name)
+            return check_code
+        src2_sew_mul_bits = src2_sew_mul.bit_length() - 1
+        constraint_checks = '''
+            const uint32_t num_microops =
+                1 << std::max<int64_t>(0, vlmul + %d);
+            if ((machInst.vs2 %% alignToPowerOfTwo(num_microops)) != 0) {
                 std::string error =
                     csprintf("Unaligned VS2 group in Narrowing op");
                 return std::make_shared<IllegalInstFault>(error, machInst);
             }
-            if ((machInst.vs2 < machInst.vd) && (machInst.vd <= (VS2 + num_microops - 1))) {
-                // A destination vector register group can overlap a source vector
-                // register group The destination EEW is smaller than the source EEW
-                // and the overlap is in the lowest-numbered part of the source
-                // register group
-                std::string error =
-                    csprintf("Unsupported overlap in Vs2 and Vd for Narrowing op");
-                return std::make_shared<IllegalInstFault>(error, machInst);
-            }
-        ''' + code
+        ''' % src2_sew_mul_bits
+        constraint_checks += checkOverlap("VS2")
+        # vs1 is indexed like vd (_microIdx / 2): same EEW and group size, so
+        # vs1/vd overlap is legal; src1_is_vec is kept for call-site compat.
+        return constraint_checks + code
 
     def fflags_wrapper(code):
         return '''
@@ -321,18 +365,23 @@ def format VectorIntWideningFormat(code, category, *flags) {{
 
     old_vd_idx = 2
     dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]"
+    dest_sew_mul = 2
     src1_reg_id = ""
+    src1_is_vec = False
     if category in ["OPIVV", "OPMVV"]:
         src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
+        src1_is_vec = True
     elif category in ["OPIVX", "OPMVX"]:
         src1_reg_id = "intRegClass[_machInst.rs1]"
     else:
         error("not supported category for VectorIntFormat: %s" % category)
     src2_reg_id = ""
+    src2_sew_mul = 1
     if inst_suffix in ["vv", "vx"]:
         src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]"
     elif inst_suffix in ["wv", "wx"]:
         src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
+        src2_sew_mul = 2
 
     set_dest_reg_idx = setDestWrapper(dest_reg_id)
 
@@ -355,7 +404,8 @@ def format VectorIntWideningFormat(code, category, *flags) {{
 
     code = eiDeclarePrefix(code, widening=True)
     code = loopWrapper(code)
-    code = wideningOpRegisterConstraintChecks(code)
+    code = wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul,
+                                              src1_is_vec)
 
     vm_decl_rd = ""
     if v0_required:
@@ -402,8 +452,10 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
 
     old_vd_idx = 2
     dest_reg_id = "vecRegClass[_machInst.vd + _microIdx / 2]"
+    src1_is_vec = False
     if category in ["OPIVV"]:
         src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
+        src1_is_vec = True
     elif category in ["OPIVX"]:
         src1_reg_id = "intRegClass[_machInst.rs1]"
     elif category == "OPIVI":
@@ -411,6 +463,7 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
     else:
         error("not supported category for VectorIntFormat: %s" % category)
     src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
+    src2_sew_mul = 2
 
     set_dest_reg_idx = setDestWrapper(dest_reg_id)
     set_src_reg_idx = ""
@@ -423,7 +476,7 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
     code = maskCondWrapper(code)
     code = eiDeclarePrefix(code, widening=True)
     code = loopWrapper(code)
-    code = narrowingOpRegisterConstraintChecks(code)
+    code = narrowingOpRegisterConstraintChecks(code, src2_sew_mul, src1_is_vec)
 
     vm_decl_rd = vmDeclAndReadData()
     set_vlenb = setVlenb();
@@ -544,7 +597,6 @@ def format VectorGatherFormat(code, category, *flags) {{
     else:
         error("not supported category for VectorIntFormat: %s" % category)
     src2_reg_id = "vecRegClass[_machInst.vs2 + vs2_idx]"
-    src2_reg_id = "vecRegClass[_machInst.vs2 + vs2_idx]"
 
     # vtmp0 as dummy src reg to create dependency with pin vd micro
     src3_reg_id = "vecRegClass[VecMemInternalReg0 + vd_idx]"
@@ -738,18 +790,23 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
 
     is_destructive_fused = iop.op_class == "SimdFloatMultAccOp"
     dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]"
+    dest_sew_mul = 2
     src1_reg_id = ""
+    src1_is_vec = False
     if category in ["OPFVV"]:
         src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
+        src1_is_vec = True
     elif category in ["OPFVF"]:
         src1_reg_id = "floatRegClass[_machInst.rs1]"
     else:
         error("not supported category for VectorFloatFormat: %s" % category)
     src2_reg_id = ""
+    src2_sew_mul = 1
     if inst_suffix in ["vv", "vf"]:
         src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]"
     elif inst_suffix in ["wv", "wf"]:
         src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
+        src2_sew_mul = 2
 
     set_dest_reg_idx = setDestWrapper(dest_reg_id)
 
@@ -773,7 +830,8 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
 
     code = loopWrapper(code)
     code = fflags_wrapper(code)
-    code = wideningOpRegisterConstraintChecks(code)
+    code = wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul,
+                                              src1_is_vec)
 
     vm_decl_rd = ""
     if v0_required:
@@ -877,6 +935,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{
     old_vd_idx = 1
     dest_reg_id = "vecRegClass[_machInst.vd + _microIdx / 2]"
     src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
+    src2_sew_mul = 2
 
     set_dest_reg_idx = setDestWrapper(dest_reg_id)
 
@@ -888,7 +947,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{
     code = eiDeclarePrefix(code, widening=True)
     code = loopWrapper(code)
     code = fflags_wrapper(code)
-    code = narrowingOpRegisterConstraintChecks(code)
+    code = narrowingOpRegisterConstraintChecks(code, src2_sew_mul, False)
 
     vm_decl_rd = vmDeclAndReadData()