"""CUDA micro-benchmarks for vector ops and GEMV workloads.

Usage: pass a JSON-encoded ``Configuration`` as the first command-line
argument; the script prints the JSON-encoded ``Statistics`` for that run.
"""

import dataclasses
import json
import sys

import numpy as np
import polars as pl
import torch
import torch.utils.benchmark as benchmark

import workloads as wl
from config import Configuration, Statistics

device = torch.device("cuda:0")

# Number of timed repetitions; the reported runtime is the mean over these.
ITERATIONS = 100_000

# Untimed repetitions run before timing starts, so one-off costs of the first
# launches are not folded into the measurement.
WARMUP_ITERATIONS = 10

# (rows, columns) of the GEMV matrix for each size level.
GEMV_SHAPES = {
    "X1": (1024 * 1, 1024 * 4),
    "X2": (1024 * 2, 1024 * 4),
    "X3": (1024 * 4, 1024 * 8),
    "X4": (1024 * 8, 1024 * 8),
}

# Side length of the square matrix used by every layer of the GEMV stack.
GEMV_LAYERS_DIMENSIONS = {
    "X1": 256,
    "X2": 512,
    "X3": 1024,
    "X4": 2048,
}

# Number of (GEMV + ReLU) layers applied per timed iteration.
GEMV_LAYERS_DEPTH = 5

# Element count of each input vector for the element-wise benchmarks.
VECTOR_LENGTHS = {
    "X1": 2097152 * 1,
    "X2": 2097152 * 2,
    "X3": 2097152 * 4,
    "X4": 2097152 * 8,
}

# Element-wise operation timed by each vector workload.
VECTOR_OPS = {
    "vadd": lambda a, b: torch.add(a, b),
    "vmul": lambda a, b: torch.mul(a, b),
    "haxpy": lambda a, b: torch.add(a, b, alpha=2),  # a + 2 * b
}


def _lookup(table, key, what):
    """Return ``table[key]``, raising a descriptive ValueError if it is missing."""
    try:
        return table[key]
    except KeyError:
        raise ValueError(
            f"unknown {what} {key!r}; expected one of {sorted(table)}"
        ) from None


def _time_per_iteration(step, iterations=ITERATIONS, warmup=WARMUP_ITERATIONS):
    """Call ``step`` repeatedly and return the mean GPU time per call.

    Timing uses CUDA events recorded on the current stream, so only work
    enqueued between the two events is measured; ``torch.cuda.synchronize``
    waits for it to finish before the elapsed time is read.

    Args:
        step: zero-argument callable that enqueues one iteration of work.
        iterations: number of timed calls.
        warmup: number of untimed calls made before timing starts.

    Returns:
        The mean time per call as an int. ``elapsed_time`` reports
        milliseconds, so ``ms * 1e9 / iterations`` is picoseconds per call.
        NOTE(review): this scale factor is kept unchanged from the original;
        confirm that ``Statistics`` expects picoseconds (nanoseconds would be 1e6).
    """
    for _ in range(warmup):
        step()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iterations):
        step()
    end.record()
    torch.cuda.synchronize()
    return int(start.elapsed_time(end) * 1e9 / iterations)


def run_gemv_bench(workload, level):
    """Time a single float16 matrix-vector product.

    Args:
        workload: unused; accepted so every runner shares the
            ``(workload, level)`` signature used by ``workloads``.
        level: size level, one of the keys of ``GEMV_SHAPES``.

    Returns:
        Mean per-iteration time (see ``_time_per_iteration`` for the unit).

    Raises:
        ValueError: if ``level`` is not a known size level.
    """
    rows, columns = _lookup(GEMV_SHAPES, level, "level")
    matrix = torch.rand((rows, columns), dtype=torch.float16, device=device)
    input_vector = torch.rand(columns, dtype=torch.float16, device=device)
    return _time_per_iteration(lambda: torch.matmul(matrix, input_vector))


def run_gemv_layers_bench(workload, level):
    """Time a stack of ``GEMV_LAYERS_DEPTH`` (GEMV + ReLU) layers sharing one square matrix.

    Each layer's output is fed into the next layer, and the state carries over
    across iterations.
    NOTE(review): the original's indentation was lost. The ReLU is applied
    after every GEMV here; confirm that this matches the intended model.
    Because ``torch.rand`` inputs are non-negative, the values grow with each
    layer and may overflow float16. Only the timing is reported, so the values
    themselves are not checked.

    Args:
        workload: unused; kept for the shared runner signature.
        level: size level, one of the keys of ``GEMV_LAYERS_DIMENSIONS``.

    Returns:
        Mean per-iteration time (see ``_time_per_iteration`` for the unit).

    Raises:
        ValueError: if ``level`` is not a known size level.
    """
    dimensions = _lookup(GEMV_LAYERS_DIMENSIONS, level, "level")
    matrix = torch.rand(
        (dimensions, dimensions), dtype=torch.float16, device=device
    )
    input_vector = torch.rand(dimensions, dtype=torch.float16, device=device)

    def forward():
        nonlocal input_vector
        for _ in range(GEMV_LAYERS_DEPTH):
            # relu() is not in-place; keep its result so it feeds the next layer.
            input_vector = torch.matmul(matrix, input_vector).relu()

    return _time_per_iteration(forward)


def run_vector_bench(workload, level):
    """Time an element-wise float16 operation on two equal-length vectors.

    Args:
        workload: one of the keys of ``VECTOR_OPS`` ("vadd", "vmul", "haxpy").
        level: size level, one of the keys of ``VECTOR_LENGTHS``.

    Returns:
        Mean per-iteration time (see ``_time_per_iteration`` for the unit).

    Raises:
        ValueError: if ``workload`` or ``level`` is unknown. Previously an
            unknown workload silently timed an empty loop.
    """
    length = _lookup(VECTOR_LENGTHS, level, "level")
    # Resolve the op once, outside the timed loop.
    op = _lookup(VECTOR_OPS, workload, "vector workload")
    vector_a = torch.rand(length, dtype=torch.float16, device=device)
    vector_b = torch.rand(length, dtype=torch.float16, device=device)
    return _time_per_iteration(lambda: op(vector_a, vector_b))


# Maps each workload name to the runner that benchmarks it.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]


def main(argv=None):
    """Parse the JSON configuration from ``argv[1]``, run it, and print JSON statistics."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit(f"usage: {argv[0] if argv else 'bench'} '<configuration JSON>'")
    config = Configuration(**json.loads(argv[1]))
    runner = _lookup(dict(workloads), config.workload, "workload")
    runtime = runner(config.workload, config.level)
    print(json.dumps(dataclasses.asdict(Statistics(runtime))))


if __name__ == "__main__":
    main()