arch-riscv: Generalize widening/narrowing vectors overlap check

As of now, the widening/narrowing vector register groups overlap check
always assumes a SEW multiplication factor equal to 2 (for either VD or
VS2). This commits aims at making this check more generic.

Change-Id: I4311fc3624cd324ccfdf2a1920a19efc85357120
This commit is contained in:
Tommaso Marinelli
2024-07-10 00:14:45 +02:00
parent 5b693fd8b6
commit a8b7e9727d

View File

@@ -79,20 +79,25 @@ let {{
uint32_t ei = i + vtype_VLMAX(vtype, vlen, true) * this->microIdx;
''' + code
def wideningOpRegisterConstraintChecks(code, src2_dw):
def wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul):
src2_sew_mul_bits = src2_sew_mul.bit_length() - 1
dest_sew_mul_bits = dest_sew_mul.bit_length() - 1
constraint_checks = '''
const uint32_t num_microops = 1 << std::max<int64_t>(0, vlmul + 1);
if ((machInst.vd % alignToPowerOfTwo(num_microops)) != 0) {
const uint32_t num_microops =
1 << std::max<int64_t>(0, vlmul + %d);
if ((machInst.vd %% alignToPowerOfTwo(num_microops)) != 0) {
std::string error =
csprintf("Unaligned Vd group in Widening op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
'''
if not src2_dw:
''' % dest_sew_mul_bits
if src2_sew_mul_bits != dest_sew_mul_bits:
constraint_checks += '''
if (((vlmul < 0) && (VS2 == VD)) ||
((vlmul >= 0) && (VS2 < VD + num_microops - (1 << vlmul)) &&
(VD < VS2 + (1 << vlmul)))) {
const int64_t vs2_emul = vlmul + %d;
if (((vs2_emul < 0) && (VS2 == VD)) ||
((vs2_emul >= 0) &&
(VS2 < VD + num_microops - (1 << vs2_emul)) &&
(VD < VS2 + (1 << vs2_emul)))) {
// A destination vector register group can overlap a source
// vector register group if the destination EEW is greater than
// the source EEW, the source EMUL is at least 1, and the
@@ -102,13 +107,15 @@ let {{
csprintf("Unsupported overlap in Vs2 and Vd for Widening op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
'''
''' % src2_sew_mul_bits
return constraint_checks + code
def narrowingOpRegisterConstraintChecks(code):
def narrowingOpRegisterConstraintChecks(code, src2_sew_mul):
src2_sew_mul_bits = src2_sew_mul.bit_length() - 1
return '''
const uint32_t num_microops = 1 << std::max<int64_t>(0, vlmul + 1);
if ((machInst.vs2 % alignToPowerOfTwo(num_microops)) != 0) {
const uint32_t num_microops =
1 << std::max<int64_t>(0, vlmul + %d);
if ((machInst.vs2 %% alignToPowerOfTwo(num_microops)) != 0) {
std::string error =
csprintf("Unaligned VS2 group in Narrowing op");
return std::make_shared<IllegalInstFault>(error, machInst);
@@ -122,7 +129,7 @@ let {{
csprintf("Unsupported overlap in Vs2 and Vd for Narrowing op");
return std::make_shared<IllegalInstFault>(error, machInst);
}
''' + code
''' % src2_sew_mul_bits + code
def fflags_wrapper(code):
return '''
@@ -328,6 +335,7 @@ def format VectorIntWideningFormat(code, category, *flags) {{
old_vd_idx = 2
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]"
dest_sew_mul = 2
src1_reg_id = ""
if category in ["OPIVV", "OPMVV"]:
src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
@@ -336,12 +344,12 @@ def format VectorIntWideningFormat(code, category, *flags) {{
else:
error("not supported category for VectorIntFormat: %s" % category)
src2_reg_id = ""
src2_dw = False
src2_sew_mul = 1
if inst_suffix in ["vv", "vx"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]"
elif inst_suffix in ["wv", "wx"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_dw = True
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
@@ -364,7 +372,7 @@ def format VectorIntWideningFormat(code, category, *flags) {{
code = eiDeclarePrefix(code, widening=True)
code = loopWrapper(code)
code = wideningOpRegisterConstraintChecks(code, src2_dw)
code = wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul)
vm_decl_rd = ""
if v0_required:
@@ -420,6 +428,7 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
else:
error("not supported category for VectorIntFormat: %s" % category)
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
set_src_reg_idx = ""
@@ -432,7 +441,7 @@ def format VectorIntNarrowingFormat(code, category, *flags) {{
code = maskCondWrapper(code)
code = eiDeclarePrefix(code, widening=True)
code = loopWrapper(code)
code = narrowingOpRegisterConstraintChecks(code)
code = narrowingOpRegisterConstraintChecks(code, src2_sew_mul)
vm_decl_rd = vmDeclAndReadData()
set_vlenb = setVlenb();
@@ -746,6 +755,7 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
is_destructive_fused = iop.op_class == "SimdFloatMultAccOp"
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx]"
dest_sew_mul = 2
src1_reg_id = ""
if category in ["OPFVV"]:
src1_reg_id = "vecRegClass[_machInst.vs1 + _microIdx / 2]"
@@ -754,12 +764,12 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
else:
error("not supported category for VectorFloatFormat: %s" % category)
src2_reg_id = ""
src2_dw = False
src2_sew_mul = 1
if inst_suffix in ["vv", "vf"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx / 2]"
elif inst_suffix in ["wv", "wf"]:
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_dw = True
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
@@ -783,7 +793,7 @@ def format VectorFloatWideningFormat(code, category, *flags) {{
code = loopWrapper(code)
code = fflags_wrapper(code)
code = wideningOpRegisterConstraintChecks(code, src2_dw)
code = wideningOpRegisterConstraintChecks(code, src2_sew_mul, dest_sew_mul)
vm_decl_rd = ""
if v0_required:
@@ -887,6 +897,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{
old_vd_idx = 1
dest_reg_id = "vecRegClass[_machInst.vd + _microIdx / 2]"
src2_reg_id = "vecRegClass[_machInst.vs2 + _microIdx]"
src2_sew_mul = 2
set_dest_reg_idx = setDestWrapper(dest_reg_id)
@@ -898,7 +909,7 @@ def format VectorFloatNarrowingCvtFormat(code, category, *flags) {{
code = eiDeclarePrefix(code, widening=True)
code = loopWrapper(code)
code = fflags_wrapper(code)
code = narrowingOpRegisterConstraintChecks(code)
code = narrowingOpRegisterConstraintChecks(code, src2_sew_mul)
vm_decl_rd = vmDeclAndReadData()