arch-vega: MFMA templates for MXFP and INT8 types
The microscaling formats (MXFP) and INT8 types require additional size checks which are not needed for the current MFMA template. The size check is done using a constexpr method exclusive to the MXFP type, therefore create a special class for MXFP types. This is preferrable to attempting to shoehorn into the existing template as it helps with readability. Similar, INT8 requires a size check to determine number of elements per VGPR, but it not an MXFP type. Create a special template for that as well. This additionally implements all of the MFMA types which have test cases in the amd-lab-notes repository (https://github.com/amd/amd-lab-notes/). The implementations were tested using the applications in the matrix-cores subfolder and achieve L2 norms equivalent or better than MI200 hardware. Change-Id: Ia5ae89387149928905e7bcd25302ed3d1df6af38
This commit is contained in:
@@ -32,8 +32,10 @@
|
||||
#ifndef __ARCH_VEGA_INSTS_INSTRUCTIONS_HH__
|
||||
#define __ARCH_VEGA_INSTS_INSTRUCTIONS_HH__
|
||||
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
|
||||
#include "arch/amdgpu/common/dtype/mxfp_types.hh"
|
||||
#include "arch/amdgpu/vega/gpu_decoder.hh"
|
||||
#include "arch/amdgpu/vega/insts/gpu_static_inst.hh"
|
||||
#include "arch/amdgpu/vega/insts/op_encodings.hh"
|
||||
@@ -43917,45 +43919,9 @@ namespace VegaISA
|
||||
|
||||
void execute(GPUDynInstPtr) override;
|
||||
}; // Inst_VOP3P__V_PK_MOV_B32
|
||||
//
|
||||
class Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 : public Inst_VOP3P_MAI
|
||||
{
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI *);
|
||||
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8();
|
||||
|
||||
int
|
||||
getNumOperands() override
|
||||
{
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
|
||||
int
|
||||
getOperandSize(int opIdx) override
|
||||
{
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 4;
|
||||
case 1: // src1 "B"
|
||||
return 4;
|
||||
case 2: // src2 "C"
|
||||
return 16;
|
||||
case 3: // dst
|
||||
return 16;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
}
|
||||
} // getOperandSize
|
||||
|
||||
void execute(GPUDynInstPtr) override;
|
||||
};
|
||||
|
||||
template <const int _delta, const int M, const int N, const int K,
|
||||
const int B, typename T1, typename T2>
|
||||
const int B, typename T1, typename T2, const char **MNEMONIC>
|
||||
class Inst_VOP3P_MAI__V_MFMA : public Inst_VOP3P_MAI
|
||||
{
|
||||
|
||||
@@ -43965,13 +43931,8 @@ namespace VegaISA
|
||||
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA(InFmt_VOP3P_MAI *iFmt)
|
||||
: Inst_VOP3P_MAI(iFmt, (_delta == 2)
|
||||
? "v_mfma_f64_" + std::to_string(M) + "x" +
|
||||
std::to_string(N) + "x" +
|
||||
std::to_string(K) + "f64"
|
||||
: "v_mfma_f32_" + std::to_string(M) + "x" +
|
||||
std::to_string(N) + "x" +
|
||||
std::to_string(K) + "f32") {
|
||||
: Inst_VOP3P_MAI(iFmt, *MNEMONIC)
|
||||
{
|
||||
setFlag(ALU);
|
||||
}
|
||||
~Inst_VOP3P_MAI__V_MFMA() {}
|
||||
@@ -44002,7 +43963,6 @@ namespace VegaISA
|
||||
void
|
||||
execute(GPUDynInstPtr gpuDynInst) override
|
||||
{
|
||||
|
||||
int acc_cd_off = 0;
|
||||
int acc_a_off = 0;
|
||||
int acc_b_off = 0;
|
||||
@@ -44019,10 +43979,10 @@ namespace VegaISA
|
||||
}
|
||||
}
|
||||
|
||||
alignas(T1) std::byte _src0[gprs_a*sizeof(T1)];
|
||||
alignas(T1) std::byte _src1[gprs_b*sizeof(T1)];
|
||||
alignas(T1) std::byte _src2[gprs_c_d*sizeof(T1)];
|
||||
alignas(T2) std::byte _vdst[gprs_c_d*sizeof(T1)];
|
||||
alignas(T1) std::byte _src0[sizeof(T1) * gprs_a];
|
||||
alignas(T1) std::byte _src1[sizeof(T1) * gprs_b];
|
||||
alignas(T1) std::byte _src2[sizeof(T1) * gprs_c_d];
|
||||
alignas(T2) std::byte _vdst[sizeof(T2) * gprs_c_d];
|
||||
T1 *src0 = std::launder(reinterpret_cast<T1*>(&_src0));
|
||||
T1 *src1 = std::launder(reinterpret_cast<T1*>(&_src1));
|
||||
T1 *src2 = std::launder(reinterpret_cast<T1*>(&_src2));
|
||||
@@ -44055,7 +44015,6 @@ namespace VegaISA
|
||||
new (&vdst[i]) T2(gpuDynInst, instData.VDST+acc_cd_off+i*_delta);
|
||||
}
|
||||
|
||||
|
||||
// These values and meanings are described in the MI300 ISA manual:
|
||||
//
|
||||
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/
|
||||
@@ -44063,7 +44022,7 @@ namespace VegaISA
|
||||
// amd-instinct-mi300-cdna3-instruction-set-architecture.pdf
|
||||
//
|
||||
// in section 7.1.4.2. In theory, only the M, N, K, and H values change
|
||||
// for each MFMA instruction and therefore this could be templated.
|
||||
// for each MFMA instruction.
|
||||
|
||||
// Output layout
|
||||
constexpr int H = _delta == 2 ? 1 : 4;
|
||||
@@ -44109,11 +44068,12 @@ namespace VegaISA
|
||||
vdst[item][lane] = result[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
std::destroy_at(&src0[i]);
|
||||
}
|
||||
@@ -44129,21 +44089,487 @@ namespace VegaISA
|
||||
} // execute
|
||||
};
|
||||
|
||||
static const char *MNEM__V_MFMA_F32_4X4X1_16B_F32 =
|
||||
"v_mfma_f32_4x4x1_16b_f32";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_4X4X1_16B_F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 4, 4, 1, 16, ConstVecOperandF32,
|
||||
VecOperandF32>;
|
||||
VecOperandF32, &MNEM__V_MFMA_F32_4X4X1_16B_F32>;
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X2F32 =
|
||||
static const char *MNEM__V_MFMA_F32_32X32X1_2B_F32 =
|
||||
"v_mfma_f32_32x32x1_2b_f32";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X1_2B_F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 32, 32, 1, 2, ConstVecOperandF32,
|
||||
VecOperandF32,
|
||||
&MNEM__V_MFMA_F32_32X32X1_2B_F32>;
|
||||
|
||||
static const char *MNEM__V_MFMA_F32_32X32X2_F32 =
|
||||
"v_mfma_f32_32x32x2_f32";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X2_F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 32, 32, 2, 1, ConstVecOperandF32,
|
||||
VecOperandF32>;
|
||||
VecOperandF32, &MNEM__V_MFMA_F32_32X32X2_F32>;
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X4F32 =
|
||||
static const char *MNEM__V_MFMA_F32_16X16X4_F32 =
|
||||
"v_mfma_f32_16x16x4_f32";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X4_F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 16, 16, 4, 1, ConstVecOperandF32,
|
||||
VecOperandF32>;
|
||||
VecOperandF32, &MNEM__V_MFMA_F32_16X16X4_F32>;
|
||||
|
||||
using Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 =
|
||||
static const char *MNEM__V_MFMA_F32_16X16X1_4B_F32 =
|
||||
"v_mfma_f32_16x16x1_4b_f32";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X1_4B_F32 =
|
||||
Inst_VOP3P_MAI__V_MFMA<1, 16, 16, 1, 4, ConstVecOperandF32,
|
||||
VecOperandF32,
|
||||
&MNEM__V_MFMA_F32_16X16X1_4B_F32>;
|
||||
|
||||
static const char *MNEM__V_MFMA_F64_4X4X4_4B_F64 =
|
||||
"v_mfma_f64_4x4x4_4b_f64";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F64_4X4X4_4B_F64 =
|
||||
Inst_VOP3P_MAI__V_MFMA<2, 4, 4, 4, 4, ConstVecOperandF64,
|
||||
VecOperandF64, &MNEM__V_MFMA_F64_4X4X4_4B_F64>;
|
||||
|
||||
static const char *MNEM__V_MFMA_F64_16X16X4_F64 =
|
||||
"v_mfma_f64_16x16x4_f64";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F64_16X16X4_F64 =
|
||||
Inst_VOP3P_MAI__V_MFMA<2, 16, 16, 4, 1, ConstVecOperandF64,
|
||||
VecOperandF64>;
|
||||
VecOperandF64, &MNEM__V_MFMA_F64_16X16X4_F64>;
|
||||
|
||||
|
||||
template <const int M, const int N, const int K,
|
||||
const int B, typename MXFPT, const char **MNEMONIC>
|
||||
class Inst_VOP3P_MAI__V_MFMA_MXFP : public Inst_VOP3P_MAI
|
||||
{
|
||||
|
||||
private:
|
||||
// Scale GPRs needed by elements / GPR (gpr_ratio)
|
||||
static constexpr int gpr_ratio = 32 / MXFPT::size();
|
||||
static constexpr int gprs_a = M * K * B / (64 * gpr_ratio);
|
||||
static constexpr int gprs_b = K * N * B / (64 * gpr_ratio);
|
||||
|
||||
// Always F32 which has an effective gpr_ratio of 1
|
||||
static constexpr int gprs_c_d = M * N * B / 64;
|
||||
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA_MXFP(InFmt_VOP3P_MAI *iFmt)
|
||||
: Inst_VOP3P_MAI(iFmt, *MNEMONIC)
|
||||
{
|
||||
setFlag(ALU);
|
||||
}
|
||||
~Inst_VOP3P_MAI__V_MFMA_MXFP() {}
|
||||
|
||||
int getNumOperands() override {
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
|
||||
int getOperandSize(int opIdx) override {
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 4*gprs_a;
|
||||
case 1: // src1 "B"
|
||||
return 4*gprs_b;
|
||||
case 2: // src2 "C"
|
||||
return 4*gprs_c_d;
|
||||
case 3: // dst
|
||||
return 4*gprs_c_d;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
}
|
||||
} // getOperandSize
|
||||
|
||||
void
|
||||
execute(GPUDynInstPtr gpuDynInst) override
|
||||
{
|
||||
int acc_cd_off = 0;
|
||||
int acc_a_off = 0;
|
||||
int acc_b_off = 0;
|
||||
if (instData.ACC_CD) {
|
||||
acc_cd_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (extData.ACC) {
|
||||
int tmp_acc = extData.ACC;
|
||||
if (tmp_acc & 0x1) {
|
||||
acc_a_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (tmp_acc & 0x2) {
|
||||
acc_b_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
}
|
||||
|
||||
// Read the MXFP types as U32 - Consider this "untyped."
|
||||
// A ConstVecOperand needs to be used for src2 as it could be an
|
||||
// inline constant. The Const version provides an operator[] overload
|
||||
// to read inline constants to each lane. The non-const type of src2
|
||||
// should be used for vdst to make it writeable.
|
||||
using T1 = ConstVecOperandU32;
|
||||
using T2 = ConstVecOperandF32;
|
||||
using T3 = VecOperandF32;
|
||||
|
||||
alignas(T1) std::byte _src0[sizeof(T1) * gprs_a];
|
||||
alignas(T1) std::byte _src1[sizeof(T1) * gprs_b];
|
||||
alignas(T2) std::byte _src2[sizeof(T2) * gprs_c_d];
|
||||
alignas(T3) std::byte _vdst[sizeof(T3) * gprs_c_d];
|
||||
T1 *src0 = std::launder(reinterpret_cast<T1*>(&_src0));
|
||||
T1 *src1 = std::launder(reinterpret_cast<T1*>(&_src1));
|
||||
T2 *src2 = std::launder(reinterpret_cast<T2*>(&_src2));
|
||||
T3 *vdst = std::launder(reinterpret_cast<T3*>(&_vdst));
|
||||
|
||||
// Handling of src2 is a bit tricky. The operator[] overload cannot
|
||||
// be used for dword count > 2, and the dword count here is 4. Usually
|
||||
// src2 is a VGPR/AccGPR, but it might also be constant. In order to
|
||||
// use operator[] and handle constants, check for VGPR here and set
|
||||
// a delta for each of the src2 GPRs.
|
||||
|
||||
int delta = isVectorReg(extData.SRC0) ? 1 : 0;
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
new (&src0[i]) T1(gpuDynInst, extData.SRC0+acc_a_off+i*delta);
|
||||
src0[i].readSrc();
|
||||
}
|
||||
|
||||
delta = isVectorReg(extData.SRC1) ? 1 : 0;
|
||||
for (int i = 0; i < gprs_b; i++) {
|
||||
new (&src1[i]) T1(gpuDynInst, extData.SRC1+acc_b_off+i*delta);
|
||||
src1[i].readSrc();
|
||||
}
|
||||
|
||||
delta = isVectorReg(extData.SRC2) ? 1 : 0;
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
new (&src2[i]) T2(gpuDynInst, extData.SRC2+acc_cd_off+i*delta);
|
||||
src2[i].readSrc();
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
new (&vdst[i]) T3(gpuDynInst, instData.VDST+acc_cd_off+i);
|
||||
}
|
||||
|
||||
// These values and meanings are described in the MI300 ISA manual:
|
||||
//
|
||||
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/
|
||||
// instruction-set-architectures/
|
||||
// amd-instinct-mi300-cdna3-instruction-set-architecture.pdf
|
||||
//
|
||||
// in section 7.1.4.2. In theory, only the M, N, K, and H values change
|
||||
// for each MFMA instruction.
|
||||
|
||||
// Output layout
|
||||
constexpr int H = 4;
|
||||
constexpr int B_I = std::ceil(64.0f / (N * M / H));
|
||||
constexpr int M_I = (64 / B_I) / N;
|
||||
constexpr int G = M / (H * M_I);
|
||||
|
||||
float result[M][N];
|
||||
|
||||
// Input layout
|
||||
constexpr int K_L = K / (64 / (M * B));
|
||||
|
||||
for (int block = 0; block < B; block++) {
|
||||
// Load src2 into result. src2 is row major
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
|
||||
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
|
||||
|
||||
result[i][j] = src2[item][lane];
|
||||
}
|
||||
}
|
||||
|
||||
// Compute new result
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
// src0 is column major, src1 is row major
|
||||
int lane_A = i + M * (block + B * (k / K_L));
|
||||
int lane_B = j + N * (block + B * (k / K_L));
|
||||
int item = k % K_L;
|
||||
|
||||
PackedReg<K_L * MXFPT::size(), MXFPT::size()> A_elems;
|
||||
PackedReg<K_L * MXFPT::size(), MXFPT::size()> B_elems;
|
||||
|
||||
for (int i = 0; i < gprs_a; ++i) {
|
||||
A_elems.setDword(i, src0[i][lane_A]);
|
||||
}
|
||||
for (int i = 0; i < gprs_b; ++i) {
|
||||
B_elems.setDword(i, src1[i][lane_B]);
|
||||
}
|
||||
|
||||
MXFPT item_A(A_elems.getElem(item));
|
||||
MXFPT item_B(B_elems.getElem(item));
|
||||
|
||||
result[i][j] += item_A * item_B;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
|
||||
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
|
||||
|
||||
vdst[item][lane] = result[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
std::destroy_at(&src0[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_b; i++) {
|
||||
std::destroy_at(&src1[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
std::destroy_at(&src2[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
std::destroy_at(&vdst[i]);
|
||||
}
|
||||
} // execute
|
||||
};
|
||||
|
||||
|
||||
static const char *MNEM__V_MFMA_F32_16X16X16_F16 =
|
||||
"v_mfma_f32_16x16x16_f16";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X16_F16 =
|
||||
Inst_VOP3P_MAI__V_MFMA_MXFP<16, 16, 16, 1, AMDGPU::mxfloat16,
|
||||
&MNEM__V_MFMA_F32_16X16X16_F16>;
|
||||
|
||||
static const char *MNEM__V_MFMA_F32_16X16X4_4B_F16 =
|
||||
"v_mfma_f32_16x16x4_4b_f16";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_16X16X4_4B_F16 =
|
||||
Inst_VOP3P_MAI__V_MFMA_MXFP<16, 16, 4, 4, AMDGPU::mxfloat16,
|
||||
&MNEM__V_MFMA_F32_16X16X4_4B_F16>;
|
||||
|
||||
static const char *MNEM__V_MFMA_F32_32X32X4_2B_F16 =
|
||||
"v_mfma_f32_32x32x4_2b_f16";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X4_2B_F16 =
|
||||
Inst_VOP3P_MAI__V_MFMA_MXFP<32, 32, 4, 2, AMDGPU::mxfloat16,
|
||||
&MNEM__V_MFMA_F32_32X32X4_2B_F16>;
|
||||
|
||||
static const char *NMEM__V_MFMA_F32_32X32X8_F16 =
|
||||
"v_mfma_f32_32x32x8_f16";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_32X32X8_F16 =
|
||||
Inst_VOP3P_MAI__V_MFMA_MXFP<32, 32, 8, 1, AMDGPU::mxfloat16,
|
||||
&NMEM__V_MFMA_F32_32X32X8_F16>;
|
||||
|
||||
static const char *MNEM__V_MFMA_F32_4X4X4_16B_F16 =
|
||||
"v_mfma_f32_4x4x4_16b_f16";
|
||||
using Inst_VOP3P_MAI__V_MFMA_F32_4X4X4_16B_F16 =
|
||||
Inst_VOP3P_MAI__V_MFMA_MXFP<4, 4, 4, 16, AMDGPU::mxfloat16,
|
||||
&MNEM__V_MFMA_F32_4X4X4_16B_F16>;
|
||||
|
||||
|
||||
template <const int M, const int N, const int K,
|
||||
const int B, const char **MNEMONIC>
|
||||
class Inst_VOP3P_MAI__V_MFMA_I8 : public Inst_VOP3P_MAI
|
||||
{
|
||||
|
||||
private:
|
||||
// Only int8 exists at the moment, but make the type a parameter.
|
||||
using DT = int8_t;
|
||||
static constexpr int DT_bits = sizeof(DT) * 8;
|
||||
|
||||
// Scale GPRs needed by elements / GPR (gpr_ratio)
|
||||
static constexpr int gpr_ratio = 32 / DT_bits;
|
||||
static constexpr int gprs_a = M * K * B / (64 * gpr_ratio);
|
||||
static constexpr int gprs_b = K * N * B / (64 * gpr_ratio);
|
||||
|
||||
// Always F32 which has an effective gpr_ratio of 1
|
||||
static constexpr int gprs_c_d = M * N * B / 64;
|
||||
|
||||
public:
|
||||
Inst_VOP3P_MAI__V_MFMA_I8(InFmt_VOP3P_MAI *iFmt)
|
||||
: Inst_VOP3P_MAI(iFmt, *MNEMONIC)
|
||||
{
|
||||
setFlag(ALU);
|
||||
}
|
||||
~Inst_VOP3P_MAI__V_MFMA_I8() {}
|
||||
|
||||
int getNumOperands() override {
|
||||
return numDstRegOperands() + numSrcRegOperands();
|
||||
} // getNumOperands
|
||||
|
||||
int numDstRegOperands() override { return 1; }
|
||||
int numSrcRegOperands() override { return 3; }
|
||||
|
||||
int getOperandSize(int opIdx) override {
|
||||
switch (opIdx) {
|
||||
case 0: // src0 "A"
|
||||
return 4*gprs_a;
|
||||
case 1: // src1 "B"
|
||||
return 4*gprs_b;
|
||||
case 2: // src2 "C"
|
||||
return 4*gprs_c_d;
|
||||
case 3: // dst
|
||||
return 4*gprs_c_d;
|
||||
default:
|
||||
fatal("op idx %i out of bounds\n", opIdx);
|
||||
return -1;
|
||||
}
|
||||
} // getOperandSize
|
||||
|
||||
void
|
||||
execute(GPUDynInstPtr gpuDynInst) override
|
||||
{
|
||||
int acc_cd_off = 0;
|
||||
int acc_a_off = 0;
|
||||
int acc_b_off = 0;
|
||||
if (instData.ACC_CD) {
|
||||
acc_cd_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (extData.ACC) {
|
||||
int tmp_acc = extData.ACC;
|
||||
if (tmp_acc & 0x1) {
|
||||
acc_a_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (tmp_acc & 0x2) {
|
||||
acc_b_off = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
}
|
||||
|
||||
// Read the packed types as U32 - Consider this "untyped."
|
||||
// A ConstVecOperand needs to be used for src2 as it could be an
|
||||
// inline constant. The Const version provides an operator[] overload
|
||||
// to read inline constants to each lane. The non-const type of src2
|
||||
// should be used for vdst to make it writeable.
|
||||
using T1 = ConstVecOperandU32;
|
||||
using T2 = ConstVecOperandI32;
|
||||
using T3 = VecOperandI32;
|
||||
|
||||
alignas(T1) std::byte _src0[sizeof(T1) * gprs_a];
|
||||
alignas(T1) std::byte _src1[sizeof(T1) * gprs_b];
|
||||
alignas(T2) std::byte _src2[sizeof(T2) * gprs_c_d];
|
||||
alignas(T3) std::byte _vdst[sizeof(T3) * gprs_c_d];
|
||||
T1 *src0 = std::launder(reinterpret_cast<T1*>(&_src0));
|
||||
T1 *src1 = std::launder(reinterpret_cast<T1*>(&_src1));
|
||||
T2 *src2 = std::launder(reinterpret_cast<T2*>(&_src2));
|
||||
T3 *vdst = std::launder(reinterpret_cast<T3*>(&_vdst));
|
||||
|
||||
// Handling of src2 is a bit tricky. The operator[] overload cannot
|
||||
// be used for dword count > 2, and the dword count here is 4. Usually
|
||||
// src2 is a VGPR/AccGPR, but it might also be constant. In order to
|
||||
// use operator[] and handle constants, check for VGPR here and set
|
||||
// a delta for each of the src2 GPRs.
|
||||
|
||||
int delta = isVectorReg(extData.SRC0) ? 1 : 0;
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
new (&src0[i]) T1(gpuDynInst, extData.SRC0+acc_a_off+i*delta);
|
||||
src0[i].readSrc();
|
||||
}
|
||||
|
||||
delta = isVectorReg(extData.SRC1) ? 1 : 0;
|
||||
for (int i = 0; i < gprs_b; i++) {
|
||||
new (&src1[i]) T1(gpuDynInst, extData.SRC1+acc_b_off+i*delta);
|
||||
src1[i].readSrc();
|
||||
}
|
||||
|
||||
delta = isVectorReg(extData.SRC2) ? 1 : 0;
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
new (&src2[i]) T2(gpuDynInst, extData.SRC2+acc_cd_off+i*delta);
|
||||
src2[i].readSrc();
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
new (&vdst[i]) T3(gpuDynInst, instData.VDST+acc_cd_off+i);
|
||||
}
|
||||
|
||||
// These values and meanings are described in the MI300 ISA manual:
|
||||
//
|
||||
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/
|
||||
// instruction-set-architectures/
|
||||
// amd-instinct-mi300-cdna3-instruction-set-architecture.pdf
|
||||
//
|
||||
// in section 7.1.4.2. In theory, only the M, N, K, and H values change
|
||||
// for each MFMA instruction.
|
||||
|
||||
// Output layout
|
||||
constexpr int H = 4;
|
||||
constexpr int B_I = std::ceil(64.0f / (N * M / H));
|
||||
constexpr int M_I = (64 / B_I) / N;
|
||||
constexpr int G = M / (H * M_I);
|
||||
|
||||
int32_t result[M][N];
|
||||
|
||||
// Input layout
|
||||
constexpr int K_L = K / (64 / (M * B));
|
||||
|
||||
for (int block = 0; block < B; block++) {
|
||||
// Load src2 into result. src2 is row major
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
|
||||
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
|
||||
|
||||
result[i][j] = src2[item][lane];
|
||||
}
|
||||
}
|
||||
|
||||
// Compute new result
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
// src0 is column major, src1 is row major
|
||||
int lane_A = i + M * (block + B * (k / K_L));
|
||||
int lane_B = j + N * (block + B * (k / K_L));
|
||||
int item = k % K_L;
|
||||
|
||||
PackedReg<K_L * DT_bits, DT_bits> A_elems;
|
||||
PackedReg<K_L * DT_bits, DT_bits> B_elems;
|
||||
|
||||
for (int i = 0; i < gprs_a; ++i) {
|
||||
A_elems.setDword(i, src0[i][lane_A]);
|
||||
}
|
||||
for (int i = 0; i < gprs_b; ++i) {
|
||||
B_elems.setDword(i, src1[i][lane_B]);
|
||||
}
|
||||
|
||||
DT item_A(A_elems.getElem(item));
|
||||
DT item_B(B_elems.getElem(item));
|
||||
|
||||
result[i][j] += int32_t(item_A) * int32_t(item_B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
int item = (i % H) + H * (i/(H*M_I) + G * (block / B_I));
|
||||
int lane = j + N * ((i / H) % M_I + M_I * (block % B_I));
|
||||
|
||||
vdst[item][lane] = result[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_c_d; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
|
||||
for (int i = 0; i < gprs_a; i++) {
|
||||
std::destroy_at(&src0[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_b; i++) {
|
||||
std::destroy_at(&src1[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
std::destroy_at(&src2[i]);
|
||||
}
|
||||
for (int i = 0; i < gprs_c_d; i++) {
|
||||
std::destroy_at(&vdst[i]);
|
||||
}
|
||||
} // execute
|
||||
};
|
||||
|
||||
static const char *MNEM__V_MFMA_I32_16X16X16_I8 =
|
||||
"v_mfma_i32_16x16x16_i8";
|
||||
using Inst_VOP3P_MAI__V_MFMA_I32_16X16X16_I8 =
|
||||
Inst_VOP3P_MAI__V_MFMA_I8<16, 16, 16, 1,
|
||||
&MNEM__V_MFMA_I32_16X16X16_I8>;
|
||||
|
||||
|
||||
class Inst_VOP3__V_CVT_PK_FP8_F32 : public Inst_VOP3A
|
||||
{
|
||||
|
||||
@@ -37,140 +37,5 @@ namespace gem5
|
||||
|
||||
namespace VegaISA
|
||||
{
|
||||
// --- Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 class methods ---
|
||||
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI *iFmt)
|
||||
: Inst_VOP3P_MAI(iFmt, "v_mfma_i32_16x16x16i8")
|
||||
{
|
||||
setFlag(ALU);
|
||||
} // Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8
|
||||
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::
|
||||
~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8()
|
||||
{
|
||||
} // ~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8
|
||||
|
||||
// D(16x16I32) = A(16x16I8) x B(16x16I8) + C(16x16I32), 1 Blocks, 8
|
||||
// pass, srcA/srcB 1 archVgpr, srcC/D 4 accVGPR
|
||||
void
|
||||
Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::execute(GPUDynInstPtr gpuDynInst)
|
||||
{
|
||||
// Accumulation register offsets for A, B, and C/D matrix.
|
||||
int a_offset = 0;
|
||||
int b_offset = 0;
|
||||
int cd_offset = 0;
|
||||
if (instData.ACC_CD) {
|
||||
cd_offset = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
if (extData.ACC) {
|
||||
if (extData.ACC & 0x1) {
|
||||
a_offset = gpuDynInst->wavefront()->accumOffset;
|
||||
} else if (extData.ACC & 0x2) {
|
||||
b_offset = gpuDynInst->wavefront()->accumOffset;
|
||||
}
|
||||
}
|
||||
|
||||
// int8 size allows for 4 elements per lane. At 16x16 this means 4
|
||||
// lanes per column (A matrix) / (B matrix). This whole matrix fits
|
||||
// in one VGPR. The C matrix with size int32 requires 4 VGPRs.
|
||||
// Handle the C matrix by using a delta. This is set to 1 normally to
|
||||
// move to the next VGPR (1 dword away) and 0 if the input is a scalar
|
||||
// reg (e.g., a constant).
|
||||
int delta = isVectorReg(extData.SRC2) ? 1 : 0;
|
||||
|
||||
// VecOperandI8 will read 8 bits and sign extend, so used U32 to read
|
||||
// as "untyped" 32-bit values.
|
||||
ConstVecOperandU32 src0(gpuDynInst, extData.SRC0+a_offset);
|
||||
ConstVecOperandU32 src1(gpuDynInst, extData.SRC1+b_offset);
|
||||
ConstVecOperandI32 src2[4] = {
|
||||
ConstVecOperandI32(gpuDynInst, extData.SRC2+cd_offset),
|
||||
ConstVecOperandI32(gpuDynInst, extData.SRC2+cd_offset+1*delta),
|
||||
ConstVecOperandI32(gpuDynInst, extData.SRC2+cd_offset+2*delta),
|
||||
ConstVecOperandI32(gpuDynInst, extData.SRC2+cd_offset+3*delta),
|
||||
};
|
||||
|
||||
VecOperandI32 vdst[4] = {
|
||||
VecOperandI32(gpuDynInst, instData.VDST+cd_offset),
|
||||
VecOperandI32(gpuDynInst, instData.VDST+cd_offset+1),
|
||||
VecOperandI32(gpuDynInst, instData.VDST+cd_offset+2),
|
||||
VecOperandI32(gpuDynInst, instData.VDST+cd_offset+3),
|
||||
};
|
||||
|
||||
src0.readSrc();
|
||||
src1.readSrc();
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
src2[i].readSrc();
|
||||
}
|
||||
|
||||
int32_t A[16][16];
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
// src0[0:15] contains columns 1 - 4 packed for rows 0 - 15,
|
||||
// src0[16:31] contains columns 5 - 8 packed for rows 0 - 15,
|
||||
// src0[32:47] contains columns 9 - 12 packed for rows 0 - 15,
|
||||
// src0[48:63] contains columns 13 - 16 packed for rows 0 - 15,
|
||||
int row = i % 16;
|
||||
int start_col = (i / 16) * 4;
|
||||
|
||||
A[row][start_col+0] = sext<8>(bits(src0[i], 7, 0));
|
||||
A[row][start_col+1] = sext<8>(bits(src0[i], 15, 8));
|
||||
A[row][start_col+2] = sext<8>(bits(src0[i], 23, 16));
|
||||
A[row][start_col+3] = sext<8>(bits(src0[i], 31, 24));
|
||||
}
|
||||
|
||||
int32_t B[16][16];
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
// src1[0:15] contains rows 1 - 4 packed for columns 0 - 15
|
||||
// src1[16:31] contains rows 5 - 8 packed for columns 0 - 15
|
||||
// src1[32:47] contains rows 9 - 12 packed for columns 0 - 15
|
||||
// src1[48:63] contains rows 13 - 16 packed for columns 0 - 15
|
||||
int start_row = (i / 16) * 4;
|
||||
int col = i % 16;
|
||||
|
||||
B[start_row+0][col] = sext<8>(bits(src1[i], 7, 0));
|
||||
B[start_row+1][col] = sext<8>(bits(src1[i], 15, 8));
|
||||
B[start_row+2][col] = sext<8>(bits(src1[i], 23, 16));
|
||||
B[start_row+3][col] = sext<8>(bits(src1[i], 31, 24));
|
||||
}
|
||||
|
||||
int32_t result[16][16];
|
||||
|
||||
// Load accumulation matrix C into result
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
// src2[0] contains rows 0, 4, 8, 12
|
||||
result[(i/16)*4][(i%16)] = src2[0][i];
|
||||
// src2[1] contains rows 1, 5, 9, 13
|
||||
result[(i/16)*4+1][(i%16)] = src2[1][i];
|
||||
// src2[2] contains rows 2, 6, 10, 14
|
||||
result[(i/16)*4+2][(i%16)] = src2[2][i];
|
||||
// src2[3] contains rows 3, 7, 11, 15
|
||||
result[(i/16)*4+3][(i%16)] = src2[3][i];
|
||||
}
|
||||
|
||||
// Compute new result - This is (obviously) not optimized
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
result[i][j] += A[i][k] * B[k][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Put result in dest VGPRs
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
// vdst[0] contains rows 0, 4, 8, 12
|
||||
vdst[0][i] = result[(i/16)*4][(i%16)];
|
||||
// vdst[1] contains rows 1, 5, 9, 13
|
||||
vdst[1][i] = result[(i/16)*4+1][(i%16)];
|
||||
// vdst[2] contains rows 2, 6, 10, 14
|
||||
vdst[2][i] = result[(i/16)*4+2][(i%16)];
|
||||
// vdst[3] contains rows 3, 7, 11, 15
|
||||
vdst[3][i] = result[(i/16)*4+3][(i%16)];
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vdst[i].write();
|
||||
}
|
||||
} // execute
|
||||
} // namespace VegaISA
|
||||
} // namespace gem5
|
||||
|
||||
Reference in New Issue
Block a user