"""PyTorch CUDA micro-benchmarks for vector and GEMV workloads.

Usage: pass a JSON object as the first command-line argument. It is expanded
into a ``Configuration`` (its ``workload`` and ``level`` attributes are read
here). The measured runtime is wrapped in ``Statistics`` and printed to stdout
as JSON.
"""

import dataclasses
import json
import sys

import numpy as np
import polars as pl
import torch
import torch.utils.benchmark as benchmark

from config import Statistics, Configuration

device = torch.device("cuda:0")

# Number of invocations handed to ``Timer.timeit`` for every measurement.
ITERATIONS = 10_000

# Problem sizes per level, kept as tables so an unknown level can be rejected
# with a clear error instead of leaving a size variable unbound.
_GEMV_SHAPES = {  # level -> (rows, columns)
    "X1": (1024 * 1, 1024 * 4),
    "X2": (1024 * 2, 1024 * 4),
    "X3": (1024 * 4, 1024 * 8),
    "X4": (1024 * 8, 1024 * 8),
}
_GEMV_LAYERS_DIMENSIONS = {  # level -> side length of the square matrix
    "X1": 256,
    "X2": 512,
    "X3": 1024,
    "X4": 2048,
}
_VECTOR_LENGTHS = {  # level -> number of elements per vector
    "X1": 2097152 * 1,
    "X2": 2097152 * 2,
    "X3": 2097152 * 4,
    "X4": 2097152 * 8,
}

# Number of chained matmul layers in the ``gemv_layers`` workload.
_GEMV_LAYER_COUNT = 5

# Element-wise kernels, keyed by workload name. The lambda wrappers are kept
# (rather than passing ``torch.add`` directly) so the Python call overhead
# included in the measurement stays the same as before.
_VECTOR_OPS = {
    "vadd": lambda vector_a, vector_b: torch.add(vector_a, vector_b),
    "vmul": lambda vector_a, vector_b: torch.mul(vector_a, vector_b),
    "haxpy": lambda vector_a, vector_b: torch.add(vector_a, vector_b, alpha=2),
}


def _size_for_level(table, level):
    """Return the entry of *table* for *level*; raise ValueError if unknown."""
    try:
        return table[level]
    except KeyError:
        raise ValueError(
            f"unknown level {level!r}; expected one of {sorted(table)}"
        ) from None


def _mean_runtime_ps(stmt, namespace):
    """Time *stmt* for ``ITERATIONS`` runs and return the scaled mean.

    *namespace* is passed to ``benchmark.Timer`` as its ``globals``. The
    measurement's ``mean`` is multiplied by 1e12 and truncated to ``int``
    (seconds -> picoseconds, given that ``mean`` is reported in seconds).
    """
    timer = benchmark.Timer(stmt, globals=namespace)
    return int(timer.timeit(ITERATIONS).mean * 1e12)


def run_gemv_bench(workload, level):
    """Benchmark one float16 matrix-vector product on ``device``.

    Args:
        workload: Unused; accepted so all runners share one signature.
        level: Size selector, one of ``"X1"`` .. ``"X4"``.

    Returns:
        The scaled mean runtime as an ``int`` (see ``_mean_runtime_ps``).

    Raises:
        ValueError: If *level* is not a known size selector.
    """
    rows, columns = _size_for_level(_GEMV_SHAPES, level)

    matrix = torch.rand((rows, columns), dtype=torch.float16, device=device)
    input_vector = torch.rand(columns, dtype=torch.float16, device=device)

    def bench_callback(matrix, input_vector):
        torch.matmul(matrix, input_vector)

    return _mean_runtime_ps(
        "bench_callback(matrix, input_vector)",
        {
            "bench_callback": bench_callback,
            "matrix": matrix,
            "input_vector": input_vector,
        },
    )


def run_gemv_layers_bench(workload, level):
    """Benchmark a chain of float16 matrix-vector products with ReLU.

    The same square matrix is applied ``_GEMV_LAYER_COUNT`` times, each
    result feeding the next product.

    Args:
        workload: Unused; accepted so all runners share one signature.
        level: Size selector, one of ``"X1"`` .. ``"X4"``.

    Returns:
        The scaled mean runtime as an ``int`` (see ``_mean_runtime_ps``).

    Raises:
        ValueError: If *level* is not a known size selector.
    """
    dimensions = _size_for_level(_GEMV_LAYERS_DIMENSIONS, level)

    matrix = torch.rand(
        (dimensions, dimensions), dtype=torch.float16, device=device
    )
    input_vector = torch.rand(dimensions, dtype=torch.float16, device=device)

    def bench_callback(matrix, input_vector):
        for _ in range(_GEMV_LAYER_COUNT):
            # ``Tensor.relu()`` is out-of-place: the previous code called it
            # and dropped the result, so the activation never reached the
            # next layer. Keep the result instead.
            # NOTE(review): assumes the ReLU belongs inside the loop (one
            # activation per layer) - confirm against the other backends.
            input_vector = torch.matmul(matrix, input_vector).relu()

    return _mean_runtime_ps(
        "bench_callback(matrix, input_vector)",
        {
            "bench_callback": bench_callback,
            "matrix": matrix,
            "input_vector": input_vector,
        },
    )


def run_vector_bench(workload, level):
    """Benchmark an element-wise float16 operation on two vectors.

    Args:
        workload: ``"vadd"`` (``a + b``), ``"vmul"`` (``a * b``) or
            ``"haxpy"`` (``torch.add(a, b, alpha=2)``).
        level: Size selector, one of ``"X1"`` .. ``"X4"``.

    Returns:
        The scaled mean runtime as an ``int`` (see ``_mean_runtime_ps``).

    Raises:
        ValueError: If *workload* or *level* is not recognised.
    """
    # Validate before allocating device memory. The former
    # ``getattr(wl, workload)`` lookup referenced an undefined name ``wl``
    # (NameError on every call) and its result was never used.
    bench_callback = _VECTOR_OPS.get(workload)
    if bench_callback is None:
        raise ValueError(
            f"unknown vector workload {workload!r}; "
            f"expected one of {sorted(_VECTOR_OPS)}"
        )
    length = _size_for_level(_VECTOR_LENGTHS, level)

    vector_a = torch.rand(length, dtype=torch.float16, device=device)
    vector_b = torch.rand(length, dtype=torch.float16, device=device)

    return _mean_runtime_ps(
        "bench_callback(vector_a, vector_b)",
        {
            "bench_callback": bench_callback,
            "vector_a": vector_a,
            "vector_b": vector_b,
        },
    )


# Workload name -> runner. This table drives the dispatch below.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]

if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} '<configuration as JSON object>'")

config = Configuration(**json.loads(sys.argv[1]))

_runners = dict(workloads)
if config.workload not in _runners:
    # Previously an unknown workload fell through the match statement and
    # failed later with a NameError on ``runtime``.
    raise ValueError(
        f"unknown workload {config.workload!r}; "
        f"expected one of {sorted(_runners)}"
    )

runtime = _runners[config.workload](config.workload, config.level)

print(json.dumps(dataclasses.asdict(Statistics(runtime))))