From 98e67c8610fd016bc4150a6ab12f1497a2524a50 Mon Sep 17 00:00:00 2001 From: Richard Cooper Date: Mon, 28 Sep 2020 17:50:52 +0100 Subject: [PATCH] arch-arm: Add support for Arm SVE Integer Matrix instructions. Add support for the Arm SVE Integer Matrix Multiply-Accumulate (SMMLA, USMMLA, UMMLA) instructions. Because the associated SUDOT and USDOT instructions have not yet been implemented, the SVE Feature ID register 0 (ID_AA64ZFR0_EL1) has not yet been updated to indicate support for SVE Int8 matrix multiplication instructions at this time. For more information please refer to the "ARM Architecture Reference Manual Supplement - The Scalable Vector Extension (SVE), for ARMv8-A" (https://developer.arm.com/architectures/cpu-architecture/a-profile/ docs/arm-architecture-reference-manual-supplement-armv8-a) Additional Contributors: Giacomo Travaglini Change-Id: Ia50e28fae03634cbe04b42a9900bab65a604817f Reviewed-by: Richard Cooper Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/70730 Maintainer: Andreas Sandberg Tested-by: kokoro Reviewed-by: Andreas Sandberg --- src/arch/arm/isa/formats/sve_2nd_level.isa | 43 ++++++++++++++++++++++ src/arch/arm/isa/formats/sve_top_level.isa | 11 ++++++ src/arch/arm/isa/insts/sve.isa | 16 ++++++++ 3 files changed, 70 insertions(+) diff --git a/src/arch/arm/isa/formats/sve_2nd_level.isa b/src/arch/arm/isa/formats/sve_2nd_level.isa index 3d211bc19b..4a44bab9b2 100644 --- a/src/arch/arm/isa/formats/sve_2nd_level.isa +++ b/src/arch/arm/isa/formats/sve_2nd_level.isa @@ -245,6 +245,33 @@ namespace Aarch64 return new Unknown64(machInst); } // decodeSveIntMulAdd + StaticInstPtr + decodeSveIntMatMulAdd(ExtMachInst machInst) + { + RegIndex zda = (RegIndex) (uint8_t) bits(machInst, 4, 0); + RegIndex zn = (RegIndex) (uint8_t) bits(machInst, 9, 5); + RegIndex zm = (RegIndex) (uint8_t) bits(machInst, 20, 16); + + uint8_t uns = bits(machInst, 23, 22); + + switch (uns) { + case 0x0: + return new SveSmmla( + machInst, zda, zn, zm); + case 0x2: + return new SveUsmmla( + machInst, zda, zn, zm); + case 0x3: + return new SveUmmla( + machInst, zda, zn, zm); + case 0x1: + default: + return new Unknown64(machInst); + } + + return new Unknown64(machInst); + } // decodeSveIntMatMulAdd + StaticInstPtr decodeSveShiftByImmPred0(ExtMachInst machInst) { @@ -3809,5 +3836,21 @@ namespace Aarch64 return new Unknown64(machInst); } // decodeSveMemStore + StaticInstPtr + decodeSveMisc(ExtMachInst machInst) { + switch(bits(machInst, 13, 10)) { + case 0b0110: { + return decodeSveIntMatMulAdd(machInst); + break; + } + default: { + return new Unknown64(machInst); + break; + } + } + return new Unknown64(machInst); + } // decodeSveMisc + + } // namespace Aarch64 }}; diff --git a/src/arch/arm/isa/formats/sve_top_level.isa b/src/arch/arm/isa/formats/sve_top_level.isa index 61f2f5ca6c..20a15a2971 100644 --- a/src/arch/arm/isa/formats/sve_top_level.isa +++ b/src/arch/arm/isa/formats/sve_top_level.isa @@ -44,6 +44,7 @@ namespace Aarch64 StaticInstPtr decodeSveShiftByImmPred(ExtMachInst machInst); StaticInstPtr decodeSveIntArithUnaryPred(ExtMachInst machInst); StaticInstPtr decodeSveIntMulAdd(ExtMachInst machInst); + StaticInstPtr decodeSveIntMatMulAdd(ExtMachInst machInst); StaticInstPtr decodeSveIntArithUnpred(ExtMachInst machInst); StaticInstPtr decodeSveIntLogUnpred(ExtMachInst machInst); StaticInstPtr decodeSveIndexGen(ExtMachInst machInst); @@ -94,6 +95,8 @@ namespace Aarch64 StaticInstPtr decodeSveMemContigLoad(ExtMachInst machInst); StaticInstPtr decodeSveMemGather64(ExtMachInst machInst); StaticInstPtr decodeSveMemStore(ExtMachInst machInst); + + StaticInstPtr decodeSveMisc(ExtMachInst machInst); } }}; @@ -104,6 +107,14 @@ namespace Aarch64 StaticInstPtr decodeSveInt(ExtMachInst machInst) { + if (bits(machInst, 31, 29) == 0b010) { + if (bits(machInst, 24) == 0b1 && + bits(machInst, 21) == 0b0 && + bits(machInst, 15, 14)==0b10) { + return decodeSveMisc(machInst); + } + } + uint8_t b_29_24_21 = (bits(machInst, 29) << 2) | (bits(machInst, 24) << 1) | bits(machInst, 21); diff --git a/src/arch/arm/isa/insts/sve.isa b/src/arch/arm/isa/insts/sve.isa index cbaa2b528a..6e8313bda5 100644 --- a/src/arch/arm/isa/insts/sve.isa +++ b/src/arch/arm/isa/insts/sve.isa @@ -4250,6 +4250,22 @@ let {{ sbclbCode = 'res = srcElem1 + ~(srcElem2) + carryIn;' sveTerInstUnpred('sbclb', 'Sbclb', 'VectorIntegerArithOp', unsignedTypes, sbclbCode, isTop=False, isAdd=False) + mmlaCode = ('destElem += srcElemA * srcElemB') + # SMMLA (vectors) + sveMatMulInst('smmla', 'Smmla', 'SimdMultAccOp', + (('int32_t', 'int8_t', 'int8_t'),), + numDestRows=2, numDestCols=2, K=8, + elt_mul_op=mmlaCode) + # USMMLA (vectors) + sveMatMulInst('usmmla', 'Usmmla', 'SimdMultAccOp', + (('int32_t', 'uint8_t', 'int8_t'),), + numDestRows=2, numDestCols=2, K=8, + elt_mul_op=mmlaCode) + # UMMLA (vectors) + sveMatMulInst('ummla', 'Ummla', 'SimdMultAccOp', + (('uint32_t', 'uint8_t', 'uint8_t'),), + numDestRows=2, numDestCols=2, K=8, + elt_mul_op=mmlaCode) # MOVPRFX (predicated) movCode = 'destElem = srcElem1;' sveUnaryInst('movprfx', 'MovprfxPredM', 'SimdMiscOp', unsignedTypes,