Files
DRAMSys/extensions/apps/traceAnalyzer/scripts/vcdExport.py

386 lines
13 KiB
Python

#!/usr/bin/env python3
# Copyright (c) 2021, RPTU Kaiserslautern-Landau
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
# OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Authors:
# Derek Christ
import sqlite3
import io
import sys
import enum
import math
import datetime
from abc import ABC, abstractmethod
from memUtil import *
from tqdm import tqdm
from vcd import VCDWriter
# Width of each query window used by dumpVcd, expressed in the raw time units
# stored in the trace database's Phases table (dumpVcd sets the VCD timescale
# to "1" + getUnitOfTime()).
TIME_STEP = 1_000_000
class Signal(ABC):
    """Abstract base for a named VCD signal.

    Concrete subclasses choose the VCD variable type and the value the
    variable is initialised with.
    """

    def __init__(self, name):
        # Used as the VCD variable name and as the key Event.signal refers to.
        self.name = name

    @abstractmethod
    def getNeutralValue(self):
        """Return the initial (idle) value of the signal."""

    @abstractmethod
    def getSignalType(self):
        """Return the VCD variable type name of the signal."""
class NumericSignal(Signal):
    """Integer-typed VCD signal whose initial value is 'z'."""

    def getNeutralValue(self):
        # 'z' is the VCD high-impedance value.
        neutral = "z"
        return neutral

    def getSignalType(self):
        vcdType = "integer"
        return vcdType
class StringSignal(Signal):
    """String-typed VCD signal whose initial value is the empty string."""

    def getNeutralValue(self):
        # An empty string is also what the event builders write to clear a signal.
        neutral = ""
        return neutral

    def getSignalType(self):
        vcdType = "string"
        return vcdType
class Event:
    """One value change: set the signal named ``signal`` to ``value``."""

    def __init__(self, signal, value):
        # ``signal`` is a signal name (matched against Signal.name in dumpVcd).
        self.signal, self.value = signal, value
class Transaction:
    """Bank coordinates and command of a single memory transaction."""

    def __init__(self, rank, bankgroup, bank, command):
        self.rank, self.bankgroup, self.bank = rank, bankgroup, bank
        # Command string as stored in the Transactions table.
        self.command = command
class Granularity(enum.Enum):
    """Which banks a command is drawn on in the per-bank signals."""

    Bankwise = 0     # only the transaction's own bank
    TwoBankwise = 1  # the transaction's bank plus one partner bank
    Groupwise = 2    # the same bank index in every bank group of the rank
    Rankwise = 3     # every bank of the rank
class TimeWindow:
    """Iterator over consecutive, fixed-size time windows.

    Yields ``(begin, begin + windowSize)`` tuples with ``begin`` starting at 0
    and advancing by ``windowSize``; iteration stops once ``begin`` exceeds
    ``lastTimestamp``, so the final window extends past ``lastTimestamp``.
    Adjacent windows share their boundary value.

    Raises:
        ValueError: if ``windowSize`` is not positive (the iterator would
            never advance).
    """

    def __init__(self, windowSize, lastTimestamp):
        if windowSize <= 0:
            raise ValueError("windowSize must be positive, got " + str(windowSize))
        self.currentTime = 0
        self.windowSize = windowSize
        self.lastTimestamp = lastTimestamp

    def __iter__(self):
        return self

    def __next__(self):
        if self.currentTime > self.lastTimestamp:
            raise StopIteration
        currentRange = (self.currentTime, self.currentTime + self.windowSize)
        self.currentTime += self.windowSize
        return currentRange

    def numberOfIterations(self):
        """Return how many windows this iterator will still yield.

        Counts every ``begin`` in ``currentTime, currentTime + windowSize, ...``
        that is ``<= lastTimestamp``, i.e. exactly the number of ``__next__``
        calls that will succeed (used as the progress-bar total).
        """
        remaining = self.lastTimestamp - self.currentTime
        if remaining < 0:
            return 0
        return int(remaining // self.windowSize) + 1
def getGranularity(phase):
    """Return the Granularity (set of affected banks) for a phase name."""
    groupwisePhases = ("PRESB", "REFSB", "RFMSB")
    twoBankwisePhases = ("REFP2B", "RFMP2B")
    rankwisePhases = ("PREAB", "PREA", "REFAB", "REFA", "RFMAB",
                      "PDNA", "PDNP", "SREF")

    if phase in groupwisePhases:
        return Granularity.Groupwise
    if phase in twoBankwisePhases:
        return Granularity.TwoBankwise
    if phase in rankwisePhases:
        return Granularity.Rankwise
    # Any other phase only affects the transaction's own bank.
    return Granularity.Bankwise
def getAmountOfCommandBusSpans(phase):
    """Return how many times a phase is drawn on the command bus.

    The PDN*/SREF* phases listed below are drawn once at their start and once
    at their end; every other phase is drawn once, at its start.
    """
    twoSpanPhases = ("PDNA", "PDNAB", "PDNP", "PDNPB", "SREF", "SREFB")
    return 2 if phase in twoSpanPhases else 1
def getUnitOfTime(connection):
    """Return the trace's clock time unit, lower-cased (used for the VCD timescale)."""
    (_clockPeriod, clockUnit) = getClock(connection)
    lowered = clockUnit.lower()
    return lowered
def getLastTimestamp(connection):
    """Return the latest end timestamp recorded in the Phases table.

    Both PhaseEnd and DataStrobeEnd are considered, so phases that end after
    the last data transfer still fall inside the exported time range.
    Returns 0 when the table is empty or both columns are entirely NULL
    (instead of crashing on a missing row).
    """
    cursor = connection.cursor()
    # Outer two-argument max() is SQLite's scalar max; the inner
    # single-argument max() calls are aggregates over the whole table.
    cursor.execute(
        "SELECT MAX(COALESCE(MAX(PhaseEnd), 0), COALESCE(MAX(DataStrobeEnd), 0)) "
        "FROM Phases"
    )
    return cursor.fetchone()[0]
def getRanksBankgroupsBanks(connection):
    """Return ``(ranks, bankgroupsPerRank, banksPerBankgroup)`` for the trace.

    The memUtil getters appear to report device-wide totals (hence the
    divisions below) — confirm in memUtil. Results are truncated to int.
    """
    ranks = getNumberOfRanks(connection)
    bankgroupsPerRank = int(getNumberOfBankGroups(connection) / ranks)
    banksPerBankgroup = int(getNumberOfBanks(connection) / (bankgroupsPerRank * ranks))
    return ranks, bankgroupsPerRank, banksPerBankgroup
def getBankName(rank, bankgroup, bank):
    """Return the VCD signal name of one bank, e.g. ``RA0_BG1_BA2``."""
    return f"RA{rank}_BG{bankgroup}_BA{bank}"
def getBankNames(ranks, bankgroups, banks):
    """Return all bank signal names, ordered by rank, then bank group, then bank."""
    return [
        getBankName(rankIndex, groupIndex, bankIndex)
        for rankIndex in range(ranks)
        for groupIndex in range(bankgroups)
        for bankIndex in range(banks)
    ]
def getOccurringSignals(connection):
    """Return the signals to declare in the VCD file, in a fixed order.

    Order: REQ, RESP, one string signal per bank (rank, bank group, bank
    order), Command_Bus, Data_Bus.

    A list is returned instead of a set. Signal objects hash by identity, so
    the iteration order of a set of them can differ between runs. dumpVcd
    registers VCD variables in iteration order, so a set made the generated
    file nondeterministic.
    """
    (ranks, bankgroups, banks) = getRanksBankgroupsBanks(connection)

    signals = [StringSignal("REQ"), StringSignal("RESP")]
    signals.extend(StringSignal(name) for name in getBankNames(ranks, bankgroups, banks))
    signals.append(StringSignal("Command_Bus"))
    signals.append(StringSignal("Data_Bus"))
    return signals
def getDataBusEvents(connection, eventDict, windowRange):
    """Add Data_Bus value changes for data transfers inside ``windowRange``.

    For every phase whose DataStrobeBegin and DataStrobeEnd both lie in the
    inclusive window, Data_Bus is set to ``"<command> <transaction id>"`` at
    DataStrobeBegin and cleared to ``""`` at DataStrobeEnd. Transfers that
    straddle a window boundary are not selected by this query.

    Args:
        connection: open sqlite3 connection to the trace database.
        eventDict: timestamp -> list of Event; extended in place.
        windowRange: ``(begin, end)`` of the window.
    """
    beginWindow, endWindow = windowRange
    cursor = connection.cursor()
    # ORDER BY matters: dumpVcd applies events that share a timestamp in list
    # order, so a burst starting exactly where another ends must be appended
    # after the earlier burst's clear event, or it would be overwritten.
    cursor.execute(
        "SELECT Transactions.ID, DataStrobeBegin, DataStrobeEnd, Command "
        "FROM Phases INNER JOIN Transactions ON Transactions.ID = Phases.Transact "
        "WHERE DataStrobeBegin BETWEEN ? AND ? AND DataStrobeEnd BETWEEN ? AND ? "
        "ORDER BY DataStrobeBegin",
        (beginWindow, endWindow, beginWindow, endWindow),
    )
    for transactionId, begin, end, command in cursor.fetchall():
        eventDict.setdefault(begin, []).append(
            Event("Data_Bus", command + " " + str(transactionId)))
        eventDict.setdefault(end, []).append(Event("Data_Bus", ""))
def getCommandBusEvents(connection, eventDict, transactionDict, windowRange):
    """Add Command_Bus and per-bank value changes for commands in ``windowRange``.

    Every phase other than REQ/RESP whose PhaseBegin and PhaseEnd both lie in
    the inclusive window is drawn on Command_Bus for one command length
    starting at PhaseBegin. Phases for which getAmountOfCommandBusSpans returns
    2 are additionally drawn for one command length ending at PhaseEnd. During
    each of those spans, the bank signals selected by getGranularity show the
    same ``"<phase> <transaction id>"`` label, which is cleared to ``""`` at the
    end of the span.

    Args:
        connection: open sqlite3 connection to the trace database.
        eventDict: timestamp -> list of Event; extended in place.
        transactionDict: transaction id -> Transaction; must contain every
            transaction referenced by the selected phases.
        windowRange: ``(begin, end)`` of the window.
    """
    beginWindow, endWindow = windowRange

    cursor = connection.cursor()
    # Ordered by PhaseBegin so that events sharing a timestamp are appended in
    # time order (dumpVcd applies them in list order).
    cursor.execute(
        "SELECT PhaseName, PhaseBegin, PhaseEnd, Transact FROM Phases "
        "WHERE PhaseBegin BETWEEN ? AND ? AND PhaseEnd BETWEEN ? AND ? "
        "ORDER BY PhaseBegin",
        (beginWindow, endWindow, beginWindow, endWindow),
    )
    rows = [row for row in cursor.fetchall() if row[0] not in ("REQ", "RESP")]
    if not rows:
        return

    # These lookups depend only on the trace, not on the phase. Query them once
    # per call instead of once per phase/span.
    # NOTE(review): the RD command length is used for every phase — confirm this
    # is intended rather than getCommandLengthForPhase(connection, phase).
    commandLengthTime = getCommandLengthForPhase(connection, "RD") * getClock(connection)[0]
    (ranks, bankgroups, banks) = getRanksBankgroupsBanks(connection)
    per2BankOffset = None  # fetched lazily, only if a two-bank phase occurs

    for phase, phaseBegin, phaseEnd, transactionId in rows:
        timespans = [(phaseBegin, phaseBegin + commandLengthTime)]
        if getAmountOfCommandBusSpans(phase) != 1:
            timespans.append((phaseEnd - commandLengthTime, phaseEnd))

        label = phase + " " + str(transactionId)
        currentTransaction = transactionDict[transactionId]
        rank = currentTransaction.rank
        bankgroup = currentTransaction.bankgroup
        bank = currentTransaction.bank

        granularity = getGranularity(phase)
        if granularity == Granularity.Rankwise:
            currentBanks = [(rank, _bankgroup, _bank)
                            for _bankgroup in range(bankgroups)
                            for _bank in range(banks)]
        elif granularity == Granularity.Groupwise:
            currentBanks = [(rank, _bankgroup, bank) for _bankgroup in range(bankgroups)]
        elif granularity == Granularity.TwoBankwise:
            if per2BankOffset is None:
                per2BankOffset = getPer2BankOffset(connection)
            # NOTE(review): the partner bank is not wrapped modulo the bank /
            # bank-group count, so it may name a non-existent bank for some
            # offsets — confirm.
            bankgroupOffset = per2BankOffset // banks
            bankOffset = per2BankOffset % banks
            currentBanks = [(rank, bankgroup, bank),
                            (rank, bankgroup + bankgroupOffset, bank + bankOffset)]
        else:
            currentBanks = [(rank, bankgroup, bank)]

        bankNames = [getBankName(*location) for location in currentBanks]

        for begin, end in timespans:
            beginEvents = eventDict.setdefault(begin, [])
            endEvents = eventDict.setdefault(end, [])
            beginEvents.append(Event("Command_Bus", label))
            endEvents.append(Event("Command_Bus", ""))
            for bankName in bankNames:
                beginEvents.append(Event(bankName, label))
                endEvents.append(Event(bankName, ""))
def getTransactionRange(connection, transactionRange, windowRange):
    """Append the smallest and largest transaction id active in ``windowRange``.

    Only phases whose PhaseBegin and PhaseEnd both fall inside the inclusive
    window are considered. ``transactionRange`` is extended in place with
    ``[minId, maxId]``; when no phase matches, ``[0, 0]`` is appended.
    """
    windowBegin, windowEnd = windowRange
    query = (
        f"SELECT Transact FROM Phases WHERE PhaseBegin BETWEEN {windowBegin} AND {windowEnd}"
        f" AND PhaseEnd BETWEEN {windowBegin} AND {windowEnd}"
    )
    cursor = connection.cursor()
    cursor.execute(query)

    lowest, highest = float("inf"), 0
    for (transact,) in cursor.fetchall():
        highest = max(highest, transact)
        lowest = min(lowest, transact)

    transactionRange.extend([0 if lowest == float("inf") else lowest, highest])
def getReqAndRespPhases(connection, eventDict, transactionDict, windowRange):
    """Add REQ/RESP value changes for phases inside ``windowRange``.

    For every REQ or RESP phase whose PhaseBegin and PhaseEnd both lie in the
    inclusive window, the signal of the same name is set to
    ``"<command> <transaction id>"`` at PhaseBegin and cleared to ``""`` at
    PhaseEnd.

    Args:
        connection: open sqlite3 connection to the trace database.
        eventDict: timestamp -> list of Event; extended in place.
        transactionDict: transaction id -> Transaction; must contain every
            transaction referenced by the selected phases.
        windowRange: ``(begin, end)`` of the window.
    """
    beginWindow, endWindow = windowRange
    cursor = connection.cursor()
    # Filter in SQL rather than skipping rows in Python. Order by begin time so
    # that events sharing a timestamp are appended in time order (dumpVcd
    # applies them in list order; otherwise a REQ starting exactly when the
    # previous one ends could be overwritten by that one's clear event).
    cursor.execute(
        "SELECT PhaseName, PhaseBegin, PhaseEnd, Transact FROM Phases "
        "WHERE PhaseName IN ('REQ', 'RESP') "
        "AND PhaseBegin BETWEEN ? AND ? AND PhaseEnd BETWEEN ? AND ? "
        "ORDER BY PhaseBegin",
        (beginWindow, endWindow, beginWindow, endWindow),
    )
    for phase, begin, end, transactionId in cursor.fetchall():
        command = transactionDict[transactionId].command
        eventDict.setdefault(begin, []).append(
            Event(phase, command + " " + str(transactionId)))
        eventDict.setdefault(end, []).append(Event(phase, ""))
def getTransactions(connection, transactionDict, transactionRange):
    """Fill ``transactionDict`` with Transaction objects for a range of ids.

    Loads every transaction whose id lies in the inclusive
    ``[minTransaction, maxTransaction]`` range and that has at least one
    phase. Rank, bank group and bank are reduced modulo the counts returned by
    getRanksBankgroupsBanks (presumably because the trace stores device-global
    numbers — confirm).

    Args:
        connection: open sqlite3 connection to the trace database.
        transactionDict: transaction id -> Transaction; filled in place.
        transactionRange: ``(minTransaction, maxTransaction)``.
    """
    minTransaction, maxTransaction = transactionRange
    cursor = connection.cursor()
    # DISTINCT: the join with Phases yields one row per phase, which would only
    # rebuild the same Transaction once per phase.
    cursor.execute(
        "SELECT DISTINCT Transactions.ID, Rank, Bankgroup, Bank, Command "
        "FROM Transactions INNER JOIN Phases ON Transactions.ID = Phases.Transact "
        "WHERE Transactions.ID BETWEEN ? AND ?",
        (minTransaction, maxTransaction),
    )
    rows = cursor.fetchall()
    if not rows:
        return

    # Device geometry is identical for every row; look it up once rather than
    # issuing three extra queries per row.
    (ranks, bankgroups, banks) = getRanksBankgroupsBanks(connection)
    for transactionId, rank, bankgroup, bank, command in rows:
        transactionDict[transactionId] = Transaction(
            rank % ranks, bankgroup % bankgroups, bank % banks, command)
def dumpVcd(pathToTrace):
    """Convert a trace database into VCD text and return it.

    The trace is processed in windows of TIME_STEP time units. For each window
    the referenced transactions are loaded; then REQ/RESP, command-bus/bank
    and data-bus events are collected, sorted by timestamp and written through
    a VCDWriter into an in-memory buffer.

    Args:
        pathToTrace: path to the SQLite trace file.

    Returns:
        The complete VCD document as a string.
    """
    connection = sqlite3.connect(pathToTrace)
    try:
        signalList = getOccurringSignals(connection)
        window = TimeWindow(TIME_STEP, getLastTimestamp(connection))

        with io.StringIO() as f:
            currentDate = datetime.date.today().strftime("%B %d, %Y")
            unit = getUnitOfTime(connection)

            with VCDWriter(f, timescale="1" + unit, date=currentDate) as writer:
                variableDict = {}
                for signal in signalList:
                    variableDict[signal.name] = writer.register_var(
                        "DRAMSys", signal.name, signal.getSignalType(),
                        init=signal.getNeutralValue())

                for windowRange in tqdm(window, total=window.numberOfIterations(), desc="VCD export"):
                    eventDict = {}
                    transactionDict = {}
                    transactionRange = []

                    getTransactionRange(connection, transactionRange, windowRange)
                    getTransactions(connection, transactionDict, transactionRange)
                    getReqAndRespPhases(connection, eventDict, transactionDict, windowRange)
                    getCommandBusEvents(connection, eventDict, transactionDict, windowRange)
                    getDataBusEvents(connection, eventDict, windowRange)

                    # VCDWriter requires non-decreasing timestamps, so emit the
                    # window's events in ascending time order.
                    for timestamp in sorted(eventDict):
                        for event in eventDict[timestamp]:
                            variable = variableDict.get(event.signal)
                            if variable is not None:
                                writer.change(variable, timestamp, event.value)

            # The writer was closed (and flushed) when its block exited.
            return f.getvalue()
    finally:
        # Close the database even if the export raises.
        connection.close()
if __name__ == "__main__":
    if len(sys.argv) not in (2, 3):
        # Usage errors go to stderr with a non-zero exit status so that
        # calling scripts can detect them.
        print("Usage: ", sys.argv[0], "<trace_file> [output_file_name]", file=sys.stderr)
        sys.exit(1)

    dump = dumpVcd(sys.argv[1])
    if len(sys.argv) == 2:
        print(dump)
    else:
        # Mode 'x' refuses to overwrite an existing output file.
        with open(sys.argv[2], 'x') as outputFile:
            outputFile.write(dump)