# GPU micro-benchmark driver: runs one configured workload (vector ops or
# GEMV variants) at a given size level and prints runtime statistics as JSON.
import torch
|
|
import sys
|
|
import torch.utils.benchmark as benchmark
|
|
import numpy as np
|
|
import json
|
|
import polars as pl
|
|
import dataclasses
|
|
|
|
import workloads as wl
|
|
|
|
from config import Statistics, Configuration
|
|
|
|
device = torch.device("cuda:0")
|
|
|
|
ITERATIONS = 10_000
|
|
|
|
|
|
def run_gemv_bench(workload, level):
|
|
match level:
|
|
case "X1":
|
|
ROWS = 1024 * 1
|
|
case "X2":
|
|
ROWS = 1024 * 2
|
|
case "X3":
|
|
ROWS = 1024 * 4
|
|
case "X4":
|
|
ROWS = 1024 * 8
|
|
|
|
match level:
|
|
case "X1":
|
|
COLUMNS = 1024 * 4
|
|
case "X2":
|
|
COLUMNS = 1024 * 4
|
|
case "X3":
|
|
COLUMNS = 1024 * 8
|
|
case "X4":
|
|
COLUMNS = 1024 * 8
|
|
|
|
matrix = torch.rand(
|
|
(ROWS, COLUMNS),
|
|
dtype=torch.float16,
|
|
device=device,
|
|
)
|
|
input_vector = torch.rand(COLUMNS, dtype=torch.float16, device=device)
|
|
|
|
def bench_callback(matrix, input_vector):
|
|
torch.matmul(matrix, input_vector)
|
|
|
|
timer = benchmark.Timer(
|
|
"bench_callback(matrix, input_vector)",
|
|
globals={
|
|
"bench_callback": bench_callback,
|
|
"matrix": matrix,
|
|
"input_vector": input_vector,
|
|
},
|
|
)
|
|
runtime = int(timer.timeit(ITERATIONS).mean * 1e12)
|
|
|
|
return runtime
|
|
|
|
|
|
def run_gemv_layers_bench(workload, level):
|
|
match level:
|
|
case "X1":
|
|
DIMENSIONS = 256
|
|
case "X2":
|
|
DIMENSIONS = 512
|
|
case "X3":
|
|
DIMENSIONS = 1024
|
|
case "X4":
|
|
DIMENSIONS = 2048
|
|
|
|
matrix = torch.rand(
|
|
(DIMENSIONS, DIMENSIONS),
|
|
dtype=torch.float16,
|
|
device=device,
|
|
)
|
|
input_vector = torch.rand(DIMENSIONS, dtype=torch.float16, device=device)
|
|
|
|
def bench_callback(matrix, input_vector):
|
|
for _ in range(5):
|
|
input_vector = torch.matmul(matrix, input_vector)
|
|
input_vector.relu()
|
|
|
|
timer = benchmark.Timer(
|
|
"bench_callback(matrix, input_vector)",
|
|
globals={
|
|
"bench_callback": bench_callback,
|
|
"matrix": matrix,
|
|
"input_vector": input_vector,
|
|
},
|
|
)
|
|
runtime = int(timer.timeit(ITERATIONS).mean * 1e12)
|
|
|
|
return runtime
|
|
|
|
|
|
def run_vector_bench(workload, level):
|
|
match level:
|
|
case "X1":
|
|
ROWS = 2097152 * 1
|
|
case "X2":
|
|
ROWS = 2097152 * 2
|
|
case "X3":
|
|
ROWS = 2097152 * 4
|
|
case "X4":
|
|
ROWS = 2097152 * 8
|
|
|
|
vector_a = torch.rand(ROWS, dtype=torch.float16, device=device)
|
|
vector_b = torch.rand(ROWS, dtype=torch.float16, device=device)
|
|
|
|
func = getattr(wl, workload)
|
|
|
|
match workload:
|
|
case "vadd":
|
|
bench_callback = lambda vector_a, vector_b: torch.add(vector_a, vector_b)
|
|
case "vmul":
|
|
bench_callback = lambda vector_a, vector_b: torch.mul(vector_a, vector_b)
|
|
case "haxpy":
|
|
bench_callback = lambda vector_a, vector_b: torch.add(
|
|
vector_a, vector_b, alpha=2
|
|
)
|
|
|
|
timer = benchmark.Timer(
|
|
"bench_callback(vector_a, vector_b)",
|
|
globals={
|
|
"bench_callback": bench_callback,
|
|
"vector_a": vector_a,
|
|
"vector_b": vector_b,
|
|
},
|
|
)
|
|
runtime = int(timer.timeit(ITERATIONS).mean * 1e12)
|
|
|
|
return runtime
|
|
|
|
|
|
# Registry mapping each workload name to the runner that benchmarks it.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]

# The benchmark configuration arrives as a JSON object in the first
# command-line argument.
config = Configuration(**json.loads(sys.argv[1]))

# Dispatch through the registry instead of a duplicated match statement, so
# adding a workload only touches the `workloads` table. Previously an
# unknown workload left `runtime` unbound and crashed with a NameError.
try:
    runner = dict(workloads)[config.workload]
except KeyError:
    raise ValueError(f"unknown workload: {config.workload!r}") from None
runtime = runner(config.workload, config.level)

# Emit the measured statistics as JSON on stdout for the harness to collect.
print(json.dumps(dataclasses.asdict(Statistics(runtime))))
|