arch-arm: Add support for Arm SVE Integer Matrix instructions.

Add support for the Arm SVE Integer Matrix Multiply-Accumulate
(SMMLA, USMMLA, UMMLA) instructions. Because the associated SUDOT and
USDOT instructions have not yet been implemented, this change does not
update the SVE Feature ID register 0 (ID_AA64ZFR0_EL1) to indicate
support for the SVE Int8 matrix multiplication instructions.

For more information please refer to the "ARM Architecture Reference
Manual Supplement - The Scalable Vector Extension (SVE), for ARMv8-A"
(https://developer.arm.com/architectures/cpu-architecture/a-profile/
docs/arm-architecture-reference-manual-supplement-armv8-a)

Additional Contributors: Giacomo Travaglini

Change-Id: Ia50e28fae03634cbe04b42a9900bab65a604817f
Reviewed-by: Richard Cooper <richard.cooper@arm.com>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/70730
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>
Tested-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
This commit is contained in:
Richard Cooper
2020-09-28 17:50:52 +01:00
committed by Bobby Bruce
parent 0f857873f9
commit 98e67c8610
3 changed files with 70 additions and 0 deletions

View File

@@ -245,6 +245,33 @@ namespace Aarch64
return new Unknown64(machInst);
} // decodeSveIntMulAdd
// Decode the SVE integer matrix multiply-accumulate instructions.
//
// Bits [23:22] of the encoding select the variant:
//   0x0 -> SMMLA  (int8_t  x int8_t,  int32_t accumulator)
//   0x2 -> USMMLA (uint8_t x int8_t,  int32_t accumulator)
//   0x3 -> UMMLA  (uint8_t x uint8_t, uint32_t accumulator)
// Every other value decodes to Unknown64.
StaticInstPtr
decodeSveIntMatMulAdd(ExtMachInst machInst)
{
    // Register operand fields: accumulator/destination in [4:0],
    // first source in [9:5], second source in [20:16].
    const RegIndex accReg =
        static_cast<RegIndex>(static_cast<uint8_t>(bits(machInst, 4, 0)));
    const RegIndex srcRegN =
        static_cast<RegIndex>(static_cast<uint8_t>(bits(machInst, 9, 5)));
    const RegIndex srcRegM =
        static_cast<RegIndex>(static_cast<uint8_t>(bits(machInst, 20, 16)));

    const uint8_t variant = bits(machInst, 23, 22);

    if (variant == 0x0) {
        return new SveSmmla<int32_t, int8_t, int8_t>(
            machInst, accReg, srcRegN, srcRegM);
    }
    if (variant == 0x2) {
        return new SveUsmmla<int32_t, uint8_t, int8_t>(
            machInst, accReg, srcRegN, srcRegM);
    }
    if (variant == 0x3) {
        return new SveUmmla<uint32_t, uint8_t, uint8_t>(
            machInst, accReg, srcRegN, srcRegM);
    }

    // variant == 0x1 has no instruction assigned by this decoder.
    return new Unknown64(machInst);
} // decodeSveIntMatMulAdd
StaticInstPtr
decodeSveShiftByImmPred0(ExtMachInst machInst)
{
@@ -3809,5 +3836,21 @@ namespace Aarch64
return new Unknown64(machInst);
} // decodeSveMemStore
// Decode the SVE "misc" group by inspecting bits [13:10] of the
// instruction. Only 0b0110 is recognised, and it is handed to
// decodeSveIntMatMulAdd(); everything else decodes to Unknown64.
StaticInstPtr
decodeSveMisc(ExtMachInst machInst)
{
    if (bits(machInst, 13, 10) == 0b0110) {
        return decodeSveIntMatMulAdd(machInst);
    }
    return new Unknown64(machInst);
} // decodeSveMisc
} // namespace Aarch64
}};

View File

@@ -44,6 +44,7 @@ namespace Aarch64
StaticInstPtr decodeSveShiftByImmPred(ExtMachInst machInst);
StaticInstPtr decodeSveIntArithUnaryPred(ExtMachInst machInst);
StaticInstPtr decodeSveIntMulAdd(ExtMachInst machInst);
// Integer matrix multiply-accumulate: SMMLA, USMMLA and UMMLA.
StaticInstPtr decodeSveIntMatMulAdd(ExtMachInst machInst);
StaticInstPtr decodeSveIntArithUnpred(ExtMachInst machInst);
StaticInstPtr decodeSveIntLogUnpred(ExtMachInst machInst);
StaticInstPtr decodeSveIndexGen(ExtMachInst machInst);
@@ -94,6 +95,8 @@ namespace Aarch64
StaticInstPtr decodeSveMemContigLoad(ExtMachInst machInst);
StaticInstPtr decodeSveMemGather64(ExtMachInst machInst);
StaticInstPtr decodeSveMemStore(ExtMachInst machInst);
// Miscellaneous SVE encodings, dispatched on instruction bits [13:10].
StaticInstPtr decodeSveMisc(ExtMachInst machInst);
}
}};
@@ -104,6 +107,14 @@ namespace Aarch64
StaticInstPtr
decodeSveInt(ExtMachInst machInst)
{
if (bits(machInst, 31, 29) == 0b010) {
if (bits(machInst, 24) == 0b1 &&
bits(machInst, 21) == 0b0 &&
bits(machInst, 15, 14)==0b10) {
return decodeSveMisc(machInst);
}
}
uint8_t b_29_24_21 = (bits(machInst, 29) << 2) |
(bits(machInst, 24) << 1) |
bits(machInst, 21);

View File

@@ -4250,6 +4250,22 @@ let {{
# Per-element operation for 'sbclb': the first source plus the bitwise
# complement of the second source plus the incoming carry.
sbclbCode = 'res = srcElem1 + ~(srcElem2) + carryIn;'
sveTerInstUnpred('sbclb', 'Sbclb', 'VectorIntegerArithOp', unsignedTypes,
sbclbCode, isTop=False, isAdd=False)
# Integer matrix multiply-accumulate family (SMMLA / USMMLA / UMMLA).
# All three variants share the same per-element multiply-accumulate
# snippet and the same geometry arguments (numDestRows=2,
# numDestCols=2, K=8); only the element type tuple differs.
# NOTE(review): the type tuple presumably lists (destination, source A,
# source B) element types - confirm against sveMatMulInst.
mmlaCode = 'destElem += srcElemA * srcElemB'

# SMMLA (vectors): int8_t and int8_t sources, int32_t accumulator type.
sveMatMulInst(
    'smmla', 'Smmla', 'SimdMultAccOp',
    (('int32_t', 'int8_t', 'int8_t'),),
    numDestRows=2,
    numDestCols=2,
    K=8,
    elt_mul_op=mmlaCode)

# USMMLA (vectors): uint8_t and int8_t sources, int32_t accumulator type.
sveMatMulInst(
    'usmmla', 'Usmmla', 'SimdMultAccOp',
    (('int32_t', 'uint8_t', 'int8_t'),),
    numDestRows=2,
    numDestCols=2,
    K=8,
    elt_mul_op=mmlaCode)

# UMMLA (vectors): uint8_t and uint8_t sources, uint32_t accumulator type.
sveMatMulInst(
    'ummla', 'Ummla', 'SimdMultAccOp',
    (('uint32_t', 'uint8_t', 'uint8_t'),),
    numDestRows=2,
    numDestCols=2,
    K=8,
    elt_mul_op=mmlaCode)
# MOVPRFX (predicated)
movCode = 'destElem = srcElem1;'
sveUnaryInst('movprfx', 'MovprfxPredM', 'SimdMiscOp', unsignedTypes,