arch-vega: Template MFMA instructions
templated - v_mfma_f64_16x16x4f64 added support for - v_mfma_f32_32x32x2f32 - v_mfma_f32_4x4x1_16b_f32 - v_mfma_f32_16x16x4f32 [formula for gprs needed](https://github.com/ROCm/amd_matrix_instruction_calculator) [formulas for register layouts and lanes used in computation](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf) Change-Id: I15d6c0a5865d58323ae8dbcb3f6dcb701a9ab3c7
This commit is contained in:
committed by
Matthew Poremba
parent
2b3beb92ff
commit
d5a734c252
@@ -3646,10 +3646,10 @@ namespace VegaISA
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_OP_VOP3P__V_MFMA_F32_4X4X1_16B_F32,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_OP_VOP3P__V_MFMA_F32_32X32X2F32,
|
||||
&Decoder::decode_OP_VOP3P__V_MFMA_F32_16X16X4F32,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
&Decoder::decode_invalid,
|
||||
@@ -13165,6 +13165,27 @@ namespace VegaISA
|
||||
return new Inst_VOP3P__V_DOT8_U32_U4(&iFmt->iFmt_VOP3P);
|
||||
}
|
||||
|
||||
GPUStaticInst*
|
||||
Decoder::decode_OP_VOP3P__V_MFMA_F32_32X32X2F32(MachInst iFmt)
|
||||
{
|
||||
return new Inst_VOP3P_MAI__V_MFMA_F32_32X32X2F32(
|
||||
&iFmt->iFmt_VOP3P_MAI);
|
||||
}
|
||||
|
||||
GPUStaticInst*
|
||||
Decoder::decode_OP_VOP3P__V_MFMA_F32_4X4X1_16B_F32(MachInst iFmt)
|
||||
{
|
||||
return new Inst_VOP3P_MAI__V_MFMA_F32_4X4X1_16B_F32(
|
||||
&iFmt->iFmt_VOP3P_MAI);
|
||||
}
|
||||
|
||||
GPUStaticInst*
|
||||
Decoder::decode_OP_VOP3P__V_MFMA_F32_16X16X4F32(MachInst iFmt)
|
||||
{
|
||||
return new Inst_VOP3P_MAI__V_MFMA_F32_16X16X4F32(
|
||||
&iFmt->iFmt_VOP3P_MAI);
|
||||
}
|
||||
|
||||
GPUStaticInst*
|
||||
Decoder::decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst iFmt)
|
||||
{
|
||||
|
||||
@@ -1605,6 +1605,9 @@ namespace VegaISA
|
||||
GPUStaticInst* decode_OP_VOP3P__V_DOT4_U32_U8(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_DOT8_I32_I4(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_DOT8_U32_U4(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F32_32X32X2F32(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F32_4X4X1_16B_F32(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_MFMA_F32_16X16X4F32(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_ACCVGPR_READ(MachInst);
|
||||
GPUStaticInst* decode_OP_VOP3P__V_ACCVGPR_WRITE(MachInst);
|
||||
|
||||
@@ -32,6 +32,8 @@
|
||||
#ifndef __ARCH_VEGA_INSTS_INSTRUCTIONS_HH__
|
||||
#define __ARCH_VEGA_INSTS_INSTRUCTIONS_HH__
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "arch/amdgpu/vega/gpu_decoder.hh"
|
||||
#include "arch/amdgpu/vega/insts/gpu_static_inst.hh"
|
||||
#include "arch/amdgpu/vega/insts/op_encodings.hh"
|
||||
@@ -43915,78 +43917,234 @@ namespace VegaISA
|
||||
|
||||
void execute(GPUDynInstPtr) override;
|
||||
}; // Inst_VOP3P__V_PK_MOV_B32
|
||||
|
||||
//
|
||||
class Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 : public Inst_VOP3P_MAI
|
||||
{
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI*);
|
||||
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8();
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI *);
|
||||
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8();
|
||||
|
||||
int
|
||||
getNumOperands() override
|
||||
{
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
int
|
||||
getNumOperands() override
|
||||
{
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
|
||||
int
|
||||
getOperandSize(int opIdx) override
|
||||
{
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 4;
|
||||
case 1: // src1 "B"
|
||||
return 4;
|
||||
case 2: // src2 "C"
|
||||
return 16;
|
||||
case 3: // dst
|
||||
return 16;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
}
|
||||
} // getOperandSize
|
||||
int
|
||||
getOperandSize(int opIdx) override
|
||||
{
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 4;
|
||||
case 1: // src1 "B"
|
||||
return 4;
|
||||
case 2: // src2 "C"
|
||||
return 16;
|
||||
case 3: // dst
|
||||
return 16;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
}
|
||||
} // getOperandSize
|
||||
|
||||
void execute(GPUDynInstPtr) override;
|
||||
void execute(GPUDynInstPtr) override;
|
||||
};
|
||||
|
||||
class Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 : public Inst_VOP3P_MAI
|
||||
template <const int _delta, const int M, const int N, const int K,
|
||||
const int B, typename T1, typename T2>
|
||||
class Inst_VOP3P_MAI__V_MFMA : public Inst_VOP3P_MAI
|
||||
{
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64(InFmt_VOP3P_MAI*);
|
||||
~Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64();
|
||||
|
||||
int
|
||||
getNumOperands() override
|
||||
{
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
private:
|
||||
static constexpr int gprs_a = M * K * B / 64, gprs_b = K * N * B / 64,
|
||||
gprs_c_d = M * N * B / 64;
|
||||
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA(InFmt_VOP3P_MAI *iFmt)
|
||||
: Inst_VOP3P_MAI(iFmt, (_delta == 2)
|
||||
? "v_mfma_f64_" + std::to_string(M) + "x" +
|
||||
std::to_string(N) + "x" +
|
||||
std::to_string(K) + "f64"
|
||||
: "v_mfma_f32_" + std::to_string(M) + "x" +
|
||||
std::to_string(N) + "x" +
|
||||
std::to_string(K) + "f32") {
|
||||
setFlag(ALU);
|
||||
}
|
||||
~Inst_VOP3P_MAI__V_MFMA() {}
|
||||
|
||||
int
|
||||
getOperandSize(int opIdx) override
|
||||
{
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 8;
|
||||
case 1: // src1 "B"
|
||||
return 8;
|
||||
case 2: // src2 "C"
|
||||
return 32;
|
||||
case 3: // dst
|
||||
return 32;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
int getNumOperands() override {
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
|
||||
int getOperandSize(int opIdx) override {
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 4*gprs_a;
|
||||
case 1: // src1 "B"
|
||||
return 4*gprs_b;
|
||||
case 2: // src2 "C"
|
||||
return 4*gprs_c_d;
|
||||
case 3: // dst
|
||||
return 4*gprs_c_d;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
}
|
||||
} // getOperandSize
|
||||
|
||||
void
|
||||
execute(GPUDynInstPtr gpuDynInst) override
|
||||
{
|
||||
|
||||
int acc_cd_off = 0;
|
||||
int acc_a_off = 0;
|
||||
int acc_b_off = 0;
|
||||
if (instData.ACC_CD) {
|
||||
acc_cd_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (extData.ACC) {
|
||||
int tmp_acc = extData.ACC;
|
||||
if (tmp_acc & 0x1) {
|
||||
acc_a_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
} // getOperandSize
|
||||
if (tmp_acc & 0x2) {
|
||||
acc_b_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
}
|
||||
|
||||
void execute(GPUDynInstPtr) override;
|
||||
alignas(T1) std::byte _src0[gprs_a*sizeof(T1)];
|
||||
alignas(T1) std::byte _src1[gprs_b*sizeof(T1)];
|
||||
alignas(T1) std::byte _src2[gprs_c_d*sizeof(T1)];
|
||||
alignas(T2) std::byte _vdst[gprs_c_d*sizeof(T1)];
|
||||
T1 *src0 = std::launder(reinterpret_cast<T1*>(&_src0));
|
||||
T1 *src1 = std::launder(reinterpret_cast<T1*>(&_src1));
|
||||
T1 *src2 = std::launder(reinterpret_cast<T1*>(&_src2));
|
||||
T2 *vdst = std::launder(reinterpret_cast<T2*>(&_vdst));
|
||||
|
||||
// Handling of src2 is a bit tricky. The operator[] overload cannot
|
||||
// be used for dword count > 2, and the dword count here is 4. Usually
|
||||
// src2 is a VGPR/AccGPR, but it might also be constant. In order to
|
||||
// use operator[] and handle constants, check for VGPR here and set
|
||||
// a delta for each of the src2 GPRs.
|
||||
int delta = isVectorReg(extData.SRC0) ? _delta : 0;
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
new (&src0[i]) T1(gpuDynInst, extData.SRC0+acc_a_off+i*delta);
|
||||
src0[i].readSrc();
|
||||
}
|
||||
|
||||
delta = isVectorReg(extData.SRC1) ? _delta : 0;
|
||||
for (int i = 0; i < gprs_b; i++) {
|
||||
new (&src1[i]) T1(gpuDynInst, extData.SRC1+acc_b_off+i*delta);
|
||||
src1[i].readSrc();
|
||||
}
|
||||
|
||||
delta = isVectorReg(extData.SRC2) ? _delta : 0;
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
new (&src2[i]) T1(gpuDynInst, extData.SRC2+acc_cd_off+i*delta);
|
||||
src2[i].readSrc();
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
new (&vdst[i]) T2(gpuDynInst, instData.VDST+acc_cd_off+i*_delta);
|
||||
}
|
||||
|
||||
|
||||
// These values and meanings are described in the MI300 ISA manual:
|
||||
//
|
||||
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/
|
||||
// instruction-set-architectures/
|
||||
// amd-instinct-mi300-cdna3-instruction-set-architecture.pdf
|
||||
//
|
||||
// in section 7.1.4.2. In theory, only the M, N, K, and H values change
|
||||
// for each MFMA instruction and therefore this could be templated.
|
||||
|
||||
// Output layout
|
||||
constexpr int H = _delta == 2 ? 1 : 4;
|
||||
constexpr int B_I = std::ceil(64.0f / (N * M / H));
|
||||
constexpr int M_I = (64 / B_I) / N;
|
||||
constexpr int G = M / (H * M_I);
|
||||
|
||||
float result[M][N];
|
||||
|
||||
// Input layout
|
||||
constexpr int K_L = K / (64 / (M * B));
|
||||
|
||||
for (int block = 0; block < B; block++) {
|
||||
// Load src2 into result. src2 is row major
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
|
||||
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
|
||||
|
||||
result[i][j] = src2[item][lane];
|
||||
}
|
||||
}
|
||||
|
||||
// Compute new result
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
// src0 is column major, src1 is row major
|
||||
int lane_A = i + M * (block + B * (k / K_L));
|
||||
int lane_B = j + N * (block + B * (k / K_L));
|
||||
int item = k % K_L;
|
||||
result[i][j] +=
|
||||
src0[item][lane_A] * src1[item][lane_B];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
|
||||
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
|
||||
|
||||
vdst[item][lane] = result[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
std::destroy_at(&src0[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_b; i++) {
|
||||
std::destroy_at(&src1[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
std::destroy_at(&src2[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
std::destroy_at(&vdst[i]);
|
||||
}
|
||||
} // execute
|
||||
};
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_4X4X1_16B_F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 4, 4, 1, 16, ConstVecOperandF32,
|
||||
VecOperandF32>;
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X2F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 32, 32, 2, 1, ConstVecOperandF32,
|
||||
VecOperandF32>;
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X4F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 16, 16, 4, 1, ConstVecOperandF32,
|
||||
VecOperandF32>;
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 =
|
||||
Inst_VOP3P_MAI__V_MFMA<2, 16, 16, 4, 1, ConstVecOperandF64,
|
||||
VecOperandF64>;
|
||||
|
||||
} // namespace VegaISA
|
||||
} // namespace gem5
|
||||
|
||||
|
||||
@@ -168,112 +168,6 @@ namespace VegaISA
|
||||
vdst[3][i] = result[(i/16)*4+3][(i%16)];
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
} // execute
|
||||
// --- Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 class methods ---
|
||||
|
||||
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::
|
||||
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64(InFmt_VOP3P_MAI *iFmt)
|
||||
: Inst_VOP3P_MAI(iFmt, "v_mfma_f64_16x16x4f64")
|
||||
{
|
||||
setFlag(ALU);
|
||||
} // Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64
|
||||
|
||||
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::
|
||||
~Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64()
|
||||
{
|
||||
} // ~Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64
|
||||
|
||||
// D(16x16F64) = A(16x4F64) x B(4x16F64) + C(16x16F64), 1 Blocks, 8
|
||||
// pass, srcA/srcB 2 VGPR, srcC/D 8 VGPR
|
||||
void
|
||||
Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64::execute(GPUDynInstPtr gpuDynInst)
|
||||
{
|
||||
// Accumulation register offsets for A, B, and C/D matrix.
|
||||
int a_offset = 0;
|
||||
int b_offset = 0;
|
||||
int cd_offset = 0;
|
||||
if (instData.ACC_CD) {
|
||||
cd_offset = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (extData.ACC) {
|
||||
if (extData.ACC & 0x1) {
|
||||
a_offset = gpuDynInst->wavefront()->accumOffset;
|
||||
} else if (extData.ACC & 0x2) {
|
||||
b_offset = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
}
|
||||
|
||||
// Handling of src2 is a bit tricky. The operator[] overload cannot
|
||||
// be used for dword count > 2, and the dword count here is 8. Usually
|
||||
// src2 is a VGPR/AccGPR, but it might also be constant. In order to
|
||||
// use operator[] and handle constants, check for VGPR here and set
|
||||
// a delta for each of the pairs of src2 GPRs.
|
||||
int delta = isVectorReg(extData.SRC2) ? 2 : 0;
|
||||
|
||||
ConstVecOperandF64 src0(gpuDynInst, extData.SRC0+a_offset);
|
||||
ConstVecOperandF64 src1(gpuDynInst, extData.SRC1+b_offset);
|
||||
ConstVecOperandF64 src2[4] = {
|
||||
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset),
|
||||
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset+1*delta),
|
||||
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset+2*delta),
|
||||
ConstVecOperandF64(gpuDynInst, extData.SRC2+cd_offset+3*delta),
|
||||
};
|
||||
|
||||
VecOperandF64 vdst[4] = {
|
||||
VecOperandF64(gpuDynInst, instData.VDST+cd_offset),
|
||||
VecOperandF64(gpuDynInst, instData.VDST+cd_offset+2),
|
||||
VecOperandF64(gpuDynInst, instData.VDST+cd_offset+4),
|
||||
VecOperandF64(gpuDynInst, instData.VDST+cd_offset+6),
|
||||
};
|
||||
|
||||
src0.readSrc();
|
||||
src1.readSrc();
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
src2[i].readSrc();
|
||||
}
|
||||
|
||||
double result[16][16];
|
||||
|
||||
// Load src2 into result. src2 is row major
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
// src2[0] contains rows 0 - 3
|
||||
result[(i/16)][(i%16)] = src2[0][i];
|
||||
// src2[1] contains rows 4 - 7
|
||||
result[(i/16)+4][(i%16)] = src2[1][i];
|
||||
// src2[2] contains rows 8 - 11
|
||||
result[(i/16)+8][(i%16)] = src2[2][i];
|
||||
// src2[3] contains rows 12 - 15
|
||||
result[(i/16)+12][(i%16)] = src2[3][i];
|
||||
}
|
||||
|
||||
// Compute new result
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
// src0 is column major, src1 is row major
|
||||
int lane_A = 16*k + i;
|
||||
int lane_B = 16*k + j;
|
||||
result[i][j] += src0[lane_A] * src1[lane_B];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put result in dest VGPRs
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
// vdst[0] contains rows 0 - 3
|
||||
vdst[0][i] = result[(i/16)][(i%16)];
|
||||
// src2[1] contains rows 4 - 7
|
||||
vdst[1][i] = result[(i/16)+4][(i%16)];
|
||||
// src2[2] contains rows 8 - 11
|
||||
vdst[2][i] = result[(i/16)+8][(i%16)];
|
||||
// src2[3] contains rows 12 - 15
|
||||
vdst[3][i] = result[(i/16)+12][(i%16)];
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user