diff --git a/src/mem/ruby/network/MessageBuffer.cc b/src/mem/ruby/network/MessageBuffer.cc index a891d5a91d..9a6500978e 100644 --- a/src/mem/ruby/network/MessageBuffer.cc +++ b/src/mem/ruby/network/MessageBuffer.cc @@ -65,6 +65,7 @@ MessageBuffer::MessageBuffer(const Params &p) m_last_arrival_time(0), m_strict_fifo(p.ordered), m_randomization(p.randomization), m_allow_zero_latency(p.allow_zero_latency), + m_routing_priority(p.routing_priority), ADD_STAT(m_not_avail_count, statistics::units::Count::get(), "Number of times this buffer did not have N slots available"), ADD_STAT(m_msg_count, statistics::units::Count::get(), diff --git a/src/mem/ruby/network/MessageBuffer.hh b/src/mem/ruby/network/MessageBuffer.hh index 9cabbaf106..279599340a 100644 --- a/src/mem/ruby/network/MessageBuffer.hh +++ b/src/mem/ruby/network/MessageBuffer.hh @@ -158,6 +158,9 @@ class MessageBuffer : public SimObject void setIncomingLink(int link_id) { m_input_link_id = link_id; } void setVnet(int net) { m_vnet_id = net; } + int getIncomingLink() const { return m_input_link_id; } + int getVnet() const { return m_vnet_id; } + Port & getPort(const std::string &, PortID idx=InvalidPortID) override { @@ -187,6 +190,8 @@ class MessageBuffer : public SimObject return functionalAccess(pkt, true, &mask) == 1; } + int routingPriority() const { return m_routing_priority; } + private: void reanalyzeList(std::list &, Tick); @@ -270,6 +275,8 @@ class MessageBuffer : public SimObject const MessageRandomization m_randomization; const bool m_allow_zero_latency; + const int m_routing_priority; + int m_input_link_id; int m_vnet_id; diff --git a/src/mem/ruby/network/MessageBuffer.py b/src/mem/ruby/network/MessageBuffer.py index 80dc872970..b776196f92 100644 --- a/src/mem/ruby/network/MessageBuffer.py +++ b/src/mem/ruby/network/MessageBuffer.py @@ -70,3 +70,6 @@ class MessageBuffer(SimObject): max_dequeue_rate = Param.Unsigned(0, "Maximum number of messages that can \ be dequeued per cycle \ (0 allows dequeueing all ready messages)") + routing_priority = Param.Int(0, "Buffer priority when messages are \ + consumed by the network. Smaller value \ + means higher priority") diff --git a/src/mem/ruby/network/simple/PerfectSwitch.cc b/src/mem/ruby/network/simple/PerfectSwitch.cc index 665fd0faae..74d78e3aae 100644 --- a/src/mem/ruby/network/simple/PerfectSwitch.cc +++ b/src/mem/ruby/network/simple/PerfectSwitch.cc @@ -88,10 +88,36 @@ PerfectSwitch::addInPort(const std::vector& in) in[i]->setConsumer(this); in[i]->setIncomingLink(port); in[i]->setVnet(i); + updatePriorityGroups(i, in[i]); } } } +void +PerfectSwitch::updatePriorityGroups(int vnet, MessageBuffer* in_buf) +{ + while (m_in_prio.size() <= vnet) { + m_in_prio.emplace_back(); + m_in_prio_groups.emplace_back(); + } + + m_in_prio[vnet].push_back(in_buf); + + std::sort(m_in_prio[vnet].begin(), m_in_prio[vnet].end(), + [](const MessageBuffer* i, const MessageBuffer* j) + { return i->routingPriority() < j->routingPriority(); }); + + // reset groups + m_in_prio_groups[vnet].clear(); + int cur_prio = m_in_prio[vnet].front()->routingPriority(); + m_in_prio_groups[vnet].emplace_back(); + for (auto buf : m_in_prio[vnet]) { + if (buf->routingPriority() != cur_prio) + m_in_prio_groups[vnet].emplace_back(); + m_in_prio_groups[vnet].back().push_back(buf); + } +} + void PerfectSwitch::addOutPort(const std::vector& out, const NetDest& routing_table_entry, @@ -126,12 +152,15 @@ PerfectSwitch::inBuffer(int in_port, int vnet) const void PerfectSwitch::operateVnet(int vnet) { - if (m_pending_message_count[vnet] > 0) { + if (m_pending_message_count[vnet] == 0) + return; + + for (auto &in : m_in_prio_groups[vnet]) { // first check the port with the oldest message unsigned start_in_port = 0; Tick lowest_tick = MaxTick; - for (int i = 0; i < m_in.size(); ++i) { - MessageBuffer *buffer = inBuffer(i, vnet); + for (int i = 0; i < in.size(); ++i) { + MessageBuffer *buffer = in[i]; if (buffer) { Tick ready_time = buffer->readyTime(); if (ready_time < lowest_tick){ @@ -141,21 +170,20 @@ PerfectSwitch::operateVnet(int vnet) } } DPRINTF(RubyNetwork, "vnet %d: %d pending msgs. " - "Checking port %d first\n", + "Checking port %d first\n", vnet, m_pending_message_count[vnet], start_in_port); // check all ports starting with the one with the oldest message - for (int i = 0; i < m_in.size(); ++i) { - int in_port = (i + start_in_port) % m_in.size(); - MessageBuffer *buffer = inBuffer(in_port, vnet); + for (int i = 0; i < in.size(); ++i) { + int in_port = (i + start_in_port) % in.size(); + MessageBuffer *buffer = in[in_port]; if (buffer) - operateMessageBuffer(buffer, in_port, vnet); + operateMessageBuffer(buffer, vnet); } } } void -PerfectSwitch::operateMessageBuffer(MessageBuffer *buffer, int incoming, - int vnet) +PerfectSwitch::operateMessageBuffer(MessageBuffer *buffer, int vnet) { MsgPtr msg_ptr; Message *net_msg_ptr = NULL; @@ -166,7 +194,7 @@ PerfectSwitch::operateMessageBuffer(MessageBuffer *buffer, int incoming, Tick current_time = m_switch->clockEdge(); while (buffer->isReady(current_time)) { - DPRINTF(RubyNetwork, "incoming: %d\n", incoming); + DPRINTF(RubyNetwork, "incoming: %d\n", buffer->getIncomingLink()); // Peek at message msg_ptr = buffer->peekMsgPtr(); @@ -237,7 +265,7 @@ PerfectSwitch::operateMessageBuffer(MessageBuffer *buffer, int incoming, // Enqeue msg DPRINTF(RubyNetwork, "Enqueuing net msg from " "inport[%d][%d] to outport [%d][%d].\n", - incoming, vnet, outgoing, vnet); + buffer->getIncomingLink(), vnet, outgoing, vnet); out_port.buffers[vnet]->enqueue(msg_ptr, current_time, out_port.latency); diff --git a/src/mem/ruby/network/simple/PerfectSwitch.hh b/src/mem/ruby/network/simple/PerfectSwitch.hh index 446ae83246..589bca1a05 100644 --- a/src/mem/ruby/network/simple/PerfectSwitch.hh +++ b/src/mem/ruby/network/simple/PerfectSwitch.hh @@ -99,7 +99,7 @@ class PerfectSwitch : public Consumer PerfectSwitch& operator=(const PerfectSwitch& obj); void operateVnet(int vnet); - void operateMessageBuffer(MessageBuffer *b, int incoming, int vnet); + void operateMessageBuffer(MessageBuffer *b, int vnet); const SwitchID m_switch_id; Switch * const m_switch; @@ -115,6 +115,13 @@ class PerfectSwitch : public Consumer }; std::vector m_out; + // input ports ordered by priority; indexed by vnet first + std::vector > m_in_prio; + // input ports grouped by priority; indexed by vnet,prio_lv + std::vector>> m_in_prio_groups; + + void updatePriorityGroups(int vnet, MessageBuffer* buf); + uint32_t m_virtual_networks; int m_wakeups_wo_switch;