Compare commits

...

10 Commits

13 changed files with 403 additions and 30 deletions

View File

@@ -43,7 +43,7 @@ scons build/ARM/gem5.opt
from gem5.isas import ISA
from gem5.utils.requires import requires
from gem5.resources.resource import Resource
from gem5.resources.resource import BinaryResource
from gem5.components.memory import SingleChannelDDR3_1600
from gem5.components.processors.cpu_types import CPUTypes
from gem5.components.boards.simple_board import SimpleBoard
@@ -84,7 +84,7 @@ board.set_se_binary_workload(
# Any resource specified in this file will be automatically retrieved.
# At the time of writing, this file is a WIP and does not contain all
# resources. Jira ticket: https://gem5.atlassian.net/browse/GEM5-1096
Resource("arm-hello64-static")
BinaryResource("physical")
)
# Lastly we run the simulation.

15
configs/pim_config.py Normal file
View File

@@ -0,0 +1,15 @@
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class Configuration:
    """Immutable description of a single simulation run.

    The launcher serialises an instance to JSON and the gem5 configuration
    script rebuilds it from that JSON, so every field has to survive a JSON
    round trip.
    """

    # Unique identifier of the run.
    name: str
    # Name of the workload (kernel) being simulated.
    workload: str
    # Location of the kernel binary.
    # NOTE(review): annotated as Path, but JSON cannot carry a Path object,
    # so the value is presumably a plain string at runtime -- confirm.
    executable: Path
    # Level label of the run; its exact meaning is not visible here.
    level: str
    # Label of the simulated memory system.
    system: str
    # Clock frequency string.
    frequency: str = "3GHz"
@dataclass(frozen=True)
class Statistics:
    """Immutable result record of a single simulation run."""

    # Tick count reported for the run.
    ticks: int

View File

@@ -1,3 +1,8 @@
import m5
import json
import dataclasses
import sys
from gem5.isas import ISA
from m5.objects import (
ArmDefaultRelease,
@@ -8,9 +13,13 @@ from gem5.resources.resource import BinaryResource
from gem5.simulate.simulator import Simulator
from m5.objects import VExpress_GEM5_Foundation
from gem5.components.boards.arm_baremetal_board import ArmBareMetalBoard
from gem5.components.memory import DRAMSysDDR3_1600
from gem5.components.memory import DRAMSysHBM2
from gem5.components.processors.cpu_types import CPUTypes
from gem5.components.processors.simple_processor import SimpleProcessor
from gem5.simulate.exit_event import ExitEvent
from dataclasses import dataclass
from pim_config import Configuration, Statistics
requires(isa_required=ISA.ARM)
@@ -19,18 +28,19 @@ from gem5.components.cachehierarchies.classic.private_l1_private_l2_cache_hierar
)
from gem5.components.cachehierarchies.classic.no_cache import NoCache
configuration = Configuration(**json.loads(sys.argv[1]))
cache_hierarchy = PrivateL1PrivateL2CacheHierarchy(
l1d_size="16kB", l1i_size="16kB", l2_size="256kB"
)
# cache_hierarchy = NoCache()
memory = DRAMSysDDR3_1600(recordable=True)
memory = DRAMSysHBM2(recordable=False)
processor = SimpleProcessor(cpu_type=CPUTypes.O3, num_cores=1, isa=ISA.ARM)
release = ArmDefaultRelease()
platform = VExpress_GEM5_Foundation()
board = ArmBareMetalBoard(
clk_freq="3GHz",
clk_freq=configuration.frequency,
processor=processor,
memory=memory,
cache_hierarchy=cache_hierarchy,
@@ -38,15 +48,52 @@ board = ArmBareMetalBoard(
platform=platform,
)
board.m5ops_base = 0x10010000
# HBM2 requires line size of 32 Bytes
board.cache_line_size = 32
for core in processor.get_cores():
core.core.fetchBufferSize = 32
workload = CustomWorkload(
"set_baremetal_workload",
{
"kernel": BinaryResource("aarch64"),
"kernel": BinaryResource(configuration.executable),
},
)
board.set_workload(workload)
simulator = Simulator(board=board)
@dataclass
class WorkloadTime:
    """Mutable pair of tick stamps delimiting the measured workload."""

    # Tick at which the workload started.
    start: int
    # Tick at which the workload finished.
    end: int


# Shared record; both stamps stay at zero until they are overwritten.
workload_time = WorkloadTime(start=0, end=0)
def exit_event():
    """Generator that reacts to three successive exit events.

    Step 1 marks the beginning of the workload: the current tick is stored
    in ``workload_time.start`` and the statistics are reset.
    Step 2 marks the end of the workload: the current tick is stored in
    ``workload_time.end`` and the statistics are dumped.
    Step 3 requests the end of the simulation.

    Yields:
        ``False`` for the first two steps and ``True`` for the last one.
        Presumably ``False`` tells the simulator to keep running and
        ``True`` tells it to stop -- confirm against the gem5 Simulator
        documentation.
    """
    # Step 1: the workload is about to start.
    begin_tick = m5.curTick()
    print(f"Workload begin @{begin_tick}")
    workload_time.start = begin_tick
    m5.stats.reset()
    yield False

    # Step 2: the workload has finished.
    end_tick = m5.curTick()
    print(f"Workload end @{end_tick}")
    workload_time.end = end_tick
    m5.stats.dump()
    yield False

    # Step 3: nothing left to measure.
    print(f"Exit simulation @{m5.curTick()}...")
    yield True
# Build the simulator with one exit_event() generator instance registered
# for ExitEvent.EXIT.
simulator = Simulator(
board=board, on_exit_event={ExitEvent.EXIT: exit_event()}
)
simulator.run()
# Report the measured workload duration in ticks (end stamp minus start stamp).
print(f"Workload took {workload_time.end - workload_time.start}")
statistics = Statistics(workload_time.end - workload_time.start)
# Emit the statistics as a JSON object. Keep this the last thing printed:
# the launcher script parses the final stdout line as JSON.
print(json.dumps(dataclasses.asdict(statistics)))

View File

@@ -49,7 +49,8 @@ subprocess.run(
f"-B{build_current}",
"-DCMAKE_BUILD_TYPE=Release",
f"-DSCONS_SOURCE_DIR:STRING={scons_root}",
"-DDRAMSYS_BUILD_CLI=OFF"
"-DDRAMSYS_BUILD_CLI=OFF",
"-DDRAMSYS_SHARED_PIM_UNITS=ON"
],
check=True
)
@@ -62,6 +63,9 @@ subprocess.run(
env.Append(LIBS="DRAMSys_libdramsys")
env.Append(LIBPATH=Dir("./DRAMSys/src/libdramsys").abspath)
env.Append(LIBS=["libpim_vm", "libpim-vm-cxx"])
env.Append(LIBPATH=Dir("./DRAMSys").abspath)
env.Append(LIBS="DRAMSys_Configuration")
env.Append(LIBPATH=Dir("./DRAMSys/src/configuration").abspath)

45
latex_table.py Normal file
View File

@@ -0,0 +1,45 @@
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl
import numpy as np
from pathlib import Path
# Position of every workload in the generated tables.
workload_order = {
    val: idx
    for idx, val in enumerate(["vadd", "vmul", "haxpy", "gemv", "gemv_layers"])
}

# Display names used for the workloads in the generated tables.
workload_mapping = {
    "vadd": "VADD",
    "vmul": "VMUL",
    "haxpy": "HAXPY",
    "gemv": "GEMV",
    "gemv_layers": "DNN",
}

out_directory = Path("tables_out")
# write_csv() does not create missing directories, so make sure the output
# directory exists before the first table is written.
out_directory.mkdir(parents=True, exist_ok=True)

# --- Simulation results: one table per frequency -------------------------
df = pl.read_csv("pim_results.csv")
df = df.select(["workload", "level", "system", "frequency", "ticks"])

for name, data in df.group_by(["frequency"], maintain_order=True):
    # One row per (workload, level), one tick column per system.
    data = data.pivot(index=["workload", "level"], columns=["system"], values=["ticks"])
    # Sort by the fixed workload order first, then switch to display names.
    data = data.sort(pl.col("workload").replace(workload_order))
    data = data.with_columns(pl.col("workload").replace(workload_mapping))
    data = data.rename({"HBM": "hbm", "PIM-HBM": "pim"})
    print(data)
    # The group key is a tuple; its only element is the frequency.
    data.write_csv(out_directory / f"simulations_{name[0]}.csv")

# --- Measured runtimes: one combined table -------------------------------
vega_df = pl.read_csv("vega_results.csv")
vega_df = vega_df.with_columns(system=pl.lit("vega"))

tesla_df = pl.read_csv("tesla_results.csv")
tesla_df = tesla_df.with_columns(system=pl.lit("tesla"))

torch_df = pl.concat([vega_df, tesla_df])
torch_df = torch_df.pivot(index=["workload", "level"], columns=["system"], values=["runtime"])
torch_df = torch_df.sort(pl.col("workload").replace(workload_order))
torch_df = torch_df.with_columns(pl.col("workload").replace(workload_mapping))
print(torch_df)
torch_df.write_csv(out_directory / "torch.csv")

64
pim_plots.py Normal file
View File

@@ -0,0 +1,64 @@
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
# Directory that receives the CSV files written by this script.
out_directory = Path("pim_plots_out")
# Simulation results table read from the current working directory.
df = pl.read_csv("pim_results.csv")
# Workloads grouped into the two plot families; the names used here are the
# ones in effect after workload_mapping has been applied.
workload_sets = {
"vector": ["vadd", "vmul", "haxpy"],
"matrix": ["gemv", "dnn"],
}
# Renames applied to the "workload" column.
workload_mapping = {
"gemv_layers": "dnn",
}
# Renames applied to the "system" column.
system_mapping = {
"HBM": "hbm",
"PIM-HBM": "pim"
}
def calc_speedup(tick_list):
    """Return the first tick count divided by the second.

    Args:
        tick_list: Indexable sequence with at least two tick counts. Only
            the first two entries are used. Presumably the first entry
            belongs to the baseline system and the second to the PIM
            system -- this depends on the row order of the input data.

    Returns:
        ``tick_list[0] / tick_list[1]``.
    """
    reference_ticks = tick_list[0]
    compared_ticks = tick_list[1]
    return reference_ticks / compared_ticks
# Normalise workload and system names.
df = df.with_columns(pl.col("workload").replace(workload_mapping))
df = df.with_columns(pl.col("system").replace(system_mapping))

# Collapse the rows of every (workload, level, frequency) group into a single
# speedup value; calc_speedup() relies on the row order inside each group.
df = df.group_by(
    ["workload", "level", "frequency"], maintain_order=True
).agg(pl.col("ticks").map_elements(calc_speedup).alias("speedup"))

# write_csv() does not create missing directories, so make sure the output
# directory exists before the first file is written.
out_directory.mkdir(parents=True, exist_ok=True)

# One plot and one CSV file per (frequency, workload family) combination.
# name[0] is the frequency, name[1] the family ("vector" or "matrix").
for name, data in df.group_by(
    "frequency",
    pl.when(pl.col("workload").is_in(workload_sets["vector"]))
    .then(pl.lit("vector"))
    .when(pl.col("workload").is_in(workload_sets["matrix"]))
    .then(pl.lit("matrix")),
):
    plot = sns.catplot(
        data=data.to_pandas(),
        kind="bar",
        x="level",
        y="speedup",
        hue="workload",
        palette="dark",
        alpha=0.6,
        height=6,
    )
    plot.set_axis_labels("Level", "Speedup")
    plot.set(title=name[0] + name[1])
    plot.fig.subplots_adjust(top=0.95)

    # Wide layout for the CSV export: one row per level, one column per workload.
    data = data.pivot(index=["level"], columns=["workload"], values=["speedup"])
    print(data)
    data.write_csv(out_directory / f"{name[1]}_{name[0]}.csv")

# Show all figures once every group has been processed.
plt.show()

104
simulation_script.py Normal file
View File

@@ -0,0 +1,104 @@
import subprocess
import dataclasses
import json
import pandas as pd
from tqdm import tqdm
from dataclasses import dataclass
from threading import Thread
from multiprocessing.pool import ThreadPool
from pathlib import Path
from configs.pim_config import Configuration, Statistics
# gem5 binary to launch, relative to the current working directory.
gem5 = Path("build/ARM/gem5.opt")
# Base directory under which every run gets its own output directory.
out_dir_base = Path("pim_out")
# Configuration script handed to gem5 for every run.
pim_simulation = Path("configs/pim_simulation.py")
@dataclass
class WorkItem:
    """A simulation job together with the slot for its result."""

    # Parameters of the run.
    configuration: Configuration
    # Result of the run; None until one has been stored.
    statistics: Statistics | None = None
def run_gem5_process(work_item: WorkItem):
    """Run one gem5 simulation and store its statistics on *work_item*.

    The configuration is passed to the gem5 script as a JSON string on the
    command line. The simulation's last stdout line is parsed as a JSON
    object and turned into a ``Statistics`` instance.

    Args:
        work_item: Job to execute; its ``statistics`` attribute is
            overwritten with the parsed result.

    Raises:
        RuntimeError: If gem5 printed nothing on stdout or its last stdout
            line is not valid JSON. The message carries the exit code and
            the captured stderr so the failing run can be diagnosed.
    """
    configuration = work_item.configuration
    serialized_configuration = json.dumps(dataclasses.asdict(configuration))

    # Every run writes into its own directory, named after the configuration.
    out_dir = out_dir_base / configuration.name

    out = subprocess.run(
        [
            gem5,
            "-d" + out_dir.as_posix(),
            pim_simulation,
            serialized_configuration,
        ],
        capture_output=True,
    )

    stdout_lines = out.stdout.splitlines()
    try:
        # The simulation script prints its statistics as the final line.
        parsed = json.loads(stdout_lines[-1])
    except (IndexError, ValueError) as err:
        # stderr is captured above; surface it instead of dropping it.
        raise RuntimeError(
            f"gem5 run '{configuration.name}' produced no parsable statistics "
            f"(exit code {out.returncode}):\n"
            f"{out.stderr.decode(errors='replace')}"
        ) from err

    work_item.statistics = Statistics(**parsed)
# Kernel binaries live in <base>/<level>/<sub directory>/<binary name>.
workload_base_directory = Path("kernels")
workload_sub_directory = Path("aarch64-unknown-none/release")

workloads = ["vadd", "vmul", "haxpy", "gemv", "gemv_layers"]
systems = ["HBM", "PIM-HBM"]

# Full sweep: every frequency x level x system x workload combination.
configurations: list[Configuration] = []

# For a quick run, restrict the sweep to ["100GHz"] and/or ["X3"].
for frequency in ["3GHz", "100GHz"]:
    for level in ["X1", "X2", "X3", "X4"]:
        for system in systems:
            for workload in workloads:
                # The plain HBM system runs the "classic_" build of the kernel.
                binary_name = (
                    f"classic_{workload}" if system == "HBM" else workload
                )
                executable = (
                    workload_base_directory
                    / level
                    / workload_sub_directory
                    / binary_name
                )

                configurations.append(
                    Configuration(
                        name=f"{workload}_{level}_{system}_{frequency}",
                        workload=workload,
                        executable=executable.as_posix(),
                        level=level,
                        system=system,
                        frequency=frequency,
                    )
                )

work_items = list(map(WorkItem, configurations))

# Run the simulations on a thread pool; iterating the result only serves to
# drive the progress bar.
with ThreadPool() as pool:
    finished_runs = pool.imap_unordered(run_gem5_process, work_items)
    for _ in tqdm(finished_runs, total=len(work_items)):
        pass

# One flat row per run: configuration fields merged with the statistics.
results: list[dict] = [
    dataclasses.asdict(item.configuration) | dataclasses.asdict(item.statistics)
    for item in work_items
]

dataframe = pd.DataFrame(results)
dataframe.to_csv("pim_results.csv", index=False)

View File

@@ -27,6 +27,8 @@
*/
#include "dramsys.hh"
#include "DRAMSys/common/Deserialize.h"
#include "DRAMSys/common/Serialize.h"
namespace gem5
{

View File

@@ -89,24 +89,6 @@ tlm::tlm_sync_enum DRAMSysWrapper::nb_transport_fw(
// Subtract base address offset
payload.set_address(payload.get_address() - range.start());
if (payload.get_address() < 0x4000 && payload.is_write() && phase == tlm::BEGIN_REQ)
{
char *msg = reinterpret_cast<char*>(payload.get_data_ptr());
for (std::size_t i = 0; i < payload.get_data_length(); i++)
{
if (msg[i] != '\0')
{
message.push_back(msg[i]);
}
else
{
std::cout << message << std::endl;
message.clear();
break;
}
}
}
return iSocket->nb_transport_fw(payload, phase, fwDelay);
}

View File

@@ -84,7 +84,6 @@ class DRAMSysWrapper : public sc_core::sc_module
tlm_utils::simple_target_socket<DRAMSysWrapper> tSocket;
std::shared_ptr<::DRAMSys::DRAMSys> dramsys;
std::string message;
AddrRange range;
};

View File

@@ -183,6 +183,6 @@ class DRAMSysHBM2(DRAMSysMem):
configuration=(
DEFAULT_DRAMSYS_DIRECTORY / "configs/hbm2-example.json"
).as_posix(),
size="1GB",
size="2GB",
recordable=recordable,
)

59
torch_plots.py Normal file
View File

@@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl
from pathlib import Path
out_directory = Path("torch_plots_out")
# write_csv() does not create missing directories, so make sure the output
# directory exists before the first file is written.
out_directory.mkdir(parents=True, exist_ok=True)

system_mapping = {
    "HBM": "hbm",
    "PIM-HBM": "pim"
}

# Simulation results: fold the frequency into the system label so that every
# (system, frequency) pair shows up as its own series.
gem_df = pl.read_csv("pim_results.csv")
gem_df = gem_df.with_columns(pl.col("system").replace(system_mapping))
gem_df = gem_df.with_columns(
    pl.concat_str(["system", "frequency"], separator="_").alias("system")
)
gem_df = gem_df.select(["workload", "level", "system", "ticks"])

# Measured results: rename "runtime" to "ticks" so all sources share one
# value column, and tag the rows with their origin.
vega_df = pl.read_csv("vega_results.csv")
vega_df = vega_df.rename({"runtime": "ticks"})
vega_df = vega_df.with_columns(pl.lit("vega").alias("system"))

tesla_df = pl.read_csv("tesla_results.csv")
tesla_df = tesla_df.rename({"runtime": "ticks"})
tesla_df = tesla_df.with_columns(pl.lit("tesla").alias("system"))

df = pl.concat([gem_df, vega_df, tesla_df], how="diagonal")

workload_sets = [["vadd", "vmul", "haxpy"], ["gemv", "dnn"]]

workload_mapping = {
    "gemv_layers": "dnn",
}

df = df.with_columns(pl.col("workload").replace(workload_mapping))

# for workload_set in workload_sets:
# temp_df = df.filter(pl.col("workload").is_in(workload_set))

g = sns.catplot(
    data=df.to_pandas(),
    kind="bar",
    x="level",
    y="ticks",
    hue="system",
    col="workload",
    palette="dark",
    alpha=0.6,
    height=6,
)

# Group by a list of keys so the group identifier is a tuple, and take its
# only element for the file name; formatting the identifier itself could
# put the tuple's repr into the file name.
for name, data in df.group_by(["system"]):
    data = data.pivot(index=["level"], columns=["workload"], values=["ticks"])
    data.write_csv(out_directory / f"{name[0]}.csv")
    print(data)

plt.show()

52
wallclock_time_plots.py Normal file
View File

@@ -0,0 +1,52 @@
import re
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import timedelta
from pathlib import Path
stats_dir = Path("pim_out")

# Entry names are split into four underscore-separated parts:
# workload, level, system and frequency.
name_pattern = re.compile(r'(\w+)_(\w+)_(\w*-*\w*)_(\w+)')
# hostSeconds line of stats.txt; the value is a decimal number.
host_seconds_pattern = re.compile(r'hostSeconds\ +(\d+\.\d+).*')

# Column name -> list of values, one entry per evaluated run.
runtime_dict: dict[str, list] = {}

for element in stats_dir.iterdir():
    print(element.name)

    matches = name_pattern.search(element.name)
    if matches is None:
        print(f"Skipping {element.name}: name does not match the expected pattern")
        continue
    workload, level, system, freq = matches.groups()

    # Reset for every entry so a file without a hostSeconds line can never
    # pick up the value of the previously processed one.
    runtime = None
    with open(element / "stats.txt") as f:
        for line in f:
            result = host_seconds_pattern.search(line)
            if result is not None:
                # Keep overwriting: the last match in the file wins.
                runtime = result.group(1)

    if runtime is None:
        print(f"Skipping {element.name}: no hostSeconds entry in stats.txt")
        continue

    runtime_dict.setdefault("workload", []).append(workload)
    runtime_dict.setdefault("level", []).append(level)
    runtime_dict.setdefault("system", []).append(system)
    runtime_dict.setdefault("freq", []).append(freq)
    runtime_dict.setdefault("runtime", []).append(float(runtime))

df = pl.DataFrame(runtime_dict)
# Only the 100GHz / X3 runs are plotted.
df = df.filter((pl.col("freq") == "100GHz") & (pl.col("level") == "X3"))
df = df.drop("freq")
print(df)

plot = sns.catplot(
    data=df.to_pandas(),
    kind="bar",
    x="system",
    y="runtime",
    hue="workload",
    palette="dark",
    alpha=0.6,
    height=6,
)
plot.set_axis_labels("PIM vs. Non-PIM", "Runtime [s]")
plot.set(title="Wallclock Time")
plot.fig.subplots_adjust(top=0.95)

plt.show()