diff --git a/src/gpu-compute/gpu_dyn_inst.cc b/src/gpu-compute/gpu_dyn_inst.cc index d4a6a8f447..c4a8e9085a 100644 --- a/src/gpu-compute/gpu_dyn_inst.cc +++ b/src/gpu-compute/gpu_dyn_inst.cc @@ -925,20 +925,14 @@ GPUDynInst::resolveFlatSegment(const VectorMask &mask) ComputeUnit *cu = wavefront()->computeUnit; if (wavefront()->gfxVersion == GfxVersion::gfx942) { - // Architected flat scratch base address in FLAT_SCRATCH registers - uint32_t fs_lo = cu->srf[simdId]->read( - VegaISA::REG_FLAT_SCRATCH_LO); - uint32_t fs_hi = cu->srf[simdId]->read( - VegaISA::REG_FLAT_SCRATCH_HI); - - Addr arch_flat_scratch = ((Addr)(fs_hi) << 32) | fs_lo; - + // Architected flat scratch base address is in a dedicated hardware + // register. for (int lane = 0; lane < cu->wfSize(); ++lane) { if (mask[lane]) { // The scratch base is added for other gfx versions, // otherwise this would simply add the register base. addr[lane] = addr[lane] - cu->shader->getScratchBase() - + arch_flat_scratch; + + wavefront()->archFlatScratchAddr; } } } else { diff --git a/src/gpu-compute/wavefront.cc b/src/gpu-compute/wavefront.cc index 1b94b13b6e..d14f8aee3c 100644 --- a/src/gpu-compute/wavefront.cc +++ b/src/gpu-compute/wavefront.cc @@ -384,14 +384,13 @@ Wavefront::initRegState(HSAQueueEntry *task, int wgSizeInWorkItems) // the FLAT_SCRATCH register pair to the scratch backing // memory: https://llvm.org/docs/AMDGPUUsage.html#flat-scratch if (task->gfxVersion() == GfxVersion::gfx942) { - Addr arch_flat_scratch = + archFlatScratchAddr = task->amdQueue.scratch_backing_memory_location; - computeUnit->srf[simdId]->write( - VegaISA::REG_FLAT_SCRATCH_HI, - bits(arch_flat_scratch, 63, 32)); - computeUnit->srf[simdId]->write( - VegaISA::REG_FLAT_SCRATCH_LO, - bits(arch_flat_scratch, 31, 0)); + + DPRINTF(GPUInitAbi, "CU%d: WF[%d][%d]: wave[%d] " + "Setting architected flat scratch = %x\n", + computeUnit->cu_id, simdId, wfSlotId, wfDynId, + archFlatScratchAddr); break; } diff --git a/src/gpu-compute/wavefront.hh b/src/gpu-compute/wavefront.hh index b7dff4617b..476393603b 100644 --- a/src/gpu-compute/wavefront.hh +++ b/src/gpu-compute/wavefront.hh @@ -205,6 +205,9 @@ class Wavefront : public SimObject // will live while the WF is executed uint32_t startSgprIndex; + // Architected flat scratch address for MI300+ + Addr archFlatScratchAddr = 0; + // Old value of destination gpr (for trace) std::vector oldVgpr; // Id of destination gpr (for trace)