arch-arm: Add support for Armv8.2-I8MM NEON extension.
Add support for the Armv8.2-I8MM NEON extension. This provides the SUDOT and USDOT mixed-sign SIMD Dot Product instructions, as well as the SMMLA, UMMLA, and USMMLA SIMD Matrix Multiply-Accumulate instructions. For more information please refer to the Arm Architecture Reference Manual (https://developer.arm.com/documentation/ddi0487/latest/). Additional Contributors: Giacomo Travaglini Change-Id: I6fb9318f67cc9d2737079283e1a095630c4d2ad9 Reviewed-by: Richard Cooper <richard.cooper@arm.com> Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/70737 Maintainer: Jason Lowe-Power <power.jg@gmail.com> Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com> Reviewed-by: Jason Lowe-Power <power.jg@gmail.com> Maintainer: Andreas Sandberg <andreas.sandberg@arm.com> Tested-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Bobby Bruce
parent
eb4f83b178
commit
9de1443ebb
@@ -510,6 +510,7 @@ namespace Aarch64
|
||||
decodeNeon3RegExtension(ExtMachInst machInst)
|
||||
{
|
||||
uint8_t q = bits(machInst, 30);
|
||||
uint8_t u = bits(machInst, 29);
|
||||
uint8_t qu = bits(machInst, 30, 29);
|
||||
uint8_t size = bits(machInst, 23, 22);
|
||||
uint8_t opcode = bits(machInst, 15, 11);
|
||||
@@ -546,6 +547,20 @@ namespace Aarch64
|
||||
default:
|
||||
return new Unknown64(machInst);
|
||||
}
|
||||
case 0x13:
|
||||
if (q) {
|
||||
return new UsdotQX<int32_t>(machInst, vd, vn, vm);
|
||||
} else {
|
||||
return new UsdotDX<int32_t>(machInst, vd, vn, vm);
|
||||
}
|
||||
case 0x14:
|
||||
if (u) {
|
||||
return new UmmlaQX<uint32_t>(machInst, vd, vn, vm);
|
||||
} else {
|
||||
return new SmmlaQX<int32_t>(machInst, vd, vn, vm);
|
||||
}
|
||||
case 0x15:
|
||||
return new UsmmlaQX<int32_t>(machInst, vd, vn, vm);
|
||||
case 0x18:
|
||||
case 0x19:
|
||||
case 0x1a:
|
||||
@@ -1590,9 +1605,6 @@ namespace Aarch64
|
||||
return decodeNeonSThreeImmHAndWReg<SqrdmulhElemDX,
|
||||
SqrdmulhElemQX>(
|
||||
q, size, machInst, vd, vn, vm, index);
|
||||
case 0xf:
|
||||
return decodeNeonSThreeImmHAndWReg<SqrdmlshElemDX, SqrdmlshElemQX>(
|
||||
q, size, machInst, vd, vn, vm, index);
|
||||
case 0xe:
|
||||
switch (qu) {
|
||||
case 0b00:
|
||||
@@ -1610,6 +1622,39 @@ namespace Aarch64
|
||||
default:
|
||||
return new Unknown64(machInst);
|
||||
}
|
||||
case 0xf:
|
||||
if (u) {
|
||||
// Armv8.2-RDMA SQRDMLSH(elem)
|
||||
return decodeNeonSThreeImmHAndWReg<SqrdmlshElemDX,
|
||||
SqrdmlshElemQX>(
|
||||
q, size, machInst, vd, vn, vm, index);
|
||||
} else {
|
||||
switch (size) {
|
||||
case 0x0:
|
||||
if (q) {
|
||||
return new SudotElemQX<int32_t>(
|
||||
machInst, vd, vn, vm_dp, index_dp);
|
||||
} else {
|
||||
return new SudotElemDX<int32_t>(
|
||||
machInst, vd, vn, vm_dp, index_dp);
|
||||
}
|
||||
case 0x1:
|
||||
// Armv8.2-BF16 BFDOT(elem)
|
||||
return new Unknown64(machInst);
|
||||
case 0x2:
|
||||
if (q) {
|
||||
return new UsdotElemQX<int32_t>(
|
||||
machInst, vd, vn, vm_dp, index_dp);
|
||||
} else {
|
||||
return new UsdotElemDX<int32_t>(
|
||||
machInst, vd, vn, vm_dp, index_dp);
|
||||
}
|
||||
case 0x3:
|
||||
default:
|
||||
// Armv8.2-BF16 BFMLALB(elem), BFMLALT(elem)
|
||||
return new Unknown64(machInst);
|
||||
}
|
||||
}
|
||||
default:
|
||||
return new Unknown64(machInst);
|
||||
}
|
||||
|
||||
@@ -1146,6 +1146,94 @@ let {{
|
||||
# UDOT (element)
|
||||
intDotInst('udot', 'UdotElemDX', 'SimdAluOp', False, False, False, 2, True)
|
||||
intDotInst('udot', 'UdotElemQX', 'SimdAluOp', False, False, False, 4, True)
|
||||
# SUDOT (element)
|
||||
intDotInst('sudot', 'SudotElemDX', 'SimdAluOp', True, True, False, 2, True)
|
||||
intDotInst('sudot', 'SudotElemQX', 'SimdAluOp', True, True, False, 4, True)
|
||||
# USDOT (vector)
|
||||
intDotInst('usdot', 'UsdotDX', 'SimdAluOp', True, False, True, 2, False)
|
||||
intDotInst('usdot', 'UsdotQX', 'SimdAluOp', True, False, True, 4, False)
|
||||
# USDOT (element)
|
||||
intDotInst('usdot', 'UsdotElemDX', 'SimdAluOp', True, False, True, 2, True)
|
||||
intDotInst('usdot', 'UsdotElemQX', 'SimdAluOp', True, False, True, 4, True)
|
||||
|
||||
def intMatMulInst(name, Name, opClass,
|
||||
destIsSigned, src1IsSigned, src2IsSigned):
|
||||
destType = "int32_t" if destIsSigned else "uint32_t"
|
||||
src1Type = "int8_t" if src1IsSigned else "uint8_t"
|
||||
src2Type = "int8_t" if src2IsSigned else "uint8_t"
|
||||
matMulCode = '''
|
||||
using Src1Element = %(src1Type)s;
|
||||
using Src2Element = %(src2Type)s;
|
||||
|
||||
// Neon MM instructions always generate four output elements
|
||||
// from 16 pairs of source elements.
|
||||
static_assert(sizeof(Element) == 4 * sizeof(Src1Element));
|
||||
static_assert(sizeof(Element) == 4 * sizeof(Src2Element));
|
||||
|
||||
// Extended source element types to avoid overflow of intermediate
|
||||
// calculations.
|
||||
using ExtendedSrc1Element =
|
||||
typename vector_element_traits::
|
||||
extend_element<Element, Src1Element>::type;
|
||||
using ExtendedSrc2Element =
|
||||
typename vector_element_traits::
|
||||
extend_element<Element, Src2Element>::type;
|
||||
|
||||
// Properties of the matrices
|
||||
constexpr unsigned destMatSize = 2; // Dest Matrices are dim 2x2
|
||||
constexpr unsigned K = 8; // Src matrices are dim 2x8 & 8x2
|
||||
|
||||
constexpr unsigned eltsPerMatrix = destMatSize * destMatSize;
|
||||
|
||||
Element destMat[eltsPerMatrix] = {0};
|
||||
for (unsigned j = 0; j < eltsPerMatrix; ++j) {
|
||||
destMat[j] = letoh(destReg.elements[j]);
|
||||
}
|
||||
|
||||
Element src1MatPacked[eltsPerMatrix] = {0};
|
||||
Element src2MatPacked[eltsPerMatrix] = {0};
|
||||
for (unsigned j = 0; j < eltsPerMatrix; ++j) {
|
||||
src1MatPacked[j] = letoh(srcReg1.elements[j]);
|
||||
src2MatPacked[j] = letoh(srcReg2.elements[j]);
|
||||
}
|
||||
|
||||
Src1Element *src1Mat =
|
||||
reinterpret_cast<Src1Element*>(&src1MatPacked);
|
||||
Src2Element *src2Mat =
|
||||
reinterpret_cast<Src2Element*>(&src2MatPacked);
|
||||
|
||||
unsigned destEltIdx = 0;
|
||||
for (unsigned rowIdx = 0; rowIdx < destMatSize; ++rowIdx) {
|
||||
for (unsigned colIdx = 0; colIdx < destMatSize; ++colIdx) {
|
||||
Element destElem = destMat[destEltIdx];
|
||||
for (unsigned k = 0; k < K; ++k) {
|
||||
const ExtendedSrc1Element src1Elem =
|
||||
static_cast<ExtendedSrc1Element>
|
||||
(src1Mat[K * rowIdx + k]);
|
||||
const ExtendedSrc2Element src2Elem =
|
||||
static_cast<ExtendedSrc2Element>
|
||||
(src2Mat[K * colIdx + k]);
|
||||
|
||||
destElem += src1Elem * src2Elem;
|
||||
}
|
||||
destMat[destEltIdx++] = destElem;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned j = 0; j < eltsPerMatrix; ++j) {
|
||||
destReg.elements[j] = htole(destMat[j]);
|
||||
}
|
||||
''' % dict(src1Type=src1Type, src2Type=src2Type)
|
||||
threeEqualRegInstX(name, Name, opClass, (destType,), 4,
|
||||
matMulCode, readDest=True, byElem=False,
|
||||
complex=True)
|
||||
|
||||
# SMMLA
|
||||
intMatMulInst('smmla', 'SmmlaQX', 'SimdMatMultAccOp', True, True, True)
|
||||
# USMMLA
|
||||
intMatMulInst('usmmla', 'UsmmlaQX', 'SimdMatMultAccOp', True, False, True)
|
||||
# UMMLA
|
||||
intMatMulInst('ummla', 'UmmlaQX', 'SimdMatMultAccOp', False, False, False)
|
||||
|
||||
# CLS
|
||||
clsCode = '''
|
||||
|
||||
@@ -320,6 +320,9 @@ ArmProcess64::armHwcapImpl2() const
|
||||
hwcap |= (isa_r0.ts >= 2) ? Arm_Flagm2 : Arm_None;
|
||||
hwcap |= (isa_r0.rndr >= 1) ? Arm_Rng : Arm_None;
|
||||
|
||||
const AA64ISAR1 isa_r1 = tc->readMiscReg(MISCREG_ID_AA64ISAR1_EL1);
|
||||
hwcap |= (isa_r1.i8mm >= 1) ? Arm_I8mm : Arm_None;
|
||||
|
||||
const AA64ZFR0 zf_r0 = tc->readMiscReg(MISCREG_ID_AA64ZFR0_EL1);
|
||||
hwcap |= (zf_r0.f32mm >= 1) ? Arm_Svef32mm : Arm_None;
|
||||
hwcap |= (zf_r0.f64mm >= 1) ? Arm_Svef64mm : Arm_None;
|
||||
|
||||
@@ -4005,6 +4005,7 @@ ISA::initializeMiscRegMetadata()
|
||||
InitReg(MISCREG_ID_AA64ISAR1_EL1)
|
||||
.reset([p,release=release](){
|
||||
AA64ISAR1 isar1_el1 = p.id_aa64isar1_el1;
|
||||
isar1_el1.i8mm = release->has(ArmExtension::FEAT_I8MM) ? 0x1 : 0x0;
|
||||
isar1_el1.apa = release->has(ArmExtension::FEAT_PAuth) ? 0x1 : 0x0;
|
||||
isar1_el1.jscvt = release->has(ArmExtension::FEAT_JSCVT) ? 0x1 : 0x0;
|
||||
isar1_el1.fcma = release->has(ArmExtension::FEAT_FCMA) ? 0x1 : 0x0;
|
||||
|
||||
@@ -127,6 +127,7 @@ namespace ArmISA
|
||||
EndBitUnion(AA64ISAR0)
|
||||
|
||||
BitUnion64(AA64ISAR1)
|
||||
Bitfield<55, 52> i8mm;
|
||||
Bitfield<43, 40> specres;
|
||||
Bitfield<39, 36> sb;
|
||||
Bitfield<35, 32> frintts;
|
||||
|
||||
Reference in New Issue
Block a user