diff --git a/src/arch/amdgpu/vega/decoder.cc b/src/arch/amdgpu/vega/decoder.cc index 2850640af2..7d3d707c56 100644 --- a/src/arch/amdgpu/vega/decoder.cc +++ b/src/arch/amdgpu/vega/decoder.cc @@ -3664,7 +3664,7 @@ namespace VegaISA &Decoder::decode_invalid, &Decoder::decode_invalid, &Decoder::decode_invalid, - &Decoder::decode_invalid, + &Decoder::decode_OP_VOP3P__V_MFMA_I32_16X16X16I8, &Decoder::decode_invalid, &Decoder::decode_invalid, &Decoder::decode_invalid, @@ -13059,6 +13059,13 @@ namespace VegaISA return new Inst_VOP3P__V_PK_MOV_B32(&iFmt->iFmt_VOP3P); } + GPUStaticInst* + Decoder::decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst iFmt) + { + return new Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8( + &iFmt->iFmt_VOP3P_MAI); + } + GPUStaticInst* Decoder::decode_OP_VOP3P__V_MFMA_F64_16X16X4F64(MachInst iFmt) { diff --git a/src/arch/amdgpu/vega/gpu_decoder.hh b/src/arch/amdgpu/vega/gpu_decoder.hh index 2a6f9370fb..11858f0375 100644 --- a/src/arch/amdgpu/vega/gpu_decoder.hh +++ b/src/arch/amdgpu/vega/gpu_decoder.hh @@ -1586,6 +1586,7 @@ namespace VegaISA GPUStaticInst* decode_OP_VOP3P__V_MAD_MIXLO_F16(MachInst); GPUStaticInst* decode_OP_VOP3P__V_MAD_MIXHI_F16(MachInst); GPUStaticInst* decode_OP_VOP3P__V_PK_MOV_B32(MachInst); + GPUStaticInst* decode_OP_VOP3P__V_MFMA_I32_16X16X16I8(MachInst); GPUStaticInst* decode_OP_VOP3P__V_MFMA_F64_16X16X4F64(MachInst); GPUStaticInst* subDecode_OPU_VOP3(MachInst); GPUStaticInst* subDecode_OP_DS(MachInst); diff --git a/src/arch/amdgpu/vega/insts/instructions.cc b/src/arch/amdgpu/vega/insts/instructions.cc index 855f91699f..f0472835dd 100644 --- a/src/arch/amdgpu/vega/insts/instructions.cc +++ b/src/arch/amdgpu/vega/insts/instructions.cc @@ -45953,6 +45953,129 @@ namespace VegaISA vdst.write(); } // execute + // --- Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 class methods --- + + Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8:: + Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI *iFmt) + : Inst_VOP3P_MAI(iFmt, "v_mfma_i32_16x16x16i8") + { + setFlag(ALU); + } // Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 + + Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8:: + ~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8() + { + } // ~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 + + // D(16x16I32) = A(16x16I8) x B(16x16I8) + C(16x16I32), 1 Blocks, 8 + // pass, srcA/srcB 1 archVgpr, srcC/D 4 accVGPR + void + Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8::execute(GPUDynInstPtr gpuDynInst) + { + int acc_offset = 0; + if (instData.ACC_CD) { + warn("ACC_CD not yet implemented\n"); + } + + // int8 size allows for 4 elements per lane. At 16x16 this means 4 + // lanes per column (A matrix) / (B matrix). This whole matrix fits + // in one VGPR. The C matrix with size int32 requires 4 VGPRs. + // Handle the C matrix by using a delta. This is set to 1 normally to + // move to the next VGPR (1 dword away) and 0 if the input is a scalar + // reg (e.g., a constant). + int delta = isVectorReg(extData.SRC2) ? 1 : 0; + + // VecOperandI8 will read 8 bits and sign extend, so used U32 to read + // as "untyped" 32-bit values. + ConstVecOperandU32 src0(gpuDynInst, extData.SRC0); + ConstVecOperandU32 src1(gpuDynInst, extData.SRC1); + ConstVecOperandI32 src2a(gpuDynInst, extData.SRC2+acc_offset); + ConstVecOperandI32 src2b(gpuDynInst, extData.SRC2+acc_offset+1*delta); + ConstVecOperandI32 src2c(gpuDynInst, extData.SRC2+acc_offset+2*delta); + ConstVecOperandI32 src2d(gpuDynInst, extData.SRC2+acc_offset+3*delta); + + VecOperandI32 vdsta(gpuDynInst, instData.VDST+acc_offset); + VecOperandI32 vdstb(gpuDynInst, instData.VDST+acc_offset+1); + VecOperandI32 vdstc(gpuDynInst, instData.VDST+acc_offset+2); + VecOperandI32 vdstd(gpuDynInst, instData.VDST+acc_offset+3); + + src0.readSrc(); + src1.readSrc(); + src2a.readSrc(); + src2b.readSrc(); + src2c.readSrc(); + src2d.readSrc(); + + int32_t A[16][16]; + for (int i = 0; i < 64; ++i) { + // src0[0:15] contains columns 1 - 4 packed for rows 0 - 15, + // src0[16:31] contains columns 5 - 8 packed for rows 0 - 15, + // src0[32:47] contains columns 9 - 12 packed for rows 0 - 15, + // src0[48:63] contains columns 13 - 16 packed for rows 0 - 15, + int row = i % 16; + int start_col = (i / 16) * 4; + + A[row][start_col+0] = sext<8>(bits(src0[i], 7, 0)); + A[row][start_col+1] = sext<8>(bits(src0[i], 15, 8)); + A[row][start_col+2] = sext<8>(bits(src0[i], 23, 16)); + A[row][start_col+3] = sext<8>(bits(src0[i], 31, 24)); + } + + int32_t B[16][16]; + for (int i = 0; i < 64; ++i) { + // src1[0:15] contains rows 1 - 4 packed for columns 0 - 15 + // src1[16:31] contains rows 5 - 8 packed for columns 0 - 15 + // src1[32:47] contains rows 9 - 12 packed for columns 0 - 15 + // src1[48:63] contains rows 13 - 16 packed for columns 0 - 15 + int start_row = (i / 16) * 4; + int col = i % 16; + + B[start_row+0][col] = sext<8>(bits(src1[i], 7, 0)); + B[start_row+1][col] = sext<8>(bits(src1[i], 15, 8)); + B[start_row+2][col] = sext<8>(bits(src1[i], 23, 16)); + B[start_row+3][col] = sext<8>(bits(src1[i], 31, 24)); + } + + int32_t result[16][16]; + + // Load accumulation matrix C into result + for (int i = 0; i < 64; ++i) { + // src2a contains rows 0, 4, 8, 12 + result[(i/16)*4][(i%16)] = src2a[i]; + // src2b contains rows 1, 5, 9, 13 + result[(i/16)*4+1][(i%16)] = src2b[i]; + // src2c contains rows 2, 6, 10, 14 + result[(i/16)*4+2][(i%16)] = src2c[i]; + // src2d contains rows 3, 7, 11, 15 + result[(i/16)*4+3][(i%16)] = src2d[i]; + } + + // Compute new result - This is (obviously) not optimized + for (int i = 0; i < 16; ++i) { + for (int j = 0; j < 16; ++j) { + for (int k = 0; k < 16; ++k) { + result[i][j] += A[i][k] * B[k][j]; + } + } + } + + // Put result in dest VGPRs + for (int i = 0; i < 64; ++i) { + // vdsta contains rows 0, 4, 8, 12 + vdsta[i] = result[(i/16)*4][(i%16)]; + // vdstb contains rows 1, 5, 9, 13 + vdstb[i] = result[(i/16)*4+1][(i%16)]; + // vdstc contains rows 2, 6, 10, 14 + vdstc[i] = result[(i/16)*4+2][(i%16)]; + // vdstd contains rows 3, 7, 11, 15 + vdstd[i] = result[(i/16)*4+3][(i%16)]; + } + + vdsta.write(); + vdstb.write(); + vdstc.write(); + vdstd.write(); + } // execute // --- Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 class methods --- Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64:: diff --git a/src/arch/amdgpu/vega/insts/instructions.hh b/src/arch/amdgpu/vega/insts/instructions.hh index 22423e14c6..c41569f193 100644 --- a/src/arch/amdgpu/vega/insts/instructions.hh +++ b/src/arch/amdgpu/vega/insts/instructions.hh @@ -43440,6 +43440,42 @@ namespace VegaISA void execute(GPUDynInstPtr) override; }; // Inst_VOP3P__V_PK_MOV_B32 + class Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8 : public Inst_VOP3P_MAI + { + public: + Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(InFmt_VOP3P_MAI*); + ~Inst_VOP3P_MAI__V_MFMA_I32_16X16X16I8(); + + int + getNumOperands() override + { + return numDstRegOperands() + numSrcRegOperands(); + } // getNumOperands + + int numDstRegOperands() override { return 1; } + int numSrcRegOperands() override { return 3; } + + int + getOperandSize(int opIdx) override + { + switch (opIdx) { + case 0: // src0 "A" + return 4; + case 1: // src1 "B" + return 4; + case 2: // src2 "C" + return 16; + case 3: // dst + return 16; + default: + fatal("op idx %i out of bounds\n", opIdx); + return -1; + } + } // getOperandSize + + void execute(GPUDynInstPtr) override; + }; + class Inst_VOP3P_MAI__V_MFMA_F64_16X16X4F64 : public Inst_VOP3P_MAI { public: