From 9de1443ebb8196950bcfdfc481ee0550feab0d7b Mon Sep 17 00:00:00 2001 From: Richard Cooper Date: Mon, 9 Nov 2020 19:05:09 +0000 Subject: [PATCH] arch-arm: Add support for Armv8.2-I8MM NEON extension. Add support for the Armv8.2-I8MM NEON extension. This provides the SUDOT and USDOT mixed-sign SIMD Dot Product instructions, as well as the SMMLA, UMMLA, and USMMLA SIMD Matrix Multiply-Accumulate instructions. For more information please refer to the Arm Architecture Reference Manual (https://developer.arm.com/documentation/ddi0487/latest/). Additional Contributors: Giacomo Travaglini Change-Id: I6fb9318f67cc9d2737079283e1a095630c4d2ad9 Reviewed-by: Richard Cooper Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/70737 Maintainer: Jason Lowe-Power Reviewed-by: Andreas Sandberg Reviewed-by: Jason Lowe-Power Maintainer: Andreas Sandberg Tested-by: kokoro --- src/arch/arm/isa/formats/neon64.isa | 51 ++++++++++++++++- src/arch/arm/isa/insts/neon64.isa | 88 +++++++++++++++++++++++++++++ src/arch/arm/process.cc | 3 + src/arch/arm/regs/misc.cc | 1 + src/arch/arm/regs/misc_types.hh | 1 + 5 files changed, 141 insertions(+), 3 deletions(-) diff --git a/src/arch/arm/isa/formats/neon64.isa b/src/arch/arm/isa/formats/neon64.isa index e083f6f25c..8d9b74dfa3 100644 --- a/src/arch/arm/isa/formats/neon64.isa +++ b/src/arch/arm/isa/formats/neon64.isa @@ -510,6 +510,7 @@ namespace Aarch64 decodeNeon3RegExtension(ExtMachInst machInst) { uint8_t q = bits(machInst, 30); + uint8_t u = bits(machInst, 29); uint8_t qu = bits(machInst, 30, 29); uint8_t size = bits(machInst, 23, 22); uint8_t opcode = bits(machInst, 15, 11); @@ -546,6 +547,20 @@ namespace Aarch64 default: return new Unknown64(machInst); } + case 0x13: + if (q) { + return new UsdotQX(machInst, vd, vn, vm); + } else { + return new UsdotDX(machInst, vd, vn, vm); + } + case 0x14: + if (u) { + return new UmmlaQX(machInst, vd, vn, vm); + } else { + return new SmmlaQX(machInst, vd, vn, vm); + } + case 0x15: + return new UsmmlaQX(machInst, vd, vn, vm); case 0x18: case 0x19: case 0x1a: @@ -1590,9 +1605,6 @@ namespace Aarch64 return decodeNeonSThreeImmHAndWReg( q, size, machInst, vd, vn, vm, index); - case 0xf: - return decodeNeonSThreeImmHAndWReg( - q, size, machInst, vd, vn, vm, index); case 0xe: switch (qu) { case 0b00: @@ -1610,6 +1622,39 @@ namespace Aarch64 default: return new Unknown64(machInst); } + case 0xf: + if (u) { + // Armv8.2-RDMA SQRDMLSH(elem) + return decodeNeonSThreeImmHAndWReg( + q, size, machInst, vd, vn, vm, index); + } else { + switch (size) { + case 0x0: + if (q) { + return new SudotElemQX( + machInst, vd, vn, vm_dp, index_dp); + } else { + return new SudotElemDX( + machInst, vd, vn, vm_dp, index_dp); + } + case 0x1: + // Armv8.2-BF16 BFDOT(elem) + return new Unknown64(machInst); + case 0x2: + if (q) { + return new UsdotElemQX( + machInst, vd, vn, vm_dp, index_dp); + } else { + return new UsdotElemDX( + machInst, vd, vn, vm_dp, index_dp); + } + case 0x3: + default: + // Armv8.2-BF16 BFMLALB(elem), BFMLALT(elem) + return new Unknown64(machInst); + } + } default: return new Unknown64(machInst); } diff --git a/src/arch/arm/isa/insts/neon64.isa b/src/arch/arm/isa/insts/neon64.isa index 53c0f112bf..6608f61688 100644 --- a/src/arch/arm/isa/insts/neon64.isa +++ b/src/arch/arm/isa/insts/neon64.isa @@ -1146,6 +1146,94 @@ let {{ # UDOT (element) intDotInst('udot', 'UdotElemDX', 'SimdAluOp', False, False, False, 2, True) intDotInst('udot', 'UdotElemQX', 'SimdAluOp', False, False, False, 4, True) + # SUDOT (element) + intDotInst('sudot', 'SudotElemDX', 'SimdAluOp', True, True, False, 2, True) + intDotInst('sudot', 'SudotElemQX', 'SimdAluOp', True, True, False, 4, True) + # USDOT (vector) + intDotInst('usdot', 'UsdotDX', 'SimdAluOp', True, False, True, 2, False) + intDotInst('usdot', 'UsdotQX', 'SimdAluOp', True, False, True, 4, False) + # USDOT (element) + intDotInst('usdot', 'UsdotElemDX', 'SimdAluOp', True, False, True, 2, True) + intDotInst('usdot', 'UsdotElemQX', 'SimdAluOp', True, False, True, 4, True) + + def intMatMulInst(name, Name, opClass, + destIsSigned, src1IsSigned, src2IsSigned): + destType = "int32_t" if destIsSigned else "uint32_t" + src1Type = "int8_t" if src1IsSigned else "uint8_t" + src2Type = "int8_t" if src2IsSigned else "uint8_t" + matMulCode = ''' + using Src1Element = %(src1Type)s; + using Src2Element = %(src2Type)s; + + // Neon MM instructions always generate four output elements + // from 16 pairs of source elements. + static_assert(sizeof(Element) == 4 * sizeof(Src1Element)); + static_assert(sizeof(Element) == 4 * sizeof(Src2Element)); + + // Extended source element types to avoid overflow of intermediate + // calculations. + using ExtendedSrc1Element = + typename vector_element_traits:: + extend_element::type; + using ExtendedSrc2Element = + typename vector_element_traits:: + extend_element::type; + + // Properties of the matrices + constexpr unsigned destMatSize = 2; // Dest Matrices are dim 2x2 + constexpr unsigned K = 8; // Src matrices are dim 2x8 & 8x2 + + constexpr unsigned eltsPerMatrix = destMatSize * destMatSize; + + Element destMat[eltsPerMatrix] = {0}; + for (unsigned j = 0; j < eltsPerMatrix; ++j) { + destMat[j] = letoh(destReg.elements[j]); + } + + Element src1MatPacked[eltsPerMatrix] = {0}; + Element src2MatPacked[eltsPerMatrix] = {0}; + for (unsigned j = 0; j < eltsPerMatrix; ++j) { + src1MatPacked[j] = letoh(srcReg1.elements[j]); + src2MatPacked[j] = letoh(srcReg2.elements[j]); + } + + Src1Element *src1Mat = + reinterpret_cast(&src1MatPacked); + Src2Element *src2Mat = + reinterpret_cast(&src2MatPacked); + + unsigned destEltIdx = 0; + for (unsigned rowIdx = 0; rowIdx < destMatSize; ++rowIdx) { + for (unsigned colIdx = 0; colIdx < destMatSize; ++colIdx) { + Element destElem = destMat[destEltIdx]; + for (unsigned k = 0; k < K; ++k) { + const ExtendedSrc1Element src1Elem = + static_cast + (src1Mat[K * rowIdx + k]); + const ExtendedSrc2Element src2Elem = + static_cast + (src2Mat[K * colIdx + k]); + + destElem += src1Elem * src2Elem; + } + destMat[destEltIdx++] = destElem; + } + } + + for (unsigned j = 0; j < eltsPerMatrix; ++j) { + destReg.elements[j] = htole(destMat[j]); + } + ''' % dict(src1Type=src1Type, src2Type=src2Type) + threeEqualRegInstX(name, Name, opClass, (destType,), 4, + matMulCode, readDest=True, byElem=False, + complex=True) + + # SMMLA + intMatMulInst('smmla', 'SmmlaQX', 'SimdMatMultAccOp', True, True, True) + # USMMLA + intMatMulInst('usmmla', 'UsmmlaQX', 'SimdMatMultAccOp', True, False, True) + # UMMLA + intMatMulInst('ummla', 'UmmlaQX', 'SimdMatMultAccOp', False, False, False) # CLS clsCode = ''' diff --git a/src/arch/arm/process.cc b/src/arch/arm/process.cc index b63567b6c3..9aa519fe36 100644 --- a/src/arch/arm/process.cc +++ b/src/arch/arm/process.cc @@ -320,6 +320,9 @@ ArmProcess64::armHwcapImpl2() const hwcap |= (isa_r0.ts >= 2) ? Arm_Flagm2 : Arm_None; hwcap |= (isa_r0.rndr >= 1) ? Arm_Rng : Arm_None; + const AA64ISAR1 isa_r1 = tc->readMiscReg(MISCREG_ID_AA64ISAR1_EL1); + hwcap |= (isa_r1.i8mm >= 1) ? Arm_I8mm : Arm_None; + const AA64ZFR0 zf_r0 = tc->readMiscReg(MISCREG_ID_AA64ZFR0_EL1); hwcap |= (zf_r0.f32mm >= 1) ? Arm_Svef32mm : Arm_None; hwcap |= (zf_r0.f64mm >= 1) ? Arm_Svef64mm : Arm_None; diff --git a/src/arch/arm/regs/misc.cc b/src/arch/arm/regs/misc.cc index ed15f25e69..dcb6e2b048 100644 --- a/src/arch/arm/regs/misc.cc +++ b/src/arch/arm/regs/misc.cc @@ -4005,6 +4005,7 @@ ISA::initializeMiscRegMetadata() InitReg(MISCREG_ID_AA64ISAR1_EL1) .reset([p,release=release](){ AA64ISAR1 isar1_el1 = p.id_aa64isar1_el1; + isar1_el1.i8mm = release->has(ArmExtension::FEAT_I8MM) ? 0x1 : 0x0; isar1_el1.apa = release->has(ArmExtension::FEAT_PAuth) ? 0x1 : 0x0; isar1_el1.jscvt = release->has(ArmExtension::FEAT_JSCVT) ? 0x1 : 0x0; isar1_el1.fcma = release->has(ArmExtension::FEAT_FCMA) ? 0x1 : 0x0; diff --git a/src/arch/arm/regs/misc_types.hh b/src/arch/arm/regs/misc_types.hh index b7a1207cf5..4bb234fd10 100644 --- a/src/arch/arm/regs/misc_types.hh +++ b/src/arch/arm/regs/misc_types.hh @@ -127,6 +127,7 @@ namespace ArmISA EndBitUnion(AA64ISAR0) BitUnion64(AA64ISAR1) + Bitfield<55, 52> i8mm; Bitfield<43, 40> specres; Bitfield<39, 36> sb; Bitfield<35, 32> frintts;