From 70f987f9e906c6453a99139ae462df20d52b84fc Mon Sep 17 00:00:00 2001 From: Derek Christ Date: Wed, 23 Feb 2022 15:57:32 +0100 Subject: [PATCH] Switch to std::variant in configuration library. --- .../src/common/configuration/Configuration.h | 9 +- .../src/common/configuration/McConfig.h | 1 - .../src/common/configuration/TraceSetup.cpp | 255 ++++++++++-------- .../src/common/configuration/TraceSetup.h | 6 +- .../common/configuration/tests/simpletest.cpp | 109 ++++---- DRAMSys/simulator/TraceSetup.cpp | 149 +++++----- DRAMSys/simulator/TrafficGenerator.cpp | 35 +-- 7 files changed, 299 insertions(+), 265 deletions(-) diff --git a/DRAMSys/library/src/common/configuration/Configuration.h b/DRAMSys/library/src/common/configuration/Configuration.h index 61091aa1..657f11be 100644 --- a/DRAMSys/library/src/common/configuration/Configuration.h +++ b/DRAMSys/library/src/common/configuration/Configuration.h @@ -44,7 +44,6 @@ #include "memspec/MemSpec.h" #include "util.h" -#include #include #include #include @@ -53,10 +52,10 @@ * To support polymorphic configurations, a Json "type" tag is used * to determine the correct type before further parsing. * - * To support optional values, std::pair is used. The first parameter is the value, - * the second parameter specifies if the value is valid. - * Replace with std::optional when this project is updated to C++17. - * Consider also switching to std::variant to achieve static polymorphism. + * To support optional values, std::optional is used. The default + * values will be provided by DRAMSys itself. + * + * To achieve static polymorphism, std::variant is used. */ namespace DRAMSysConfiguration diff --git a/DRAMSys/library/src/common/configuration/McConfig.h b/DRAMSys/library/src/common/configuration/McConfig.h index ee1b51d8..051c05e4 100644 --- a/DRAMSys/library/src/common/configuration/McConfig.h +++ b/DRAMSys/library/src/common/configuration/McConfig.h @@ -38,7 +38,6 @@ #include "util.h" -#include #include #include #include diff --git a/DRAMSys/library/src/common/configuration/TraceSetup.cpp b/DRAMSys/library/src/common/configuration/TraceSetup.cpp index 19241e33..1e25b4a4 100644 --- a/DRAMSys/library/src/common/configuration/TraceSetup.cpp +++ b/DRAMSys/library/src/common/configuration/TraceSetup.cpp @@ -35,6 +35,8 @@ #include "TraceSetup.h" +#include + namespace DRAMSysConfiguration { @@ -55,94 +57,112 @@ void to_json(json &j, const TraceSetup &c) { json initiator_j; - initiator_j["name"] = initiator->name; - initiator_j["clkMhz"] = initiator->clkMhz; - initiator_j["maxPendingReadRequests"] = initiator->maxPendingReadRequests; - initiator_j["maxPendingWriteRequests"] = initiator->maxPendingWriteRequests; - initiator_j["addLengthConverter"] = initiator->addLengthConverter; - - if (const auto generator = dynamic_cast(initiator.get())) - { - initiator_j["type"] = "generator"; - initiator_j["seed"] = generator->seed; - initiator_j["maxTransactions"] = generator->maxTransactions; - initiator_j["idleUntil"] = generator->idleUntil; - - // When there are less than 2 states, flatten out the json. - if (generator->states.size() == 1) + std::visit( + [&initiator_j](auto &&initiator) { - const auto &state = generator->states[0]; + initiator_j["name"] = initiator.name; + initiator_j["clkMhz"] = initiator.clkMhz; + initiator_j["maxPendingReadRequests"] = initiator.maxPendingReadRequests; + initiator_j["maxPendingWriteRequests"] = initiator.maxPendingWriteRequests; + initiator_j["addLengthConverter"] = initiator.addLengthConverter; - if (const auto trafficState = dynamic_cast(state.get())) + using T = std::decay_t; + if constexpr (std::is_same_v) { - initiator_j["numRequests"] = trafficState->numRequests; - initiator_j["rwRatio"] = trafficState->rwRatio; - initiator_j["addressDistribution"] = trafficState->addressDistribution; - initiator_j["addressIncrement"] = trafficState->addressIncrement; - initiator_j["minAddress"] = trafficState->minAddress; - initiator_j["maxAddress"] = trafficState->maxAddress; - initiator_j["clksPerRequest"] = trafficState->clksPerRequest; - initiator_j["notify"] = trafficState->notify; - } - else if (const auto idleState = dynamic_cast(state.get())) - { - initiator_j["idleClks"] = idleState->idleClks; - } - } - else - { - json states_j = json::array(); + initiator_j["type"] = "generator"; + initiator_j["seed"] = initiator.seed; + initiator_j["maxTransactions"] = initiator.maxTransactions; + initiator_j["idleUntil"] = initiator.idleUntil; - for (const auto &state : generator->states) - { - json state_j; - state_j["id"] = state.first; - - if (const auto trafficState = dynamic_cast(state.second.get())) + // When there are less than 2 states, flatten out the json. + if (initiator.states.size() == 1) { - state_j["numRequests"] = trafficState->numRequests; - state_j["rwRatio"] = trafficState->rwRatio; - state_j["addressDistribution"] = trafficState->addressDistribution; - state_j["addressIncrement"] = trafficState->addressIncrement; - state_j["minAddress"] = trafficState->minAddress; - state_j["maxAddress"] = trafficState->maxAddress; - state_j["clksPerRequest"] = trafficState->clksPerRequest; - state_j["notify"] = trafficState->notify; + std::visit( + [&initiator_j](auto &&state) + { + using U = std::decay_t; + + if constexpr (std::is_same_v) + { + initiator_j["numRequests"] = state.numRequests; + initiator_j["rwRatio"] = state.rwRatio; + initiator_j["addressDistribution"] = state.addressDistribution; + initiator_j["addressIncrement"] = state.addressIncrement; + initiator_j["minAddress"] = state.minAddress; + initiator_j["maxAddress"] = state.maxAddress; + initiator_j["clksPerRequest"] = state.clksPerRequest; + initiator_j["notify"] = state.notify; + } + else // if constexpr (std::is_same_v) + { + initiator_j["idleClks"] = state.idleClks; + } + }, + initiator.states.at(0)); } - else if (const auto idleState = dynamic_cast(state.second.get())) + else { - state_j["idleClks"] = idleState->idleClks; + json states_j = json::array(); + + for (const auto &state : initiator.states) + { + json state_j; + state_j["id"] = state.first; + + std::visit( + [&state_j](auto &&state) + { + using U = std::decay_t; + + if constexpr (std::is_same_v) + { + state_j["numRequests"] = state.numRequests; + state_j["rwRatio"] = state.rwRatio; + state_j["addressDistribution"] = state.addressDistribution; + state_j["addressIncrement"] = state.addressIncrement; + state_j["minAddress"] = state.minAddress; + state_j["maxAddress"] = state.maxAddress; + state_j["clksPerRequest"] = state.clksPerRequest; + state_j["notify"] = state.notify; + } + else // if constexpr (std::is_same_v) + { + state_j["idleClks"] = state.idleClks; + } + }, + state.second); + + remove_null_values(state_j); + states_j.insert(states_j.end(), state_j); + } + initiator_j["states"] = states_j; + + json transitions_j = json::array(); + + for (const auto &transition : initiator.transitions) + { + json transition_j; + transition_j["from"] = transition.first; + transition_j["to"] = transition.second.to; + transition_j["propability"] = transition.second.propability; + remove_null_values(transition_j); + transitions_j.insert(transitions_j.end(), transition_j); + } + initiator_j["transitions"] = transitions_j; } - - remove_null_values(state_j); - states_j.insert(states_j.end(), state_j); } - initiator_j["states"] = states_j; - - json transitions_j = json::array(); - - for (const auto &transition : generator->transitions) + else if constexpr (std::is_same_v) { - json transition_j; - transition_j["from"] = transition.first; - transition_j["to"] = transition.second.to; - transition_j["propability"] = transition.second.propability; - remove_null_values(transition_j); - transitions_j.insert(transitions_j.end(), transition_j); + initiator_j["type"] = "hammer"; + initiator_j["numRequests"] = initiator.numRequests; + initiator_j["rowIncrement"] = initiator.rowIncrement; } - initiator_j["transitions"] = transitions_j; - } - } - else if (const auto hammer = dynamic_cast(initiator.get())) - { - initiator_j["type"] = "hammer"; - initiator_j["numRequests"] = hammer->numRequests; - initiator_j["rowIncrement"] = hammer->rowIncrement; - } - else if (const auto player = dynamic_cast(initiator.get())) - { - initiator_j["type"] = "player"; - } + else // if constexpr (std::is_same_v) + { + initiator_j["type"] = "player"; + } + }, + initiator); remove_null_values(initiator_j); j.insert(j.end(), initiator_j); @@ -156,51 +176,53 @@ void from_json(const json &j, TraceSetup &c) // Default to Player, when not specified TrafficInitiatorType type = initiator_j.value("type", TrafficInitiatorType::Player); - std::unique_ptr initiator; + std::variant initiator; + if (type == TrafficInitiatorType::Player) { - initiator = std::unique_ptr(new TracePlayer); + initiator = TracePlayer{}; } else if (type == TrafficInitiatorType::Generator) { - TraceGenerator *generator = new TraceGenerator; + TraceGenerator generator; - auto process_state = [](const json &state_j) -> std::pair> + auto process_state = [](const json &state_j) + -> std::pair> { - std::unique_ptr state; + std::variant state; if (state_j.contains("idleClks")) { // Idle state - auto idleState = new TraceGeneratorIdleState; - state_j.at("idleClks").get_to(idleState->idleClks); + TraceGeneratorIdleState idleState; + state_j.at("idleClks").get_to(idleState.idleClks); - state = std::unique_ptr(idleState); + state = std::move(idleState); } else { // Traffic state - auto trafficState = new TraceGeneratorTrafficState; - state_j.at("numRequests").get_to(trafficState->numRequests); - state_j.at("rwRatio").get_to(trafficState->rwRatio); - state_j.at("addressDistribution").get_to(trafficState->addressDistribution); + TraceGeneratorTrafficState trafficState; + state_j.at("numRequests").get_to(trafficState.numRequests); + state_j.at("rwRatio").get_to(trafficState.rwRatio); + state_j.at("addressDistribution").get_to(trafficState.addressDistribution); if (state_j.contains("addressIncrement")) - state_j.at("addressIncrement").get_to(trafficState->addressIncrement); + state_j.at("addressIncrement").get_to(trafficState.addressIncrement); if (state_j.contains("minAddress")) - state_j.at("minAddress").get_to(trafficState->minAddress); + state_j.at("minAddress").get_to(trafficState.minAddress); if (state_j.contains("maxAddress")) - state_j.at("maxAddress").get_to(trafficState->maxAddress); + state_j.at("maxAddress").get_to(trafficState.maxAddress); if (state_j.contains("clksPerRequest")) - state_j.at("clksPerRequest").get_to(trafficState->clksPerRequest); + state_j.at("clksPerRequest").get_to(trafficState.clksPerRequest); if (state_j.contains("notify")) - state_j.at("notify").get_to(trafficState->notify); + state_j.at("notify").get_to(trafficState.notify); - state = std::unique_ptr(trafficState); + state = std::move(trafficState); } // Default to 0 @@ -217,7 +239,7 @@ void from_json(const json &j, TraceSetup &c) for (const auto &state_j : initiator_j.at("states")) { auto state = process_state(state_j); - generator->states[state.first] = std::move(state.second); + generator.states[state.first] = std::move(state.second); } for (const auto &transition_j : initiator_j.at("transitions")) @@ -226,47 +248,52 @@ void from_json(const json &j, TraceSetup &c) unsigned int from = transition_j.at("from"); transition.to = transition_j.at("to"); transition.propability = transition_j.at("propability"); - generator->transitions.emplace(from, transition); + generator.transitions.emplace(from, transition); } } else // Only one state will be created { auto state = process_state(initiator_j); - generator->states[state.first] = std::move(state.second); + generator.states[state.first] = std::move(state.second); } if (initiator_j.contains("seed")) - initiator_j.at("seed").get_to(generator->seed); + initiator_j.at("seed").get_to(generator.seed); if (initiator_j.contains("maxTransactions")) - initiator_j.at("maxTransactions").get_to(generator->maxTransactions); + initiator_j.at("maxTransactions").get_to(generator.maxTransactions); if (initiator_j.contains("idleUntil")) - initiator_j.at("idleUntil").get_to(generator->idleUntil); + initiator_j.at("idleUntil").get_to(generator.idleUntil); - initiator = std::unique_ptr(generator); + initiator = generator; } else if (type == TrafficInitiatorType::Hammer) { - TraceHammer *hammer = new TraceHammer; + TraceHammer hammer; - initiator_j.at("numRequests").get_to(hammer->numRequests); - initiator_j.at("rowIncrement").get_to(hammer->rowIncrement); + initiator_j.at("numRequests").get_to(hammer.numRequests); + initiator_j.at("rowIncrement").get_to(hammer.rowIncrement); - initiator = std::unique_ptr(hammer); + initiator = hammer; } - initiator_j.at("name").get_to(initiator->name); - initiator_j.at("clkMhz").get_to(initiator->clkMhz); + std::visit( + [&initiator_j](auto &&initiator) + { + initiator_j.at("name").get_to(initiator.name); + initiator_j.at("clkMhz").get_to(initiator.clkMhz); - if (initiator_j.contains("maxPendingReadRequests")) - initiator_j.at("maxPendingReadRequests").get_to(initiator->maxPendingReadRequests); + if (initiator_j.contains("maxPendingReadRequests")) + initiator_j.at("maxPendingReadRequests").get_to(initiator.maxPendingReadRequests); - if (initiator_j.contains("maxPendingWriteRequests")) - initiator_j.at("maxPendingWriteRequests").get_to(initiator->maxPendingWriteRequests); + if (initiator_j.contains("maxPendingWriteRequests")) + initiator_j.at("maxPendingWriteRequests").get_to(initiator.maxPendingWriteRequests); - if (initiator_j.contains("addLengthConverter")) - initiator_j.at("addLengthConverter").get_to(initiator->addLengthConverter); + if (initiator_j.contains("addLengthConverter")) + initiator_j.at("addLengthConverter").get_to(initiator.addLengthConverter); + }, + initiator); c.initiators.emplace_back(std::move(initiator)); } diff --git a/DRAMSys/library/src/common/configuration/TraceSetup.h b/DRAMSys/library/src/common/configuration/TraceSetup.h index 54298cf7..d9f0ed9a 100644 --- a/DRAMSys/library/src/common/configuration/TraceSetup.h +++ b/DRAMSys/library/src/common/configuration/TraceSetup.h @@ -38,9 +38,9 @@ #include "util.h" -#include #include #include +#include namespace DRAMSysConfiguration { @@ -117,7 +117,7 @@ struct TraceGenerator : public TrafficInitiator { std::optional seed; std::optional maxTransactions; - std::map> states; + std::map> states; std::multimap transitions; std::optional idleUntil; }; @@ -130,7 +130,7 @@ struct TraceHammer : public TrafficInitiator struct TraceSetup { - std::vector> initiators; + std::vector> initiators; }; void to_json(json &j, const TraceSetup &c); diff --git a/DRAMSys/library/src/common/configuration/tests/simpletest.cpp b/DRAMSys/library/src/common/configuration/tests/simpletest.cpp index 520791f0..c0a9ac95 100644 --- a/DRAMSys/library/src/common/configuration/tests/simpletest.cpp +++ b/DRAMSys/library/src/common/configuration/tests/simpletest.cpp @@ -94,89 +94,88 @@ DRAMSysConfiguration::ThermalConfig getThermalConfig() true}; } -std::unique_ptr getTracePlayer() +DRAMSysConfiguration::TracePlayer getTracePlayer() { - DRAMSysConfiguration::TracePlayer *player = new DRAMSysConfiguration::TracePlayer; + DRAMSysConfiguration::TracePlayer player; + player.clkMhz = 100; + player.name = "mytrace.stl"; - player->clkMhz = 100; - player->name = "mytrace.stl"; - - return std::unique_ptr(player); + return player; } -std::unique_ptr getTraceGeneratorOneState() +DRAMSysConfiguration::TraceGenerator getTraceGeneratorOneState() { - DRAMSysConfiguration::TraceGenerator *gen = new DRAMSysConfiguration::TraceGenerator; + DRAMSysConfiguration::TraceGenerator gen; + gen.clkMhz = 100; + gen.name = "MyTestGen"; - gen->clkMhz = 100; - gen->name = "MyTestGen"; + DRAMSysConfiguration::TraceGeneratorTrafficState state0; + state0.numRequests = 1000; + state0.rwRatio = 0.5; + state0.addressDistribution = DRAMSysConfiguration::AddressDistribution::Random; + state0.addressIncrement = {}; + state0.minAddress = {}; + state0.maxAddress = {}; + state0.clksPerRequest = {}; - auto state0 = new DRAMSysConfiguration::TraceGeneratorTrafficState; - state0->numRequests = 1000; - state0->rwRatio = 0.5; - state0->addressDistribution = DRAMSysConfiguration::AddressDistribution::Random; - state0->addressIncrement = {}; - state0->minAddress = {}; - state0->maxAddress = {}; - state0->clksPerRequest = {}; + gen.states.emplace(0, state0); - gen->states[0] = std::unique_ptr(state0); - - return std::unique_ptr(gen); + return gen; } -std::unique_ptr getTraceGeneratorMultipleStates() +DRAMSysConfiguration::TraceGenerator getTraceGeneratorMultipleStates() { - DRAMSysConfiguration::TraceGenerator *gen = new DRAMSysConfiguration::TraceGenerator; + DRAMSysConfiguration::TraceGenerator gen; - gen->clkMhz = 100; - gen->name = "MyTestGen"; - gen->maxPendingReadRequests = 8; + gen.clkMhz = 100; + gen.name = "MyTestGen"; + gen.maxPendingReadRequests = 8; - auto state0 = new DRAMSysConfiguration::TraceGeneratorTrafficState; - state0->numRequests = 1000; - state0->rwRatio = 0.5; - state0->addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential; - state0->addressIncrement = 256; - state0->minAddress = {}; - state0->maxAddress = 1024; - state0->clksPerRequest = {}; + DRAMSysConfiguration::TraceGeneratorTrafficState state0; + state0.numRequests = 1000; + state0.rwRatio = 0.5; + state0.addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential; + state0.addressIncrement = 256; + state0.minAddress = {}; + state0.maxAddress = 1024; + state0.clksPerRequest = {}; - auto state1 = new DRAMSysConfiguration::TraceGeneratorTrafficState; - state1->numRequests = 100; - state1->rwRatio = 0.75; - state1->addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential; - state1->addressIncrement = 512; - state1->minAddress = 1024; - state1->maxAddress = 2048; - state1->clksPerRequest = {}; + DRAMSysConfiguration::TraceGeneratorTrafficState state1; + state1.numRequests = 100; + state1.rwRatio = 0.75; + state1.addressDistribution = DRAMSysConfiguration::AddressDistribution::Sequential; + state1.addressIncrement = 512; + state1.minAddress = 1024; + state1.maxAddress = 2048; + state1.clksPerRequest = {}; - gen->states[0] = std::unique_ptr(state0); - gen->states[1] = std::unique_ptr(state1); + gen.states.emplace(0, state0); + gen.states.emplace(1, state1); DRAMSysConfiguration::TraceGeneratorStateTransition transistion0{1, 1.0}; - gen->transitions.emplace(0, transistion0); + gen.transitions.emplace(0, transistion0); - return std::unique_ptr(gen); + return gen; } -std::unique_ptr getTraceHammer() +DRAMSysConfiguration::TraceHammer getTraceHammer() { - DRAMSysConfiguration::TraceHammer *hammer = new DRAMSysConfiguration::TraceHammer; + DRAMSysConfiguration::TraceHammer hammer; - hammer->clkMhz = 100; - hammer->name = "MyTestHammer"; + hammer.clkMhz = 100; + hammer.name = "MyTestHammer"; + hammer.numRequests = 4000; + hammer.rowIncrement = 2097152; - hammer->numRequests = 4000; - hammer->rowIncrement = 2097152; - - return std::unique_ptr(hammer); + return hammer; } DRAMSysConfiguration::TraceSetup getTraceSetup() { - std::vector> initiators; + using namespace DRAMSysConfiguration; + + std::vector> initiators; initiators.emplace_back(getTracePlayer()); initiators.emplace_back(getTraceGeneratorOneState()); initiators.emplace_back(getTraceGeneratorMultipleStates()); diff --git a/DRAMSys/simulator/TraceSetup.cpp b/DRAMSys/simulator/TraceSetup.cpp index 3672408d..fd6d8e54 100644 --- a/DRAMSys/simulator/TraceSetup.cpp +++ b/DRAMSys/simulator/TraceSetup.cpp @@ -45,102 +45,107 @@ using namespace sc_core; using namespace tlm; -TraceSetup::TraceSetup(const DRAMSysConfiguration::TraceSetup &traceSetup, - const std::string &pathToResources, +TraceSetup::TraceSetup(const DRAMSysConfiguration::TraceSetup &traceSetup, const std::string &pathToResources, std::vector> &players) { if (traceSetup.initiators.empty()) SC_REPORT_FATAL("TraceSetup", "No traffic initiators specified"); - for (const auto &inititator : traceSetup.initiators) + for (const auto &initiator : traceSetup.initiators) { - double frequencyMHz = inititator->clkMhz; - sc_time playerClk = sc_time(1.0 / frequencyMHz, SC_US); + std::visit( + [&](auto &&initiator) + { + std::string name = initiator.name; + double frequencyMHz = initiator.clkMhz; + sc_time playerClk = sc_time(1.0 / frequencyMHz, SC_US); - std::string name = inititator->name; + unsigned int maxPendingReadRequests = [=]() -> unsigned int + { + if (const auto &maxPendingReadRequests = initiator.maxPendingReadRequests) + return *maxPendingReadRequests; + else + return 0; + }(); - unsigned int maxPendingReadRequests = [=]() -> unsigned int - { - if (const auto &maxPendingReadRequests = inititator->maxPendingReadRequests) - return *maxPendingReadRequests; - else - return 0; - }(); + unsigned int maxPendingWriteRequests = [=]() -> unsigned int + { + if (const auto &maxPendingWriteRequests = initiator.maxPendingWriteRequests) + return *maxPendingWriteRequests; + else + return 0; + }(); - unsigned int maxPendingWriteRequests = [=]() -> unsigned int - { - if (const auto &maxPendingWriteRequests = inititator->maxPendingWriteRequests) - return *maxPendingWriteRequests; - else - return 0; - }(); + bool addLengthConverter = [=]() -> bool + { + if (const auto &addLengthConverter = initiator.addLengthConverter) + return *addLengthConverter; + else + return false; + }(); - bool addLengthConverter = [=]() -> bool - { - if (const auto &addLengthConverter = inititator->addLengthConverter) - return *addLengthConverter; - else - return false; - }(); + using T = std::decay_t; + if constexpr (std::is_same_v) + { + size_t pos = name.rfind('.'); + if (pos == std::string::npos) + throw std::runtime_error("Name of the trace file does not contain a valid extension."); - if (std::dynamic_pointer_cast(inititator)) - { - size_t pos = name.rfind('.'); - if (pos == std::string::npos) - throw std::runtime_error("Name of the trace file does not contain a valid extension."); + // Get the extension and make it lower case + std::string ext = name.substr(pos + 1); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); - // Get the extension and make it lower case - std::string ext = name.substr(pos + 1); - std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + std::stringstream stlFileStream; + stlFileStream << pathToResources << "/traces/" << name; + std::string stlFile = stlFileStream.str(); + std::string moduleName = name; - std::stringstream stlFileStream; - stlFileStream << pathToResources << "/traces/" << name; - std::string stlFile = stlFileStream.str(); - std::string moduleName = name; + // replace all '.' to '_' + std::replace(moduleName.begin(), moduleName.end(), '.', '_'); - // replace all '.' to '_' - std::replace(moduleName.begin(), moduleName.end(), '.', '_'); + StlPlayer *player; + if (ext == "stl") + player = new StlPlayer(moduleName.c_str(), stlFile, playerClk, maxPendingReadRequests, + maxPendingWriteRequests, addLengthConverter, this, false); + else if (ext == "rstl") + player = new StlPlayer(moduleName.c_str(), stlFile, playerClk, maxPendingReadRequests, + maxPendingWriteRequests, addLengthConverter, this, true); + else + throw std::runtime_error("Unsupported file extension in " + name); - StlPlayer *player; - if (ext == "stl") - player = new StlPlayer(moduleName.c_str(), stlFile, playerClk, - maxPendingReadRequests, maxPendingWriteRequests, addLengthConverter, this, false); - else if (ext == "rstl") - player = new StlPlayer(moduleName.c_str(), stlFile, playerClk, - maxPendingReadRequests, maxPendingWriteRequests, addLengthConverter, this, true); - else - throw std::runtime_error("Unsupported file extension in " + name); + players.push_back(std::unique_ptr(player)); + totalTransactions += player->getNumberOfLines(); + } + else if constexpr (std::is_same_v) + { + TrafficGenerator *trafficGenerator = new TrafficGenerator(name.c_str(), initiator, this); + players.push_back(std::unique_ptr(trafficGenerator)); - players.push_back(std::unique_ptr(player)); - totalTransactions += player->getNumberOfLines(); - } - else if (auto generator = std::dynamic_pointer_cast(inititator)) - { - TrafficGenerator *trafficGenerator = new TrafficGenerator(name.c_str(), *generator, this); - players.push_back(std::unique_ptr(trafficGenerator)); + totalTransactions += trafficGenerator->getTotalTransactions(); + } + else // if constexpr (std::is_same_v) + { + uint64_t numRequests = initiator.numRequests; + uint64_t rowIncrement = initiator.rowIncrement; - totalTransactions += trafficGenerator->getTotalTransactions(); - } - else if (auto hammer = std::dynamic_pointer_cast(inititator)) - { - uint64_t numRequests = hammer->numRequests; - uint64_t rowIncrement = hammer->rowIncrement; - - players.push_back(std::unique_ptr(new TrafficGeneratorHammer(name.c_str(), *hammer, this))); - totalTransactions += numRequests; - } + players.push_back( + std::unique_ptr(new TrafficGeneratorHammer(name.c_str(), initiator, this))); + totalTransactions += numRequests; + } + }, + initiator); } for (const auto &inititatorConf : traceSetup.initiators) { - if (auto generatorConf = std::dynamic_pointer_cast(inititatorConf)) + if (auto generatorConf = std::get_if(&inititatorConf)) { if (const auto &idleUntil = generatorConf->idleUntil) { const std::string name = generatorConf->name; - auto listenerIt = - std::find_if(players.begin(), players.end(), - [&name](const std::unique_ptr &initiator) { return initiator->name() == name; }); + auto listenerIt = std::find_if(players.begin(), players.end(), + [&name](const std::unique_ptr &initiator) + { return initiator->name() == name; }); // Should be found auto listener = dynamic_cast(listenerIt->get()); @@ -205,7 +210,7 @@ void TraceSetup::loadBar(uint64_t x, uint64_t n, unsigned int w, unsigned int gr if ((n < 100) || ((x != n) && (x % (n / 100 * granularity) != 0))) return; - float ratio = x / (float) n; + float ratio = x / (float)n; unsigned int c = (ratio * w); float rest = (ratio * w) - c; std::cout << std::setw(3) << round(ratio * 100) << "% |"; diff --git a/DRAMSys/simulator/TrafficGenerator.cpp b/DRAMSys/simulator/TrafficGenerator.cpp index 1598a82b..7706840c 100644 --- a/DRAMSys/simulator/TrafficGenerator.cpp +++ b/DRAMSys/simulator/TrafficGenerator.cpp @@ -37,6 +37,7 @@ */ #include "TrafficGenerator.h" +#include "TraceSetup.h" #include @@ -103,7 +104,7 @@ TrafficGenerator::TrafficGenerator(const sc_module_name &name, const DRAMSysConf // Perform checks for all states for (const auto &state : conf.states) { - if (auto trafficState = dynamic_cast(state.second.get())) + if (auto trafficState = std::get_if(&state.second)) { uint64_t minAddress = evaluateMinAddress(*trafficState); uint64_t maxAddress = evaluateMaxAddress(*trafficState); @@ -129,7 +130,8 @@ TrafficGenerator::TrafficGenerator(const sc_module_name &name, const DRAMSysConf } } - if (auto trafficState = dynamic_cast(conf.states.at(currentState).get())) + if (auto trafficState = + std::get_if(&conf.states.at(currentState))) { uint64_t minAddress = evaluateMinAddress(*trafficState); uint64_t maxAddress = evaluateMaxAddress(*trafficState); @@ -182,7 +184,8 @@ void TrafficGenerator::calculateTransitions() if (transitionFound) { - if (auto trafficState = dynamic_cast(conf.states.at(state).get())) + if (auto trafficState = + std::get_if(&conf.states.at(state))) totalTransactions += trafficState->numRequests; if (totalTransactions < maxTransactions) @@ -221,7 +224,7 @@ uint64_t TrafficGenerator::getTotalTransactions() const for (auto state : stateSequence) { - if (auto trafficState = dynamic_cast(conf.states.at(state).get())) + if (auto trafficState = std::get_if(&conf.states.at(state))) totalTransactions += trafficState->numRequests; } @@ -256,13 +259,14 @@ void TrafficGenerator::transitionToNextState() it.second.event.notify(); } - if (auto idleState = dynamic_cast(conf.states.at(currentState).get())) + if (auto idleState = std::get_if(&conf.states.at(currentState))) { currentClksToIdle += idleState->idleClks; transitionToNextState(); return; } - else if (auto trafficState = dynamic_cast(conf.states.at(currentState).get())) + else if (auto trafficState = + std::get_if(&conf.states.at(currentState))) { uint64_t minAddress = evaluateMinAddress(*trafficState); uint64_t maxAddress = evaluateMaxAddress(*trafficState); @@ -285,7 +289,8 @@ void TrafficGenerator::prepareNextPayload() if (startEvent && transactionsSent == 0) wait(*startEvent); - if (auto trafficState = dynamic_cast(conf.states.at(currentState).get())) + if (auto trafficState = + std::get_if(&conf.states.at(currentState))) { if (transactionsSentInCurrentState >= trafficState->numRequests) transitionToNextState(); @@ -303,10 +308,10 @@ void TrafficGenerator::payloadSent() tlm::tlm_command TrafficGenerator::getNextCommand() { // An idle state should never reach this method. - auto state = static_cast(conf.states.at(currentState).get()); + auto &state = std::get(conf.states.at(currentState)); tlm_command command; - if (randomDistribution(randomGenerator) < state->rwRatio) + if (randomDistribution(randomGenerator) < state.rwRatio) command = tlm::TLM_READ_COMMAND; else command = tlm::TLM_WRITE_COMMAND; @@ -324,14 +329,14 @@ uint64_t TrafficGenerator::getNextAddress() using DRAMSysConfiguration::AddressDistribution; // An idle state should never reach this method. - auto state = static_cast(conf.states.at(currentState).get()); + auto &state = std::get(conf.states.at(currentState)); - uint64_t minAddress = evaluateMinAddress(*state); - uint64_t maxAddress = evaluateMaxAddress(*state); + uint64_t minAddress = evaluateMinAddress(state); + uint64_t maxAddress = evaluateMaxAddress(state); - if (state->addressDistribution == AddressDistribution::Sequential) + if (state.addressDistribution == AddressDistribution::Sequential) { - uint64_t addressIncrement = state->addressIncrement.value_or(defaultAddressIncrement); + uint64_t addressIncrement = state.addressIncrement.value_or(defaultAddressIncrement); uint64_t address = currentAddress; currentAddress += addressIncrement; @@ -339,7 +344,7 @@ uint64_t TrafficGenerator::getNextAddress() currentAddress = minAddress; return address; } - else if (state->addressDistribution == AddressDistribution::Random) + else if (state.addressDistribution == AddressDistribution::Random) { return randomAddressDistribution(randomGenerator); }