arch-vega: Implement v_mfma_i32_16x16x16i8

Tested using AMD labs notes examples located on github:

https://github.com/amd/amd-lab-notes/blob/release/matrix-cores/
    src/mfma_i32_16x16x16i8.cpp

Change-Id: Ib0e50162288528012b6d3395e1f629ebf12e8e54
This commit is contained in:
Matthew Poremba
2023-11-26 12:54:42 -06:00
parent cc75281802
commit 472c697d88
4 changed files with 168 additions and 1 deletions

View File

@@ -3664,7 +3664,7 @@ namespace VegaISA
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_OP_VOP3P__V_MFMA_I32_16X16X16I8,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
@@ -13059,6 +13059,13 @@ namespace VegaISA
return new Inst_VOP3P__V_PK_MOV_B32(&iFmt->iFmt_VOP3P);
}
GPUStaticInst*
Decoder::decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst iFmt)
{
return new Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(
&iFmt->iFmt_VOP3P_MAI);
}
GPUStaticInst*
Decoder::decode_OP_VOP3P__V_MFMA_F64_16X16X4F64(MachInst iFmt)
{

View File

@@ -1586,6 +1586,7 @@ namespace VegaISA
GPUStaticInst* decode_OP_VOP3P__V_MAD_MIXLO_F16(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MAD_MIXHI_F16(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_PK_MOV_B32(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F64_16X16X4F64(MachInst);
GPUStaticInst* subDecode_OPU_VOP3(MachInst);
GPUStaticInst* subDecode_OP_DS(MachInst);

View File

@@ -45953,6 +45953,129 @@ namespace VegaISA
vdst.write();
} // execute
// --- Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 class methods ---
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI *iFmt)
: Inst_VOP3P_MAI(iFmt, "v_mfma_i32_16x16x16i8")
{
setFlag(ALU);
} // Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8()
{
} // ~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8
// D(16x16I32) = A(16x16I8) x B(16x16I8) + C(16x16I32), 1 Blocks, 8
// pass, srcA/srcB 1 archVgpr, srcC/D 4 accVGPR
void
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::execute(GPUDynInstPtr gpuDynInst)
{
int acc_offset = 0;
if (instData.ACC_CD) {
warn("ACC_CD not yet implemented\n");
}
// int8 size allows for 4 elements per lane. At 16x16 this means 4
// lanes per column (A matrix) / (B matrix). This whole matrix fits
// in one VGPR. The C matrix with size int32 requires 4 VGPRs.
// Handle the C matrix by using a delta. This is set to 1 normally to
// move to the next VGPR (1 dword away) and 0 if the input is a scalar
// reg (e.g., a constant).
int delta = isVectorReg(extData.SRC2) ? 1 : 0;
// VecOperandI8 will read 8 bits and sign extend, so used U32 to read
// as "untyped" 32-bit values.
ConstVecOperandU32 src0(gpuDynInst, extData.SRC0);
ConstVecOperandU32 src1(gpuDynInst, extData.SRC1);
ConstVecOperandI32 src2a(gpuDynInst, extData.SRC2+acc_offset);
ConstVecOperandI32 src2b(gpuDynInst, extData.SRC2+acc_offset+1*delta);
ConstVecOperandI32 src2c(gpuDynInst, extData.SRC2+acc_offset+2*delta);
ConstVecOperandI32 src2d(gpuDynInst, extData.SRC2+acc_offset+3*delta);
VecOperandI32 vdsta(gpuDynInst, instData.VDST+acc_offset);
VecOperandI32 vdstb(gpuDynInst, instData.VDST+acc_offset+1);
VecOperandI32 vdstc(gpuDynInst, instData.VDST+acc_offset+2);
VecOperandI32 vdstd(gpuDynInst, instData.VDST+acc_offset+3);
src0.readSrc();
src1.readSrc();
src2a.readSrc();
src2b.readSrc();
src2c.readSrc();
src2d.readSrc();
int32_t A[16][16];
for (int i = 0; i < 64; ++i) {
// src0[0:15] contains columns 1 - 4 packed for rows 0 - 15,
// src0[16:31] contains columns 5 - 8 packed for rows 0 - 15,
// src0[32:47] contains columns 9 - 12 packed for rows 0 - 15,
// src0[48:63] contains columns 13 - 16 packed for rows 0 - 15,
int row = i % 16;
int start_col = (i / 16) * 4;
A[row][start_col+0] = sext<8>(bits(src0[i], 7, 0));
A[row][start_col+1] = sext<8>(bits(src0[i], 15, 8));
A[row][start_col+2] = sext<8>(bits(src0[i], 23, 16));
A[row][start_col+3] = sext<8>(bits(src0[i], 31, 24));
}
int32_t B[16][16];
for (int i = 0; i < 64; ++i) {
// src1[0:15] contains rows 1 - 4 packed for columns 0 - 15
// src1[16:31] contains rows 5 - 8 packed for columns 0 - 15
// src1[32:47] contains rows 9 - 12 packed for columns 0 - 15
// src1[48:63] contains rows 13 - 16 packed for columns 0 - 15
int start_row = (i / 16) * 4;
int col = i % 16;
B[start_row+0][col] = sext<8>(bits(src1[i], 7, 0));
B[start_row+1][col] = sext<8>(bits(src1[i], 15, 8));
B[start_row+2][col] = sext<8>(bits(src1[i], 23, 16));
B[start_row+3][col] = sext<8>(bits(src1[i], 31, 24));
}
int32_t result[16][16];
// Load accumulation matrix C into result
for (int i = 0; i < 64; ++i) {
// src2a contains rows 0, 4, 8, 12
result[(i/16)*4][(i%16)] = src2a[i];
// src2b contains rows 1, 5, 9, 13
result[(i/16)*4+1][(i%16)] = src2b[i];
// src2c contains rows 2, 6, 10, 14
result[(i/16)*4+2][(i%16)] = src2c[i];
// src2d contains rows 3, 7, 11, 15
result[(i/16)*4+3][(i%16)] = src2d[i];
}
// Compute new result - This is (obviously) not optimized
for (int i = 0; i < 16; ++i) {
for (int j = 0; j < 16; ++j) {
for (int k = 0; k < 16; ++k) {
result[i][j] += A[i][k] * B[k][j];
}
}
}
// Put result in dest VGPRs
for (int i = 0; i < 64; ++i) {
// vdsta contains rows 0, 4, 8, 12
vdsta[i] = result[(i/16)*4][(i%16)];
// vdstb contains rows 1, 5, 9, 13
vdstb[i] = result[(i/16)*4+1][(i%16)];
// vdstc contains rows 2, 6, 10, 14
vdstc[i] = result[(i/16)*4+2][(i%16)];
// vdstd contains rows 3, 7, 11, 15
vdstd[i] = result[(i/16)*4+3][(i%16)];
}
vdsta.write();
vdstb.write();
vdstc.write();
vdstd.write();
} // execute
// --- Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 class methods ---
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::

View File

@@ -43440,6 +43440,42 @@ namespace VegaISA
void execute(GPUDynInstPtr) override;
}; // Inst_VOP3P__V_PK_MOV_B32
class Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 : public Inst_VOP3P_MAI
{
public:
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI*);
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8();
int
getNumOperands() override
{
return numDstRegOperands() + numSrcRegOperands();
} // getNumOperands
int numDstRegOperands() override { return 1; }
int numSrcRegOperands() override { return 3; }
int
getOperandSize(int opIdx) override
{
switch (opIdx) {
case 0: // src0 "A"
return 4;
case 1: // src1 "B"
return 4;
case 2: // src2 "C"
return 16;
case 3: // dst
return 16;
default:
fatal("op idx %i out of bounds\n", opIdx);
return -1;
}
} // getOperandSize
void execute(GPUDynInstPtr) override;
};
class Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 : public Inst_VOP3P_MAI
{
public: