base: Find lsb set generalization and optimization (#76)

* base: Generalize findLsbSet to std::bitset<N>

* base: Split builtin and fallback implementations of findLsbSet

* base: Add more unit testing for findLsbSet

Change-Id: Id75dfb7d306c9a8228fa893798b1b867137465a9

---------

Co-authored-by: Gabriel Busnot <gabriel.busnot@arteris.com>
This commit is contained in:
Gabriel Busnot
2023-07-18 00:32:04 +02:00
committed by GitHub
parent f80015ea18
commit 6fb72d84e1
2 changed files with 91 additions and 26 deletions

View File

@@ -41,9 +41,12 @@
#ifndef __BASE_BITFIELD_HH__
#define __BASE_BITFIELD_HH__
#include <bitset>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace gem5
@@ -303,40 +306,84 @@ findMsbSet(uint64_t val)
return msb;
}
namespace {
/**
 * Tells at compile time whether a value of type T can be handed to the
 * compiler's 64-bit "count trailing zeros" builtin.
 *
 * @tparam T Type of the value whose trailing zeros are to be counted.
 * @return true when the compiler advertises __builtin_ctzll and T is no
 *         wider than unsigned long long; false otherwise.
 */
template<typename T>
constexpr bool
hasBuiltinCtz() {
// __has_builtin may only be used inside an #if once it is known to exist: a
// preprocessor that lacks it cannot parse "__has_builtin(...)" even when it
// is guarded by "defined(__has_builtin) &&" on the same line. The two checks
// are therefore nested. The builtin probed is the long long variant, which
// is the one the uint64_t findLsbSet actually calls.
#if defined(__has_builtin)
#  if __has_builtin(__builtin_ctzll)
    return sizeof(unsigned long long) >= sizeof(T);
#  else
    return false;
#  endif
#else
    return false;
#endif
}
/**
 * Portable implementation of findLsbSet that does not rely on any compiler
 * builtin.
 *
 * @param val The value to scan.
 * @return The bit position of the least significant bit set in val, or 64
 *         when val has no bit set.
 */
int
findLsbSetFallback(uint64_t val) {
    // Subtracting 1 flips the trailing zeros of val to 1 and its lsb set to
    // 0, leaving the rest unchanged.
    // e.g.: 0101000 - 1 = 0100111
    //           ^^^^        ^^^^
    const uint64_t flipped = val - 1;
    // Keep only the bits that are 0 in val and 1 in flipped: those are
    // exactly the trailing zeros of val.
    // e.g.: ~0101000 & 0100111 = 0000111
    // When val is 0, flipped wraps around to all ones and so does ~val, so
    // all 64 bits are kept and the result is 64 rather than a value that
    // could be mistaken for a valid bit position.
    const uint64_t trailingZeros = ~val & flipped;
    // The number of bits set in that mask is the position of the lsb set.
    // This operation should be optimized by the compiler without using
    // intrinsics. std::bitset::count becomes constexpr starting from C++23.
    // In the meantime, this fallback should not be used much, in favor of
    // the constexpr intrinsic.
    return static_cast<int>(
        std::bitset<sizeof(trailingZeros) * CHAR_BIT>(trailingZeros).count());
}
}
/**
* Returns the bit position of the least significant bit that is set in the
* input, or 64 when no bit is set.
*
* The implementation is selected at compile time: when hasBuiltinCtz()
* reports that the builtin is usable for the type of val, the result comes
* from __builtin_ctzll; otherwise it comes from the bit-fiddling algorithm
* implemented in findLsbSetFallback().
*
* @param val The value to scan.
* @return The position of the LSB set, or 64 if val is 0.
*
* @ingroup api_bitfield
*/
constexpr int
findLsbSet(uint64_t val)
findLsbSet(uint64_t val) {
// Zero is answered here, so neither implementation below ever sees it.
if (val == 0) return 64;
// Compile-time choice between the compiler builtin and the portable code.
if constexpr (hasBuiltinCtz<decltype(val)>()) {
return __builtin_ctzll(val);
} else {
return findLsbSetFallback(val);
}
}
/**
* Overload of findLsbSet for std::bitset<N>.
*
* In the N <= 64 branch the bitset is converted with to_ullong() and handed
* to the uint64_t overload. In the other branch, N is returned when no bit
* is set; otherwise the bitset is examined 64 bits at a time, starting from
* the least significant ones, until a chunk with a bit set is found.
*
* NOTE(review): in the N <= 64 branch an empty bitset yields whatever the
* uint64_t overload returns for 0 rather than N, unlike the wide branch.
* Confirm this is intended for N < 64.
*
* @tparam N Number of bits in the bitset.
* @param bs The bitset to scan (taken by value, it is shifted in place).
* @return The position of the least significant bit set in bs.
*/
template<size_t N>
constexpr int
findLsbSet(std::bitset<N> bs)
{
// NOTE(review): the statements below that read or shift `val` do not match
// this overload, whose only parameter is `bs`. They look like lines of the
// previous uint64_t implementation shown interleaved by the diff view;
// confirm before relying on them.
int lsb = 0;
if (!val)
return sizeof(val) * 8;
if (!bits(val, 31, 0)) {
lsb += 32;
val >>= 32;
if constexpr (N <= 64) {
// Narrow enough to be scanned as a single unsigned long long.
return findLsbSet(bs.to_ullong());
} else {
if (bs.none()) return N;
// Mask of ones
constexpr std::bitset<N> mask(std::numeric_limits<uint64_t>::max());
// Is the lsb set in the rightmost 64 bits ?
auto nextQword{bs & mask};
// Number of bit positions already shifted out of bs.
int i{0};
while (nextQword.none()) {
// If no, shift by 64 bits and repeat
i += 64;
bs >>= 64;
nextQword = bs & mask;
}
// If yes, account for the number of 64-bit shifts and add the
// remaining using the uint64_t implementation. Store in intermediate
// variable to ensure valid conversion from ullong to uint64_t.
uint64_t remaining{nextQword.to_ullong()};
return i + findLsbSet(remaining);
}
// NOTE(review): as above, the remaining `val`/`lsb` statements appear to
// belong to the previous binary-search implementation; confirm.
if (!bits(val, 15, 0)) {
lsb += 16;
val >>= 16;
}
if (!bits(val, 7, 0)) {
lsb += 8;
val >>= 8;
}
if (!bits(val, 3, 0)) {
lsb += 4;
val >>= 4;
}
if (!bits(val, 1, 0)) {
lsb += 2;
val >>= 2;
}
if (!bits(val, 0, 0))
lsb += 1;
return lsb;
}
/**

View File

@@ -316,6 +316,7 @@ TEST(BitfieldTest, FindLsb)
{
uint64_t val = (1ULL << 63) + (1 << 1);
EXPECT_EQ(1, findLsbSet(val));
EXPECT_EQ(1, findLsbSetFallback(val));
}
TEST(BitfieldTest, FindLsbZero)
@@ -323,6 +324,23 @@ TEST(BitfieldTest, FindLsbZero)
EXPECT_EQ(64, findLsbSet(0));
}
/*
 * Exercises findLsbSet on a std::bitset wider than 64 bits: the empty
 * bitset, every single-bit pattern, and every single-bit pattern combined
 * with the most significant bit.
 */
TEST(BitfieldTest, FindLsbGeneralized)
{
    static constexpr size_t N{1000};
    const std::bitset<N> unit{1};

    // With no bit set, the width of the bitset is expected.
    std::bitset<N> bs;
    EXPECT_EQ(findLsbSet(bs), N);

    // A lone bit at each possible position.
    for (size_t i = 0; i != N; ++i) {
        bs = unit << i;
        ASSERT_EQ(findLsbSet(bs), i);
    }

    // Same sweep with the top bit set as well: the lowest bit must still
    // be the one reported.
    const std::bitset<N> leadingOne = unit << (N - 1);
    for (size_t i = 0; i != N; ++i) {
        bs = unit << i;
        bs |= leadingOne;
        ASSERT_EQ(findLsbSet(bs), i);
    }
}
/*
* The following tests "popCount(X)". popCount counts the number of bits set to
* one.