import torch
import sys
import torch.utils.benchmark as benchmark
import numpy as np
import json
import polars as pl
import dataclasses

from config import Statistics, Configuration

# Device the benchmark tensors are allocated on.
# Switch to the CUDA line to benchmark on a GPU.
# device = torch.device("cuda:0")
device = torch.device("cpu")

# Number of timed executions per measurement; gemv_layers uses 1/100 of this.
# ITERATIONS = 1_000_000
ITERATIONS = 1_000

# Problem sizes for each scaling level, one table per workload family.
_GEMV_ROWS = {"X1": 16, "X2": 32, "X3": 64, "X4": 128}
_GEMV_COLUMNS = 128
_GEMV_LAYERS_DIMENSIONS = {"X1": 128, "X2": 256, "X3": 512, "X4": 1024}
_VECTOR_LENGTHS = {"X1": 256, "X2": 512, "X3": 1024, "X4": 2048}


def _size_for_level(sizes, level):
    """Return ``sizes[level]``, raising ValueError for an unknown level.

    Without this check an unexpected level would surface as an
    UnboundLocalError deep inside the bench function.
    """
    try:
        return sizes[level]
    except KeyError:
        raise ValueError(
            f"unknown level {level!r}; expected one of {sorted(sizes)}"
        ) from None


def _time_workload(workload, named_args, number):
    """Time ``workloads.<workload>(...)`` and return picoseconds per call.

    ``named_args`` maps argument names to values. Its insertion order is the
    positional order used in the timed call, e.g.
    ``{"matrix": m, "input_vector": v}`` times ``gemv(matrix, input_vector)``.
    ``number`` is how many executions ``Timer.timeit`` runs.

    Returns ``Measurement.mean`` (seconds per call) scaled by 1e12, i.e.
    picoseconds, truncated to int.
    """
    call_args = ", ".join(named_args)
    timer = benchmark.Timer(
        stmt=f"{workload}({call_args})",
        setup=f"from workloads import {workload}",
        globals=dict(named_args),
    )
    return int(timer.timeit(number).mean * 1e12)


def run_gemv_bench(workload, level):
    """Benchmark ``workloads.<workload>(matrix, input_vector)``.

    ``matrix`` is a random fp16 tensor of shape (rows, 128), where rows comes
    from ``_GEMV_ROWS[level]``. ``input_vector`` is a random fp16 tensor of
    length 128. Returns the mean time per call in picoseconds.

    Raises:
        ValueError: if ``level`` is not one of the known levels.
    """
    rows = _size_for_level(_GEMV_ROWS, level)
    matrix = torch.rand(
        (rows, _GEMV_COLUMNS),
        dtype=torch.float16,
        device=device,
    )
    input_vector = torch.rand(_GEMV_COLUMNS, dtype=torch.float16, device=device)
    return _time_workload(
        workload,
        {"matrix": matrix, "input_vector": input_vector},
        ITERATIONS,
    )


def run_gemv_layers_bench(workload, level):
    """Benchmark ``workloads.<workload>(matrix, input_vector)`` on square inputs.

    ``matrix`` is a random fp16 tensor of shape (d, d) and ``input_vector``
    has length d, where d comes from ``_GEMV_LAYERS_DIMENSIONS[level]``.
    Runs 1/100 of ITERATIONS (at least one) because these inputs are much
    larger. Returns the mean time per call in picoseconds.

    Raises:
        ValueError: if ``level`` is not one of the known levels.
    """
    # Guard against 0 iterations when ITERATIONS < 100.
    less_iterations = max(1, ITERATIONS // 100)
    dimensions = _size_for_level(_GEMV_LAYERS_DIMENSIONS, level)
    matrix = torch.rand(
        (dimensions, dimensions),
        dtype=torch.float16,
        device=device,
    )
    input_vector = torch.rand(dimensions, dtype=torch.float16, device=device)
    return _time_workload(
        workload,
        {"matrix": matrix, "input_vector": input_vector},
        less_iterations,
    )


def run_vector_bench(workload, level):
    """Benchmark ``workloads.<workload>(vector_a, vector_b)``.

    Both vectors are random fp16 tensors whose length comes from
    ``_VECTOR_LENGTHS[level]``. Returns the mean time per call in
    picoseconds.

    Raises:
        ValueError: if ``level`` is not one of the known levels.
    """
    length = _size_for_level(_VECTOR_LENGTHS, level)
    vector_a = torch.rand(length, dtype=torch.float16, device=device)
    vector_b = torch.rand(length, dtype=torch.float16, device=device)
    return _time_workload(
        workload,
        {"vector_a": vector_a, "vector_b": vector_b},
        ITERATIONS,
    )


# (function name in the `workloads` module, bench driver that times it)
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]
levels = ["X1", "X2", "X3", "X4"]

results: list[dict] = []
for workload, workload_callback in workloads:
    for level in levels:
        runtime = workload_callback(workload, level)
        element = {"workload": workload, "level": level, "runtime": runtime}
        results.append(element)
        print(element)

df = pl.DataFrame(results)
# NOTE(review): output is named "rocm_results" even when `device` is the CPU;
# confirm which backend these numbers are meant to represent.
df.write_csv("rocm_results.csv")