From 748b613c94f8dbd693f429eb2086b8ad37c630dc Mon Sep 17 00:00:00 2001 From: Gabriel Busnot Date: Fri, 21 Jan 2022 15:12:56 +0100 Subject: [PATCH] mem-ruby: Fix switch storage in SimpleNetwork In SimpleNetwork, switches were assigned an index depending on their position in params().routers. But switches are also referenced by their router_id parameter in other locations of the ruby network system (e.g., src and dst node parameter in links). If the router_id does not match the position in SimpleNetwork::m_switches, the network initialization might fail or implement a different topology from what the user intended. This patch fixes this issue by storing switches in a map instead of a vector. Change-Id: I398f950ad404efbf9516ea9bbced598970a2bc24 Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/55723 Reviewed-by: Jason Lowe-Power Maintainer: Jason Lowe-Power Tested-by: kokoro --- src/mem/ruby/network/simple/SimpleNetwork.cc | 26 ++++++++++---------- src/mem/ruby/network/simple/SimpleNetwork.hh | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/mem/ruby/network/simple/SimpleNetwork.cc b/src/mem/ruby/network/simple/SimpleNetwork.cc index 8e10081c28..ce7bf2c6c1 100644 --- a/src/mem/ruby/network/simple/SimpleNetwork.cc +++ b/src/mem/ruby/network/simple/SimpleNetwork.cc @@ -66,9 +66,10 @@ SimpleNetwork::SimpleNetwork(const Params &p) // record the routers for (std::vector::const_iterator i = p.routers.begin(); i != p.routers.end(); ++i) { - Switch* s = safe_cast(*i); - m_switches.push_back(s); + auto* s = safe_cast(*i); s->init_net_ptr(this); + auto id = static_cast(s->params().router_id); + m_switches[id] = s; } const std::vector &physical_vnets_channels = @@ -108,7 +109,6 @@ SimpleNetwork::makeExtOutLink(SwitchID src, NodeID global_dest, { NodeID local_dest = getLocalNodeID(global_dest); assert(local_dest < m_nodes); - assert(src < m_switches.size()); assert(m_switches[src] != NULL); SimpleExtLink *simple_link = safe_cast(link); @@ -179,9 +179,9 @@ SimpleNetwork::regStats() ; // Now state what the formula is. - for (int i = 0; i < m_switches.size(); i++) { + for (auto& [id, sw]: m_switches) { *(networkStats.m_msg_counts[(unsigned int) type]) += - sum(m_switches[i]->getMsgCount(type)); + sum(sw->getMsgCount(type)); } *(networkStats.m_msg_bytes[(unsigned int) type]) = @@ -193,8 +193,8 @@ SimpleNetwork::regStats() void SimpleNetwork::collateStats() { - for (int i = 0; i < m_switches.size(); i++) { - m_switches[i]->collateStats(); + for (auto& [id, sw]: m_switches) { + sw->collateStats(); } } @@ -212,8 +212,8 @@ SimpleNetwork::print(std::ostream& out) const bool SimpleNetwork::functionalRead(Packet *pkt) { - for (unsigned int i = 0; i < m_switches.size(); i++) { - if (m_switches[i]->functionalRead(pkt)) + for (auto& [id, sw]: m_switches) { + if (sw->functionalRead(pkt)) return true; } for (unsigned int i = 0; i < m_int_link_buffers.size(); ++i) { @@ -228,8 +228,8 @@ bool SimpleNetwork::functionalRead(Packet *pkt, WriteMask &mask) { bool read = false; - for (unsigned int i = 0; i < m_switches.size(); i++) { - if (m_switches[i]->functionalRead(pkt, mask)) + for (auto& [id, sw]: m_switches) { + if (sw->functionalRead(pkt, mask)) read = true; } for (unsigned int i = 0; i < m_int_link_buffers.size(); ++i) { @@ -244,8 +244,8 @@ SimpleNetwork::functionalWrite(Packet *pkt) { uint32_t num_functional_writes = 0; - for (unsigned int i = 0; i < m_switches.size(); i++) { - num_functional_writes += m_switches[i]->functionalWrite(pkt); + for (auto& [id, sw]: m_switches) { + num_functional_writes += sw->functionalWrite(pkt); } for (unsigned int i = 0; i < m_int_link_buffers.size(); ++i) { diff --git a/src/mem/ruby/network/simple/SimpleNetwork.hh b/src/mem/ruby/network/simple/SimpleNetwork.hh index e3364922e2..b90ee3332f 100644 --- a/src/mem/ruby/network/simple/SimpleNetwork.hh +++ b/src/mem/ruby/network/simple/SimpleNetwork.hh @@ -102,7 +102,7 @@ class SimpleNetwork : public Network SimpleNetwork(const SimpleNetwork& obj); SimpleNetwork& operator=(const SimpleNetwork& obj); - std::vector m_switches; + std::unordered_map m_switches; std::vector m_int_link_buffers; const int m_buffer_size; const int m_endpoint_bandwidth;