From a25d9a126f43f9ba7651a147991b62a32d919b00 Mon Sep 17 00:00:00 2001 From: Junshi Wang Date: Fri, 6 Sep 2024 13:32:45 +0800 Subject: [PATCH] arch-arm: Add recursive reduce in Neon instruction. FMAXV, FMINV, FMAXNMV, FMINNMV and ADDV instructions perform recursive reduction. Different reduction methods lie to different result when handle NaN values. Reuse the template of `twoRegAcrossInstX`. Add one more option `recursive` for recursive reduction. Change-Id: I69e690ce7668baee818542d3ea463f7a5f269a69 Reviewed-by: Giacomo Travaglini --- src/arch/arm/isa/insts/neon64.isa | 36 +++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/arch/arm/isa/insts/neon64.isa b/src/arch/arm/isa/insts/neon64.isa index a3b79be912..42fe379ba4 100644 --- a/src/arch/arm/isa/insts/neon64.isa +++ b/src/arch/arm/isa/insts/neon64.isa @@ -1,6 +1,6 @@ // -*- mode: c++ -*- -// Copyright (c) 2012-2013, 2015-2018, 2020 ARM Limited +// Copyright (c) 2012-2013, 2015-2018, 2020, 2024 ARM Limited // All rights reserved // // The license below extends only to copyright in the software and shall @@ -577,7 +577,7 @@ let {{ exec_output += NeonXExecDeclare.subst(substDict) def twoRegAcrossInstX(name, Name, opClass, types, rCount, op, - doubleDest=False, long=False): + doubleDest=False, long=False, recursive=False): global header_output, exec_output destPrefix = "Big" if long else "" eWalkCode = simd64EnabledCheckCode + ''' @@ -588,7 +588,25 @@ let {{ eWalkCode += ''' srcReg1.regs[%(reg)d] = htole(AA64FpOp1P%(reg)d_uw); ''' % { "reg" : reg } - eWalkCode += ''' + if recursive: + eWalkCode += ''' + RegVect tmpReg = srcReg1; + destReg.regs[0] = 0; + for (unsigned gap = 1; gap < eCount; gap = gap * 2) { + for (unsigned i = 0; i < eCount; i = i + gap * 2) { + unsigned src_id0 = i; + unsigned src_id1 = i + gap; + unsigned dst_id = i; + %(destPrefix)sElement destElem = letoh(tmpReg.elements[src_id0]); + %(destPrefix)sElement srcElem1 = letoh(tmpReg.elements[src_id1]); + %(op)s + tmpReg.elements[dst_id] = destElem; + } + } + destReg.elements[0] = htole(tmpReg.elements[0]); + ''' % { "op" : op, "destPrefix" : destPrefix } + else: + eWalkCode += ''' destReg.regs[0] = 0; %(destPrefix)sElement destElem = 0; for (unsigned i = 0; i < eCount; i++) { @@ -934,9 +952,9 @@ let {{ # Note: SimdAddOp can be a bit optimistic here addAcrossCode = "destElem += srcElem1;" twoRegAcrossInstX("addv", "AddvDX", "SimdAddOp", ("uint8_t", "uint16_t"), - 2, addAcrossCode) + 2, addAcrossCode, False, False, True) twoRegAcrossInstX("addv", "AddvQX", "SimdAddOp", smallUnsignedTypes, 4, - addAcrossCode) + addAcrossCode, False, False, True) # AND andCode = "destElem = srcElem1 & srcElem2;" threeEqualRegInstX("and", "AndDX", "SimdAluOp", ("uint64_t",), 2, andCode) @@ -1649,7 +1667,7 @@ let {{ fpAcrossOp = fpOp % "fplib%s(destElem, srcElem1, fpscr)" fmaxnmAcrossCode = fpAcrossOp % "MaxNum" twoRegAcrossInstX("fmaxnmv", "FmaxnmvQX", "SimdFloatCmpOp", ("uint32_t",), - 4, fmaxnmAcrossCode) + 4, fmaxnmAcrossCode, False, False, True) # FMAXP (scalar) twoRegPairwiseScInstX("fmaxp", "FmaxpScDX", "SimdFloatCmpOp", ("uint32_t",), 2, fmaxCode) @@ -1664,7 +1682,7 @@ let {{ # Note: SimdFloatCmpOp can be a bit optimistic here fmaxAcrossCode = fpAcrossOp % "Max" twoRegAcrossInstX("fmaxv", "FmaxvQX", "SimdFloatCmpOp", ("uint32_t",), 4, - fmaxAcrossCode) + fmaxAcrossCode, False, False, True) # FMIN fminCode = fpBinOp % "Min" threeEqualRegInstX("fmin", "FminDX", "SimdFloatCmpOp", smallFloatTypes, 2, @@ -1691,7 +1709,7 @@ let {{ # Note: SimdFloatCmpOp can be a bit optimistic here fminnmAcrossCode = fpAcrossOp % "MinNum" twoRegAcrossInstX("fminnmv", "FminnmvQX", "SimdFloatCmpOp", ("uint32_t",), - 4, fminnmAcrossCode) + 4, fminnmAcrossCode, False, False, True) # FMINP (scalar) twoRegPairwiseScInstX("fminp", "FminpScDX", "SimdFloatCmpOp", ("uint32_t",), 2, fminCode) @@ -1706,7 +1724,7 @@ let {{ # Note: SimdFloatCmpOp can be a bit optimistic here fminAcrossCode = fpAcrossOp % "Min" twoRegAcrossInstX("fminv", "FminvQX", "SimdFloatCmpOp", ("uint32_t",), 4, - fminAcrossCode) + fminAcrossCode, False, False, True) # FMLA (by element) fmlaCode = fpOp % ("fplibMulAdd(" "destElem, srcElem1, srcElem2, fpscr)")