From c694d8589f1023f6e565d83b1c636da6a9088bc0 Mon Sep 17 00:00:00 2001 From: Sascha Bischoff Date: Wed, 3 Aug 2022 17:10:29 +0100 Subject: [PATCH] arch-arm, cpu: Implement instructions added by FEAT_SME We add the full set of instructions added by Arm's FEAT_SME, with the exception of BMOPA/BMOPS which are BrainFloat16-based outer product instructions. These have been omitted due to the lack of support for BF16 in fplib - the software FP library used for the Arm ISA implementation. The SMEv1 specification can be found at the following location: https://developer.arm.com/documentation/ddi0616/latest Jira Issue: https://gem5.atlassian.net/browse/GEM5-1289 Change-Id: I4882ab452bfc48770419860f89f1f60c7af8aceb Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/64339 Reviewed-by: Giacomo Travaglini Tested-by: kokoro Maintainer: Giacomo Travaglini --- src/arch/arm/SConscript | 1 + src/arch/arm/insts/sme.cc | 183 +++++ src/arch/arm/insts/sme.hh | 229 ++++++ src/arch/arm/insts/sve.cc | 32 + src/arch/arm/insts/sve.hh | 41 + src/arch/arm/isa/formats/aarch64.isa | 37 +- src/arch/arm/isa/formats/formats.isa | 3 + src/arch/arm/isa/formats/sme.isa | 738 ++++++++++++++++++ src/arch/arm/isa/formats/sve_2nd_level.isa | 135 +++- src/arch/arm/isa/formats/sve_top_level.isa | 9 + src/arch/arm/isa/includes.isa | 1 + src/arch/arm/isa/insts/insts.isa | 3 + src/arch/arm/isa/insts/sme.isa | 821 +++++++++++++++++++++ src/arch/arm/isa/insts/sve.isa | 63 ++ src/arch/arm/isa/operands.isa | 5 + src/arch/arm/isa/templates/sme.isa | 773 +++++++++++++++++++ src/arch/arm/isa/templates/sve.isa | 53 ++ src/arch/arm/isa/templates/templates.isa | 3 + 18 files changed, 3103 insertions(+), 27 deletions(-) create mode 100644 src/arch/arm/insts/sme.cc create mode 100644 src/arch/arm/insts/sme.hh create mode 100644 src/arch/arm/isa/formats/sme.isa create mode 100644 src/arch/arm/isa/insts/sme.isa create mode 100644 src/arch/arm/isa/templates/sme.isa diff --git a/src/arch/arm/SConscript 
b/src/arch/arm/SConscript index 935f082c11..ee5efebf13 100644 --- a/src/arch/arm/SConscript +++ b/src/arch/arm/SConscript @@ -68,6 +68,7 @@ Source('insts/misc.cc', tags='arm isa') Source('insts/misc64.cc', tags='arm isa') Source('insts/pred_inst.cc', tags='arm isa') Source('insts/pseudo.cc', tags='arm isa') +Source('insts/sme.cc', tags='arm isa') Source('insts/static_inst.cc', tags='arm isa') Source('insts/sve.cc', tags='arm isa') Source('insts/sve_mem.cc', tags='arm isa') diff --git a/src/arch/arm/insts/sme.cc b/src/arch/arm/insts/sme.cc new file mode 100644 index 0000000000..305d332514 --- /dev/null +++ b/src/arch/arm/insts/sme.cc @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2022 ARM Limited + * All rights reserved + * + * The license below extends only to copyright in the software and shall + * not be construed as granting a license to any other intellectual + * property including but not limited to intellectual property relating + * to a hardware implementation of the functionality of the software + * licensed hereunder. You may use the software subject to the license + * terms below provided that you ensure that this notice is replicated + * unmodified and in its entirety in all distributions of the software, + * modified or unmodified, in source code or in binary form. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "arch/arm/insts/sme.hh" + +namespace gem5 +{ + +namespace ArmISA +{ + +std::string +SmeAddOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + ccprintf(ss, "#%d", imm); + ss << ", "; + printVecReg(ss, op1, true); + ss << ", "; + printVecPredReg(ss, gp1); + ss << ", "; + printVecPredReg(ss, gp2); + return ss.str(); +} + +std::string +SmeAddVlOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + ss << ", "; + printVecReg(ss, dest); + ss << ", "; + printVecReg(ss, op1); + ss << ", "; + ccprintf(ss, "#%d", imm); + return ss.str(); +} + +std::string +SmeLd1xSt1xOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + ccprintf(ss, "#%d", imm); + ss << ", "; + printIntReg(ss, op1); + ss << ", "; + printVecPredReg(ss, gp); + ss << ", "; + printIntReg(ss, op2); + ss << ", "; + printIntReg(ss, op3); + return ss.str(); +} + +std::string +SmeLdrStrOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + 
std::stringstream ss; + printMnemonic(ss, "", false); + ccprintf(ss, "#%d", imm); + ss << ", "; + printIntReg(ss, op1, true); + ss << ", "; + printIntReg(ss, op2, true); + return ss.str(); +} + +std::string +SmeMovExtractOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + printVecReg(ss, op1, true); + ss << ", "; + ccprintf(ss, "#%d", imm); + ss << ", "; + printVecPredReg(ss, gp); + ss << ", "; + printIntReg(ss, op2); + return ss.str(); +} + +std::string +SmeMovInsertOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + ccprintf(ss, "#%d", imm); + ss << ", "; + printVecReg(ss, op1, true); + ss << ", "; + printVecPredReg(ss, gp); + ss << ", "; + printIntReg(ss, op2); + return ss.str(); +} + +std::string +SmeOPOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + ccprintf(ss, "#%d", imm); + ss << ", "; + printVecPredReg(ss, gp1); + ss << ", "; + printVecPredReg(ss, gp2); + ss << ", "; + printVecReg(ss, op1, true); + ss << ", "; + printVecReg(ss, op2, true); + return ss.str(); +} + +std::string +SmeRdsvlOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + ss << ", "; + printVecReg(ss, dest); + ss << ", "; + ccprintf(ss, "#%d", imm); + return ss.str(); +} + +std::string +SmeZeroOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + ArmStaticInst::printMnemonic(ss, "", false); + ccprintf(ss, "#%d", imm); + return ss.str(); +} + +} // namespace ArmISA +} // namespace gem5 diff --git a/src/arch/arm/insts/sme.hh b/src/arch/arm/insts/sme.hh new file mode 100644 index 0000000000..d6cbdde5a7 --- /dev/null +++ b/src/arch/arm/insts/sme.hh @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2022 ARM Limited + * 
All rights reserved + * + * The license below extends only to copyright in the software and shall + * not be construed as granting a license to any other intellectual + * property including but not limited to intellectual property relating + * to a hardware implementation of the functionality of the software + * licensed hereunder. You may use the software subject to the license + * terms below provided that you ensure that this notice is replicated + * unmodified and in its entirety in all distributions of the software, + * modified or unmodified, in source code or in binary form. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef __ARCH_ARM_INSTS_SME_HH__ +#define __ARCH_ARM_INSTS_SME_HH__ + +#include "arch/arm/insts/static_inst.hh" + +namespace gem5 +{ + +namespace ArmISA +{ + +// Used for SME ADDHA/ADDVA +class SmeAddOp : public ArmStaticInst +{ + protected: + uint64_t imm; + RegIndex op1; + RegIndex gp1; + RegIndex gp2; + + SmeAddOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, uint64_t _imm, RegIndex _op1, + RegIndex _gp1, RegIndex _gp2) : + ArmStaticInst(mnem, _machInst, __opClass), + imm(_imm), op1(_op1), gp1(_gp1), gp2(_gp2) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for the SME ADDSPL/ADDSVL instructions +class SmeAddVlOp : public ArmStaticInst +{ + protected: + RegIndex dest; + RegIndex op1; + int8_t imm; + + SmeAddVlOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, RegIndex _dest, RegIndex _op1, + int8_t _imm) : + ArmStaticInst(mnem, _machInst, __opClass), + dest(_dest), op1(_op1), imm(_imm) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for SME LD1x/ST1x instructions +class SmeLd1xSt1xOp : public ArmStaticInst +{ + protected: + uint64_t imm; + RegIndex op1; + RegIndex gp; + RegIndex op2; + RegIndex op3; + bool V; + + SmeLd1xSt1xOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, uint64_t _imm, RegIndex _op1, + RegIndex _gp, RegIndex _op2, +
RegIndex _op3, bool _V) : + ArmStaticInst(mnem, _machInst, __opClass), + imm(_imm), op1(_op1), gp(_gp), op2(_op2), op3(_op3), V(_V) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for SME LDR/STR instructions +class SmeLdrStrOp : public ArmStaticInst +{ + protected: + uint64_t imm; + RegIndex op1; + RegIndex op2; + + SmeLdrStrOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, uint64_t _imm, RegIndex _op1, + RegIndex _op2) : + ArmStaticInst(mnem, _machInst, __opClass), + imm(_imm), op1(_op1), op2(_op2) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for SME MOVA (Tile to Vector) +class SmeMovExtractOp : public ArmStaticInst +{ + protected: + RegIndex op1; + uint8_t imm; + RegIndex gp; + RegIndex op2; + bool v; + + SmeMovExtractOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, RegIndex _op1, uint8_t _imm, + RegIndex _gp, RegIndex _op2, bool _v) : + ArmStaticInst(mnem, _machInst, __opClass), + op1(_op1), imm(_imm), gp(_gp), op2(_op2), v(_v) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for SME MOVA (Vector to Tile) +class SmeMovInsertOp : public ArmStaticInst +{ + protected: + uint8_t imm; + RegIndex op1; + RegIndex gp; + RegIndex op2; + bool v; + + SmeMovInsertOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, uint8_t _imm, RegIndex _op1, + RegIndex _gp, RegIndex _op2, bool _v) : + ArmStaticInst(mnem, _machInst, __opClass), + imm(_imm), op1(_op1), gp(_gp), op2(_op2), v(_v) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for SME outer product instructions +class SmeOPOp : public ArmStaticInst +{ + protected: + uint64_t imm; + RegIndex op1; + RegIndex gp1; + RegIndex gp2; + RegIndex op2; + + SmeOPOp(const char *mnem, ExtMachInst _machInst, 
OpClass __opClass, + uint64_t _imm, RegIndex _op1, RegIndex _gp1, + RegIndex _gp2, RegIndex _op2) : + ArmStaticInst(mnem, _machInst, __opClass), + imm(_imm), op1(_op1), gp1(_gp1), gp2(_gp2), op2(_op2) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for the SME RDSVL instruction +class SmeRdsvlOp : public ArmStaticInst +{ + protected: + RegIndex dest; + int8_t imm; + + SmeRdsvlOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, RegIndex _dest, int8_t _imm) : + ArmStaticInst(mnem, _machInst, __opClass), + dest(_dest), imm(_imm) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +// Used for SME ZERO +class SmeZeroOp : public ArmStaticInst +{ + protected: + uint8_t imm; + + SmeZeroOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, uint8_t _imm) : + ArmStaticInst(mnem, _machInst, __opClass), + imm(_imm) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + +} // namespace ArmISA +} // namespace gem5 + +#endif // __ARCH_ARM_INSTS_SME_HH__ diff --git a/src/arch/arm/insts/sve.cc b/src/arch/arm/insts/sve.cc index 9a525b195d..9d9c2bcb1c 100644 --- a/src/arch/arm/insts/sve.cc +++ b/src/arch/arm/insts/sve.cc @@ -161,6 +161,24 @@ SveWhileOp::generateDisassembly( return ss.str(); } +std::string +SvePselOp::generateDisassembly(Addr pc, + const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + printVecPredReg(ss, dest); + ss << ", "; + printVecPredReg(ss, op1); + ss << ", "; + printVecPredReg(ss, gp); + ss << ", "; + printIntReg(ss, op2); + ss << ", "; + ccprintf(ss, "#%d", imm); + return ss.str(); +} + std::string SveCompTermOp::generateDisassembly( Addr pc, const loader::SymbolTable *symtab) const @@ -831,6 +849,20 @@ SveComplexIdxOp::generateDisassembly( return ss.str(); } +std::string +SveClampOp::generateDisassembly( + 
Addr pc, const Loader::SymbolTable *symtab) const +{ + std::stringstream ss; + printMnemonic(ss, "", false); + printVecReg(ss, dest, true); + ss << ", "; + printVecReg(ss, op1, true); + ss << ", "; + printVecReg(ss, op2, true); + return ss.str(); +} + std::string sveDisasmPredCountImm(uint8_t imm) { diff --git a/src/arch/arm/insts/sve.hh b/src/arch/arm/insts/sve.hh index f9939e1f22..63a59d493a 100644 --- a/src/arch/arm/insts/sve.hh +++ b/src/arch/arm/insts/sve.hh @@ -180,6 +180,28 @@ class SveWhileOp : public ArmStaticInst Addr pc, const loader::SymbolTable *symtab) const override; }; +/// Psel predicate selection SVE instruction. +class SvePselOp : public ArmStaticInst +{ + protected: + RegIndex dest; + RegIndex op1; + RegIndex gp; + RegIndex op2; + uint64_t imm; + + SvePselOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, RegIndex _dest, + RegIndex _op1, RegIndex _gp, + RegIndex _op2, uint64_t _imm) : + ArmStaticInst(mnem, _machInst, __opClass), + dest(_dest), op1(_op1), gp(_gp), op2(_op2), imm(_imm) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + /// Compare and terminate loop SVE instruction. class SveCompTermOp : public ArmStaticInst { @@ -951,6 +973,25 @@ class SveComplexIdxOp : public ArmStaticInst Addr pc, const loader::SymbolTable *symtab) const override; }; +// SVE2 SCLAMP/UCLAMP instructions +class SveClampOp : public ArmStaticInst +{ + protected: + RegIndex dest; + RegIndex op1; + RegIndex op2; + + SveClampOp(const char *mnem, ExtMachInst _machInst, + OpClass __opClass, RegIndex _dest, + RegIndex _op1, RegIndex _op2) : + ArmStaticInst(mnem, _machInst, __opClass), + dest(_dest), op1(_op1), op2(_op2) + {} + + std::string generateDisassembly( + Addr pc, const Loader::SymbolTable *symtab) const override; +}; + /// Returns the symbolic name associated with pattern `imm` for PTRUE(S) /// instructions. 
diff --git a/src/arch/arm/isa/formats/aarch64.isa b/src/arch/arm/isa/formats/aarch64.isa index 37eb995bfd..2fd28f8209 100644 --- a/src/arch/arm/isa/formats/aarch64.isa +++ b/src/arch/arm/isa/formats/aarch64.isa @@ -436,6 +436,9 @@ namespace Aarch64 // SP return new MsrImm64( machInst, MISCREG_SPSEL, crm); + case 0x1b: + // SVE SVCR - SMSTART/SMSTOP + return decodeSmeMgmt(machInst); case 0x1e: // DAIFSet return new MsrImmDAIFSet64( @@ -3073,20 +3076,30 @@ def format Aarch64() {{ using namespace Aarch64; if (bits(machInst, 27) == 0x0) { if (bits(machInst, 28) == 0x0) { - if (bits(machInst, 26, 25) != 0x2) { - return new Unknown64(machInst); - } - if (bits(machInst, 31) == 0x0) { - switch (bits(machInst, 30, 29)) { - case 0x0: - case 0x1: - case 0x2: - return decodeSveInt(machInst); - case 0x3: - return decodeSveFp(machInst); + if (bits(machInst, 26) == 0x1) { + if (bits(machInst, 31) == 0x0) { + if (bits(machInst, 25) == 0x1) { + return new Unknown64(machInst); + } + switch (bits(machInst, 30, 29)) { + case 0x0: + case 0x1: + case 0x2: + return decodeSveInt(machInst); + case 0x3: + return decodeSveFp(machInst); + } + } else { + return decodeSveMem(machInst); } } else { - return decodeSveMem(machInst); + if ((bits(machInst, 25) == 0x0) && \ + (bits(machInst, 31) == 0x1)) { + // bit 31:25=1xx0000 + return decodeSmeInst(machInst); + } else { + return new Unknown64(machInst); + } } } else if (bits(machInst, 26) == 0) // bit 28:26=100 diff --git a/src/arch/arm/isa/formats/formats.isa b/src/arch/arm/isa/formats/formats.isa index 5ef65966af..0a1f8f8ce2 100644 --- a/src/arch/arm/isa/formats/formats.isa +++ b/src/arch/arm/isa/formats/formats.isa @@ -52,6 +52,9 @@ ##include "sve_top_level.isa" ##include "sve_2nd_level.isa" +//Include support for decoding SME instructions (AArch64-only) +##include "sme.isa" + //Include support for predicated instructions ##include "pred.isa" diff --git a/src/arch/arm/isa/formats/sme.isa b/src/arch/arm/isa/formats/sme.isa new file mode 100644 
index 0000000000..ac75d151ce --- /dev/null +++ b/src/arch/arm/isa/formats/sme.isa @@ -0,0 +1,738 @@ +// Copyright (c) 2022 ARM Limited +// All rights reserved +// +// The license below extends only to copyright in the software and shall +// not be construed as granting a license to any other intellectual +// property including but not limited to intellectual property relating +// to a hardware implementation of the functionality of the software +// licensed hereunder. You may use the software subject to the license +// terms below provided that you ensure that this notice is replicated +// unmodified and in its entirety in all distributions of the software, +// modified or unmodified, in source code or in binary form. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer; +// redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution; +// neither the name of the copyright holders nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +output header {{ +namespace Aarch64 +{ + StaticInstPtr decodeSmeMgmt(ExtMachInst); + StaticInstPtr decodeSmeInst(ExtMachInst); + + StaticInstPtr decodeSmeOp32(ExtMachInst); + StaticInstPtr decodeSmeOpFp32(ExtMachInst); + StaticInstPtr decodeSmeOpBf16(ExtMachInst); + StaticInstPtr decodeSmeOpFp16(ExtMachInst); + StaticInstPtr decodeSmeOpInt8(ExtMachInst); + + StaticInstPtr decodeSmeOp64(ExtMachInst); + StaticInstPtr decodeSmeOpFp64(ExtMachInst); + StaticInstPtr decodeSmeOpInt16(ExtMachInst); + + StaticInstPtr decodeSmeMovaInsert(ExtMachInst); + StaticInstPtr decodeSmeMovaExtract(ExtMachInst); + + StaticInstPtr decodeSmeMisc(ExtMachInst); + StaticInstPtr decodeSmeZero(ExtMachInst); + + StaticInstPtr decodeSmeAddArray(ExtMachInst); + StaticInstPtr decodeSmeAddhv(ExtMachInst); + + StaticInstPtr decodeSmeMemory(ExtMachInst); + StaticInstPtr decodeSmeLoad(ExtMachInst); + StaticInstPtr decodeSmeStore(ExtMachInst); + StaticInstPtr decodeSmeLoadStoreArray(ExtMachInst); + StaticInstPtr decodeSmeLoadQuadWord(ExtMachInst); + StaticInstPtr decodeSmeStoreQuadWord(ExtMachInst); +} +}}; + +output decoder {{ +namespace Aarch64 +{ + // NOTE: This is called from a different decode tree (aarch64.isa). + // For neatness and clarity we keep the code here in order to keep all + // SME things together.
+ StaticInstPtr + decodeSmeMgmt(ExtMachInst machInst) + { + const uint8_t imm = (uint8_t)bits(machInst, 10, 8); + + if (bits(machInst, 8)) { + return new SmeSmstart(machInst, imm); + } else { + return new SmeSmstop(machInst, imm); + } + } + + StaticInstPtr + decodeSmeInst(ExtMachInst machInst) + { + // Starting point for decoding: bits 31:25=1xx0000 + + const uint8_t op0 = (uint8_t)bits(machInst, 30, 29); + const uint8_t op1 = (uint8_t)bits(machInst, 24, 19); + const uint8_t op2 = (uint8_t)bits(machInst, 17); + const uint8_t op3 = (uint8_t)bits(machInst, 4, 2); + + if ((op0 & 0b10) == 0b00) { + if ((op1 & 0b011000) == 0b010000) { + if ((op3 & 0b001) == 0b000) { + return decodeSmeOp32(machInst); + } + } + + if ((op1 & 0b011000) == 0b011000) { + if ((op3 & 0b010) == 0b000) { + return decodeSmeOp64(machInst); + } + } + } + + if (op0 == 0b10) { + if ((op1 & 0b100111) == 0b000000) { + if (op2 == 0b0) { + if ((op3 & 0b100) == 0b000) { + return decodeSmeMovaInsert(machInst); + } + } + + if (op2 ==0b1) { + return decodeSmeMovaExtract(machInst); + } + } + + if ((op1 & 0b100111) == 0b000001) { + return decodeSmeMisc(machInst); + } + + if ((op1 & 0b100111) == 0b000010) { + if ((op3 & 0b010) == 0b000) { + return decodeSmeAddArray(machInst); + } + } + } + + if (op0 == 0b11) { + return decodeSmeMemory(machInst); + } + + // We should not get here + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOp32(ExtMachInst machInst) + { + const uint8_t op0 = (uint8_t)bits(machInst, 29); + const uint8_t op1 = (uint8_t)bits(machInst, 24); + const uint8_t op2 = (uint8_t)bits(machInst, 21); + const uint8_t op3 = (uint8_t)bits(machInst, 3); + + if (op0 == 0) { + if (op1 == 0) { + if (op2 == 0) { + if (op3 == 0) { + return decodeSmeOpFp32(machInst); + } + } + } + + if (op1 == 1) { + if (op2 == 0) { + if (op3 == 0) { + return decodeSmeOpBf16(machInst); + } + } + + if (op2 == 1) { + if (op3 == 0) { + return decodeSmeOpFp16(machInst); + } + } + } + } + + if (op0 == 1) { + if (op3 == 
0) { + return decodeSmeOpInt8(machInst); + } + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOpFp32(ExtMachInst machInst) + { + const uint32_t S = (uint32_t)bits(machInst, 4, 4); + + const RegIndex Zm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Pn = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex Pm = (RegIndex)(uint32_t)(bits(machInst, 15, 13)); + const RegIndex ZAda = (RegIndex)(uint32_t)(bits(machInst, 1, 0)); + + if (S == 0) { + return new SmeFmopa(machInst, ZAda, Zn, + Pn, Pm, Zm); + } else { + return new SmeFmops(machInst, ZAda, Zn, + Pn, Pm, Zm); + } + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOpBf16(ExtMachInst machInst) + { + // The following code is functionally correct for decode, but + // remains commented out as the current gem5 fplib implementation + // doesn't support BF16, and hence the instructions themselves + // remain unimplemented. Once these have been implemented, this code + // can be safely uncommented to enable decode for the two BF16 Outer + // Product instructions added by FEAT_SME. 
+ + // const uint32_t S = (uint32_t)bits(machInst, 4, 4); + + // const RegIndex Zm = (RegIndex)(uint32_t)( + // bits(machInst, 20, 16)); + // const RegIndex Zn = (RegIndex)(uint32_t)( + // bits(machInst, 9, 5)); + // const RegIndex Pn = (RegIndex)(uint32_t)( + // bits(machInst, 12, 10)); + // const RegIndex Pm = (RegIndex)(uint32_t)( + // bits(machInst, 15, 13)); + // const RegIndex ZAda = (RegIndex)(uint32_t)( + // bits(machInst, 1, 0)); + + // if (S == 0) { + // return new SmeBmopa(machInst); + // } else { + // return new SmeBmops(machInst); + // } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOpFp16(ExtMachInst machInst) + { + const uint32_t S = (uint32_t)bits(machInst, 4, 4); + + const RegIndex Zm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Pn = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex Pm = (RegIndex)(uint32_t)(bits(machInst, 15, 13)); + const RegIndex ZAda = (RegIndex)(uint32_t)(bits(machInst, 1, 0)); + + if (S == 0) { + return new SmeFmopaWidening(machInst, ZAda, Zn, + Pn, Pm, Zm); + } else { + return new SmeFmopsWidening(machInst, ZAda, Zn, + Pn, Pm, Zm); + } + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOpInt8(ExtMachInst machInst) + { + const uint32_t u0 = (uint32_t)bits(machInst, 24); + const uint32_t u1 = (uint32_t)bits(machInst, 21); + const uint32_t S = (uint32_t)bits(machInst, 4); + + const RegIndex Zm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Pn = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex Pm = (RegIndex)(uint32_t)(bits(machInst, 15, 13)); + const RegIndex ZAda = (RegIndex)(uint32_t)(bits(machInst, 1, 0)); + + if (u0 == 0) { + if (u1 == 0) { + if (S == 0) { + return new SmeSmopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeSmops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } 
else { + if (S == 0) { + return new SmeSumopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeSumops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } + } else { + if (u1 == 0) { + if (S == 0) { + return new SmeUsmopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeUsmops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } else { + if (S == 0) { + return new SmeUmopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeUmops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOp64(ExtMachInst machInst) + { + const uint8_t op0 = (uint8_t)bits(machInst, 29); + const uint8_t op1 = (uint8_t)bits(machInst, 24); + const uint8_t op2 = (uint8_t)bits(machInst, 21); + + if (op0 == 0) { + if (op1 == 0) { + if (op2 == 0) { + return decodeSmeOpFp64(machInst); + } + } + } + + if (op0 == 1) { + return decodeSmeOpInt16(machInst); + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOpFp64(ExtMachInst machInst) + { + const uint32_t S = (uint32_t)bits(machInst, 4, 4); + + const RegIndex Zm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Pn = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex Pm = (RegIndex)(uint32_t)(bits(machInst, 15, 13)); + const RegIndex ZAda = (RegIndex)(uint32_t)(bits(machInst, 2, 0)); + + if (S == 0) { + return new SmeFmopa(machInst, ZAda, Zn, + Pn, Pm, Zm); + } else { + return new SmeFmops(machInst, ZAda, Zn, + Pn, Pm, Zm); + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeOpInt16(ExtMachInst machInst) + { + const uint32_t u0 = (uint32_t)bits(machInst, 24); + const uint32_t u1 = (uint32_t)bits(machInst, 21); + const uint32_t S = (uint32_t)bits(machInst, 4); + + const RegIndex Zm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Pn = 
(RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex Pm = (RegIndex)(uint32_t)(bits(machInst, 15, 13)); + const RegIndex ZAda = (RegIndex)(uint32_t)(bits(machInst, 2, 0)); + + if (u0 == 0) { + if (u1 == 0) { + if (S == 0) { + return new SmeSmopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeSmops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } else { + if (S == 0) { + return new SmeSumopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeSumops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } + } else { + if (u1 == 0) { + if (S == 0) { + return new SmeUsmopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeUsmops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } else { + if (S == 0) { + return new SmeUmopa( + machInst, ZAda, Zn, Pn, Pm, Zm); + } else { + return new SmeUmops( + machInst, ZAda, Zn, Pn, Pm, Zm); + } + } + } + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeMovaInsert(ExtMachInst machInst) + { + const uint8_t op0 = (uint8_t)bits(machInst, 18); + + if (op0 == 1) { + return new Unknown64(machInst); + } + + const uint32_t size = (uint32_t)bits(machInst, 23, 22); + const uint32_t Q = (uint32_t)bits(machInst, 16, 16); + + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Ws = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const RegIndex Pg = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex ZAd_imm = (RegIndex)(uint32_t)( + bits(machInst, 3, 0)); + const bool V = (bool)bits(machInst, 15); + + if (Q == 0) { + switch (size) { + case 0b00: + return new SmeMovaInsert(machInst, ZAd_imm, + Zn, Pg, Ws, V); + case 0b01: + return new SmeMovaInsert(machInst, ZAd_imm, + Zn, Pg, Ws, V); + case 0b10: + return new SmeMovaInsert(machInst, ZAd_imm, + Zn, Pg, Ws, V); + case 0b11: + return new SmeMovaInsert(machInst, ZAd_imm, + Zn, Pg, Ws, V); + default: + break; + } + } + + if ((Q == 1) && (size == 0b11)) { + return new SmeMovaInsert<__uint128_t>(machInst, ZAd_imm, + 
Zn, Pg, Ws, V); + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeMovaExtract(ExtMachInst machInst) + { + const uint8_t op0 = (uint8_t)bits(machInst, 18); + const uint8_t op1 = (uint8_t)bits(machInst, 9); + + if ((op0 == 1) || (op1 == 1)) { + return new Unknown64(machInst); + } + + const uint32_t size = (uint32_t)bits(machInst, 23, 22); + const uint32_t Q = (uint32_t)bits(machInst, 16, 16); + + const RegIndex Zd = (RegIndex)(uint32_t)(bits(machInst, 4, 0)); + const RegIndex Ws = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const RegIndex Pg = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex ZAn_imm = (RegIndex)(uint32_t)( + bits(machInst, 8, 5)); + const bool V = (bool)bits(machInst, 15); + + if (Q == 0) { + switch (size) { + case 0b00: + return new SmeMovaExtract(machInst, Zd, + ZAn_imm, Pg, Ws, V); + case 0b01: + return new SmeMovaExtract(machInst, Zd, + ZAn_imm, Pg, Ws, V); + case 0b10: + return new SmeMovaExtract(machInst, Zd, + ZAn_imm, Pg, Ws, V); + case 0b11: + return new SmeMovaExtract(machInst, Zd, + ZAn_imm, Pg, Ws, V); + default: + break; + } + } + + if ((Q == 1) && (size == 0b11)) { + return new SmeMovaExtract<__uint128_t>(machInst, Zd, + ZAn_imm, Pg, Ws, V); + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeMisc(ExtMachInst machInst) + { + const uint32_t op0 = (uint32_t)bits(machInst, 23, 22); + const uint32_t op1 = (uint32_t)bits(machInst, 18, 8); + + if (op0 == 0b00) { + if (op1 == 0b00000000000) { + return decodeSmeZero(machInst); + } + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeZero(ExtMachInst machInst) + { + const uint8_t imm8 = (uint8_t)bits(machInst, 7, 0); + + return new SmeZero(machInst, imm8); + } + + StaticInstPtr + decodeSmeAddArray(ExtMachInst machInst) + { + const uint32_t op0 = (uint32_t)bits(machInst, 23); + const uint32_t op1 = (uint32_t)bits(machInst, 18, 17); + const uint32_t op2 = (uint32_t)bits(machInst, 4); + + if (op0 == 1) { + if 
(op1 == 0b00) { + if (op2 == 0) { + return decodeSmeAddhv(machInst); + } + } + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeAddhv(ExtMachInst machInst) + { + const uint32_t V = (uint32_t)bits(machInst, 16, 16); + const uint32_t op = (uint32_t)bits(machInst, 22, 22); + const uint32_t op2 = (uint32_t)bits(machInst, 2, 0); + + const RegIndex Zn = (RegIndex)(uint32_t)(bits(machInst, 9, 5)); + const RegIndex Pn = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + const RegIndex Pm = (RegIndex)(uint32_t)(bits(machInst, 15, 13)); + const RegIndex ZAda = (RegIndex)(uint32_t)(bits(machInst, 2, 0)); + + if (op == 0) { // 32-bit + if (V == 0) { + if ((op2 & 0b100) == 0b000) { + return new SmeAddha(machInst, ZAda, Zn, Pn, Pm); + } + } else { + if ((op2 & 0b100) == 0b000) { + return new SmeAddva(machInst, ZAda, Zn, Pn, Pm); + } + } + } else { + if (V == 0) { + return new SmeAddha(machInst, ZAda, Zn, Pn, Pm); + } else { + return new SmeAddva(machInst, ZAda, Zn, Pn, Pm); + } + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeMemory(ExtMachInst machInst) + { + const uint8_t op0 = (uint8_t)bits(machInst, 24, 21); + const uint8_t op1 = (uint8_t)bits(machInst, 20, 15); + const uint8_t op2 = (uint8_t)bits(machInst, 12, 10); + const uint8_t op3 = (uint8_t)bits(machInst, 4); + + if ((op0 & 0b1001) == 0b0000) { + if (op3 == 0b0) { + return decodeSmeLoad(machInst); + } + } + + if ((op0 & 0b1001) == 0b0001) { + if (op3 == 0b0) { + return decodeSmeStore(machInst); + } + } + + if ((op0 & 0b1110) == 0b1000) { + if (op1 == 0b000000) { + if (op2 == 0b000) { + if (op3 == 0b0) { + return decodeSmeLoadStoreArray(machInst); + } + } + } + } + + if (op0 == 0b1110) { + if (op3 == 0b0) { + return decodeSmeLoadQuadWord(machInst); + } + } + + if (op0 == 0b1111) { + if (op3 == 0b0) { + return decodeSmeStoreQuadWord(machInst); + } + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeLoad(ExtMachInst machInst) + { + const uint8_t msz = 
(uint8_t)bits(machInst, 23, 22); + const bool V = (bool)bits(machInst, 15); + + const RegIndex Rn = makeSP( + (RegIndex)(uint32_t)bits(machInst, 9, 5)); + const RegIndex Rm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Rs = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const uint32_t ZAt_imm = (uint32_t)bits(machInst, 3, 0); + const RegIndex Pg = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + + switch(msz) + { + case 0b00: + return new SmeLd1b(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + case 0b01: + return new SmeLd1h(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + case 0b10: + return new SmeLd1w(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + case 0b11: + return new SmeLd1d(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + default: + break; + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeStore(ExtMachInst machInst) + { + const uint8_t msz = (uint8_t)bits(machInst, 23, 22); + const bool V = (bool)bits(machInst, 15); + + const RegIndex Rn = makeSP( + (RegIndex)(uint32_t)bits(machInst, 9, 5)); + const RegIndex Rm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Rs = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const uint32_t ZAt_imm = (uint32_t)bits(machInst, 3, 0); + const RegIndex Pg = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + + switch(msz) + { + case 0b00: + return new SmeSt1b(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + case 0b01: + return new SmeSt1h(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + case 0b10: + return new SmeSt1w(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + case 0b11: + return new SmeSt1d(machInst, ZAt_imm, Rn, Pg, Rs, Rm, V); + default: + break; + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeLoadStoreArray(ExtMachInst machInst) + { + const uint8_t op = (uint8_t)bits(machInst, 21); + + const RegIndex Rn = makeSP( + (RegIndex)(uint32_t)bits(machInst, 9, 5)); + const RegIndex Rv = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const uint32_t imm4 = 
(uint32_t)bits(machInst, 3, 0); + + if (op == 0) { + return new SmeLdr(machInst, imm4, Rn, Rv); + } else { + return new SmeStr(machInst, imm4, Rn, Rv); + } + + return new Unknown64(machInst); + } + + StaticInstPtr + decodeSmeLoadQuadWord(ExtMachInst machInst) + { + const bool V = (bool)bits(machInst, 15); + + const RegIndex Rn = makeSP( + (RegIndex)(uint32_t)bits(machInst, 9, 5)); + const RegIndex Rm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Rs = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const uint32_t ZAt = (uint32_t)bits(machInst, 3, 0); + const RegIndex Pg = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + + return new SmeLd1q<__uint128_t>(machInst, ZAt, Rn, Pg, Rs, Rm, V); + } + + StaticInstPtr + decodeSmeStoreQuadWord(ExtMachInst machInst) + { + const bool V = (bool)bits(machInst, 15); + + const RegIndex Rn = makeSP( + (RegIndex)(uint32_t)bits(machInst, 9, 5)); + const RegIndex Rm = (RegIndex)(uint32_t)(bits(machInst, 20, 16)); + const RegIndex Rs = (RegIndex)(uint32_t)( + bits(machInst, 14, 13) + 12); + const uint32_t ZAt = (uint32_t)bits(machInst, 3, 0); + const RegIndex Pg = (RegIndex)(uint32_t)(bits(machInst, 12, 10)); + + return new SmeSt1q<__uint128_t>(machInst, ZAt, Rn, Pg, Rs, Rm, V); + } +} +}}; diff --git a/src/arch/arm/isa/formats/sve_2nd_level.isa b/src/arch/arm/isa/formats/sve_2nd_level.isa index cbd5466b82..2ee3817445 100644 --- a/src/arch/arm/isa/formats/sve_2nd_level.isa +++ b/src/arch/arm/isa/formats/sve_2nd_level.isa @@ -605,22 +605,43 @@ namespace Aarch64 { uint8_t b23_22 = bits(machInst, 23, 22); uint8_t b11 = bits(machInst, 11); - if ((b23_22 & 0x2) == 0x0 && b11 == 0x0) { - RegIndex rd = makeSP( - (RegIndex) (uint8_t) bits(machInst, 4, 0)); - RegIndex rn = makeSP( - (RegIndex) (uint8_t) bits(machInst, 20, 16)); - uint64_t imm = sext<6>(bits(machInst, 10, 5)); - if ((b23_22 & 0x1) == 0x0) { - return new AddvlXImm(machInst, rd, rn, imm); - } else { - return new AddplXImm(machInst, rd, rn, imm); + if (b11 
== 0x0) { + if ((b23_22 & 0x2) == 0x0) { + RegIndex rd = makeSP( + (RegIndex) (uint8_t) bits(machInst, 4, 0)); + RegIndex rn = makeSP( + (RegIndex) (uint8_t) bits(machInst, 20, 16)); + uint64_t imm = sext<6>(bits(machInst, 10, 5)); + if ((b23_22 & 0x1) == 0x0) { + return new AddvlXImm(machInst, rd, rn, imm); + } else { + return new AddplXImm(machInst, rd, rn, imm); + } + } else if (b23_22 == 0x2) { + RegIndex rd = (RegIndex) (uint8_t) bits(machInst, 4, 0); + uint64_t imm = sext<6>(bits(machInst, 10, 5)); + if (bits(machInst, 20, 16) == 0x1f) { + return new SveRdvl(machInst, rd, imm); + } } - } else if (b23_22 == 0x2 && b11 == 0x0) { - RegIndex rd = (RegIndex) (uint8_t) bits(machInst, 4, 0); - uint64_t imm = sext<6>(bits(machInst, 10, 5)); - if (bits(machInst, 20, 16) == 0x1f) { - return new SveRdvl(machInst, rd, imm); + } else { // b11 == 1 + if ((b23_22 & 0x2) == 0x0) { + RegIndex rd = makeSP( + (RegIndex) (uint8_t) bits(machInst, 4, 0)); + RegIndex rn = makeSP( + (RegIndex) (uint8_t) bits(machInst, 20, 16)); + uint64_t imm = sext<6>(bits(machInst, 10, 5)); + if ((b23_22 & 0x1) == 0x0) { + return new SmeAddsvl(machInst, rd, rn, imm); + } else { + return new SmeAddspl(machInst, rd, rn, imm); + } + } else if (b23_22 == 0x2) { + RegIndex rd = (RegIndex) (uint8_t) bits(machInst, 4, 0); + uint64_t imm = sext<6>(bits(machInst, 10, 5)); + if (bits(machInst, 20, 16) == 0x1f) { + return new SmeRdsvl(machInst, rd, imm); + } } } return new Unknown64(machInst); @@ -1201,6 +1222,18 @@ namespace Aarch64 zdn, zm, pg); } break; + case 0xE: + if(!b13) { + unsigned size = (unsigned) bits(machInst, 23, 22); + RegIndex pg = (RegIndex)(uint8_t) bits(machInst, 12, 10); + RegIndex zn = (RegIndex)(uint8_t) bits(machInst, 9, 5); + RegIndex zd = (RegIndex)(uint8_t) bits(machInst, 4, 0); + + if (size == 0b00) { + return new SveRevd<__uint128_t>(machInst, zd, zn, pg); + } + } + break; } switch (bits(machInst, 20, 17)) { case 0x0: @@ -1951,6 +1984,36 @@ namespace Aarch64 return new 
Unknown64(machInst); } // decodeSveIntCmpSca + StaticInstPtr + decodeSvePsel(ExtMachInst machInst) + { + RegIndex Pd = (RegIndex)(uint8_t)bits(machInst, 3, 0); + RegIndex Pn = (RegIndex)(uint8_t)bits(machInst, 8, 5); + RegIndex Pg = (RegIndex)(uint8_t)bits(machInst, 13, 10); + RegIndex Rm = (RegIndex)(0b01100 + + (uint8_t)bits(machInst, 17, 16)); + uint8_t imm = (uint8_t)bits(machInst, 20, 18); + imm += (uint8_t)bits(machInst, 23, 22) << 3; + + const uint8_t size = imm & 0xF; + + if (size == 0) { + return new Unknown64(machInst); + } + + if (size & 0b0001) { + return new SvePsel(machInst, Pd, Pn, Pg, Rm, imm >> 1); + } else if (size & 0b0010) { + return new SvePsel(machInst, Pd, Pn, Pg, Rm, imm >> 2); + } else if (size & 0b0100) { + return new SvePsel(machInst, Pd, Pn, Pg, Rm, imm >> 3); + } else if (size & 0b1000) { + return new SvePsel(machInst, Pd, Pn, Pg, Rm, imm >> 4); + } + + return new Unknown64(machInst); + } // decodeSvePsel + StaticInstPtr decodeSveIntWideImmUnpred0(ExtMachInst machInst) { @@ -2106,6 +2169,48 @@ namespace Aarch64 return new Unknown64(machInst); } // decodeSveIntWideImmUnpred + StaticInstPtr + decodeSveClamp(ExtMachInst machInst) + { + RegIndex zda = (RegIndex)(uint8_t)bits(machInst, 4, 0); + RegIndex zn = (RegIndex)(uint8_t)bits(machInst, 9, 5); + RegIndex zm = (RegIndex)(uint8_t)bits(machInst, 20, 16); + + switch(bits(machInst, 10)) { + case 0: + switch(bits(machInst, 23, 22)) { + case 0x0: + return new SveSclamp(machInst, zm, zn, zda); + case 0x1: + return new SveSclamp(machInst, zm, zn, zda); + case 0x2: + return new SveSclamp(machInst, zm, zn, zda); + case 0x3: + return new SveSclamp(machInst, zm, zn, zda); + default: + break; + } + break; + case 1: + switch(bits(machInst, 23, 22)) { + case 0x0: + return new SveUclamp(machInst, zm, zn, zda); + case 0x1: + return new SveUclamp(machInst, zm, zn, zda); + case 0x2: + return new SveUclamp(machInst, zm, zn, zda); + case 0x3: + return new SveUclamp(machInst, zm, zn, zda); + default: + break; 
+ } + default: + break; + } + + return new Unknown64(machInst); + } + StaticInstPtr decodeSveMultiplyAddUnpred(ExtMachInst machInst) { diff --git a/src/arch/arm/isa/formats/sve_top_level.isa b/src/arch/arm/isa/formats/sve_top_level.isa index 803029a2a4..155ec1c42f 100644 --- a/src/arch/arm/isa/formats/sve_top_level.isa +++ b/src/arch/arm/isa/formats/sve_top_level.isa @@ -66,7 +66,9 @@ namespace Aarch64 StaticInstPtr decodeSvePredGen(ExtMachInst machInst); StaticInstPtr decodeSvePredCount(ExtMachInst machInst); StaticInstPtr decodeSveIntCmpSca(ExtMachInst machInst); + StaticInstPtr decodeSvePsel(ExtMachInst machInst); StaticInstPtr decodeSveIntWideImmUnpred(ExtMachInst machInst); + StaticInstPtr decodeSveClamp(ExtMachInst machInst); StaticInstPtr decodeSveMultiplyAddUnpred(ExtMachInst machInst); StaticInstPtr decodeSveMultiplyIndexed(ExtMachInst machInst); @@ -107,6 +109,9 @@ namespace Aarch64 case 0x0: { if (bits(machInst, 14)) { + if (bits(machInst, 15, 11) == 0b11000) { + return decodeSveClamp(machInst); + } return decodeSveIntMulAdd(machInst); } else { uint8_t b_15_13 = (bits(machInst, 15) << 1) | @@ -210,10 +215,14 @@ namespace Aarch64 case 0x7: { uint8_t b_15_14 = bits(machInst, 15, 14); + uint8_t b_4 = bits(machInst, 4, 4); switch (b_15_14) { case 0x0: return decodeSveIntCmpSca(machInst); case 0x1: + if (b_4 == 0) { + return decodeSvePsel(machInst); + } return new Unknown64(machInst); case 0x2: return decodeSvePredCount(machInst); diff --git a/src/arch/arm/isa/includes.isa b/src/arch/arm/isa/includes.isa index 386af4e05d..e2534a6728 100644 --- a/src/arch/arm/isa/includes.isa +++ b/src/arch/arm/isa/includes.isa @@ -61,6 +61,7 @@ output header {{ #include "arch/arm/insts/neon64_mem.hh" #include "arch/arm/insts/pred_inst.hh" #include "arch/arm/insts/pseudo.hh" +#include "arch/arm/insts/sme.hh" #include "arch/arm/insts/static_inst.hh" #include "arch/arm/insts/sve.hh" #include "arch/arm/insts/sve_mem.hh" diff --git a/src/arch/arm/isa/insts/insts.isa 
b/src/arch/arm/isa/insts/insts.isa index 0697ca49d2..cdc162f2b5 100644 --- a/src/arch/arm/isa/insts/insts.isa +++ b/src/arch/arm/isa/insts/insts.isa @@ -105,6 +105,9 @@ split decoder; ##include "sve.isa" ##include "sve_mem.isa" +//SME +##include "sme.isa" + //m5 Pseudo-ops ##include "m5ops.isa" diff --git a/src/arch/arm/isa/insts/sme.isa b/src/arch/arm/isa/insts/sme.isa new file mode 100644 index 0000000000..b9f6115432 --- /dev/null +++ b/src/arch/arm/isa/insts/sme.isa @@ -0,0 +1,821 @@ +// Copyright (c) 2022 ARM Limited +// All rights reserved +// +// The license below extends only to copyright in the software and shall +// not be construed as granting a license to any other intellectual +// property including but not limited to intellectual property relating +// to a hardware implementation of the functionality of the software +// licensed hereunder. You may use the software subject to the license +// terms below provided that you ensure that this notice is replicated +// unmodified and in its entirety in all distributions of the software, +// modified or unmodified, in source code or in binary form. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer; +// redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution; +// neither the name of the copyright holders nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// @file Definition of SME instructions. + +let {{ + + header_output = "" + decoder_output = "" + exec_output = "" + + def smeAddInst(name, Name, opClass, types, op): + global header_output, decoder_output, exec_output + code = smEnCheckCode + smeZaWrite + ''' + // imm stores the tile index + // op1 is the source SVE vector register + // gp1 is the row predecate register + // gp2 is the column predecate register + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + uint8_t tile_index = imm & 0x7; + + // View the tile as the correct data type, extract the sub-tile + auto tile = getTile(ZA, tile_index); + ''' + code += op + + iop = InstObjParams(name, "Sme" + Name, "SmeAddOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeAddDeclare.subst(iop) + exec_output += SmeTemplatedExecute.subst(iop) + + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sme' + Name} + exec_output += SmeOpExecDeclare.subst(substDict) + + def smeAddVlInst(name, Name, opClass, op): + global header_output, decoder_output, exec_output + code = smEnCheckCodeNoPstate + ''' + // dest is the 64-bit destination register + // op1 is the 
64-bit source register + // imm is a signed multiplier + ''' + code += op + + iop = InstObjParams(name, "Sme" + Name, "SmeAddVlOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeAddVlDeclare.subst(iop) + exec_output += SmeExecute.subst(iop) + + def smeLd1xInst(name, Name, opClass, types): + global header_output, decoder_output, exec_output + code = smEnCheckCode + smeZaWrite + ''' + // imm stores the tile number as well as the vector offset. The + // size of the fields changes based on the data type being used. + // XOp1 stores Rn + // GpOp stores the governing predicate register + // WOp2 stores Rs - the vector index register + // XOp3 stores Rm - the offset register (applied to Rn) + + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + uint8_t offset = imm & (0xf >> (findMsbSet(sizeof(TPElem)))); + M5_VAR_USED uint8_t tile_idx = + imm >> (4 - findMsbSet(sizeof(TPElem))); + M5_VAR_USED uint8_t vec_idx = (WOp2 + offset) % eCount; + + // Calculate the address + M5_VAR_USED Addr EA = XOp1 + XOp3 * sizeof(TPElem); + + // Calculate the read predicate. One boolean per byte, + // initialised to all true. + auto rdEn = std::vector(eCount * sizeof(TPElem), true); + for (int i = 0; i < eCount; ++i) { + if (GpOp_x[i]) { + continue; + } + + // Mark each byte of the corresponding elem as false + for (int j = 0; j < sizeof(TPElem); ++j) { + rdEn[i * sizeof(TPElem) + j] = false; + } + } + ''' + + zaWriteCode = ''' + // Here we write the data we just got from memory to the tile: + if (V) { + auto col = getTileVSlice(ZA, tile_idx, vec_idx); + for(int i = 0; i < eCount; ++i) { + col[i] = GpOp_x[i] ? data[i] : 0; + } + } else { + auto row = getTileHSlice(ZA, tile_idx, vec_idx); + for(int i = 0; i < eCount; ++i) { + row[i] = GpOp_x[i] ? 
data[i] : 0; + } + } + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeLd1xSt1xOp", + {'code': code, 'za_write': zaWriteCode, + 'op_class': opClass}, ['IsLoad', + 'IsNonSpeculative']) + header_output += SmeLd1xDeclare.subst(iop) + exec_output += SmeLd1xExecute.subst(iop) + exec_output += SmeLd1xInitiateAcc.subst(iop) + exec_output += SmeLd1xCompleteAcc.subst(iop) + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sme' + Name} + exec_output += SmeLd1xExecDeclare.subst(substDict) + + def smeLdrInst(name, Name, opClass): + global header_output, decoder_output, exec_output + code = smEnCheckCodeNoSM + smeZaWrite + ''' + // imm stores the vector offset. We do not have a tile number as + // we target the whole accumulator array. + // imm also stores the offset applied to the base memory access + // register. + // Op1 stores Rn, which is the base memory access register + // Op2 stores Rv, which is the vector select register + + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + M5_VAR_USED uint8_t vec_index = (WOp2 + imm) % eCount; + + // Calculate the address + M5_VAR_USED Addr EA = XOp1 + imm; + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeLdrStrOp", + {'code': code, 'op_class': opClass}, + ['IsLoad', 'IsNonSpeculative']) + header_output += SmeLdrDeclare.subst(iop) + exec_output += SmeLdrExecute.subst(iop) + exec_output += SmeLdrInitiateAcc.subst(iop) + exec_output += SmeLdrCompleteAcc.subst(iop) + + def smeMovaExtractInst(name, Name, opClass, types): + global header_output, decoder_output, exec_output + code = smEnCheckCode + ''' + // imm stores the tile index + // op1 is the source SVE vector register + // gp is the governing predecate register + // op2 is the slice index register + // v is the row/col select immediate - true for column accesses + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + uint8_t offset = imm & (0xf >> (findMsbSet(sizeof(TPElem)))); + uint8_t tile_idx = imm >> (4 - 
findMsbSet(sizeof(TPElem))); + + uint32_t vec_idx = (WOp2 + offset) % eCount; + + if (!v) { // Horizontal (row) access + auto row = getTileHSlice(ZA, tile_idx, vec_idx); + for (int i = 0; i < eCount; ++i) { + if (!GpOp_x[i]) { + continue; + } + + AA64FpOp1_x[i] = row[i]; + } + } else { // Vertical (column) access + auto col = getTileVSlice(ZA, tile_idx, vec_idx); + for (int i = 0; i < eCount; ++i) { + if (!GpOp_x[i]) { + continue; + } + + AA64FpOp1_x[i] = col[i]; + } + } + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeMovExtractOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeMovaExtractDeclare.subst(iop) + exec_output += SmeTemplatedExecute.subst(iop) + + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sme' + Name} + exec_output += SmeOpExecDeclare.subst(substDict) + + def smeMovaInsertInst(name, Name, opClass, types): + global header_output, decoder_output, exec_output + code = smEnCheckCode + smeZaWrite + ''' + // imm stores the tile index + // op1 is the source SVE vector register + // gp is the governing predecate register + // op2 is the slice index register + // v is the row/col select immediate - true for column accesses + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + uint8_t offset = imm & (0xf >> (findMsbSet(sizeof(TPElem)))); + uint8_t tile_idx = imm >> (4 - findMsbSet(sizeof(TPElem))); + + uint32_t vec_idx = (WOp2 + offset) % eCount; + + if (!v) { // Horizontal (row) access + auto row = getTileHSlice(ZA, tile_idx, vec_idx); + for (int i = 0; i < eCount; ++i) { + if (!GpOp_x[i]) { + continue; + } + + row[i] = AA64FpOp1_x[i]; + } + } else { // Vertical (column) access + auto col = getTileVSlice(ZA, tile_idx, vec_idx); + for (int i = 0; i < eCount; ++i) { + if (!GpOp_x[i]) { + continue; + } + + col[i] = AA64FpOp1_x[i]; + } + } + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeMovInsertOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + 
header_output += SmeMovaInsertDeclare.subst(iop) + exec_output += SmeTemplatedExecute.subst(iop) + + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sme' + Name} + exec_output += SmeOpExecDeclare.subst(substDict) + + def smeMsrInst(name, Name, opClass, op): + global header_output, decoder_output, exec_output + code = ''' + if (FullSystem) { + fault = this->checkSmeAccess(xc->tcBase(), Cpsr, Cpacr64); + if (fault != NoFault) { + return fault; + } + } + ''' + op + + iop = InstObjParams(name, "Sme" + Name, "ImmOp64", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative', 'IsSerializeAfter']) + header_output += SMEMgmtDeclare.subst(iop) + exec_output += SmeExecute.subst(iop) + + def smeFPOPInst(name, Name, opClass, srcTypes, dstTypes, op): + global header_output, decoder_output, exec_output + code = smEnCheckCode + smeZaWrite + ''' + // imm stores the tile index + // op1 is the first SVE vector register + // gp1 is the predecate register corresponding to the first + // SVE vector register + // gp2 is the predecate register corresponding to the second + // SVE vector register + // op2 is the second SVE vector register + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + ''' + code += op + + iop = InstObjParams(name, "Sme" + Name, "SmeOPOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeFPOPDeclare.subst(iop) + exec_output += SmeDualTemplatedExecute.subst(iop) + for src, dst in zip(srcTypes, dstTypes): + substDict = {'targs' : "{}, {}".format(src, dst), + 'class_name' : 'Sme' + Name} + exec_output += SmeOpExecDeclare.subst(substDict) + + def smeIntOPInst(name, Name, opClass, src1Types, src2Types, dstTypes, op): + global header_output, decoder_output, exec_output + code = smEnCheckCode + smeZaWrite + ''' + // imm stores the tile index + // op1 is the first SVE vector register + // gp1 is the predecate register corresponding to the first + // SVE vector register + // gp2 is the predecate 
register corresponding to the second + // SVE vector register + // op2 is the second SVE vector register + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + ''' + code += op + + iop = InstObjParams(name, "Sme" + Name, "SmeOPOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeIntOPDeclare.subst(iop) + exec_output += SmeTripleTemplatedExecute.subst(iop) + for src1, src2, dst in zip(src1Types, src2Types, dstTypes): + substDict = {'targs' : "{}, {}, {}".format(src1, src2, dst), + 'class_name' : 'Sme' + Name} + exec_output += SmeOpExecDeclare.subst(substDict) + + def smeRdsvlInst(name, Name, opClass): + global header_output, decoder_output, exec_output + code = smEnCheckCodeNoPstate + ''' + // dest is the 64-bit destination register + // imm is a signed multiplier + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + Dest64 = eCount * imm; + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeRdsvlOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeRdsvlDeclare.subst(iop) + exec_output += SmeExecute.subst(iop) + + def smeSt1xInst(name, Name, opClass, types): + global header_output, decoder_output, exec_output + code = smEnCheckCode + ''' + // imm stores the tile number as well as the vector offset. The + // size of the fields changes based on the data type being used. + // XOp1 stores Rn + // GpOp stores the governing predicate register + // WOp2 stores Rs - the vector index register + // XOp3 stores Rm - the offset register (applied to Rn) + + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + uint8_t offset = imm & (0xf >> (findMsbSet(sizeof(TPElem)))); + M5_VAR_USED uint8_t tile_idx = + imm >> (4 - findMsbSet(sizeof(TPElem))); + M5_VAR_USED uint8_t vec_idx = (WOp2 + offset) % eCount; + + // Calculate the address + M5_VAR_USED Addr EA = XOp1 + XOp3 * sizeof(TPElem); + + // Calculate the write predicate. 
One boolean per byte, + // initialised to all true. + auto wrEn = std::vector(eCount * sizeof(TPElem), true); + for (int i = 0; i < eCount; ++i) { + if (GpOp_x[i]) { + continue; + } + + // Mark each byte of the corresponding elem as false + for (int j = 0; j < sizeof(TPElem); ++j) { + wrEn[i * sizeof(TPElem) + j] = false; + } + } + + // Extract the data to be stored from the tile. We don't worry + // about the predicate here as that's already handled by wrEn. + TPElem data[MaxSmeVecLenInBytes / sizeof(TPElem)]; + if(V) { + auto col = getTileVSlice(ZA, tile_idx, vec_idx); + for (int i = 0; i < eCount; ++i) { + data[i] = col[i]; + } + } else { + auto row = getTileHSlice(ZA, tile_idx, vec_idx); + for (int i = 0; i < eCount; ++i) { + data[i] = row[i]; + } + } + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeLd1xSt1xOp", + {'code': code, 'op_class': opClass}, + ['IsStore', 'IsNonSpeculative']) + header_output += SmeSt1xDeclare.subst(iop) + exec_output += SmeSt1xExecute.subst(iop) + exec_output += SmeSt1xInitiateAcc.subst(iop) + exec_output += SmeSt1xCompleteAcc.subst(iop) + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sme' + Name} + exec_output += SmeSt1xExecDeclare.subst(substDict) + + def smeStrInst(name, Name, opClass): + global header_output, decoder_output, exec_output + code = smEnCheckCodeNoSM + ''' + // imm stores the vector offset. We do not have a tile number + // as we target the whole accumulator array. + // imm also stores the offset applied to the base memory access + // register. 
+ // Op1 stores Rn, which is the base memory access register + // Op2 stores Rv, which is the vector select register + + + unsigned eCount = ArmStaticInst::getCurSmeVecLen( + xc->tcBase()); + + uint8_t vec_index = (WOp2 + imm) % eCount; + + auto row = getTileHSlice(ZA, 0, vec_index); + + // Calculate the address + M5_VAR_USED Addr EA = XOp1 + imm; + + uint8_t data[MaxSmeVecLenInBytes]; + + // Update data which will then by used to store the row to memory + for (int i = 0; i < eCount; ++i) { + data[i] = row[i]; + } + ''' + + iop = InstObjParams(name, "Sme" + Name, "SmeLdrStrOp", + {'code': code, 'op_class': opClass}, + ['IsStore', 'IsNonSpeculative']) + header_output += SmeStrDeclare.subst(iop) + exec_output += SmeStrExecute.subst(iop) + exec_output += SmeStrInitiateAcc.subst(iop) + exec_output += SmeStrCompleteAcc.subst(iop) + + def smeZeroInst(name, Name, opClass, types): + global header_output, decoder_output, exec_output + code = smEnCheckCodeNoSM + smeZaWrite + ''' + // When zeroing tiles, we use 64-bit elements. This means + // that we have up to eight subtiles to clear in the ZA tile. 
+ + ZA = ZA; + + for (int i = 0; i < 8; ++i) { + if (((imm >> i) & 0x1) == 0x1) { + getTile(ZA, i).zero(); + } + }''' + + iop = InstObjParams(name, "Sme" + Name, "SmeZeroOp", + {'code': code, 'op_class': opClass}, + ['IsNonSpeculative']) + header_output += SmeZeroDeclare.subst(iop) + exec_output += SmeTemplatedExecute.subst(iop) + + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sme' + Name} + exec_output += SmeOpExecDeclare.subst(substDict) + + # ADDHA + addCode = ''' + for (int col = 0; col < eCount; ++col) { + TPElem val = AA64FpOp1_x[col]; + + for (int row = 0; row < eCount; ++row) { + if (!(GpOp1_x[row] && GpOp2_x[col])) { + continue; + } + + tile[col][row] += val; + } + } + ''' + smeAddInst('addha', "Addha", "SimdAddOp", ['int32_t', 'int64_t'], addCode) + + # ADDSPL + addSplCode = ''' + Dest64 = imm * ArmStaticInst::getCurSmeVecLen(xc->tcBase()); + // Divide down to get the predicate length in bytes + Dest64 /= 8; + Dest64 += XOp1; + ''' + smeAddVlInst('addspl', "Addspl", "SimdAddOp", addSplCode) + + # ADDSVL + addSvlCode = ''' + Dest64 = imm * ArmStaticInst::getCurSmeVecLen(xc->tcBase()); + Dest64 += XOp1; + ''' + smeAddVlInst('addsvl', "Addsvl", "SimdAddOp", addSvlCode) + + # ADDVA + addCode = ''' + for (int row = 0; row < eCount; ++row) { + TPElem val = AA64FpOp1_x[row]; + + for (int col = 0; col < eCount; ++col) { + if (!(GpOp1_x[row] && GpOp2_x[col])) { + continue; + } + + tile[col][row] += val; + } + } + ''' + smeAddInst('addva', "Addva", "SimdAddOp", ['int32_t', 'int64_t'], addCode) + + # BFMOPA + # BFMOPS + + # FMOPA (non-widening) + fmopxCode = ''' + auto tile = getTile(ZA, imm); + FPSCR fpscr = (FPSCR) Fpscr; + + for (int j = 0; j < eCount; ++j) { + if (!GpOp1_xd[j]) { + continue; + } + + TPDElem val1 = AA64FpOp1_xd[j]; + + for (int i = 0; i < eCount; ++i) { + if (!GpOp2_xd[i]) { + continue; + } + + TPDElem val2 = AA64FpOp2_xd[i]; + + #if %s + val2 = fplibNeg(val2); + #endif + + TPDElem res = fplibMul(val1, val2, fpscr); + + 
tile[j][i] = fplibAdd(tile[j][i], + res, fpscr); + } + } + ''' + smeFPOPInst('fmopa', 'Fmopa', 'MatrixOPOp', ['uint32_t', 'uint64_t'], + ['uint32_t', 'uint64_t'], fmopxCode % "0") + + # FMOPA (widening) + wideningFmopxCode = ''' + auto tile = getTile(ZA, imm); + FPSCR fpscr = (FPSCR) Fpscr; + + for (int j = 0; j < eCount; ++j) { + if (!GpOp1_xd[j]) { + continue; + } + for (int i = 0; i < eCount; ++i) { + if (!GpOp2_xd[i]) { + continue; + } + + for (int k = 0; k < 2; ++k) { + TPSElem temp1 = (AA64FpOp1_xd[j] >> (16 * k)) & 0xFFFF; + TPSElem temp2 = (AA64FpOp2_xd[j] >> (16 * k)) & 0xFFFF; + TPDElem val1 = fplibConvert(temp1, + FPCRRounding(fpscr), fpscr); + TPDElem val2 = fplibConvert(temp2, + FPCRRounding(fpscr), fpscr); + + #if %s + val2 = fplibNeg(val2); + #endif + + TPDElem res = fplibMul(val1, val2, fpscr); + tile[j][i] = fplibAdd(tile[j][i], res, fpscr); + } + } + } + ''' + smeFPOPInst('fmopa', 'FmopaWidening', 'MatrixOPOp', + ['uint16_t'], ['uint32_t'], wideningFmopxCode % "0") + + # FMOPS (non-widening) + smeFPOPInst('fmops', 'Fmops', 'MatrixOPOp', ['uint32_t', 'uint64_t'], + ['uint32_t', 'uint64_t'], fmopxCode % "1") + + # FMOPS (widening) + smeFPOPInst('fmops', 'FmopsWidening', 'MatrixOPOp', + ['uint16_t'], ['uint32_t'], wideningFmopxCode % "1") + + # LD1B + smeLd1xInst('ld1b', 'Ld1b', 'MemReadOp', ['uint8_t']) + + # LD1D + smeLd1xInst('ld1d', 'Ld1d', 'MemReadOp', ['uint64_t']) + + # LD1H + smeLd1xInst('ld1h', 'Ld1h', 'MemReadOp', ['uint16_t']) + + # LD1Q + smeLd1xInst('ld1q', 'Ld1q', 'MemReadOp', ['__uint128_t']) + + # LD1W + smeLd1xInst('ld1w', 'Ld1w', 'MemReadOp', ['uint32_t']) + + # LDR + smeLdrInst("ldr", "Ldr", 'MemReadOp') + + # MOV (tile to vector) - ALIAS; see MOVA + # MOV (vector to tile) - ALIAS; see MOVA + # MOVA (tile to vector) + smeMovaExtractInst("mova", "MovaExtract", 'MatrixMovOp', + ["uint8_t", "uint16_t", "uint32_t", "uint64_t", + "__uint128_t"]) + + # MOVA (vector to tile) + smeMovaInsertInst("mova", "MovaInsert", 'MatrixMovOp', + 
["uint8_t", "uint16_t", "uint32_t", "uint64_t", + "__uint128_t"]) + + # RDSVL + smeRdsvlInst('rdsvl', 'Rdsvl', 'SimdAddOp') + + # SMOPA + intMopxCode = ''' + auto tile = getTile(ZA, imm); + + size_t shift = 8 * sizeof(TPS1Elem); + size_t mask = (1 << shift) - 1; + + for (int j = 0; j < eCount; ++j) { + for (int i = 0; i < eCount; ++i) { + for (int k = 0; k < 4; ++k) { + if (!GpOp1_xs1[4 * j + k]) { + continue; + } + + if (!GpOp2_xs2[4 * i + k]) { + continue; + } + + TPS1Elem temp1 = + (TPS1Elem)(AA64FpOp1_xd[j] >> (shift * k)) & mask; + TPS2Elem temp2 = + (TPS2Elem)(AA64FpOp2_xd[i] >> (shift * k)) & mask; + + tile[j][i] %s= (TPDElem)temp1 * (TPDElem)temp2; + } + } + } + ''' + smeIntOPInst('smopa', 'Smopa', 'MatrixOPOp', ['int8_t', 'int16_t'], + ['int8_t', 'int16_t'], ['int32_t', 'int64_t'], + intMopxCode % "+") + + # SMOPS + smeIntOPInst('smops', 'Smops', 'MatrixOPOp', ['int8_t', 'int16_t'], + ['int8_t', 'int16_t'], ['int32_t', 'int64_t'], + intMopxCode % "-") + + # SMSTART + smstartSmstopCode = ''' + // Bit 0 of imm determines if we are setting or clearing + // (smstart vs smstop) + // Bit 1 means that we are applying this to SM + // Bit 2 means that we are applying this to ZA + bool new_state = imm & 0x1; + bool sm_affected = imm & 0x2; + bool za_affected = imm & 0x4; + bool old_sm_state = Svcr & 0x1; + bool old_za_state = Svcr & 0x2; + + bool sm_changed = sm_affected && old_sm_state != new_state; + bool za_changed = za_affected && old_za_state != new_state; + + if (sm_changed) { + // We need to zero the SVE Z, P, FFR registers on SM change. Also, + // set FPSR to a default value. Note that we use the max SVE len + // instead of the actual vector length. + // + // For the Z, P registers we are directly setting these to zero + // without going through the ISA parser (which generates the + // dependencies) as otherwise the O3 CPU can deadlock when there + // are too few free physical registers. 
We therefore rely on this + // instruction being a barrier (IsSerialiseAfter). + + // Z Registers, including special and interleave registers + ArmISA::VecRegContainer zeroed_z_reg; + zeroed_z_reg.zero(); + + for (int reg_idx = 0; reg_idx < NumVecRegs; ++reg_idx) { + auto reg_id = ArmISA::vecRegClass[reg_idx]; + xc->tcBase()->setReg(reg_id, &zeroed_z_reg); + } + + // P Registers, including the FFR + ArmISA::VecPredRegContainer zeroed_p_reg; + zeroed_p_reg.reset(); + + for (int reg_idx = 0; reg_idx < NumVecPredRegs; ++reg_idx) { + auto reg_id = ArmISA::vecPredRegClass[reg_idx]; + xc->tcBase()->setReg(reg_id, &zeroed_p_reg); + } + + // FPSR + Fpsr = 0x0800009f; + } + + if (za_changed) { + // ZA write + ZA = ZA; + ZA.zero(); + } + + // Now that we've handled the zeroing of the appropriate registers, + // we update the pstate accordingly. + + if (sm_changed) { + if (new_state == 1) { + Svcr = Svcr | 0x1; // Set SM + } else { + Svcr = Svcr & ~(uint64_t)0x1; // Clear SM + } + } + + if (za_changed) { + if (new_state == 1) { + Svcr = Svcr | 0x2; // Set ZA + } else { + Svcr = Svcr & ~(uint64_t)0x2; // Clear ZA + } + } + ''' + + smeMsrInst('smstart', 'Smstart', 'IntAluOp', + smstartSmstopCode) + + # SMSTOP + smeMsrInst('smstop', 'Smstop', 'IntAluOp', + smstartSmstopCode) + + # ST1B + smeSt1xInst('st1b', 'St1b', 'MemWriteOp', ['uint8_t']) + + # ST1D + smeSt1xInst('st1d', 'St1d', 'MemWriteOp', ['uint64_t']) + + # ST1H + smeSt1xInst('st1h', 'St1h', 'MemWriteOp', ['uint16_t']) + + # ST1Q + smeSt1xInst('st1q', 'St1q', 'MemWriteOp', ['__uint128_t']) + + # ST1W + smeSt1xInst('st1w', 'St1w', 'MemWriteOp', ['uint32_t']) + + # STR + smeStrInst("str", "Str", "MemWriteOp") + + # SUMOPA + smeIntOPInst('sumopa', 'Sumopa', 'MatrixOPOp', ['int8_t', 'int16_t'], + ['uint8_t', 'uint16_t'], ['int32_t', 'int64_t'], + intMopxCode % "+") + + # SUMOPS + smeIntOPInst('sumops', 'Sumops', 'MatrixOPOp', ['int8_t', 'int16_t'], + ['uint8_t', 'uint16_t'], ['int32_t', 'int64_t'], + intMopxCode % "-") + + # 
UMOPA + smeIntOPInst('umopa', 'Umopa', 'MatrixOPOp', ['uint8_t', 'uint16_t'], + ['uint8_t', 'uint16_t'], ['int32_t', 'int64_t'], + intMopxCode % "+") + + # UMOPS + smeIntOPInst('umops', 'Umops', 'MatrixOPOp', ['uint8_t', 'uint16_t'], + ['uint8_t', 'uint16_t'], ['int32_t', 'int64_t'], + intMopxCode % "-") + + # USMOPA + smeIntOPInst('usmopa', 'Usmopa', 'MatrixOPOp', ['uint8_t', 'uint16_t'], + ['int8_t', 'int16_t'], ['int32_t', 'int64_t'], + intMopxCode % "+") + + # USMOPS + smeIntOPInst('usmops', 'Usmops', 'MatrixOPOp', ['uint8_t', 'uint16_t'], + ['int8_t', 'int16_t'], ['int32_t', 'int64_t'], + intMopxCode % "-") + + # ZERO + smeZeroInst("zero", "Zero", "MatrixOp", ["uint64_t"]) + +}}; diff --git a/src/arch/arm/isa/insts/sve.isa b/src/arch/arm/isa/insts/sve.isa index 7cb733100f..97d4ec7e56 100644 --- a/src/arch/arm/isa/insts/sve.isa +++ b/src/arch/arm/isa/insts/sve.isa @@ -1310,6 +1310,34 @@ let {{ substDict = {'targs' : type, 'class_name' : 'Sve' + Name} exec_output += SveOpExecDeclare.subst(substDict); + # Generates definition for SVE psel predicate selection instructions + def svePselInst(name, Name, opClass, types): + global header_output, exec_output, decoders + code = sveEnabledCheckCode + ''' + unsigned eCount = ArmStaticInst::getCurSveVecLen( + xc->tcBase()); + + uint8_t index = ((uint32_t)Op2 + imm) % eCount; + + bool copy = POp1_x[index]; + if (copy) { + for (int i = 0; i < eCount; ++i) { + PDest_x[i] = GpOp_x[i]; + } + } else { + for (int i = 0; i < eCount; ++i) { + PDest_x[i] = false; + } + } + ''' + iop = ArmInstObjParams(name, 'Sve' + Name, 'SvePselOp', + {'code': code, 'op_class': opClass}, []) + header_output += SvePselOpDeclare.subst(iop) + exec_output += SveOpExecute.subst(iop) + for type in types: + substDict = {'targs' : type, 'class_name' : 'Sve' + Name} + exec_output += SveOpExecDeclare.subst(substDict); + # Generate definition for SVE compare & terminate instructions def sveCompTermInst(name, Name, opClass, types, op): global header_output, 
exec_output, decoders @@ -3096,6 +3124,31 @@ let {{ 'class_name' : 'Sve' + Name} exec_output += SveOpExecDeclare.subst(substDict) + # Generate definitions for clamp to min/max instructions + def sveClampInst(name, Name, opClass, types, + decoder = 'Generic'): + global header_output, exec_output, decoders + code = sveEnabledCheckCode + ''' + unsigned eCount = ArmStaticInst::getCurSveVecLen( + xc->tcBase()); + + for (int i = 0 ; i < eCount ; ++i) { + if (AA64FpDestMerge_x[i] < AA64FpOp2_x[i]) { + AA64FpDest_x[i] = AA64FpOp2_x[i]; + } else if (AA64FpDestMerge_x[i] > AA64FpOp1_x[i]) { + AA64FpDest_x[i] = AA64FpOp1_x[i]; + } + } + ''' + iop = ArmInstObjParams(name, 'Sve' + Name, 'SveClampOp', + {'code': code, 'op_class': opClass}, []) + header_output += SveClampOpDeclare.subst(iop) + exec_output += SveOpExecute.subst(iop) + for type in types: + substDict = {'targs' : type, + 'class_name' : 'Sve' + Name} + exec_output += SveOpExecDeclare.subst(substDict) + fpTypes = ('uint16_t', 'uint32_t', 'uint64_t') signedTypes = ('int8_t', 'int16_t', 'int32_t', 'int64_t') unsignedTypes = ('uint8_t', 'uint16_t', 'uint32_t', 'uint64_t') @@ -4071,6 +4124,8 @@ let {{ svePFirstInst('pfirst', 'Pfirst', 'SimdPredAluOp') # PNEXT svePNextInst('pnext', 'Pnext', 'SimdPredAluOp', unsignedTypes) + # PSEL + svePselInst('psel', 'Psel', 'SimdPredAluOp', unsignedTypes) # PTEST svePredTestInst('ptest', 'Ptest', 'SimdPredAluOp') # PTRUE @@ -4140,6 +4195,10 @@ let {{ ['uint16_t', 'uint32_t', 'uint64_t'], revCode % {'revtype' : 'uint8_t'}, predType=PredType.MERGE, srcRegType=SrcRegType.Vector, decoder='Generic') + # REVD + sveUnaryInst('revd', 'Revd', 'SimdAluOp', ['__uint128_t'], + revCode % {'revtype' : 'uint64_t'}, predType=PredType.MERGE, + srcRegType=SrcRegType.Vector, decoder='Generic') # REVH sveUnaryInst('revh', 'Revh', 'SimdAluOp', ['uint32_t', 'uint64_t'], revCode % {'revtype' : 'uint16_t'}, predType=PredType.MERGE, @@ -4160,6 +4219,8 @@ let {{ sveWideningAssocReducInst('saddv', 'Saddv', 
'SimdReduceAddOp', ['int8_t, int64_t', 'int16_t, int64_t', 'int32_t, int64_t'], addvCode, '0') + # SCLAMP + sveClampInst('sclamp', 'Sclamp', 'SimdAluOp', signedTypes) # SCVTF scvtfCode = fpOp % ('fplibFixedToFP(' 'sext(srcElem1), 0,' @@ -4545,6 +4606,8 @@ let {{ ['uint8_t, uint64_t', 'uint16_t, uint64_t', 'uint32_t, uint64_t', 'uint64_t, uint64_t'], addvCode, '0') + # UCLAMP + sveClampInst('uclamp', 'Uclamp', 'SimdAluOp', unsignedTypes) # UCVTF ucvtfCode = fpOp % ('fplibFixedToFP(srcElem1, 0, true,' ' FPCRRounding(fpscr), fpscr)') diff --git a/src/arch/arm/isa/operands.isa b/src/arch/arm/isa/operands.isa index 5919ae974e..24a0af9155 100644 --- a/src/arch/arm/isa/operands.isa +++ b/src/arch/arm/isa/operands.isa @@ -57,6 +57,8 @@ def operand_types {{ # For operations that are implemented as a template 'x' : 'TPElem', 'xs' : 'TPSElem', + 'xs1' : 'TPS1Elem', + 'xs2' : 'TPS2Elem', 'xd' : 'TPDElem', 'pc' : 'ArmISA::VecPredRegContainer', 'pb' : 'uint8_t' @@ -451,6 +453,8 @@ def operands {{ # Predicate register operands 'GpOp': VecPredReg('gp'), + 'GpOp1': VecPredReg('gp1'), + 'GpOp2': VecPredReg('gp2'), 'POp1': VecPredReg('op1'), 'POp2': VecPredReg('op2'), 'PDest': VecPredReg('dest'), @@ -496,6 +500,7 @@ def operands {{ 'LLSCLock': CntrlRegNC('MISCREG_LOCKFLAG'), 'Dczid' : CntrlRegNC('MISCREG_DCZID_EL0'), 'PendingDvm': CntrlRegNC('MISCREG_TLBINEEDSYNC'), + 'Svcr' : CntrlReg('MISCREG_SVCR'), #Register fields for microops 'URa' : IntReg('ura'), diff --git a/src/arch/arm/isa/templates/sme.isa b/src/arch/arm/isa/templates/sme.isa new file mode 100644 index 0000000000..1bec2a3a71 --- /dev/null +++ b/src/arch/arm/isa/templates/sme.isa @@ -0,0 +1,773 @@ +// Copyright (c) 2022 ARM Limited +// All rights reserved +// +// The license below extends only to copyright in the software and shall +// not be construed as granting a license to any other intellectual +// property including but not limited to intellectual property relating +// to a hardware implementation of the 
functionality of the software +// licensed hereunder. You may use the software subject to the license +// terms below provided that you ensure that this notice is replicated +// unmodified and in its entirety in all distributions of the software, +// modified or unmodified, in source code or in binary form. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer; +// redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution; +// neither the name of the copyright holders nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// @file Definition of SME instruction templates. + +let {{ + # All SME instructions should be checking if Streaming Mode is + # enabled in the PSTATE. 
The following call checks both the SME and + # the FP enable flags in the relevant registers depending on the + # current EL. + smEnCheckCodeNoPstate = ''' + if (FullSystem) { + fault = this->checkSmeEnabled(xc->tcBase(), Cpsr, Cpacr64); + if (fault != NoFault) { + return fault; + } + } + ''' + + smPreamble = ''' + CPSR cpsr = (CPSR) Cpsr; + ExceptionLevel target_el = (ExceptionLevel) (uint8_t) cpsr.el; + if (target_el == EL0) { + target_el = EL1; + } + ''' + + smCheckCode = ''' + // Check streaming mode first + if ((Svcr & 1) != 0b1) { + fault = smeAccessTrap(target_el, 0b10); + return fault; + } + ''' + + zaCheckCode = ''' + // Check if ZA is enabled + if ((Svcr & 2) >> 1 != 0b1) { + fault = smeAccessTrap(target_el, 0b11); + return fault; + } + ''' + + # If streaming mode is disabled or ZA is disabled we trap + smEnCheckCode = smPreamble + smCheckCode + zaCheckCode + \ + smEnCheckCodeNoPstate + + # If ZA is disabled we trap + smEnCheckCodeNoSM = smPreamble + zaCheckCode + smEnCheckCodeNoPstate + + # If streaming mode is disabled we trap + smEnCheckCodeNoZA = smPreamble + smCheckCode + smEnCheckCodeNoPstate + + smeZaWrite = ''' + // Force the ISA parser to see the access to ZA as a write, + // not a read. + ZA = ZA; + ''' +}}; + +def template SmeAddDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex gp1, + RegIndex gp2) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, gp1, gp2) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeAddVlDeclare {{ + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. 
+ %(class_name)s(ExtMachInst machInst, + RegIndex dest, RegIndex op1, + int8_t imm) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + dest, op1, imm) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeLd1xDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex mpop1, + RegIndex op2, RegIndex op3, + bool V) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, mpop1, op2, op3, V) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + Fault initiateAcc(ExecContext *, trace::InstRecord *) const override; + Fault completeAcc(PacketPtr, ExecContext *, + trace::InstRecord *) const override; + }; +}}; + +def template SmeLd1xExecute {{ + template + Fault %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + // We need a buffer in which to store the data: + TPElem data[MaxSmeVecLenInBytes / sizeof(TPElem)]; + + if (fault == NoFault) { + // The size of the access is controlled by the type of data, and + // the number of elements. 
+ fault = xc->readMem(EA, (uint8_t*)data, eCount * sizeof(TPElem), + flags, rdEn); + } + + if (fault == NoFault) { + %(za_write)s + + // Write back the changes to the actual tile + %(op_wb)s; + } + + return fault; + } +}}; + +def template SmeLd1xInitiateAcc {{ + template + Fault %(class_name)s::initiateAcc(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + fault = xc->initiateMemRead(EA, eCount * sizeof(TPElem), + flags, rdEn); + } + + return fault; + } +}}; + +def template SmeLd1xCompleteAcc {{ + template + Fault %(class_name)s::completeAcc(PacketPtr pkt, ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + // The O3 CPU will call this with a NULL-pointer if the access was + // disabled. Just return. + if (pkt == NULL) { + return fault; + } + + if (fault == NoFault) { + // We need a buffer in which to store the data: + TPElem data[MaxSmeVecLenInBytes / sizeof(TPElem)]; + + // The size for the amount of data returned here should + // have been set in initiateAcc. + memcpy((uint8_t*)data, pkt->getPtr(), pkt->getSize()); + + %(za_write)s + + // Write back the changes to the tile + %(op_wb)s; + } + return fault; + } +}}; + +def template SmeLd1xExecDeclare {{ + template + Fault %(class_name)s<%(targs)s>::execute( + ExecContext *, trace::InstRecord *) const; + template + Fault %(class_name)s<%(targs)s>::initiateAcc( + ExecContext *, trace::InstRecord *) const; + template + Fault %(class_name)s<%(targs)s>::completeAcc( + PacketPtr, ExecContext *, trace::InstRecord *) const; +}}; + +def template SmeLdrDeclare {{ + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. 
+ %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex op2) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, op2) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + Fault initiateAcc(ExecContext *, trace::InstRecord *) const override; + Fault completeAcc(PacketPtr, ExecContext *, + trace::InstRecord *) const override; + }; +}}; + +def template SmeLdrExecute {{ + Fault %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + auto rdEn = std::vector(eCount, true); + + // We need a buffer in which to store the data: + uint8_t data[MaxSmeVecLenInBytes]; + + if (fault == NoFault) { + fault = xc->readMem(EA, (uint8_t*)data, eCount, flags, rdEn); + } + + if (fault == NoFault) { + auto row = getTileHSlice(ZA, 0, vec_index); + for (int i = 0; i < eCount; ++i) { + row[i] = data[i]; + } + + %(op_wb)s; + } + + return fault; + } +}}; + +def template SmeLdrInitiateAcc {{ + Fault %(class_name)s::initiateAcc(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + auto rdEn = std::vector(eCount, true); + + if (fault == NoFault) { + fault = xc->initiateMemRead(EA, eCount, flags, rdEn); + } + + return fault; + } +}}; + +def template SmeLdrCompleteAcc {{ + Fault %(class_name)s::completeAcc(PacketPtr pkt, ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + // The O3 CPU will call this with a NULL-pointer if the access was + // disabled. Just return. 
+ if (pkt == NULL) { + return fault; + } + + if (fault == NoFault) { + // Get the data out of the packet + auto row = getTileHSlice(ZA, 0, vec_index); + for (int i = 0; i < eCount; ++i) { + row[i] = pkt->getPtr()[i]; + } + + %(op_wb)s; + } + + return fault; + } +}}; + +def template SMEMgmtDeclare {{ + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, uint64_t imm) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, imm) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeMovaExtractDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, RegIndex op1, + uint8_t imm, RegIndex gp, + RegIndex op2, bool v) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + op1, imm, gp, op2, v) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeMovaInsertDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, uint8_t imm, + RegIndex op1, RegIndex gp, + RegIndex op2, bool v) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, gp, op2, v) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeFPOPDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. 
+ %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex gp1, + RegIndex gp2, RegIndex op2) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, gp1, gp2, op2) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeIntOPDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex gp1, + RegIndex gp2, RegIndex op2) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, gp1, gp2, op2) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeRdsvlDeclare {{ + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. + %(class_name)s(ExtMachInst machInst, + RegIndex dest, int8_t imm) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + dest, imm) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeSt1xDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. 
+ %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex mpop1, + RegIndex op2, RegIndex op3, bool V) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, mpop1, op2, op3, V) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + Fault initiateAcc(ExecContext *, trace::InstRecord *) const override; + Fault completeAcc(PacketPtr, ExecContext *, + trace::InstRecord *) const override; + }; +}}; + +def template SmeSt1xExecute {{ + template + Fault %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + fault = xc->writeMem((uint8_t*)data, eCount * sizeof(TPElem), EA, + flags, NULL, wrEn); + } + + return fault; + } +}}; + +def template SmeSt1xInitiateAcc {{ + template + Fault %(class_name)s::initiateAcc(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + fault = xc->writeMem((uint8_t*)data, eCount * sizeof(TPElem), EA, + flags, NULL, wrEn); + } + + return fault; + } +}}; + +def template SmeSt1xCompleteAcc {{ + template + Fault %(class_name)s::completeAcc(PacketPtr pkt, ExecContext *xc, + trace::InstRecord *traceData) const + { + return NoFault; + } +}}; + +def template SmeStrDeclare {{ + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. 
+ %(class_name)s(ExtMachInst machInst, uint64_t imm, + RegIndex op1, RegIndex op2) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + imm, op1, op2) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + Fault initiateAcc(ExecContext *, trace::InstRecord *) const override; + Fault completeAcc(PacketPtr, ExecContext *, + trace::InstRecord *) const override; + }; +}}; + +def template SmeStrExecute {{ + Fault %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + + if (fault == NoFault) { + auto wrEn = std::vector(eCount, true); + fault = xc->writeMem((uint8_t*)data, eCount, EA, + flags, NULL, wrEn); + } + + return fault; + } +}}; + +def template SmeStrInitiateAcc {{ + Fault %(class_name)s::initiateAcc(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + Request::Flags flags = 0; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + auto wrEn = std::vector(eCount, true); + fault = xc->writeMem((uint8_t*)data, eCount, EA, + flags, NULL, wrEn); + } + + return fault; + } +}}; + +def template SmeStrCompleteAcc {{ + Fault %(class_name)s::completeAcc(PacketPtr pkt, ExecContext *xc, + trace::InstRecord *traceData) const + { + // TODO-SME: Can this fail? + return NoFault; + } +}}; + +def template SmeSt1xExecDeclare {{ + template + Fault %(class_name)s<%(targs)s>::execute( + ExecContext *, trace::InstRecord *) const; + template + Fault %(class_name)s<%(targs)s>::initiateAcc( + ExecContext *, trace::InstRecord *) const; + template + Fault %(class_name)s<%(targs)s>::completeAcc( + PacketPtr, ExecContext *, trace::InstRecord *) const; +}}; + +def template SmeZeroDeclare {{ + template + class %(class_name)s : public %(base_class)s + { + private: + %(reg_idx_arr_decl)s; + + public: + /// Constructor. 
+ %(class_name)s(ExtMachInst machInst, uint8_t imm) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, imm) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; + }; +}}; + +def template SmeExecute {{ + Fault + %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + %(op_wb)s; + } + + return fault; + } +}}; + +def template SmeTemplatedExecute {{ + template + Fault + %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + %(op_wb)s; + } + + return fault; + } +}}; + +def template SmeDualTemplatedExecute {{ + template + Fault + %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + %(op_wb)s; + } + + return fault; + } +}}; + +def template SmeTripleTemplatedExecute {{ + template + Fault + %(class_name)s::execute(ExecContext *xc, + trace::InstRecord *traceData) const + { + Fault fault = NoFault; + + %(op_decl)s; + %(op_rd)s; + %(code)s; + + if (fault == NoFault) { + %(op_wb)s; + } + + return fault; + } +}}; + +def template SmeOpExecDeclare {{ + template + Fault %(class_name)s<%(targs)s>::execute( + ExecContext *, trace::InstRecord *) const; +}}; diff --git a/src/arch/arm/isa/templates/sve.isa b/src/arch/arm/isa/templates/sve.isa index fc38a2b979..9260441c2c 100644 --- a/src/arch/arm/isa/templates/sve.isa +++ b/src/arch/arm/isa/templates/sve.isa @@ -800,6 +800,33 @@ class %(class_name)s : public %(base_class)s }; }}; +def template SvePselOpDeclare {{ +template +class %(class_name)s : public %(base_class)s +{ + private: + %(reg_idx_arr_decl)s; + + protected: + typedef _Element Element; + typedef _Element TPElem; + + public: + 
%(class_name)s(ExtMachInst machInst, + RegIndex dest, RegIndex op1, + RegIndex gp, RegIndex op2, + uint64_t imm) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + dest, op1, gp, op2, imm) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; +}; +}}; + def template SveCompTermOpDeclare {{ template class %(class_name)s : public %(base_class)s @@ -1170,6 +1197,32 @@ class %(class_name)s : public %(base_class)s }; }}; +def template SveClampOpDeclare {{ +template +class %(class_name)s : public %(base_class)s +{ + private: + %(reg_idx_arr_decl)s; + + protected: + typedef _Element Element; + typedef _Element TPElem; + + public: + // Constructor + %(class_name)s(ExtMachInst machInst, + RegIndex dest, RegIndex op1, RegIndex op2) + : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s, + dest, op1, op2) + { + %(set_reg_idx_arr)s; + %(constructor)s; + } + + Fault execute(ExecContext *, trace::InstRecord *) const override; +}; +}}; + def template SveWideningOpExecute {{ template Fault diff --git a/src/arch/arm/isa/templates/templates.isa b/src/arch/arm/isa/templates/templates.isa index 0b4abfcce4..047cd1ef79 100644 --- a/src/arch/arm/isa/templates/templates.isa +++ b/src/arch/arm/isa/templates/templates.isa @@ -82,3 +82,6 @@ //Templates for SVE instructions ##include "sve.isa" ##include "sve_mem.isa" + +//Templates for SME instructions +##include "sme.isa"