arch-riscv: Improve widening/narrowing vectors overlap check (#1331)

This PR improves the vector register groups overlap check in
widening/narrowing
instructions.

- Fix wrong illegal overlap condition between VS2 and VD vector register
groups.
- Also check VS1 vector register group for overlap with VD in
vector-vector
instructions.
- Parametrize widening/narrowing factors in overlap check function to
potentially
handle more cases.

Fixes issue #442.
This commit is contained in:
Ivana Mitrovic
2024-07-22 10:54:02 -07:00
committed by GitHub

View File

@@ -79,43 +79,79 @@ let {{
uint32_t ei = i + vtype_VLMAX(vtype, vlen, true) * this->microIdx;
''' + code
def wideningOpRegisterConstraintChecks(code):
return '''
const uint32_t num_microops = 1 << std::max<int64_t>(0, vtype_vlmul(machInst.vtype8) + 1);
if ((machInst.vd % alignToPowerOfTwo(num_microops)) != 0) {
def wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul,
src1_is_vec):
def checkOverlap(vreg_name, vreg_emul):
check_code = '''
if ((({vreg_emul} < 0) && ({vreg_name} == VD)) ||
(({vreg_emul} >= 0) &&
({vreg_name} < VD + num_microops - (1 << {vreg_emul})) &&
(VD < {vreg_name} + (1 << {vreg_emul})))) {
// A destination vector register group can overlap a source
// vector register group if the destination EEW is greater than
// the source EEW, the source EMUL is at least 1, and the
// overlap is in the highest- numbered part of the destination
// register group.
std::string error =
csprintf("Unsupported overlap in {vreg_name} and VD for "
"Widening op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
'''
check_code = check_code.replace("{vreg_name}", vreg_name)
check_code = check_code.replace("{vreg_emul}", vreg_emul)
return check_code
src2_sew_mul_bits = src2_sew_mul.bit_length() - 1
dest_sew_mul_bits = dest_sew_mul.bit_length() - 1
constraint_checks = '''
const uint32_t num_microops =
1 << std::max<int64_t>(0, vlmul + %d);
if ((machInst.vd %% alignToPowerOfTwo(num_microops)) != 0) {
std::string error =
csprintf("Unaligned Vd group in Widening op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if ((machInst.vs2 <= machInst.vd) && (machInst.vd < (machInst.vs2 + num_microops - 1))) {
// A destination vector register group can overlap a source vector
// register group if The destination EEW is greater than the source
// EEW, the source EMUL is at least 1, and the overlap is in the
// highest- numbered part of the destination register group.
''' % dest_sew_mul_bits
if src2_sew_mul_bits != dest_sew_mul_bits:
constraint_checks += (
"const int64_t vs2_emul = vlmul + %d;" % src2_sew_mul_bits
)
constraint_checks += checkOverlap("VS2", "vs2_emul")
if src1_is_vec:
constraint_checks += checkOverlap("VS1", "vlmul")
return constraint_checks + code
def narrowingOpRegisterConstraintChecks(code, src2_sew_mul, src1_is_vec):
def checkOverlap(vreg_name):
check_code = '''
if (({vreg_name} < VD) &&
(VD <= ({vreg_name} + num_microops - 1))) {
// A destination vector register group can overlap a source
// vector register group if the destination EEW is smaller than
// the source EEW and the overlap is in the lowest-numbered
// part of the source register group
std::string error =
csprintf("Unsupported overlap in Vs2 and Vd for Widening op");
csprintf("Unsupported overlap in {vreg_name} and VD for "
"Narrowing op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
''' + code
def narrowingOpRegisterConstraintChecks(code):
return '''
const uint32_t num_microops = 1 << std::max<int64_t>(0, vtype_vlmul(machInst.vtype8) + 1);
if ((machInst.vs2 % alignToPowerOfTwo(num_microops)) != 0) {
'''
check_code = check_code.replace("{vreg_name}", vreg_name)
return check_code
src2_sew_mul_bits = src2_sew_mul.bit_length() - 1
constraint_checks = '''
const uint32_t num_microops =
1 << std::max<int64_t>(0, vlmul + %d);
if ((machInst.vs2 %% alignToPowerOfTwo(num_microops)) != 0) {
std::string error =
csprintf("Unaligned VS2 group in Narrowing op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
if ((machInst.vs2 < machInst.vd) && (machInst.vd <= (VS2 + num_microops - 1))) {
// A destination vector register group can overlap a source vector
// register group The destination EEW is smaller than the source EEW
// and the overlap is in the lowest-numbered part of the source
// register group
std::string error =
csprintf("Unsupported overlap in Vs2 and Vd for Narrowing op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
''' + code
''' % src2_sew_mul_bits
constraint_checks += checkOverlap("VS2")
if src1_is_vec:
constraint_checks += checkOverlap("VS1")
return constraint_checks + code
def fflags_wrapper(code):
return '''
@@ -321,18 +357,23 @@ def format VectorIntWideningFormat(code, category, *flags) {{
old_vd_idx = 2
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]"
dest_sew_mul = 2
src1_reg_id = ""
src1_is_vec = False
if category in ["OPIVV", "OPMVV"]:
src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
src1_is_vec = True
elif category in ["OPIVX", "OPMVX"]:
src1_reg_id = "intRegClass[_machInst.rs1]"
else:
error("not supported category for VectorIntFormat: %s" % category)
src2_reg_id = ""
src2_sew_mul = 1
if inst_suffix in ["vv", "vx"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]"
elif inst_suffix in ["wv", "wx"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
@@ -355,7 +396,8 @@ def format VectorIntWideningFormat(code, category, *flags) {{
code = eiDeclarePrefix(code, widening=True)
code = loopWrapper(code)
code = wideningOpRegisterConstraintChecks(code)
code = wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul,
src1_is_vec)
vm_decl_rd = ""
if v0_required:
@@ -402,8 +444,10 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
old_vd_idx = 2
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx / 2]"
src1_is_vec = False
if category in ["OPIVV"]:
src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
src1_is_vec = True
elif category in ["OPIVX"]:
src1_reg_id = "intRegClass[_machInst.rs1]"
elif category == "OPIVI":
@@ -411,6 +455,7 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
else:
error("not supported category for VectorIntFormat: %s" % category)
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
set_src_reg_idx = ""
@@ -423,7 +468,7 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
code = maskCondWrapper(code)
code = eiDeclarePrefix(code, widening=True)
code = loopWrapper(code)
code = narrowingOpRegisterConstraintChecks(code)
code = narrowingOpRegisterConstraintChecks(code, src2_sew_mul, src1_is_vec)
vm_decl_rd = vmDeclAndReadData()
set_vlenb = setVlenb();
@@ -544,7 +589,6 @@ def format VectorGatherFormat(code, category, *flags) {{
else:
error("not supported category for VectorIntFormat: %s" % category)
src2_reg_id = "vecRegClass[_machInst.vs2 + vs2_idx]"
src2_reg_id = "vecRegClass[_machInst.vs2 + vs2_idx]"
# vtmp0 as dummy src reg to create dependency with pin vd micro
src3_reg_id = "vecRegClass[VecMemInternalReg0 + vd_idx]"
@@ -738,18 +782,23 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
is_destructive_fused = iop.op_class == "SimdFloatMultAccOp"
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]"
dest_sew_mul = 2
src1_reg_id = ""
src1_is_vec = False
if category in ["OPFVV"]:
src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
src1_is_vec = True
elif category in ["OPFVF"]:
src1_reg_id = "floatRegClass[_machInst.rs1]"
else:
error("not supported category for VectorFloatFormat: %s" % category)
src2_reg_id = ""
src2_sew_mul = 1
if inst_suffix in ["vv", "vf"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]"
elif inst_suffix in ["wv", "wf"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
@@ -773,7 +822,8 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
code = loopWrapper(code)
code = fflags_wrapper(code)
code = wideningOpRegisterConstraintChecks(code)
code = wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul,
src1_is_vec)
vm_decl_rd = ""
if v0_required:
@@ -877,6 +927,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{
old_vd_idx = 1
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx / 2]"
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
@@ -888,7 +939,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{
code = eiDeclarePrefix(code, widening=True)
code = loopWrapper(code)
code = fflags_wrapper(code)
code = narrowingOpRegisterConstraintChecks(code)
code = narrowingOpRegisterConstraintChecks(code, src2_sew_mul, False)
vm_decl_rd = vmDeclAndReadData()