diff --git a/src/base/sat_counter.hh b/src/base/sat_counter.hh index d257cdabf4..4849c2aa09 100644 --- a/src/base/sat_counter.hh +++ b/src/base/sat_counter.hh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Inria + * Copyright (c) 2019, 2020 Inria * All rights reserved. * * The license below extends only to copyright in the software and shall @@ -50,12 +50,15 @@ /** * Implements an n bit saturating counter and provides methods to * increment, decrement, and read it. + * + * @tparam T The type of the underlying counter container. */ -class SatCounter +template +class GenericSatCounter { public: /** The default constructor should never be used. */ - SatCounter() = delete; + GenericSatCounter() = delete; /** * Constructor for the counter. The explicit keyword is used to make @@ -68,11 +71,11 @@ class SatCounter * * @ingroup api_sat_counter */ - explicit SatCounter(unsigned bits, uint8_t initial_val = 0) - : initialVal(initial_val), maxVal((1 << bits) - 1), + explicit GenericSatCounter(unsigned bits, T initial_val = 0) + : initialVal(initial_val), maxVal((1ULL << bits) - 1), counter(initial_val) { - fatal_if(bits > 8*sizeof(uint8_t), + fatal_if(bits > 8*sizeof(T), "Number of bits exceeds counter size"); fatal_if(initial_val > maxVal, "Saturating counter's Initial value exceeds max value."); @@ -83,7 +86,7 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter(const SatCounter& other) + GenericSatCounter(const GenericSatCounter& other) : initialVal(other.initialVal), maxVal(other.maxVal), counter(other.counter) { @@ -94,9 +97,9 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& operator=(const SatCounter& other) { + GenericSatCounter& operator=(const GenericSatCounter& other) { if (this != &other) { - SatCounter temp(other); + GenericSatCounter temp(other); this->swap(temp); } return *this; @@ -107,12 +110,12 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter(SatCounter&& other) + GenericSatCounter(GenericSatCounter&& other) { initialVal = other.initialVal; maxVal = other.maxVal; counter = other.counter; - SatCounter temp(0); + GenericSatCounter temp(0); other.swap(temp); } @@ -121,12 +124,12 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& operator=(SatCounter&& other) { + GenericSatCounter& operator=(GenericSatCounter&& other) { if (this != &other) { initialVal = other.initialVal; maxVal = other.maxVal; counter = other.counter; - SatCounter temp(0); + GenericSatCounter temp(0); other.swap(temp); } return *this; @@ -141,7 +144,7 @@ class SatCounter * @ingroup api_sat_counter */ void - swap(SatCounter& other) + swap(GenericSatCounter& other) { std::swap(initialVal, other.initialVal); std::swap(maxVal, other.maxVal); @@ -153,7 +156,7 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& + GenericSatCounter& operator++() { if (counter < maxVal) { @@ -167,10 +170,10 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter + GenericSatCounter operator++(int) { - SatCounter old_counter = *this; + GenericSatCounter old_counter = *this; ++*this; return old_counter; } @@ -180,7 +183,7 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& + GenericSatCounter& operator--() { if (counter > 0) { @@ -194,10 +197,10 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter + GenericSatCounter operator--(int) { - SatCounter old_counter = *this; + GenericSatCounter old_counter = *this; --*this; return old_counter; } @@ -207,7 +210,7 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& + GenericSatCounter& operator>>=(const int& shift) { assert(shift >= 0); @@ -220,7 +223,7 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& + GenericSatCounter& operator<<=(const int& shift) { assert(shift >= 0); @@ -236,8 +239,8 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& - operator+=(const int& value) + GenericSatCounter& + operator+=(const long long& value) { if (value >= 0) { if (maxVal - this->counter >= value) { @@ -256,8 +259,8 @@ class SatCounter * * @ingroup api_sat_counter */ - SatCounter& - operator-=(const int& value) + GenericSatCounter& + operator-=(const long long& value) { if (value >= 0) { if (this->counter > value) { @@ -276,7 +279,7 @@ class SatCounter * * @ingroup api_sat_counter */ - operator uint8_t() const { return counter; } + operator T() const { return counter; } /** * Reset the counter to its initial value. @@ -320,9 +323,21 @@ class SatCounter } private: - uint8_t initialVal; - uint8_t maxVal; - uint8_t counter; + T initialVal; + T maxVal; + T counter; }; +/** @ingroup api_sat_counter + * @{ + */ +typedef GenericSatCounter SatCounter8; +typedef GenericSatCounter SatCounter16; +typedef GenericSatCounter SatCounter32; +typedef GenericSatCounter SatCounter64; +/** @} */ + +[[deprecated("Use SatCounter8 (or variants) instead")]] +typedef SatCounter8 SatCounter; + #endif // __BASE_SAT_COUNTER_HH__ diff --git a/src/base/sat_counter.test.cc b/src/base/sat_counter.test.cc index 214b015774..4d400c0f79 100644 --- a/src/base/sat_counter.test.cc +++ b/src/base/sat_counter.test.cc @@ -360,3 +360,98 @@ TEST(SatCounterTest, NegativeAddSubAssignment) ASSERT_EQ(counter, value); } +/** Test max and min when using SatCounter16. */ +TEST(SatCounterTest, Size16) +{ + const uint16_t bits_16 = 9; + const uint16_t max_value_16 = (1 << bits_16) - 1; + SatCounter16 counter_16(bits_16); + + // Increasing + counter_16++; + ASSERT_EQ(counter_16, 1); + counter_16 <<= 1; + ASSERT_EQ(counter_16, 2); + counter_16 += 2 * max_value_16; + ASSERT_EQ(counter_16, max_value_16); + counter_16++; + ASSERT_EQ(counter_16, max_value_16); + counter_16 <<= 1; + ASSERT_EQ(counter_16, max_value_16); + + // Decreasing + counter_16--; + ASSERT_EQ(counter_16, max_value_16 - 1); + counter_16 >>= 1; + ASSERT_EQ(counter_16, (max_value_16 - 1) >> 1); + counter_16 -= 2 * max_value_16; + ASSERT_EQ(counter_16, 0); + counter_16--; + ASSERT_EQ(counter_16, 0); + counter_16 >>= 1; + ASSERT_EQ(counter_16, 0); +} + +/** Test max and min when using SatCounter32. */ +TEST(SatCounterTest, Size32) +{ + const uint32_t bits_32 = 17; + const uint32_t max_value_32 = (1 << bits_32) - 1; + SatCounter32 counter_32(bits_32); + + // Increasing + counter_32++; + ASSERT_EQ(counter_32, 1); + counter_32 <<= 1; + ASSERT_EQ(counter_32, 2); + counter_32 += 2 * max_value_32; + ASSERT_EQ(counter_32, max_value_32); + counter_32++; + ASSERT_EQ(counter_32, max_value_32); + counter_32 <<= 1; + ASSERT_EQ(counter_32, max_value_32); + + // Decreasing + counter_32--; + ASSERT_EQ(counter_32, max_value_32 - 1); + counter_32 >>= 1; + ASSERT_EQ(counter_32, (max_value_32 - 1) >> 1); + counter_32 -= 2 * max_value_32; + ASSERT_EQ(counter_32, 0); + counter_32--; + ASSERT_EQ(counter_32, 0); + counter_32 >>= 1; + ASSERT_EQ(counter_32, 0); +} + +/** Test max and min when using SatCounter64. */ +TEST(SatCounterTest, Size64) +{ + const uint64_t bits_64 = 33; + const uint64_t max_value_64 = (1ULL << bits_64) - 1; + SatCounter64 counter_64(bits_64); + + // Increasing + counter_64++; + ASSERT_EQ(counter_64, 1); + counter_64 <<= 1; + ASSERT_EQ(counter_64, 2); + counter_64 += max_value_64; + ASSERT_EQ(counter_64, max_value_64); + counter_64++; + ASSERT_EQ(counter_64, max_value_64); + counter_64 <<= 1; + ASSERT_EQ(counter_64, max_value_64); + + // Decreasing + counter_64--; + ASSERT_EQ(counter_64, max_value_64 - 1); + counter_64 >>= 1; + ASSERT_EQ(counter_64, (max_value_64 - 1) >> 1); + counter_64 -= max_value_64; + ASSERT_EQ(counter_64, 0); + counter_64--; + ASSERT_EQ(counter_64, 0); + counter_64 >>= 1; + ASSERT_EQ(counter_64, 0); +}