diff --git a/DRAMSys/simulator/src/simulation/Arbiter.h b/DRAMSys/simulator/src/simulation/Arbiter.h index d41273cd..6ddd56f1 100644 --- a/DRAMSys/simulator/src/simulation/Arbiter.h +++ b/DRAMSys/simulator/src/simulation/Arbiter.h @@ -101,10 +101,18 @@ private: std::vector tlmRecorders; + //used to map the transaction from devices to the arbiter's target socket ID. + std::map routeMap; + // Initiated by dram side // This function is called when an arbiter's initiator socket receives a transaction from a memory controller tlm_sync_enum nb_transport_bw(int channelId, tlm_generic_payload &payload, tlm_phase &phase, sc_time &bwDelay) { + // Check channel ID + assert((unsigned int)channelId < iSocket.size()); + if ((unsigned int)channelId != DramExtension::getExtension(payload).getChannel().ID()) { + SC_REPORT_FATAL("Arbiter", "Payload extension was corrupted"); + } sc_time recTime = bwDelay + sc_time_stamp(); sc_time notDelay = bwDelay; @@ -122,12 +130,17 @@ private: // This function is called when an arbiter's target socket receives a transaction from a device tlm_sync_enum nb_transport_fw(int id, tlm_generic_payload& payload, tlm_phase& phase, sc_time& fwDelay) { + assert ((unsigned int)id < tSocket.size()); if (phase == BEGIN_REQ) { + // Map the payload with socket id. + routeMap[&payload] = id; // In the begin request phase the socket ID is appended to the payload. // It will extracted from the payload and used later. appendDramExtension(id, payload); payload.acquire(); } else if (phase == END_RESP) { + // Erase before the payload is released. + routeMap.erase(&payload); payload.release(); } @@ -147,7 +160,9 @@ private: unsigned int initiatorSocket = DramExtension::getExtension(payload).getThread().ID()-1; unsigned int channelId = DramExtension::getExtension(payload).getChannel().ID(); - // TODO: here check if the channel and the initiatorSocket ID are valid. If not, the payload extension was corrupted. + // Check the valid range of initiatorSocket ID and channel Id + assert(initiatorSocket < Configuration::getInstance().NumberOfTracePlayers); + assert(channelId < Configuration::getInstance().NumberOfMemChannels); // Phases initiated by the intiator side from arbiter's point of view (devices performing memory requests to the arbiter) if (phase == BEGIN_REQ) { @@ -177,6 +192,10 @@ private: // Phases initiated by the target side from arbiter's point of view (memory side) else if (phase == END_REQ) { channelIsFree[channelId] = true; + // Validate the initiatorSocket ID + if ((int)initiatorSocket != routeMap[&payload]) { + SC_REPORT_FATAL("Arbiter", "Payload extension was corrupted"); + } // The arbiter receives a transaction which phase is END_REQ from memory controller and forwards it to the requester device. sendToInitiator(initiatorSocket, payload, phase, SC_ZERO_TIME); @@ -190,6 +209,10 @@ private: channelIsFree[channelId] = false; } } else if (phase == BEGIN_RESP) { + // Validate the initiatorSocket ID + if ((int)initiatorSocket != routeMap[&payload]) { + SC_REPORT_FATAL("Arbiter", "Payload extension was corrupted"); + } // The arbiter receives a transaction in BEGIN_RESP phase (that came from the memory side) and forwards it to the requester device if (receivedResponses[initiatorSocket].empty()) sendToInitiator(initiatorSocket, payload, phase, SC_ZERO_TIME); @@ -216,12 +239,35 @@ private: void appendDramExtension(int socketId, tlm_generic_payload& payload) { - // TODO: check if channel valid before appending. - // TODO: check if all parts of the decodedAddress are inside the valid range (devices should not perform invalid requests to the arbiter, right?). unsigned int burstlength = payload.get_streaming_width(); DecodedAddress decodedAddress = xmlAddressDecoder::getInstance().decodeAddress(payload.get_address()); - DramExtension* extension = new DramExtension(Thread(socketId+1), Channel(decodedAddress.channel), Bank(decodedAddress.bank), BankGroup(decodedAddress.bankgroup), Row(decodedAddress.row), Column(decodedAddress.column),burstlength); - payload.set_auto_extension(extension); + // Check the valid range of decodedAddress + if (addressIsValid(decodedAddress)) { + DramExtension* extension = new DramExtension(Thread(socketId+1), Channel(decodedAddress.channel), Bank(decodedAddress.bank), BankGroup(decodedAddress.bankgroup), Row(decodedAddress.row), Column(decodedAddress.column),burstlength); + payload.set_auto_extension(extension); + } else { + SC_REPORT_FATAL("Arbiter", "Decoded Address are not inside the valid range"); + } + } + + bool addressIsValid(DecodedAddress& decodedAddress) + { + if (decodedAddress.channel >= xmlAddressDecoder::getInstance().amount["channel"]) { + return false; + } + if (decodedAddress.bank >= xmlAddressDecoder::getInstance().amount["bank"]) { + return false; + } + if (decodedAddress.bankgroup > xmlAddressDecoder::getInstance().amount["bankgroup"]) { + return false; + } + if (decodedAddress.column >= xmlAddressDecoder::getInstance().amount["column"]) { + return false; + } + if (decodedAddress.row >= xmlAddressDecoder::getInstance().amount["row"]) { + return false; + } + return true; } void printDebugMessage(std::string message)