arch-vega: Template MFMA instructions

templated
- v_mfma_f64_16x16x4f64

added support for
- v_mfma_f32_32x32x2f32
- v_mfma_f32_4x4x1_16b_f32
- v_mfma_f32_16x16x4f32

[formula for gprs needed](https://github.com/ROCm/amd_matrix_instruction_calculator)

[formulas for register layouts and lanes used in computation](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf)

Change-Id: I15d6c0a5865d58323ae8dbcb3f6dcb701a9ab3c7
This commit is contained in:
Marco Kurzynski
2024-05-12 18:14:20 -05:00
committed by Matthew Poremba
parent 2b3beb92ff
commit d5a734c252
4 changed files with 242 additions and 166 deletions

View File

@@ -3646,10 +3646,10 @@ namespace VegaISA
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_OP_VOP3P__V_MFMA_F32_4X4X1_16B_F32,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_OP_VOP3P__V_MFMA_F32_32X32X2F32,
&Decoder::decode_OP_VOP3P__V_MFMA_F32_16X16X4F32,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
@@ -13165,6 +13165,27 @@ namespace VegaISA
return new Inst_VOP3P__V_DOT8_U32_U4(&iFmt->iFmt_VOP3P);
}
GPUStaticInst*
Decoder::decode_OP_VOP3P__V_MFMA_F32_32X32X2F32(MachInst iFmt)
{
return new Inst_VOP3P_MAI__V_MFMA_F32_32X32X2F32(
&iFmt->iFmt_VOP3P_MAI);
}
GPUStaticInst*
Decoder::decode_OP_VOP3P__V_MFMA_F32_4X4X1_16B_F32(MachInst iFmt)
{
return new Inst_VOP3P_MAI__V_MFMA_F32_4X4X1_16B_F32(
&iFmt->iFmt_VOP3P_MAI);
}
GPUStaticInst*
Decoder::decode_OP_VOP3P__V_MFMA_F32_16X16X4F32(MachInst iFmt)
{
return new Inst_VOP3P_MAI__V_MFMA_F32_16X16X4F32(
&iFmt->iFmt_VOP3P_MAI);
}
GPUStaticInst*
Decoder::decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst iFmt)
{

View File

@@ -1605,6 +1605,9 @@ namespace VegaISA
GPUStaticInst* decode_OP_VOP3P__V_DOT4_U32_U8(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_DOT8_I32_I4(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_DOT8_U32_U4(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F32_32X32X2F32(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F32_4X4X1_16B_F32(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F32_16X16X4F32(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_ACCVGPR_READ(MachInst);
GPUStaticInst* decode_OP_VOP3P__V_ACCVGPR_WRITE(MachInst);

View File

@@ -32,6 +32,8 @@
#ifndef __ARCH_VEGA_INSTS_INSTRUCTIONS_HH__
#define __ARCH_VEGA_INSTS_INSTRUCTIONS_HH__
#include <type_traits>
#include "arch/amdgpu/vega/gpu_decoder.hh"
#include "arch/amdgpu/vega/insts/gpu_static_inst.hh"
#include "arch/amdgpu/vega/insts/op_encodings.hh"
@@ -43915,78 +43917,234 @@ namespace VegaISA
void execute(GPUDynInstPtr) override;
}; // Inst_VOP3P__V_PK_MOV_B32
//
class Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 : public Inst_VOP3P_MAI
{
public:
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI*);
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8();
public:
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI *);
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8();
int
getNumOperands() override
{
return numDstRegOperands() + numSrcRegOperands();
} // getNumOperands
int
getNumOperands() override
{
return numDstRegOperands() + numSrcRegOperands();
} // getNumOperands
int numDstRegOperands() override { return 1; }
int numSrcRegOperands() override { return 3; }
int numDstRegOperands() override { return 1; }
int numSrcRegOperands() override { return 3; }
int
getOperandSize(int opIdx) override
{
switch (opIdx) {
case 0: // src0 "A"
return 4;
case 1: // src1 "B"
return 4;
case 2: // src2 "C"
return 16;
case 3: // dst
return 16;
default:
fatal("op idx %i out of bounds\n", opIdx);
return -1;
}
} // getOperandSize
int
getOperandSize(int opIdx) override
{
switch (opIdx) {
case 0: // src0 "A"
return 4;
case 1: // src1 "B"
return 4;
case 2: // src2 "C"
return 16;
case 3: // dst
return 16;
default:
fatal("op idx %i out of bounds\n", opIdx);
return -1;
}
} // getOperandSize
void execute(GPUDynInstPtr) override;
void execute(GPUDynInstPtr) override;
};
class Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 : public Inst_VOP3P_MAI
template <const int _delta, const int M, const int N, const int K,
const int B, typename T1, typename T2>
class Inst_VOP3P_MAI__V_MFMA : public Inst_VOP3P_MAI
{
public:
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64(InFmt_VOP3P_MAI*);
~Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64();
int
getNumOperands() override
{
return numDstRegOperands() + numSrcRegOperands();
} // getNumOperands
private:
static constexpr int gprs_a = M * K * B / 64, gprs_b = K * N * B / 64,
gprs_c_d = M * N * B / 64;
int numDstRegOperands() override { return 1; }
int numSrcRegOperands() override { return 3; }
public:
Inst_VOP3P_MAI__V_MFMA(InFmt_VOP3P_MAI *iFmt)
: Inst_VOP3P_MAI(iFmt, (_delta == 2)
? "v_mfma_f64_" + std::to_string(M) + "x" +
std::to_string(N) + "x" +
std::to_string(K) + "f64"
: "v_mfma_f32_" + std::to_string(M) + "x" +
std::to_string(N) + "x" +
std::to_string(K) + "f32") {
setFlag(ALU);
}
~Inst_VOP3P_MAI__V_MFMA() {}
int
getOperandSize(int opIdx) override
{
switch (opIdx) {
case 0: // src0 "A"
return 8;
case 1: // src1 "B"
return 8;
case 2: // src2 "C"
return 32;
case 3: // dst
return 32;
default:
fatal("op idx %i out of bounds\n", opIdx);
return -1;
int getNumOperands() override {
return numDstRegOperands() + numSrcRegOperands();
} // getNumOperands
int numDstRegOperands() override { return 1; }
int numSrcRegOperands() override { return 3; }
int getOperandSize(int opIdx) override {
switch (opIdx) {
case 0: // src0 "A"
return 4*gprs_a;
case 1: // src1 "B"
return 4*gprs_b;
case 2: // src2 "C"
return 4*gprs_c_d;
case 3: // dst
return 4*gprs_c_d;
default:
fatal("op idx %i out of bounds\n", opIdx);
return -1;
}
} // getOperandSize
void
execute(GPUDynInstPtr gpuDynInst) override
{
int acc_cd_off = 0;
int acc_a_off = 0;
int acc_b_off = 0;
if (instData.ACC_CD) {
acc_cd_off = gpuDynInst->wavefront()->accumOffset;
}
if (extData.ACC) {
int tmp_acc = extData.ACC;
if (tmp_acc & 0x1) {
acc_a_off = gpuDynInst->wavefront()->accumOffset;
}
} // getOperandSize
if (tmp_acc & 0x2) {
acc_b_off = gpuDynInst->wavefront()->accumOffset;
}
}
void execute(GPUDynInstPtr) override;
alignas(T1) std::byte _src0[gprs_a*sizeof(T1)];
alignas(T1) std::byte _src1[gprs_b*sizeof(T1)];
alignas(T1) std::byte _src2[gprs_c_d*sizeof(T1)];
alignas(T2) std::byte _vdst[gprs_c_d*sizeof(T1)];
T1 *src0 = std::launder(reinterpret_cast<T1*>(&_src0));
T1 *src1 = std::launder(reinterpret_cast<T1*>(&_src1));
T1 *src2 = std::launder(reinterpret_cast<T1*>(&_src2));
T2 *vdst = std::launder(reinterpret_cast<T2*>(&_vdst));
// Handling of src2 is a bit tricky. The operator[] overload cannot
// be used for dword count > 2, and the dword count here is 4. Usually
// src2 is a VGPR/AccGPR, but it might also be constant. In order to
// use operator[] and handle constants, check for VGPR here and set
// a delta for each of the src2 GPRs.
int delta = isVectorReg(extData.SRC0) ? _delta : 0;
for (int i = 0; i < gprs_a; i++) {
new (&src0[i]) T1(gpuDynInst, extData.SRC0+acc_a_off+i*delta);
src0[i].readSrc();
}
delta = isVectorReg(extData.SRC1) ? _delta : 0;
for (int i = 0; i < gprs_b; i++) {
new (&src1[i]) T1(gpuDynInst, extData.SRC1+acc_b_off+i*delta);
src1[i].readSrc();
}
delta = isVectorReg(extData.SRC2) ? _delta : 0;
for (int i = 0; i < gprs_c_d; i++) {
new (&src2[i]) T1(gpuDynInst, extData.SRC2+acc_cd_off+i*delta);
src2[i].readSrc();
}
for (int i = 0; i < gprs_c_d; i++) {
new (&vdst[i]) T2(gpuDynInst, instData.VDST+acc_cd_off+i*_delta);
}
// These values and meanings are described in the MI300 ISA manual:
//
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/
// instruction-set-architectures/
// amd-instinct-mi300-cdna3-instruction-set-architecture.pdf
//
// in section 7.1.4.2. In theory, only the M, N, K, and H values change
// for each MFMA instruction and therefore this could be templated.
// Output layout
constexpr int H = _delta == 2 ? 1 : 4;
constexpr int B_I = std::ceil(64.0f / (N * M / H));
constexpr int M_I = (64 / B_I) / N;
constexpr int G = M / (H * M_I);
float result[M][N];
// Input layout
constexpr int K_L = K / (64 / (M * B));
for (int block = 0; block < B; block++) {
// Load src2 into result. src2 is row major
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
result[i][j] = src2[item][lane];
}
}
// Compute new result
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
for (int k = 0; k < K; ++k) {
// src0 is column major, src1 is row major
int lane_A = i + M * (block + B * (k / K_L));
int lane_B = j + N * (block + B * (k / K_L));
int item = k % K_L;
result[i][j] +=
src0[item][lane_A] * src1[item][lane_B];
}
}
}
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
vdst[item][lane] = result[i][j];
}
}
for (int i = 0; i < gprs_c_d; ++i) {
vdst[i].write();
}
}
for (int i = 0; i < gprs_a; i++) {
std::destroy_at(&src0[i]);
}
for (int i = 0; i < gprs_b; i++) {
std::destroy_at(&src1[i]);
}
for (int i = 0; i < gprs_c_d; i++) {
std::destroy_at(&src2[i]);
}
for (int i = 0; i < gprs_c_d; i++) {
std::destroy_at(&vdst[i]);
}
} // execute
};
using Inst_VOP3P_MAI__V_MFMA_F32_4X4X1_16B_F32 =
Inst_VOP3P_MAI__V_MFMA<1, 4, 4, 1, 16, ConstVecOperandF32,
VecOperandF32>;
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X2F32 =
Inst_VOP3P_MAI__V_MFMA<1, 32, 32, 2, 1, ConstVecOperandF32,
VecOperandF32>;
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X4F32 =
Inst_VOP3P_MAI__V_MFMA<1, 16, 16, 4, 1, ConstVecOperandF32,
VecOperandF32>;
using Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 =
Inst_VOP3P_MAI__V_MFMA<2, 16, 16, 4, 1, ConstVecOperandF64,
VecOperandF64>;
} // namespace VegaISA
} // namespace gem5

View File

@@ -168,112 +168,6 @@ namespace VegaISA
vdst[3][i] = result[(i/16)*4+3][(i%16)];
}
for (int i = 0; i < 4; ++i) {
vdst[i].write();
}
} // execute
// --- Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 class methods ---
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64(InFmt_VOP3P_MAI *iFmt)
: Inst_VOP3P_MAI(iFmt, "v_mfma_f64_16x16x4f64")
{
setFlag(ALU);
} // Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::
~Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64()
{
} // ~Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64
// D(16x16F64) = A(16x4F64) x B(4x16F64) + C(16x16F64), 1 Blocks, 8
// pass, srcA/srcB 2 VGPR, srcC/D 8 VGPR
void
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::execute(GPUDynInstPtr gpuDynInst)
{
// Accumulation register offsets for A, B, and C/D matrix.
int a_offset = 0;
int b_offset = 0;
int cd_offset = 0;
if (instData.ACC_CD) {
cd_offset = gpuDynInst->wavefront()->accumOffset;
}
if (extData.ACC) {
if (extData.ACC & 0x1) {
a_offset = gpuDynInst->wavefront()->accumOffset;
} else if (extData.ACC & 0x2) {
b_offset = gpuDynInst->wavefront()->accumOffset;
}
}
// Handling of src2 is a bit tricky. The operator[] overload cannot
// be used for dword count > 2, and the dword count here is 8. Usually
// src2 is a VGPR/AccGPR, but it might also be constant. In order to
// use operator[] and handle constants, check for VGPR here and set
// a delta for each of the pairs of src2 GPRs.
int delta = isVectorReg(extData.SRC2) ? 2 : 0;
ConstVecOperandF64 src0(gpuDynInst, extData.SRC0+a_offset);
ConstVecOperandF64 src1(gpuDynInst, extData.SRC1+b_offset);
ConstVecOperandF64 src2[4] = {
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset),
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset+1*delta),
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset+2*delta),
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset+3*delta),
};
VecOperandF64 vdst[4] = {
VecOperandF64(gpuDynInst, instData.VDST+cd_offset),
VecOperandF64(gpuDynInst, instData.VDST+cd_offset+2),
VecOperandF64(gpuDynInst, instData.VDST+cd_offset+4),
VecOperandF64(gpuDynInst, instData.VDST+cd_offset+6),
};
src0.readSrc();
src1.readSrc();
for (int i = 0; i < 4; ++i) {
src2[i].readSrc();
}
double result[16][16];
// Load src2 into result. src2 is row major
for (int i = 0; i < 64; ++i) {
// src2[0] contains rows 0 - 3
result[(i/16)][(i%16)] = src2[0][i];
// src2[1] contains rows 4 - 7
result[(i/16)+4][(i%16)] = src2[1][i];
// src2[2] contains rows 8 - 11
result[(i/16)+8][(i%16)] = src2[2][i];
// src2[3] contains rows 12 - 15
result[(i/16)+12][(i%16)] = src2[3][i];
}
// Compute new result
for (int i = 0; i < 16; ++i) {
for (int j = 0; j < 16; ++j) {
for (int k = 0; k < 4; ++k) {
// src0 is column major, src1 is row major
int lane_A = 16*k + i;
int lane_B = 16*k + j;
result[i][j] += src0[lane_A] * src1[lane_B];
}
}
}
// Put result in dest VGPRs
for (int i = 0; i < 64; ++i) {
// vdst[0] contains rows 0 - 3
vdst[0][i] = result[(i/16)][(i%16)];
// src2[1] contains rows 4 - 7
vdst[1][i] = result[(i/16)+4][(i%16)];
// src2[2] contains rows 8 - 11
vdst[2][i] = result[(i/16)+8][(i%16)];
// src2[3] contains rows 12 - 15
vdst[3][i] = result[(i/16)+12][(i%16)];
}
for (int i = 0; i < 4; ++i) {
vdst[i].write();
}