# pytorch-pim/benches.py
# (paste-listing metadata removed; original header read: "Files /
#  pytorch-pim/benches.py / 156 lines / 3.9 KiB / Python")
import torch
import sys
import torch.utils.benchmark as benchmark
import numpy as np
import json
import polars as pl
import dataclasses
import workloads as wl
from config import Statistics, Configuration
# All benches allocate their tensors on the first CUDA device.
device = torch.device("cuda:0")
# Number of timed iterations passed to torch.utils.benchmark.Timer.timeit.
ITERATIONS = 10_000
def run_gemv_bench(workload, level):
match level:
case "X1":
ROWS = 1024 * 1
case "X2":
ROWS = 1024 * 2
case "X3":
ROWS = 1024 * 4
case "X4":
ROWS = 1024 * 8
match level:
case "X1":
COLUMNS = 1024 * 4
case "X2":
COLUMNS = 1024 * 4
case "X3":
COLUMNS = 1024 * 8
case "X4":
COLUMNS = 1024 * 8
matrix = torch.rand(
(ROWS, COLUMNS),
dtype=torch.float16,
device=device,
)
input_vector = torch.rand(COLUMNS, dtype=torch.float16, device=device)
def bench_callback(matrix, input_vector):
torch.matmul(matrix, input_vector)
timer = benchmark.Timer(
"bench_callback(matrix, input_vector)",
globals={
"bench_callback": bench_callback,
"matrix": matrix,
"input_vector": input_vector,
},
)
runtime = int(timer.timeit(ITERATIONS).mean * 1e12)
return runtime
def run_gemv_layers_bench(workload, level):
match level:
case "X1":
DIMENSIONS = 256
case "X2":
DIMENSIONS = 512
case "X3":
DIMENSIONS = 1024
case "X4":
DIMENSIONS = 2048
matrix = torch.rand(
(DIMENSIONS, DIMENSIONS),
dtype=torch.float16,
device=device,
)
input_vector = torch.rand(DIMENSIONS, dtype=torch.float16, device=device)
def bench_callback(matrix, input_vector):
for _ in range(5):
input_vector = torch.matmul(matrix, input_vector)
input_vector.relu()
timer = benchmark.Timer(
"bench_callback(matrix, input_vector)",
globals={
"bench_callback": bench_callback,
"matrix": matrix,
"input_vector": input_vector,
},
)
runtime = int(timer.timeit(ITERATIONS).mean * 1e12)
return runtime
def run_vector_bench(workload, level):
match level:
case "X1":
ROWS = 2097152 * 1
case "X2":
ROWS = 2097152 * 2
case "X3":
ROWS = 2097152 * 4
case "X4":
ROWS = 2097152 * 8
vector_a = torch.rand(ROWS, dtype=torch.float16, device=device)
vector_b = torch.rand(ROWS, dtype=torch.float16, device=device)
func = getattr(wl, workload)
match workload:
case "vadd":
bench_callback = lambda vector_a, vector_b: torch.add(vector_a, vector_b)
case "vmul":
bench_callback = lambda vector_a, vector_b: torch.mul(vector_a, vector_b)
case "haxpy":
bench_callback = lambda vector_a, vector_b: torch.add(
vector_a, vector_b, alpha=2
)
timer = benchmark.Timer(
"bench_callback(vector_a, vector_b)",
globals={
"bench_callback": bench_callback,
"vector_a": vector_a,
"vector_b": vector_b,
},
)
runtime = int(timer.timeit(ITERATIONS).mean * 1e12)
return runtime
# (name, runner) pairs for every supported benchmark; each runner takes
# (workload, level) and returns the mean runtime in picoseconds.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]
# ---- script entry point -------------------------------------------------
# argv[1] carries a JSON-encoded Configuration (workload + level).
config = Configuration(**json.loads(sys.argv[1]))

# Dispatch through the `workloads` table so adding a bench needs one edit.
# The original match re-listed every name and had no default case, so an
# unknown workload fell through and raised NameError on `runtime`.
try:
    runner = dict(workloads)[config.workload]
except KeyError:
    raise SystemExit(f"unknown workload: {config.workload!r}") from None

runtime = runner(config.workload, config.level)
# Emit the result as a JSON-serialised Statistics record on stdout.
print(json.dumps(dataclasses.asdict(Statistics(runtime))))