Bigger dimensions
This commit is contained in:
149
benches.py
Normal file
149
benches.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
import sys
|
||||
import torch.utils.benchmark as benchmark
|
||||
import numpy as np
|
||||
import json
|
||||
import polars as pl
|
||||
import dataclasses
|
||||
|
||||
import workloads as wl
|
||||
|
||||
from config import Statistics, Configuration
|
||||
|
||||
# Benchmark target device — hard-coded to the first CUDA GPU.
# NOTE(review): assumes a CUDA device is present; no fallback to CPU.
device = torch.device("cuda:0")

# Number of kernel launches timed per benchmark; runtimes below are
# averaged over this count.
ITERATIONS = 100_000
|
||||
|
||||
|
||||
def run_gemv_bench(workload, level):
    """Time a half-precision GEMV (matrix-vector product) on the GPU.

    Args:
        workload: Workload name. Unused here; kept so all benchmark
            runners share the same (workload, level) call signature.
        level: Size level, one of "X1".."X4", selecting matrix dimensions.

    Returns:
        Average runtime per iteration in nanoseconds (int).

    Raises:
        ValueError: If ``level`` is not a known size level.
    """
    # (rows, columns) per level. The original used two separate `match`
    # statements with no default case, so an unknown level crashed later
    # with an unrelated NameError on ROWS/COLUMNS.
    SIZES = {
        "X1": (1024 * 1, 1024 * 4),
        "X2": (1024 * 2, 1024 * 4),
        "X3": (1024 * 4, 1024 * 8),
        "X4": (1024 * 8, 1024 * 8),
    }
    try:
        rows, columns = SIZES[level]
    except KeyError:
        raise ValueError(f"unknown level: {level!r}") from None

    matrix = torch.rand((rows, columns), dtype=torch.float16, device=device)
    input_vector = torch.rand(columns, dtype=torch.float16, device=device)

    # CUDA events time GPU work without forcing a sync on every iteration.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(ITERATIONS):
        torch.matmul(matrix, input_vector)
    end.record()

    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()

    # elapsed_time() returns milliseconds; ms -> ns is *1e6 (the original
    # used 1e9, inflating every reported runtime by 1000x).
    return int(start.elapsed_time(end) * 1e6 / ITERATIONS)
|
||||
|
||||
|
||||
def run_gemv_layers_bench(workload, level):
    """Time a 5-layer chain of square GEMV + ReLU, MLP-style, on the GPU.

    Args:
        workload: Workload name. Unused here; kept so all benchmark
            runners share the same (workload, level) call signature.
        level: Size level, one of "X1".."X4", selecting the square
            matrix/vector dimension.

    Returns:
        Average runtime per iteration (one 5-layer pass) in nanoseconds.

    Raises:
        ValueError: If ``level`` is not a known size level.
    """
    DIMS = {"X1": 256, "X2": 512, "X3": 1024, "X4": 2048}
    try:
        dim = DIMS[level]
    except KeyError:
        raise ValueError(f"unknown level: {level!r}") from None

    matrix = torch.rand((dim, dim), dtype=torch.float16, device=device)
    input_vector = torch.rand(dim, dtype=torch.float16, device=device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(ITERATIONS):
        for _ in range(5):
            input_vector = torch.matmul(matrix, input_vector)
            # Bug fix: relu() is out-of-place and the original discarded
            # its result, so the activation never fed the next layer.
            input_vector = input_vector.relu()
    end.record()

    # NOTE(review): the vector is carried across all ITERATIONS, so fp16
    # values will saturate after a few passes; kernel launch count — the
    # quantity being timed — is unaffected.
    torch.cuda.synchronize()

    # elapsed_time() returns milliseconds; ms -> ns is *1e6 (the original
    # used 1e9, inflating every reported runtime by 1000x).
    return int(start.elapsed_time(end) * 1e6 / ITERATIONS)
|
||||
|
||||
|
||||
def run_vector_bench(workload, level):
    """Time an element-wise vector workload (add / mul / haxpy) on the GPU.

    Args:
        workload: One of "vadd", "vmul", "haxpy".
        level: Size level, one of "X1".."X4", selecting the vector length.

    Returns:
        Average runtime per iteration in nanoseconds (int).

    Raises:
        ValueError: If ``workload`` or ``level`` is unknown.
    """
    # 2_097_152 = 2 * 1024**2 base elements, scaled by level.
    LENGTHS = {
        "X1": 2097152 * 1,
        "X2": 2097152 * 2,
        "X3": 2097152 * 4,
        "X4": 2097152 * 8,
    }
    try:
        length = LENGTHS[level]
    except KeyError:
        raise ValueError(f"unknown level: {level!r}") from None

    vector_a = torch.rand(length, dtype=torch.float16, device=device)
    vector_b = torch.rand(length, dtype=torch.float16, device=device)

    # Dispatch chosen once, before the timed loop: the original re-ran a
    # `match` on every iteration and had no default case for bad names.
    # (A dead `func = getattr(wl, workload)` was also removed.)
    OPS = {
        "vadd": lambda: torch.add(vector_a, vector_b),
        "vmul": lambda: torch.mul(vector_a, vector_b),
        # haxpy: a + alpha*b in half precision.
        "haxpy": lambda: torch.add(vector_a, vector_b, alpha=2),
    }
    try:
        op = OPS[workload]
    except KeyError:
        raise ValueError(f"unknown workload: {workload!r}") from None

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(ITERATIONS):
        op()
    end.record()

    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()

    # elapsed_time() returns milliseconds; ms -> ns is *1e6 (the original
    # used 1e9, inflating every reported runtime by 1000x).
    return int(start.elapsed_time(end) * 1e6 / ITERATIONS)
|
||||
|
||||
# Workload name -> runner function; also serves as the dispatch table.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]

# First CLI argument is a JSON-encoded Configuration (workload + level).
config = Configuration(**json.loads(sys.argv[1]))

# Dispatch through the table instead of a duplicated `match` statement;
# the original had no default case, so an unknown workload crashed later
# with an unrelated NameError on `runtime`.
_runners = dict(workloads)
try:
    _bench = _runners[config.workload]
except KeyError:
    raise ValueError(f"unknown workload: {config.workload!r}") from None

runtime = _bench(config.workload, config.level)

# Emit the result as JSON on stdout for the harness to consume.
print(json.dumps(dataclasses.asdict(Statistics(runtime))))
|
||||
Reference in New Issue
Block a user