arch-vega: Various vega fixes to enable nanogpt (#950)

This PR fixes some issues observed that were needed to get nanogpt
working.
This commit is contained in:
Matthew Poremba
2024-03-21 21:11:44 -07:00
committed by GitHub
4 changed files with 134 additions and 21 deletions

View File

@@ -1238,6 +1238,7 @@ namespace VegaISA
&Decoder::decode_OPU_VOP3__V_CVT_PK_I16_I32,
&Decoder::decode_OPU_VOP3__V_PKNORM_I16_F16,
&Decoder::decode_OPU_VOP3__V_PKNORM_U16_F16,
&Decoder::decode_invalid,
&Decoder::decode_OPU_VOP3__V_ADD_I32,
&Decoder::decode_OPU_VOP3__V_SUB_I32,
&Decoder::decode_OPU_VOP3__V_ADD_I16,
@@ -1337,7 +1338,6 @@ namespace VegaISA
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid,
&Decoder::decode_invalid
};
@@ -4217,8 +4217,7 @@ namespace VegaISA
GPUStaticInst*
Decoder::decode_OP_VOP2__V_XNOR_B32(MachInst iFmt)
{
fatal("Trying to decode instruction without a class\n");
return nullptr;
return new Inst_VOP2__V_XNOR_B32(&iFmt->iFmt_VOP2);
}
GPUStaticInst*

View File

@@ -8132,6 +8132,40 @@ namespace VegaISA
void execute(GPUDynInstPtr) override;
}; // Inst_VOP2__V_FMAC_F32
class Inst_VOP2__V_XNOR_B32 : public Inst_VOP2
{
public:
Inst_VOP2__V_XNOR_B32(InFmt_VOP2*);
~Inst_VOP2__V_XNOR_B32();
int
getNumOperands() override
{
return numDstRegOperands() + numSrcRegOperands();
} // getNumOperands
int numDstRegOperands() override { return 1; }
int numSrcRegOperands() override { return 2; }
int
getOperandSize(int opIdx) override
{
switch (opIdx) {
case 0: //src_0
return 4;
case 1: //src_1
return 4;
case 2: //vdst
return 4;
default:
fatal("op idx %i out of bounds\n", opIdx);
return -1;
}
} // getOperandSize
void execute(GPUDynInstPtr) override;
}; // Inst_VOP2__V_XNOR_B32
class Inst_VOP1__V_NOP : public Inst_VOP1
{
public:

View File

@@ -2167,9 +2167,9 @@ namespace VegaISA
Inst_VOP2__V_FMAC_F32::execute(GPUDynInstPtr gpuDynInst)
{
Wavefront *wf = gpuDynInst->wavefront();
ConstVecOperandU32 src0(gpuDynInst, instData.SRC0);
ConstVecOperandU32 src1(gpuDynInst, instData.VSRC1);
VecOperandU32 vdst(gpuDynInst, instData.VDST);
ConstVecOperandF32 src0(gpuDynInst, instData.SRC0);
ConstVecOperandF32 src1(gpuDynInst, instData.VSRC1);
VecOperandF32 vdst(gpuDynInst, instData.VDST);
src0.readSrc();
src1.read();
@@ -2181,6 +2181,40 @@ namespace VegaISA
}
}
vdst.write();
} // execute
// --- Inst_VOP2__V_XNOR_B32 class methods ---
Inst_VOP2__V_XNOR_B32::Inst_VOP2__V_XNOR_B32(InFmt_VOP2 *iFmt)
: Inst_VOP2(iFmt, "v_xnor_b32")
{
setFlag(ALU);
} // Inst_VOP2__V_XNOR_B32
Inst_VOP2__V_XNOR_B32::~Inst_VOP2__V_XNOR_B32()
{
} // ~Inst_VOP2__V_XNOR_B32
// --- description from .arch file ---
// D.u = S1.u - S0.u;
void
Inst_VOP2__V_XNOR_B32::execute(GPUDynInstPtr gpuDynInst)
{
Wavefront *wf = gpuDynInst->wavefront();
ConstVecOperandU32 src0(gpuDynInst, instData.SRC0);
ConstVecOperandU32 src1(gpuDynInst, instData.VSRC1);
VecOperandU32 vdst(gpuDynInst, instData.VDST);
src0.readSrc();
src1.read();
vdst.read();
for (int lane = 0; lane < NumVecElemPerVecReg; ++lane) {
if (wf->execMask(lane)) {
vdst[lane] = ~(src0[lane] ^ src1[lane]);
}
}
vdst.write();
} // execute
} // namespace VegaISA

View File

@@ -666,6 +666,9 @@ Inst_VOP3P__V_PK_FMA_F32::execute(GPUDynInstPtr gpuDynInst)
int opsel = instData.OPSEL;
int opsel_hi = extData.OPSEL_HI | (instData.OPSEL_HI2 << 2);
int neg = extData.NEG;
int neg_hi = instData.NEG_HI;
for (int lane = 0; lane < NumVecElemPerVecReg; ++lane) {
if (wf->execMask(lane)) {
uint32_t s0l = (opsel & 1) ? bits(src0[lane], 63, 32)
@@ -675,9 +678,15 @@ Inst_VOP3P__V_PK_FMA_F32::execute(GPUDynInstPtr gpuDynInst)
uint32_t s2l = (opsel & 4) ? bits(src2[lane], 63, 32)
: bits(src2[lane], 31, 0);
float dword1 = std::fma(*reinterpret_cast<float*>(&s0l),
*reinterpret_cast<float*>(&s1l),
*reinterpret_cast<float*>(&s2l));
float s0lf = *reinterpret_cast<float*>(&s0l);
float s1lf = *reinterpret_cast<float*>(&s1l);
float s2lf = *reinterpret_cast<float*>(&s2l);
if (neg & 1) s0lf = -s0lf;
if (neg & 1) s1lf = -s1lf;
if (neg & 1) s2lf = -s2lf;
float dword1 = std::fma(s0lf, s1lf, s2lf);
uint32_t s0h = (opsel_hi & 1) ? bits(src0[lane], 63, 32)
: bits(src0[lane], 31, 0);
@@ -686,9 +695,15 @@ Inst_VOP3P__V_PK_FMA_F32::execute(GPUDynInstPtr gpuDynInst)
uint32_t s2h = (opsel_hi & 4) ? bits(src2[lane], 63, 32)
: bits(src2[lane], 31, 0);
float dword2 = std::fma(*reinterpret_cast<float*>(&s0h),
*reinterpret_cast<float*>(&s1h),
*reinterpret_cast<float*>(&s2h));
float s0hf = *reinterpret_cast<float*>(&s0h);
float s1hf = *reinterpret_cast<float*>(&s1h);
float s2hf = *reinterpret_cast<float*>(&s2h);
if (neg_hi & 1) s0hf = -s0hf;
if (neg_hi & 1) s1hf = -s1hf;
if (neg_hi & 1) s2hf = -s2hf;
float dword2 = std::fma(s0hf, s1hf, s2hf);
uint32_t result1 = *reinterpret_cast<uint32_t*>(&dword1);
uint32_t result2 = *reinterpret_cast<uint32_t*>(&dword2);
@@ -731,6 +746,9 @@ Inst_VOP3P__V_PK_MUL_F32::execute(GPUDynInstPtr gpuDynInst)
int opsel = instData.OPSEL;
int opsel_hi = extData.OPSEL_HI;
int neg = extData.NEG;
int neg_hi = instData.NEG_HI;
for (int lane = 0; lane < NumVecElemPerVecReg; ++lane) {
if (wf->execMask(lane)) {
uint32_t lower_dword = (opsel & 1) ? bits(src0[lane], 63, 32)
@@ -738,16 +756,26 @@ Inst_VOP3P__V_PK_MUL_F32::execute(GPUDynInstPtr gpuDynInst)
uint32_t upper_dword = (opsel & 2) ? bits(src1[lane], 63, 32)
: bits(src1[lane], 31, 0);
float dword1 = *reinterpret_cast<float*>(&lower_dword)
* *reinterpret_cast<float*>(&upper_dword);
float ldwordf = *reinterpret_cast<float*>(&lower_dword);
float udwordf = *reinterpret_cast<float*>(&upper_dword);
if (neg & 1) ldwordf = -ldwordf;
if (neg & 2) udwordf = -udwordf;
float dword1 = ldwordf * udwordf;
lower_dword = (opsel_hi & 1) ? bits(src0[lane], 63, 32)
: bits(src0[lane], 31, 0);
upper_dword = (opsel_hi & 2) ? bits(src1[lane], 63, 32)
: bits(src1[lane], 31, 0);
float dword2 = *reinterpret_cast<float*>(&lower_dword)
* *reinterpret_cast<float*>(&upper_dword);
ldwordf = *reinterpret_cast<float*>(&lower_dword);
udwordf = *reinterpret_cast<float*>(&upper_dword);
if (neg_hi & 1) ldwordf = -ldwordf;
if (neg_hi & 2) udwordf = -udwordf;
float dword2 = ldwordf * udwordf;
uint32_t result1 = *reinterpret_cast<uint32_t*>(&dword1);
uint32_t result2 = *reinterpret_cast<uint32_t*>(&dword2);
@@ -787,9 +815,15 @@ Inst_VOP3P__V_PK_ADD_F32::execute(GPUDynInstPtr gpuDynInst)
src0.readSrc();
src1.readSrc();
panic_if(isSDWAInst(), "SDWA not supported for %s", _opcode);
panic_if(isDPPInst(), "DPP not supported for %s", _opcode);
int opsel = instData.OPSEL;
int opsel_hi = extData.OPSEL_HI;
int neg = extData.NEG;
int neg_hi = instData.NEG_HI;
for (int lane = 0; lane < NumVecElemPerVecReg; ++lane) {
if (wf->execMask(lane)) {
uint32_t lower_dword = (opsel & 1) ? bits(src0[lane], 63, 32)
@@ -797,16 +831,26 @@ Inst_VOP3P__V_PK_ADD_F32::execute(GPUDynInstPtr gpuDynInst)
uint32_t upper_dword = (opsel & 2) ? bits(src1[lane], 63, 32)
: bits(src1[lane], 31, 0);
float dword1 = *reinterpret_cast<float*>(&lower_dword)
+ *reinterpret_cast<float*>(&upper_dword);
float ldwordf = *reinterpret_cast<float*>(&lower_dword);
float udwordf = *reinterpret_cast<float*>(&upper_dword);
if (neg & 1) ldwordf = -ldwordf;
if (neg & 2) udwordf = -udwordf;
float dword1 = ldwordf + udwordf;
lower_dword = (opsel_hi & 1) ? bits(src0[lane], 63, 32)
: bits(src0[lane], 31, 0);
upper_dword = (opsel_hi & 2) ? bits(src1[lane], 63, 32)
: bits(src1[lane], 31, 0);
float dword2 = *reinterpret_cast<float*>(&lower_dword)
+ *reinterpret_cast<float*>(&upper_dword);
ldwordf = *reinterpret_cast<float*>(&lower_dword);
udwordf = *reinterpret_cast<float*>(&upper_dword);
if (neg_hi & 1) ldwordf = -ldwordf;
if (neg_hi & 2) udwordf = -udwordf;
float dword2 = ldwordf + udwordf;
uint32_t result1 = *reinterpret_cast<uint32_t*>(&dword1);
uint32_t result2 = *reinterpret_cast<uint32_t*>(&dword2);
@@ -845,9 +889,11 @@ Inst_VOP3P__V_PK_MOV_B32::execute(GPUDynInstPtr gpuDynInst)
// Only OPSEL[1:0] are used
// OPSEL[0] 0/1: Lower dest dword = lower/upper dword of src0
int opsel = instData.OPSEL;
warn_if(instData.NEG_HI || extData.NEG,
"Negative modifier undefined for %s", _opcode);
for (int lane = 0; lane < NumVecElemPerVecReg; ++lane) {
if (wf->execMask(lane)) {
// OPSEL[1] 0/1: Lower dest dword = lower/upper dword of src1