diff --git a/src/arch/arm/ArmISA.py b/src/arch/arm/ArmISA.py index ffe63ebb0a..8c1ee5ae42 100644 --- a/src/arch/arm/ArmISA.py +++ b/src/arch/arm/ArmISA.py @@ -57,6 +57,7 @@ class ArmDefaultSERelease(ArmRelease): "FEAT_F64MM", "FEAT_SVE", "FEAT_I8MM", + "FEAT_DOTPROD", # Armv8.3 "FEAT_FCMA", "FEAT_JSCVT", diff --git a/src/arch/arm/ArmSystem.py b/src/arch/arm/ArmSystem.py index c5c0f436a3..eaaf4b1cb3 100644 --- a/src/arch/arm/ArmSystem.py +++ b/src/arch/arm/ArmSystem.py @@ -81,6 +81,7 @@ class ArmExtension(ScopedEnum): "FEAT_F32MM", # Optional in Armv8.2 "FEAT_F64MM", # Optional in Armv8.2 "FEAT_I8MM", # Optional in Armv8.2 + "FEAT_DOTPROD", # Optional in Armv8.2 # Armv8.3 "FEAT_FCMA", "FEAT_JSCVT", @@ -169,6 +170,7 @@ class ArmDefaultRelease(Armv8): "FEAT_F32MM", "FEAT_F64MM", "FEAT_I8MM", + "FEAT_DOTPROD", # Armv8.3 "FEAT_FCMA", "FEAT_JSCVT", @@ -205,6 +207,7 @@ class Armv82(Armv81): "FEAT_F32MM", "FEAT_F64MM", "FEAT_I8MM", + "FEAT_DOTPROD", ] diff --git a/src/arch/arm/isa/formats/neon64.isa b/src/arch/arm/isa/formats/neon64.isa index 5cce0d7c23..e083f6f25c 100644 --- a/src/arch/arm/isa/formats/neon64.isa +++ b/src/arch/arm/isa/formats/neon64.isa @@ -510,6 +510,7 @@ namespace Aarch64 decodeNeon3RegExtension(ExtMachInst machInst) { uint8_t q = bits(machInst, 30); + uint8_t qu = bits(machInst, 30, 29); uint8_t size = bits(machInst, 23, 22); uint8_t opcode = bits(machInst, 15, 11); @@ -532,6 +533,19 @@ namespace Aarch64 else return decodeNeonSThreeHAndWReg( size, machInst, vd, vn, vm); + case 0x12: + switch (qu) { + case 0b00: + return new SdotDX(machInst, vd, vn, vm); + case 0b01: + return new UdotDX(machInst, vd, vn, vm); + case 0b10: + return new SdotQX(machInst, vd, vn, vm); + case 0b11: + return new UdotQX(machInst, vd, vn, vm); + default: + return new Unknown64(machInst); + } case 0x18: case 0x19: case 0x1a: @@ -1351,6 +1365,7 @@ namespace Aarch64 { uint8_t q = bits(machInst, 30); uint8_t u = bits(machInst, 29); + uint8_t qu = bits(machInst, 30, 29); uint8_t size = bits(machInst, 23, 22); uint8_t L = bits(machInst, 21); uint8_t M = bits(machInst, 20); @@ -1387,6 +1402,11 @@ namespace Aarch64 } RegIndex vm_fp = (RegIndex) (uint8_t) (vmh << 4 | vm_bf); + // Index and 2nd register operand for FEAT_DOTPROD and + // FEAT_I8MM instructions + uint8_t index_dp = (H << 1) | L; + RegIndex vm_dp = (RegIndex) (uint8_t) (M << 4 | vm_bf); + switch (opcode) { case 0x0: if (!u || (size == 0x0 || size == 0x3)) @@ -1573,6 +1593,23 @@ namespace Aarch64 case 0xf: return decodeNeonSThreeImmHAndWReg( q, size, machInst, vd, vn, vm, index); + case 0xe: + switch (qu) { + case 0b00: + return new SdotElemDX(machInst, + vd, vn, vm_dp, index_dp); + case 0b01: + return new UdotElemDX(machInst, + vd, vn, vm_dp, index_dp); + case 0b10: + return new SdotElemQX(machInst, + vd, vn, vm_dp, index_dp); + case 0b11: + return new UdotElemQX(machInst, + vd, vn, vm_dp, index_dp); + default: + return new Unknown64(machInst); + } default: return new Unknown64(machInst); } diff --git a/src/arch/arm/isa/insts/neon64.isa b/src/arch/arm/isa/insts/neon64.isa index 0da7f06ec3..53c0f112bf 100644 --- a/src/arch/arm/isa/insts/neon64.isa +++ b/src/arch/arm/isa/insts/neon64.isa @@ -1082,6 +1082,71 @@ let {{ complex=True) threeEqualRegInstX("fcmla", "FcmlaQX", "SimdFloatMultAccOp", floatTypes, 4, fcmla_vec, True, complex=True) + + def intDotInst(name, Name, opClass, + destIsSigned, src1IsSigned, src2IsSigned, + rCount, byElem): + destType = "int32_t" if destIsSigned else "uint32_t" + src1Type = "int8_t" if src1IsSigned else "uint8_t" + src2Type = "int8_t" if src2IsSigned else "uint8_t" + dotCode = ''' + using Src1Element = %(src1Type)s; + using Src2Element = %(src2Type)s; + + // Neon dot instructions always generate one output element + // from 4 pairs of source elements. + static_assert(sizeof(Element) == 4 * sizeof(Src1Element)); + static_assert(sizeof(Element) == 4 * sizeof(Src2Element)); + + // Extended source element types to avoid overflow of intermediate + // calculations. + using ExtendedSrc1Element = + typename vector_element_traits:: + extend_element::type; + using ExtendedSrc2Element = + typename vector_element_traits:: + extend_element::type; + + for (unsigned i = 0; i < eCount; ++i) { + Element src1ElemsPacked = letoh(srcReg1.elements[i]); + Element src2ElemsPacked = letoh(srcReg2.elements[%(src2Index)s]); + + Src1Element *src1Elems = + reinterpret_cast(&src1ElemsPacked); + Src2Element *src2Elems = + reinterpret_cast(&src2ElemsPacked); + + // Dot instructions accumulate into the dest reg + Element destElem = letoh(destReg.elements[i]); + + for (unsigned j = 0; j < 4; ++j) { + ExtendedSrc1Element src1Elem = + static_cast(src1Elems[j]); + ExtendedSrc2Element src2Elem = + static_cast(src2Elems[j]); + destElem += src1Elem * src2Elem; + } + destReg.elements[i] = htole(destElem); + } + ''' % dict(src1Type=src1Type, src2Type=src2Type, + src2Index="imm" if byElem else "i") + threeEqualRegInstX(name, Name, opClass, (destType,), rCount, + dotCode, readDest=True, byElem=byElem, + complex=True) + + # SDOT (vector) + intDotInst('sdot', 'SdotDX', 'SimdAluOp', True, True, True, 2, False) + intDotInst('sdot', 'SdotQX', 'SimdAluOp', True, True, True, 4, False) + # SDOT (element) + intDotInst('sdot', 'SdotElemDX', 'SimdAluOp', True, True, True, 2, True) + intDotInst('sdot', 'SdotElemQX', 'SimdAluOp', True, True, True, 4, True) + # UDOT (vector) + intDotInst('udot', 'UdotDX', 'SimdAluOp', False, False, False, 2, False) + intDotInst('udot', 'UdotQX', 'SimdAluOp', False, False, False, 4, False) + # UDOT (element) + intDotInst('udot', 'UdotElemDX', 'SimdAluOp', False, False, False, 2, True) + intDotInst('udot', 'UdotElemQX', 'SimdAluOp', False, False, False, 4, True) + # CLS clsCode = ''' unsigned count = 0; diff --git a/src/arch/arm/regs/misc.cc b/src/arch/arm/regs/misc.cc index b978044855..ed15f25e69 100644 --- a/src/arch/arm/regs/misc.cc +++ b/src/arch/arm/regs/misc.cc @@ -3988,6 +3988,7 @@ ISA::initializeMiscRegMetadata() isar0_el1.sha1 = 0; isar0_el1.aes = 0; } + isar0_el1.dp = release->has(ArmExtension::FEAT_DOTPROD) ? 0x1 : 0x0; isar0_el1.atomic = release->has(ArmExtension::FEAT_LSE) ? 0x2 : 0x0; isar0_el1.rdm = release->has(ArmExtension::FEAT_RDM) ? 0x1 : 0x0; isar0_el1.tme = release->has(ArmExtension::TME) ? 0x1 : 0x0;