Bigger dimensions

This commit is contained in:
2024-03-06 12:09:54 +01:00
parent fbbe34212f
commit 3808bcd478
4 changed files with 186 additions and 120 deletions

149
benches.py Normal file
View File

@@ -0,0 +1,149 @@
import torch
import sys
import torch.utils.benchmark as benchmark
import numpy as np
import json
import polars as pl
import dataclasses
import workloads as wl
from config import Statistics, Configuration
# All benchmarks allocate and run on the first CUDA device.
device = torch.device("cuda:0")
# Timed kernel launches per benchmark; the reported runtime is the mean over these.
ITERATIONS = 100_000
def run_gemv_bench(workload, level):
match level:
case "X1":
ROWS = 1024 * 1
case "X2":
ROWS = 1024 * 2
case "X3":
ROWS = 1024 * 4
case "X4":
ROWS = 1024 * 8
match level:
case "X1":
COLUMNS = 1024 * 4
case "X2":
COLUMNS = 1024 * 4
case "X3":
COLUMNS = 1024 * 8
case "X4":
COLUMNS = 1024 * 8
matrix = torch.rand(
(ROWS, COLUMNS),
dtype=torch.float16,
device=device,
)
input_vector = torch.rand(COLUMNS, dtype=torch.float16, device=device)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(ITERATIONS):
torch.matmul(matrix, input_vector)
end.record()
torch.cuda.synchronize()
runtime = int(start.elapsed_time(end) * 1e9 / ITERATIONS)
return runtime
def run_gemv_layers_bench(workload, level):
match level:
case "X1":
DIMENSIONS = 256
case "X2":
DIMENSIONS = 512
case "X3":
DIMENSIONS = 1024
case "X4":
DIMENSIONS = 2048
matrix = torch.rand(
(DIMENSIONS, DIMENSIONS),
dtype=torch.float16,
device=device,
)
input_vector = torch.rand(DIMENSIONS, dtype=torch.float16, device=device)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(ITERATIONS):
for _ in range(5):
input_vector = torch.matmul(matrix, input_vector)
input_vector.relu()
end.record()
torch.cuda.synchronize()
runtime = int(start.elapsed_time(end) * 1e9 / ITERATIONS)
return runtime
def run_vector_bench(workload, level):
match level:
case "X1":
ROWS = 2097152 * 1
case "X2":
ROWS = 2097152 * 2
case "X3":
ROWS = 2097152 * 4
case "X4":
ROWS = 2097152 * 8
vector_a = torch.rand(ROWS, dtype=torch.float16, device=device)
vector_b = torch.rand(ROWS, dtype=torch.float16, device=device)
func = getattr(wl, workload)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(ITERATIONS):
match workload:
case "vadd":
torch.add(vector_a, vector_b)
case "vmul":
torch.mul(vector_a, vector_b)
case "haxpy":
torch.add(vector_a, vector_b, alpha=2)
end.record()
torch.cuda.synchronize()
runtime = int(start.elapsed_time(end) * 1e9 / ITERATIONS)
return runtime
# Registry of (workload name, runner); also serves as the dispatch table.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]
# The configuration arrives as a JSON object on argv[1] (this script is
# spawned per configuration by run_microbenchmarks.py).
config = Configuration(**json.loads(sys.argv[1]))
# Dispatch through the registry instead of a parallel match statement
# that had to be kept in sync with it by hand.
try:
    bench = dict(workloads)[config.workload]
except KeyError:
    raise SystemExit(f"unknown workload: {config.workload!r}")
runtime = bench(config.workload, config.level)
print(json.dumps(dataclasses.asdict(Statistics(runtime))))

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
@dataclass
class Configuration:
    """A single benchmark selection, serialized as JSON between the runner and benches.py."""

    # Workload name, e.g. "vadd", "vmul", "haxpy", "gemv", "gemv_layers".
    workload: str
    # Problem-size level: "X1" through "X4".
    level: str
@dataclass

View File

@@ -1,120 +0,0 @@
import torch
import sys
import torch.utils.benchmark as benchmark
import numpy as np
import json
import polars as pl
import dataclasses
from config import Statistics, Configuration
# CUDA path disabled in this version; benchmarks run on the CPU.
# device = torch.device("cuda:0")
device = torch.device("cpu")
# Iteration count lowered to keep the CPU run tractable.
# ITERATIONS = 1_000_000
ITERATIONS = 1_000
def run_gemv_bench(workload, level):
match level:
case "X1":
ROWS = 16
case "X2":
ROWS = 32
case "X3":
ROWS = 64
case "X4":
ROWS = 128
COLUMNS = 128
matrix = torch.rand(
(ROWS, COLUMNS),
dtype=torch.float16,
device=device,
)
input_vector = torch.rand(COLUMNS, dtype=torch.float16, device=device)
timer = benchmark.Timer(
stmt="gemv(matrix, input_vector)",
setup="from workloads import gemv",
globals={"input_vector": input_vector, "matrix": matrix},
)
return int(timer.timeit(ITERATIONS).mean * 1e12)
def run_gemv_layers_bench(workload, level):
LESS_ITERATIONS = int(ITERATIONS / 100)
match level:
case "X1":
DIMENSIONS = 128
case "X2":
DIMENSIONS = 256
case "X3":
DIMENSIONS = 512
case "X4":
DIMENSIONS = 1024
matrix = torch.rand(
(DIMENSIONS, DIMENSIONS),
dtype=torch.float16,
device=device,
)
input_vector = torch.rand(DIMENSIONS, dtype=torch.float16, device=device)
timer = benchmark.Timer(
stmt="gemv_layers(matrix, input_vector)",
setup="from workloads import gemv_layers",
globals={"input_vector": input_vector, "matrix": matrix},
)
return int(timer.timeit(LESS_ITERATIONS).mean * 1e12)
def run_vector_bench(workload, level):
match level:
case "X1":
ROWS = 256
case "X2":
ROWS = 512
case "X3":
ROWS = 1024
case "X4":
ROWS = 2048
vector_a = torch.rand(ROWS, dtype=torch.float16, device=device)
vector_b = torch.rand(ROWS, dtype=torch.float16, device=device)
timer = benchmark.Timer(
stmt=f"{workload}(vector_a, vector_b)",
setup=f"from workloads import {workload}",
globals={"vector_a": vector_a, "vector_b": vector_b},
)
return int(timer.timeit(ITERATIONS).mean * 1e12)
# Every (workload, runner) pair is crossed with every size level.
workloads = [
    ("vadd", run_vector_bench),
    ("vmul", run_vector_bench),
    ("haxpy", run_vector_bench),
    ("gemv", run_gemv_bench),
    ("gemv_layers", run_gemv_layers_bench),
]
levels = ["X1", "X2", "X3", "X4"]
results: list[dict] = []
for name, runner in workloads:
    for lvl in levels:
        element = {"workload": name, "level": lvl, "runtime": runner(name, lvl)}
        results.append(element)
        print(element)
# Persist the full sweep as CSV for later analysis.
pl.DataFrame(results).write_csv("rocm_results.csv")

36
run_microbenchmarks.py Normal file
View File

@@ -0,0 +1,36 @@
import dataclasses
import json
import polars as pl
import subprocess
from config import Configuration, Statistics
# Workload / level grid to sweep; each combination runs benches.py in its
# own subprocess, passing the Configuration as JSON on argv[1].
workloads = [
    "vadd",
    "vmul",
    "haxpy",
    "gemv",
    "gemv_layers"
]
levels = ["X1", "X2", "X3", "X4"]
results: list[dict] = []
for workload in workloads:
    for level in levels:
        config = Configuration(workload, level)
        serialized_config = json.dumps(dataclasses.asdict(config))
        out = subprocess.run(
            ["python3", "benches.py", serialized_config],
            capture_output=True,
            text=True,
        )
        # BUG FIX: a failing child used to surface as a confusing
        # JSONDecodeError on empty stdout while its stderr was swallowed;
        # fail fast and show the child's error output instead.
        if out.returncode != 0:
            raise RuntimeError(
                f"benches.py failed for {workload}/{level}:\n{out.stderr}"
            )
        statistics = Statistics(**json.loads(out.stdout))
        result = {"workload": workload, "level": level, "runtime": statistics.runtime}
        results.append(result)
        print(result)
df = pl.DataFrame(results)
df.write_csv("rocm_results.csv")