Switch to std::variant in configuration library.

This commit is contained in:
2022-02-23 15:57:32 +01:00
parent aed3d37699
commit 70f987f9e9
7 changed files with 299 additions and 265 deletions

View File

@@ -44,7 +44,6 @@
#include "memspec/MemSpec.h"
#include "util.h"
#include <memory>
#include <nlohmann/json.hpp>
#include <optional>
#include <string>
@@ -53,10 +52,10 @@
* To support polymorphic configurations, a Json "type" tag is used
* to determine the correct type before further parsing.
*
* To support optional values, std::pair is used. The first parameter is the value,
* the second parameter specifies if the value is valid.
* Replace with std::optional when this project is updated to C++17.
* Consider also switching to std::variant to achieve static polymorphism.
* To support optional values, std::optional is used. The default
* values will be provided by DRAMSys itself.
*
* To achieve static polymorphism, std::variant is used.
*/
namespace DRAMSysConfiguration

View File

@@ -38,7 +38,6 @@
#include "util.h"
#include <memory>
#include <nlohmann/json.hpp>
#include <optional>
#include <utility>

View File

@@ -35,6 +35,8 @@
#include "TraceSetup.h"
#include <variant>
namespace DRAMSysConfiguration
{
@@ -55,94 +57,112 @@ void to_json(json &j, const TraceSetup &c)
{
json initiator_j;
initiator_j["name"] = initiator->name;
initiator_j["clkMhz"] = initiator->clkMhz;
initiator_j["maxPendingReadRequests"] = initiator->maxPendingReadRequests;
initiator_j["maxPendingWriteRequests"] = initiator->maxPendingWriteRequests;
initiator_j["addLengthConverter"] = initiator->addLengthConverter;
if (const auto generator = dynamic_cast<TraceGenerator *>(initiator.get()))
{
initiator_j["type"] = "generator";
initiator_j["seed"] = generator->seed;
initiator_j["maxTransactions"] = generator->maxTransactions;
initiator_j["idleUntil"] = generator->idleUntil;
// When there are less than 2 states, flatten out the json.
if (generator->states.size() == 1)
std::visit(
[&initiator_j](auto &&initiator)
{
const auto &state = generator->states[0];
initiator_j["name"] = initiator.name;
initiator_j["clkMhz"] = initiator.clkMhz;
initiator_j["maxPendingReadRequests"] = initiator.maxPendingReadRequests;
initiator_j["maxPendingWriteRequests"] = initiator.maxPendingWriteRequests;
initiator_j["addLengthConverter"] = initiator.addLengthConverter;
if (const auto trafficState = dynamic_cast<TraceGeneratorTrafficState *>(state.get()))
using T = std::decay_t<decltype(initiator)>;
if constexpr (std::is_same_v<T, TraceGenerator>)
{
initiator_j["numRequests"] = trafficState->numRequests;
initiator_j["rwRatio"] = trafficState->rwRatio;
initiator_j["addressDistribution"] = trafficState->addressDistribution;
initiator_j["addressIncrement"] = trafficState->addressIncrement;
initiator_j["minAddress"] = trafficState->minAddress;
initiator_j["maxAddress"] = trafficState->maxAddress;
initiator_j["clksPerRequest"] = trafficState->clksPerRequest;
initiator_j["notify"] = trafficState->notify;
}
else if (const auto idleState = dynamic_cast<TraceGeneratorIdleState *>(state.get()))
{
initiator_j["idleClks"] = idleState->idleClks;
}
}
else
{
json states_j = json::array();
initiator_j["type"] = "generator";
initiator_j["seed"] = initiator.seed;
initiator_j["maxTransactions"] = initiator.maxTransactions;
initiator_j["idleUntil"] = initiator.idleUntil;
for (const auto &state : generator->states)
{
json state_j;
state_j["id"] = state.first;
if (const auto trafficState = dynamic_cast<TraceGeneratorTrafficState *>(state.second.get()))
// When there are less than 2 states, flatten out the json.
if (initiator.states.size() == 1)
{
state_j["numRequests"] = trafficState->numRequests;
state_j["rwRatio"] = trafficState->rwRatio;
state_j["addressDistribution"] = trafficState->addressDistribution;
state_j["addressIncrement"] = trafficState->addressIncrement;
state_j["minAddress"] = trafficState->minAddress;
state_j["maxAddress"] = trafficState->maxAddress;
state_j["clksPerRequest"] = trafficState->clksPerRequest;
state_j["notify"] = trafficState->notify;
std::visit(
[&initiator_j](auto &&state)
{
using U = std::decay_t<decltype(state)>;
if constexpr (std::is_same_v<U, TraceGeneratorTrafficState>)
{
initiator_j["numRequests"] = state.numRequests;
initiator_j["rwRatio"] = state.rwRatio;
initiator_j["addressDistribution"] = state.addressDistribution;
initiator_j["addressIncrement"] = state.addressIncrement;
initiator_j["minAddress"] = state.minAddress;
initiator_j["maxAddress"] = state.maxAddress;
initiator_j["clksPerRequest"] = state.clksPerRequest;
initiator_j["notify"] = state.notify;
}
else // if constexpr (std::is_same_v<U, TraceGeneratorIdleState>)
{
initiator_j["idleClks"] = state.idleClks;
}
},
initiator.states.at(0));
}
else if (const auto idleState = dynamic_cast<TraceGeneratorIdleState *>(state.second.get()))
else
{
state_j["idleClks"] = idleState->idleClks;
json states_j = json::array();
for (const auto &state : initiator.states)
{
json state_j;
state_j["id"] = state.first;
std::visit(
[&state_j](auto &&state)
{
using U = std::decay_t<decltype(state)>;
if constexpr (std::is_same_v<U, TraceGeneratorTrafficState>)
{
state_j["numRequests"] = state.numRequests;
state_j["rwRatio"] = state.rwRatio;
state_j["addressDistribution"] = state.addressDistribution;
state_j["addressIncrement"] = state.addressIncrement;
state_j["minAddress"] = state.minAddress;
state_j["maxAddress"] = state.maxAddress;
state_j["clksPerRequest"] = state.clksPerRequest;
state_j["notify"] = state.notify;
}
else // if constexpr (std::is_same_v<U, TraceGeneratorIdleState>)
{
state_j["idleClks"] = state.idleClks;
}
},
state.second);
remove_null_values(state_j);
states_j.insert(states_j.end(), state_j);
}
initiator_j["states"] = states_j;
json transitions_j = json::array();
for (const auto &transition : initiator.transitions)
{
json transition_j;
transition_j["from"] = transition.first;
transition_j["to"] = transition.second.to;
transition_j["propability"] = transition.second.propability;
remove_null_values(transition_j);
transitions_j.insert(transitions_j.end(), transition_j);
}
initiator_j["transitions"] = transitions_j;
}
remove_null_values(state_j);
states_j.insert(states_j.end(), state_j);
}
initiator_j["states"] = states_j;
json transitions_j = json::array();
for (const auto &transition : generator->transitions)
else if constexpr (std::is_same_v<T, TraceHammer>)
{
json transition_j;
transition_j["from"] = transition.first;
transition_j["to"] = transition.second.to;
transition_j["propability"] = transition.second.propability;
remove_null_values(transition_j);
transitions_j.insert(transitions_j.end(), transition_j);
initiator_j["type"] = "hammer";
initiator_j["numRequests"] = initiator.numRequests;
initiator_j["rowIncrement"] = initiator.rowIncrement;
}
initiator_j["transitions"] = transitions_j;
}
}
else if (const auto hammer = dynamic_cast<TraceHammer *>(initiator.get()))
{
initiator_j["type"] = "hammer";
initiator_j["numRequests"] = hammer->numRequests;
initiator_j["rowIncrement"] = hammer->rowIncrement;
}
else if (const auto player = dynamic_cast<TracePlayer *>(initiator.get()))
{
initiator_j["type"] = "player";
}
else // if constexpr (std::is_same_v<T, TracePlayer>)
{
initiator_j["type"] = "player";
}
},
initiator);
remove_null_values(initiator_j);
j.insert(j.end(), initiator_j);
@@ -156,51 +176,53 @@ void from_json(const json &j, TraceSetup &c)
// Default to Player, when not specified
TrafficInitiatorType type = initiator_j.value("type", TrafficInitiatorType::Player);
std::unique_ptr<TrafficInitiator> initiator;
std::variant<TracePlayer, TraceGenerator, TraceHammer> initiator;
if (type == TrafficInitiatorType::Player)
{
initiator = std::unique_ptr<TracePlayer>(new TracePlayer);
initiator = TracePlayer{};
}
else if (type == TrafficInitiatorType::Generator)
{
TraceGenerator *generator = new TraceGenerator;
TraceGenerator generator;
auto process_state = [](const json &state_j) -> std::pair<unsigned int, std::unique_ptr<TraceGeneratorState>>
auto process_state = [](const json &state_j)
-> std::pair<unsigned int, std::variant<TraceGeneratorIdleState, TraceGeneratorTrafficState>>
{
std::unique_ptr<TraceGeneratorState> state;
std::variant<TraceGeneratorIdleState, TraceGeneratorTrafficState> state;
if (state_j.contains("idleClks"))
{
// Idle state
auto idleState = new TraceGeneratorIdleState;
state_j.at("idleClks").get_to(idleState->idleClks);
TraceGeneratorIdleState idleState;
state_j.at("idleClks").get_to(idleState.idleClks);
state = std::unique_ptr<TraceGeneratorIdleState>(idleState);
state = std::move(idleState);
}
else
{
// Traffic state
auto trafficState = new TraceGeneratorTrafficState;
state_j.at("numRequests").get_to(trafficState->numRequests);
state_j.at("rwRatio").get_to(trafficState->rwRatio);
state_j.at("addressDistribution").get_to(trafficState->addressDistribution);
TraceGeneratorTrafficState trafficState;
state_j.at("numRequests").get_to(trafficState.numRequests);
state_j.at("rwRatio").get_to(trafficState.rwRatio);
state_j.at("addressDistribution").get_to(trafficState.addressDistribution);
if (state_j.contains("addressIncrement"))
state_j.at("addressIncrement").get_to(trafficState->addressIncrement);
state_j.at("addressIncrement").get_to(trafficState.addressIncrement);
if (state_j.contains("minAddress"))
state_j.at("minAddress").get_to(trafficState->minAddress);
state_j.at("minAddress").get_to(trafficState.minAddress);
if (state_j.contains("maxAddress"))
state_j.at("maxAddress").get_to(trafficState->maxAddress);
state_j.at("maxAddress").get_to(trafficState.maxAddress);
if (state_j.contains("clksPerRequest"))
state_j.at("clksPerRequest").get_to(trafficState->clksPerRequest);
state_j.at("clksPerRequest").get_to(trafficState.clksPerRequest);
if (state_j.contains("notify"))
state_j.at("notify").get_to(trafficState->notify);
state_j.at("notify").get_to(trafficState.notify);
state = std::unique_ptr<TraceGeneratorTrafficState>(trafficState);
state = std::move(trafficState);
}
// Default to 0
@@ -217,7 +239,7 @@ void from_json(const json &j, TraceSetup &c)
for (const auto &state_j : initiator_j.at("states"))
{
auto state = process_state(state_j);
generator->states[state.first] = std::move(state.second);
generator.states[state.first] = std::move(state.second);
}
for (const auto &transition_j : initiator_j.at("transitions"))
@@ -226,47 +248,52 @@ void from_json(const json &j, TraceSetup &c)
unsigned int from = transition_j.at("from");
transition.to = transition_j.at("to");
transition.propability = transition_j.at("propability");
generator->transitions.emplace(from, transition);
generator.transitions.emplace(from, transition);
}
}
else // Only one state will be created
{
auto state = process_state(initiator_j);
generator->states[state.first] = std::move(state.second);
generator.states[state.first] = std::move(state.second);
}
if (initiator_j.contains("seed"))
initiator_j.at("seed").get_to(generator->seed);
initiator_j.at("seed").get_to(generator.seed);
if (initiator_j.contains("maxTransactions"))
initiator_j.at("maxTransactions").get_to(generator->maxTransactions);
initiator_j.at("maxTransactions").get_to(generator.maxTransactions);
if (initiator_j.contains("idleUntil"))
initiator_j.at("idleUntil").get_to(generator->idleUntil);
initiator_j.at("idleUntil").get_to(generator.idleUntil);
initiator = std::unique_ptr<TraceGenerator>(generator);
initiator = generator;
}
else if (type == TrafficInitiatorType::Hammer)
{
TraceHammer *hammer = new TraceHammer;
TraceHammer hammer;
initiator_j.at("numRequests").get_to(hammer->numRequests);
initiator_j.at("rowIncrement").get_to(hammer->rowIncrement);
initiator_j.at("numRequests").get_to(hammer.numRequests);
initiator_j.at("rowIncrement").get_to(hammer.rowIncrement);
initiator = std::unique_ptr<TraceHammer>(hammer);
initiator = hammer;
}
initiator_j.at("name").get_to(initiator->name);
initiator_j.at("clkMhz").get_to(initiator->clkMhz);
std::visit(
[&initiator_j](auto &&initiator)
{
initiator_j.at("name").get_to(initiator.name);
initiator_j.at("clkMhz").get_to(initiator.clkMhz);
if (initiator_j.contains("maxPendingReadRequests"))
initiator_j.at("maxPendingReadRequests").get_to(initiator->maxPendingReadRequests);
if (initiator_j.contains("maxPendingReadRequests"))
initiator_j.at("maxPendingReadRequests").get_to(initiator.maxPendingReadRequests);
if (initiator_j.contains("maxPendingWriteRequests"))
initiator_j.at("maxPendingWriteRequests").get_to(initiator->maxPendingWriteRequests);
if (initiator_j.contains("maxPendingWriteRequests"))
initiator_j.at("maxPendingWriteRequests").get_to(initiator.maxPendingWriteRequests);
if (initiator_j.contains("addLengthConverter"))
initiator_j.at("addLengthConverter").get_to(initiator->addLengthConverter);
if (initiator_j.contains("addLengthConverter"))
initiator_j.at("addLengthConverter").get_to(initiator.addLengthConverter);
},
initiator);
c.initiators.emplace_back(std::move(initiator));
}

View File

@@ -38,9 +38,9 @@
#include "util.h"
#include <memory>
#include <nlohmann/json.hpp>
#include <optional>
#include <variant>
namespace DRAMSysConfiguration
{
@@ -117,7 +117,7 @@ struct TraceGenerator : public TrafficInitiator
{
std::optional<uint64_t> seed;
std::optional<uint64_t> maxTransactions;
std::map<unsigned int, std::unique_ptr<TraceGeneratorState>> states;
std::map<unsigned int, std::variant<TraceGeneratorIdleState, TraceGeneratorTrafficState>> states;
std::multimap<unsigned int, TraceGeneratorStateTransition> transitions;
std::optional<std::string> idleUntil;
};
@@ -130,7 +130,7 @@ struct TraceHammer : public TrafficInitiator
struct TraceSetup
{
std::vector<std::shared_ptr<TrafficInitiator>> initiators;
std::vector<std::variant<TracePlayer, TraceGenerator, TraceHammer>> initiators;
};
void to_json(json &j, const TraceSetup &c);

View File

@@ -94,89 +94,88 @@ DRAMSysConfiguration::ThermalConfig getThermalConfig()
true};
}
std::unique_ptr<DRAMSysConfiguration::TracePlayer> getTracePlayer()
DRAMSysConfiguration::TracePlayer getTracePlayer()
{
DRAMSysConfiguration::TracePlayer *player = new DRAMSysConfiguration::TracePlayer;
DRAMSysConfiguration::TracePlayer player;
player.clkMhz = 100;
player.name = "mytrace.stl";
player->clkMhz = 100;
player->name = "mytrace.stl";
return std::unique_ptr<DRAMSysConfiguration::TracePlayer>(player);
return player;
}
std::unique_ptr<DRAMSysConfiguration::TraceGenerator> getTraceGeneratorOneState()
DRAMSysConfiguration::TraceGenerator getTraceGeneratorOneState()
{
DRAMSysConfiguration::TraceGenerator *gen = new DRAMSysConfiguration::TraceGenerator;
DRAMSysConfiguration::TraceGenerator gen;
gen.clkMhz = 100;
gen.name = "MyTestGen";
gen->clkMhz = 100;
gen->name = "MyTestGen";
DRAMSysConfiguration::TraceGeneratorTrafficState state0;
state0.numRequests = 1000;
state0.rwRatio = 0.5;
state0.addressDistribution = DRAMSysConfiguration::AddressDistribution::Random;
state0.addressIncrement = {};
state0.minAddress = {};
state0.maxAddress = {};
state0.clksPerRequest = {};
auto state0 = new DRAMSysConfiguration::TraceGeneratorTrafficState;
state0->numRequests = 1000;
state0->rwRatio = 0.5;
state0->addressDistribution = DRAMSysConfiguration::AddressDistribution::Random;
state0->addressIncrement = {};
state0->minAddress = {};
state0->maxAddress = {};
state0->clksPerRequest = {};
gen.states.emplace(0, state0);
gen->states[0] = std::unique_ptr<DRAMSysConfiguration::TraceGeneratorTrafficState>(state0);
return std::unique_ptr<DRAMSysConfiguration::TraceGenerator>(gen);
return gen;
}
std::unique_ptr<DRAMSysConfiguration::TraceGenerator> getTraceGeneratorMultipleStates()
DRAMSysConfiguration::TraceGenerator getTraceGeneratorMultipleStates()
{
DRAMSysConfiguration::TraceGenerator *gen = new DRAMSysConfiguration::TraceGenerator;
DRAMSysConfiguration::TraceGenerator gen;
gen->clkMhz = 100;
gen->name = "MyTestGen";
gen->maxPendingReadRequests = 8;
gen.clkMhz = 100;
gen.name = "MyTestGen";
gen.maxPendingReadRequests = 8;
auto state0 = new DRAMSysConfiguration::TraceGeneratorTrafficState;
state0->numRequests = 1000;
state0->rwRatio = 0.5;
state0->addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential;
state0->addressIncrement = 256;
state0->minAddress = {};
state0->maxAddress = 1024;
state0->clksPerRequest = {};
DRAMSysConfiguration::TraceGeneratorTrafficState state0;
state0.numRequests = 1000;
state0.rwRatio = 0.5;
state0.addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential;
state0.addressIncrement = 256;
state0.minAddress = {};
state0.maxAddress = 1024;
state0.clksPerRequest = {};
auto state1 = new DRAMSysConfiguration::TraceGeneratorTrafficState;
state1->numRequests = 100;
state1->rwRatio = 0.75;
state1->addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential;
state1->addressIncrement = 512;
state1->minAddress = 1024;
state1->maxAddress = 2048;
state1->clksPerRequest = {};
DRAMSysConfiguration::TraceGeneratorTrafficState state1;
state1.numRequests = 100;
state1.rwRatio = 0.75;
state1.addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential;
state1.addressIncrement = 512;
state1.minAddress = 1024;
state1.maxAddress = 2048;
state1.clksPerRequest = {};
gen->states[0] = std::unique_ptr<DRAMSysConfiguration::TraceGeneratorTrafficState>(state0);
gen->states[1] = std::unique_ptr<DRAMSysConfiguration::TraceGeneratorTrafficState>(state1);
gen.states.emplace(0, state0);
gen.states.emplace(1, state1);
DRAMSysConfiguration::TraceGeneratorStateTransition transistion0{1, 1.0};
gen->transitions.emplace(0, transistion0);
gen.transitions.emplace(0, transistion0);
return std::unique_ptr<DRAMSysConfiguration::TraceGenerator>(gen);
return gen;
}
std::unique_ptr<DRAMSysConfiguration::TraceHammer> getTraceHammer()
DRAMSysConfiguration::TraceHammer getTraceHammer()
{
DRAMSysConfiguration::TraceHammer *hammer = new DRAMSysConfiguration::TraceHammer;
DRAMSysConfiguration::TraceHammer hammer;
hammer->clkMhz = 100;
hammer->name = "MyTestHammer";
hammer.clkMhz = 100;
hammer.name = "MyTestHammer";
hammer.numRequests = 4000;
hammer.rowIncrement = 2097152;
hammer->numRequests = 4000;
hammer->rowIncrement = 2097152;
return std::unique_ptr<DRAMSysConfiguration::TraceHammer>(hammer);
return hammer;
}
DRAMSysConfiguration::TraceSetup getTraceSetup()
{
std::vector<std::shared_ptr<DRAMSysConfiguration::TrafficInitiator>> initiators;
using namespace DRAMSysConfiguration;
std::vector<std::variant<TracePlayer, TraceGenerator, TraceHammer>> initiators;
initiators.emplace_back(getTracePlayer());
initiators.emplace_back(getTraceGeneratorOneState());
initiators.emplace_back(getTraceGeneratorMultipleStates());

View File

@@ -45,102 +45,107 @@
using namespace sc_core;
using namespace tlm;
TraceSetup::TraceSetup(const DRAMSysConfiguration::TraceSetup &traceSetup,
const std::string &pathToResources,
TraceSetup::TraceSetup(const DRAMSysConfiguration::TraceSetup &traceSetup, const std::string &pathToResources,
std::vector<std::unique_ptr<TrafficInitiator>> &players)
{
if (traceSetup.initiators.empty())
SC_REPORT_FATAL("TraceSetup", "No traffic initiators specified");
for (const auto &inititator : traceSetup.initiators)
for (const auto &initiator : traceSetup.initiators)
{
double frequencyMHz = inititator->clkMhz;
sc_time playerClk = sc_time(1.0 / frequencyMHz, SC_US);
std::visit(
[&](auto &&initiator)
{
std::string name = initiator.name;
double frequencyMHz = initiator.clkMhz;
sc_time playerClk = sc_time(1.0 / frequencyMHz, SC_US);
std::string name = inititator->name;
unsigned int maxPendingReadRequests = [=]() -> unsigned int
{
if (const auto &maxPendingReadRequests = initiator.maxPendingReadRequests)
return *maxPendingReadRequests;
else
return 0;
}();
unsigned int maxPendingReadRequests = [=]() -> unsigned int
{
if (const auto &maxPendingReadRequests = inititator->maxPendingReadRequests)
return *maxPendingReadRequests;
else
return 0;
}();
unsigned int maxPendingWriteRequests = [=]() -> unsigned int
{
if (const auto &maxPendingWriteRequests = initiator.maxPendingWriteRequests)
return *maxPendingWriteRequests;
else
return 0;
}();
unsigned int maxPendingWriteRequests = [=]() -> unsigned int
{
if (const auto &maxPendingWriteRequests = inititator->maxPendingWriteRequests)
return *maxPendingWriteRequests;
else
return 0;
}();
bool addLengthConverter = [=]() -> bool
{
if (const auto &addLengthConverter = initiator.addLengthConverter)
return *addLengthConverter;
else
return false;
}();
bool addLengthConverter = [=]() -> bool
{
if (const auto &addLengthConverter = inititator->addLengthConverter)
return *addLengthConverter;
else
return false;
}();
using T = std::decay_t<decltype(initiator)>;
if constexpr (std::is_same_v<T, DRAMSysConfiguration::TracePlayer>)
{
size_t pos = name.rfind('.');
if (pos == std::string::npos)
throw std::runtime_error("Name of the trace file does not contain a valid extension.");
if (std::dynamic_pointer_cast<DRAMSysConfiguration::TracePlayer>(inititator))
{
size_t pos = name.rfind('.');
if (pos == std::string::npos)
throw std::runtime_error("Name of the trace file does not contain a valid extension.");
// Get the extension and make it lower case
std::string ext = name.substr(pos + 1);
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
// Get the extension and make it lower case
std::string ext = name.substr(pos + 1);
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
std::stringstream stlFileStream;
stlFileStream << pathToResources << "/traces/" << name;
std::string stlFile = stlFileStream.str();
std::string moduleName = name;
std::stringstream stlFileStream;
stlFileStream << pathToResources << "/traces/" << name;
std::string stlFile = stlFileStream.str();
std::string moduleName = name;
// replace all '.' to '_'
std::replace(moduleName.begin(), moduleName.end(), '.', '_');
// replace all '.' to '_'
std::replace(moduleName.begin(), moduleName.end(), '.', '_');
StlPlayer *player;
if (ext == "stl")
player = new StlPlayer(moduleName.c_str(), stlFile, playerClk, maxPendingReadRequests,
maxPendingWriteRequests, addLengthConverter, this, false);
else if (ext == "rstl")
player = new StlPlayer(moduleName.c_str(), stlFile, playerClk, maxPendingReadRequests,
maxPendingWriteRequests, addLengthConverter, this, true);
else
throw std::runtime_error("Unsupported file extension in " + name);
StlPlayer *player;
if (ext == "stl")
player = new StlPlayer(moduleName.c_str(), stlFile, playerClk,
maxPendingReadRequests, maxPendingWriteRequests, addLengthConverter, this, false);
else if (ext == "rstl")
player = new StlPlayer(moduleName.c_str(), stlFile, playerClk,
maxPendingReadRequests, maxPendingWriteRequests, addLengthConverter, this, true);
else
throw std::runtime_error("Unsupported file extension in " + name);
players.push_back(std::unique_ptr<TrafficInitiator>(player));
totalTransactions += player->getNumberOfLines();
}
else if constexpr (std::is_same_v<T, DRAMSysConfiguration::TraceGenerator>)
{
TrafficGenerator *trafficGenerator = new TrafficGenerator(name.c_str(), initiator, this);
players.push_back(std::unique_ptr<TrafficInitiator>(trafficGenerator));
players.push_back(std::unique_ptr<TrafficInitiator>(player));
totalTransactions += player->getNumberOfLines();
}
else if (auto generator = std::dynamic_pointer_cast<DRAMSysConfiguration::TraceGenerator>(inititator))
{
TrafficGenerator *trafficGenerator = new TrafficGenerator(name.c_str(), *generator, this);
players.push_back(std::unique_ptr<TrafficInitiator>(trafficGenerator));
totalTransactions += trafficGenerator->getTotalTransactions();
}
else // if constexpr (std::is_same_v<T, DRAMSysConfiguration::TraceHammer>)
{
uint64_t numRequests = initiator.numRequests;
uint64_t rowIncrement = initiator.rowIncrement;
totalTransactions += trafficGenerator->getTotalTransactions();
}
else if (auto hammer = std::dynamic_pointer_cast<DRAMSysConfiguration::TraceHammer>(inititator))
{
uint64_t numRequests = hammer->numRequests;
uint64_t rowIncrement = hammer->rowIncrement;
players.push_back(std::unique_ptr<TrafficInitiator>(new TrafficGeneratorHammer(name.c_str(), *hammer, this)));
totalTransactions += numRequests;
}
players.push_back(
std::unique_ptr<TrafficInitiator>(new TrafficGeneratorHammer(name.c_str(), initiator, this)));
totalTransactions += numRequests;
}
},
initiator);
}
for (const auto &inititatorConf : traceSetup.initiators)
{
if (auto generatorConf = std::dynamic_pointer_cast<DRAMSysConfiguration::TraceGenerator>(inititatorConf))
if (auto generatorConf = std::get_if<DRAMSysConfiguration::TraceGenerator>(&inititatorConf))
{
if (const auto &idleUntil = generatorConf->idleUntil)
{
const std::string name = generatorConf->name;
auto listenerIt =
std::find_if(players.begin(), players.end(),
[&name](const std::unique_ptr<TrafficInitiator> &initiator) { return initiator->name() == name; });
auto listenerIt = std::find_if(players.begin(), players.end(),
[&name](const std::unique_ptr<TrafficInitiator> &initiator)
{ return initiator->name() == name; });
// Should be found
auto listener = dynamic_cast<TrafficGenerator *>(listenerIt->get());
@@ -205,7 +210,7 @@ void TraceSetup::loadBar(uint64_t x, uint64_t n, unsigned int w, unsigned int gr
if ((n < 100) || ((x != n) && (x % (n / 100 * granularity) != 0)))
return;
float ratio = x / (float) n;
float ratio = x / (float)n;
unsigned int c = (ratio * w);
float rest = (ratio * w) - c;
std::cout << std::setw(3) << round(ratio * 100) << "% |";

View File

@@ -37,6 +37,7 @@
*/
#include "TrafficGenerator.h"
#include "TraceSetup.h"
#include <limits>
@@ -103,7 +104,7 @@ TrafficGenerator::TrafficGenerator(const sc_module_name &name, const DRAMSysConf
// Perform checks for all states
for (const auto &state : conf.states)
{
if (auto trafficState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(state.second.get()))
if (auto trafficState = std::get_if<DRAMSysConfiguration::TraceGeneratorTrafficState>(&state.second))
{
uint64_t minAddress = evaluateMinAddress(*trafficState);
uint64_t maxAddress = evaluateMaxAddress(*trafficState);
@@ -129,7 +130,8 @@ TrafficGenerator::TrafficGenerator(const sc_module_name &name, const DRAMSysConf
}
}
if (auto trafficState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(currentState).get()))
if (auto trafficState =
std::get_if<DRAMSysConfiguration::TraceGeneratorTrafficState>(&conf.states.at(currentState)))
{
uint64_t minAddress = evaluateMinAddress(*trafficState);
uint64_t maxAddress = evaluateMaxAddress(*trafficState);
@@ -182,7 +184,8 @@ void TrafficGenerator::calculateTransitions()
if (transitionFound)
{
if (auto trafficState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(state).get()))
if (auto trafficState =
std::get_if<DRAMSysConfiguration::TraceGeneratorTrafficState>(&conf.states.at(state)))
totalTransactions += trafficState->numRequests;
if (totalTransactions < maxTransactions)
@@ -221,7 +224,7 @@ uint64_t TrafficGenerator::getTotalTransactions() const
for (auto state : stateSequence)
{
if (auto trafficState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(state).get()))
if (auto trafficState = std::get_if<DRAMSysConfiguration::TraceGeneratorTrafficState>(&conf.states.at(state)))
totalTransactions += trafficState->numRequests;
}
@@ -256,13 +259,14 @@ void TrafficGenerator::transitionToNextState()
it.second.event.notify();
}
if (auto idleState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorIdleState *>(conf.states.at(currentState).get()))
if (auto idleState = std::get_if<DRAMSysConfiguration::TraceGeneratorIdleState>(&conf.states.at(currentState)))
{
currentClksToIdle += idleState->idleClks;
transitionToNextState();
return;
}
else if (auto trafficState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(currentState).get()))
else if (auto trafficState =
std::get_if<DRAMSysConfiguration::TraceGeneratorTrafficState>(&conf.states.at(currentState)))
{
uint64_t minAddress = evaluateMinAddress(*trafficState);
uint64_t maxAddress = evaluateMaxAddress(*trafficState);
@@ -285,7 +289,8 @@ void TrafficGenerator::prepareNextPayload()
if (startEvent && transactionsSent == 0)
wait(*startEvent);
if (auto trafficState = dynamic_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(currentState).get()))
if (auto trafficState =
std::get_if<DRAMSysConfiguration::TraceGeneratorTrafficState>(&conf.states.at(currentState)))
{
if (transactionsSentInCurrentState >= trafficState->numRequests)
transitionToNextState();
@@ -303,10 +308,10 @@ void TrafficGenerator::payloadSent()
tlm::tlm_command TrafficGenerator::getNextCommand()
{
// An idle state should never reach this method.
auto state = static_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(currentState).get());
auto &state = std::get<DRAMSysConfiguration::TraceGeneratorTrafficState>(conf.states.at(currentState));
tlm_command command;
if (randomDistribution(randomGenerator) < state->rwRatio)
if (randomDistribution(randomGenerator) < state.rwRatio)
command = tlm::TLM_READ_COMMAND;
else
command = tlm::TLM_WRITE_COMMAND;
@@ -324,14 +329,14 @@ uint64_t TrafficGenerator::getNextAddress()
using DRAMSysConfiguration::AddressDistribution;
// An idle state should never reach this method.
auto state = static_cast<DRAMSysConfiguration::TraceGeneratorTrafficState *>(conf.states.at(currentState).get());
auto &state = std::get<DRAMSysConfiguration::TraceGeneratorTrafficState>(conf.states.at(currentState));
uint64_t minAddress = evaluateMinAddress(*state);
uint64_t maxAddress = evaluateMaxAddress(*state);
uint64_t minAddress = evaluateMinAddress(state);
uint64_t maxAddress = evaluateMaxAddress(state);
if (state->addressDistribution == AddressDistribution::Sequential)
if (state.addressDistribution == AddressDistribution::Sequential)
{
uint64_t addressIncrement = state->addressIncrement.value_or(defaultAddressIncrement);
uint64_t addressIncrement = state.addressIncrement.value_or(defaultAddressIncrement);
uint64_t address = currentAddress;
currentAddress += addressIncrement;
@@ -339,7 +344,7 @@ uint64_t TrafficGenerator::getNextAddress()
currentAddress = minAddress;
return address;
}
else if (state->addressDistribution == AddressDistribution::Random)
else if (state.addressDistribution == AddressDistribution::Random)
{
return randomAddressDistribution(randomGenerator);
}