arch-arm: Add support for Arm SVE Integer Matrix instructions.
Add support for the Arm SVE Integer Matrix Multiply-Accumulate (SMMLA, USMMLA, UMMLA) instructions.

Because the associated SUDOT and USDOT instructions have not yet been implemented, the SVE Feature ID register 0 (ID_AA64ZFR0_EL1) has not yet been updated to indicate support for SVE Int8 matrix multiplication instructions at this time.

For more information please refer to the "Arm Architecture Reference Manual Supplement - The Scalable Vector Extension (SVE), for Armv8-A" (https://developer.arm.com/architectures/cpu-architecture/a-profile/docs/arm-architecture-reference-manual-supplement-armv8-a)

Additional Contributors: Giacomo Travaglini

Change-Id: Ia50e28fae03634cbe04b42a9900bab65a604817f
Reviewed-by: Richard Cooper <richard.cooper@arm.com>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/70730
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>
Tested-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
This commit is contained in:
committed by
Bobby Bruce
parent
0f857873f9
commit
98e67c8610
@@ -245,6 +245,33 @@ namespace Aarch64
|
||||
return new Unknown64(machInst);
|
||||
} // decodeSveIntMulAdd
|
||||
|
||||
StaticInstPtr
decodeSveIntMatMulAdd(ExtMachInst machInst)
{
    // Decode the SVE integer matrix multiply-accumulate instructions
    // (SMMLA, USMMLA, UMMLA). Register fields: Zda (accumulator) is
    // bits 4:0, Zn is bits 9:5, Zm is bits 20:16.
    RegIndex zda = (RegIndex) (uint8_t) bits(machInst, 4, 0);
    RegIndex zn = (RegIndex) (uint8_t) bits(machInst, 9, 5);
    RegIndex zm = (RegIndex) (uint8_t) bits(machInst, 20, 16);

    // Bits 23:22 select the signedness variant of the multiply.
    uint8_t uns = bits(machInst, 23, 22);

    switch (uns) {
      case 0x0:
        // SMMLA: signed x signed, signed 32-bit accumulators.
        return new SveSmmla<int32_t, int8_t, int8_t>(
                machInst, zda, zn, zm);
      case 0x2:
        // USMMLA: unsigned x signed, signed 32-bit accumulators.
        return new SveUsmmla<int32_t, uint8_t, int8_t>(
                machInst, zda, zn, zm);
      case 0x3:
        // UMMLA: unsigned x unsigned, unsigned 32-bit accumulators.
        return new SveUmmla<uint32_t, uint8_t, uint8_t>(
                machInst, zda, zn, zm);
      case 0x1: // Unallocated encoding.
      default:
        return new Unknown64(machInst);
    }
    // All switch paths return, so no trailing return is needed
    // (the original's trailing "return new Unknown64(machInst);"
    // was unreachable dead code).
} // decodeSveIntMatMulAdd
|
||||
|
||||
StaticInstPtr
|
||||
decodeSveShiftByImmPred0(ExtMachInst machInst)
|
||||
{
|
||||
@@ -3809,5 +3836,21 @@ namespace Aarch64
|
||||
return new Unknown64(machInst);
|
||||
} // decodeSveMemStore
|
||||
|
||||
StaticInstPtr
decodeSveMisc(ExtMachInst machInst)
{
    // Decode the SVE "misc" encoding group, selected by bits 13:10.
    // Only the integer matrix multiply-accumulate sub-group (0b0110)
    // is implemented here; every other value is unallocated.
    //
    // Cleanups vs. the original: removed the unreachable `break`
    // statements after each `return`, the redundant per-case braces,
    // and the unreachable trailing return; moved the opening brace to
    // its own line to match the file's other decode functions.
    switch (bits(machInst, 13, 10)) {
      case 0b0110:
        return decodeSveIntMatMulAdd(machInst);
      default:
        return new Unknown64(machInst);
    }
} // decodeSveMisc
|
||||
|
||||
|
||||
} // namespace Aarch64
|
||||
}};
|
||||
|
||||
@@ -44,6 +44,7 @@ namespace Aarch64
|
||||
// Forward declarations of second-level SVE decode helpers. Each takes
// the raw machine instruction and returns the decoded StaticInst
// (Unknown64 is returned for unallocated encodings).
StaticInstPtr decodeSveShiftByImmPred(ExtMachInst machInst);
StaticInstPtr decodeSveIntArithUnaryPred(ExtMachInst machInst);
StaticInstPtr decodeSveIntMulAdd(ExtMachInst machInst);
// Integer matrix multiply-accumulate (SMMLA/USMMLA/UMMLA) group.
StaticInstPtr decodeSveIntMatMulAdd(ExtMachInst machInst);
StaticInstPtr decodeSveIntArithUnpred(ExtMachInst machInst);
StaticInstPtr decodeSveIntLogUnpred(ExtMachInst machInst);
StaticInstPtr decodeSveIndexGen(ExtMachInst machInst);
|
||||
@@ -94,6 +95,8 @@ namespace Aarch64
|
||||
StaticInstPtr decodeSveMemContigLoad(ExtMachInst machInst);
StaticInstPtr decodeSveMemGather64(ExtMachInst machInst);
StaticInstPtr decodeSveMemStore(ExtMachInst machInst);

// Miscellaneous SVE encodings; currently dispatches only to the
// integer matrix multiply-accumulate group.
StaticInstPtr decodeSveMisc(ExtMachInst machInst);
|
||||
}
|
||||
}};
|
||||
|
||||
@@ -104,6 +107,14 @@ namespace Aarch64
|
||||
StaticInstPtr
|
||||
decodeSveInt(ExtMachInst machInst)
|
||||
{
|
||||
if (bits(machInst, 31, 29) == 0b010) {
|
||||
if (bits(machInst, 24) == 0b1 &&
|
||||
bits(machInst, 21) == 0b0 &&
|
||||
bits(machInst, 15, 14)==0b10) {
|
||||
return decodeSveMisc(machInst);
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t b_29_24_21 = (bits(machInst, 29) << 2) |
|
||||
(bits(machInst, 24) << 1) |
|
||||
bits(machInst, 21);
|
||||
|
||||
@@ -4250,6 +4250,22 @@ let {{
|
||||
# SBCLB: subtract-with-carry long (bottom). Implemented as addition of
# the bitwise complement of the second operand plus the carry-in.
sbclbCode = 'res = srcElem1 + ~(srcElem2) + carryIn;'
sveTerInstUnpred('sbclb', 'Sbclb', 'VectorIntegerArithOp', unsignedTypes,
                 sbclbCode, isTop=False, isAdd=False)

# Per-element multiply-accumulate operation shared by the integer
# matrix multiply instructions generated below.
mmlaCode = ('destElem += srcElemA * srcElemB')
# SMMLA (vectors): signed x signed 8-bit sources, signed 32-bit
# accumulators; each 2x2 destination tile accumulates a K=8 dot product.
sveMatMulInst('smmla', 'Smmla', 'SimdMultAccOp',
              (('int32_t', 'int8_t', 'int8_t'),),
              numDestRows=2, numDestCols=2, K=8,
              elt_mul_op=mmlaCode)
# USMMLA (vectors): unsigned x signed 8-bit sources, signed 32-bit
# accumulators.
sveMatMulInst('usmmla', 'Usmmla', 'SimdMultAccOp',
              (('int32_t', 'uint8_t', 'int8_t'),),
              numDestRows=2, numDestCols=2, K=8,
              elt_mul_op=mmlaCode)
# UMMLA (vectors): unsigned x unsigned 8-bit sources, unsigned 32-bit
# accumulators.
sveMatMulInst('ummla', 'Ummla', 'SimdMultAccOp',
              (('uint32_t', 'uint8_t', 'uint8_t'),),
              numDestRows=2, numDestCols=2, K=8,
              elt_mul_op=mmlaCode)
|
||||
# MOVPRFX (predicated)
|
||||
movCode = 'destElem = srcElem1;'
|
||||
sveUnaryInst('movprfx', 'MovprfxPredM', 'SimdMiscOp', unsignedTypes,
|
||||
|
||||
Reference in New Issue
Block a user