Prepare dataframe format for Latex plots

This commit is contained in:
2024-02-28 16:26:37 +01:00
committed by Derek Christ
parent 3d9533c10c
commit 353488837c
2 changed files with 33 additions and 23 deletions

View File

@@ -3,33 +3,46 @@ import seaborn as sns
import pandas as pd
import numpy as np
df = pd.read_csv("pim_results.csv")
from pathlib import Path
workloads = df["workload"].unique()
df = pd.read_csv("pim_results.csv")
sns.set_theme()
def calc_speedup(x):
return x.iat[0] / x.iat[1]
for workload in df["workload"].unique():
workload_filter = df["workload"] == workload
workload_sets = [["vadd", "vmul", "haxpy"], ["gemv", "gemv_layers"]]
filtered_df = df[workload_filter]
preprocessed_df = filtered_df.groupby(["workload", "level", "frequency"], as_index=False).agg({"ticks": calc_speedup}).rename(columns={"ticks":"speedup"})
for workload_set in workload_sets:
workload_filter = df["workload"].isin(workload_set)
# print(preprocessed_df)
# preprocessed_df.to_csv("plot.csv", index=False)
for frequency in df["frequency"].unique():
frequency_filter = df["frequency"] == frequency
g = sns.catplot(
data=preprocessed_df, kind="bar",
x="level", y="speedup", hue="frequency",
palette="dark", alpha=.6, height=6
)
filtered_df = df[workload_filter & frequency_filter]
print(filtered_df)
preprocessed_df = filtered_df.groupby(["workload", "level", "frequency"], as_index=False).agg({"ticks": calc_speedup}).rename(columns={"ticks":"speedup"})
g.despine(left=True)
g.set_axis_labels("", "Speedup")
g.set(title=workload)
g.legend.set_title("")
print(preprocessed_df)
# preprocessed_df.to_csv("plot.csv", index=False)
g = sns.catplot(
data=preprocessed_df, kind="bar",
x="level", y="speedup", hue="workload",
palette="dark", alpha=.6, height=6
)
g.despine(left=True)
g.set_axis_labels("", "Speedup")
g.set(title=frequency)
g.legend.set_title("")
for workload in workload_set:
export_df = preprocessed_df[preprocessed_df["workload"] == workload]
filename = f"{workload}_{frequency}.csv"
directory = Path("plots_out")
export_df.to_csv(directory / filename, index=False)
plt.show()

View File

@@ -49,7 +49,7 @@ workloads = [
"vmul",
"haxpy",
"gemv",
# "gemv_layers",
"gemv_layers",
]
systems = [
@@ -63,9 +63,6 @@ for frequency in ["3GHz", "100GHz"]:
for level in ["X1", "X2", "X3", "X4"]:
for system in systems:
for workload in workloads:
if workload == "gemv_layers" and level != "X4":
continue
executable = workload
if system == "HBM":