Files
pytorch-pim/workloads.py
2024-03-04 23:04:16 +01:00

21 lines
488 B
Python

import torch
def gemv(matrix: torch.Tensor, input_vector: torch.Tensor) -> torch.Tensor:
    """Return the product of *matrix* and *input_vector*.

    This is a thin wrapper around ``torch.matmul``. The name suggests a
    2-D matrix times a 1-D vector, but nothing here enforces that:
    ``torch.matmul``'s own shape and broadcasting rules apply to whatever
    operands are passed in.

    Args:
        matrix: Left operand of the product.
        input_vector: Right operand of the product.

    Returns:
        The tensor produced by ``torch.matmul(matrix, input_vector)``.
    """
    return torch.matmul(matrix, input_vector)
def gemv_layers(
    matrix: torch.Tensor,
    input_vector: torch.Tensor,
    num_layers: int = 5,
) -> torch.Tensor:
    """Apply ``relu(matrix @ x)`` repeatedly, feeding each output back in.

    Each layer multiplies the current vector by *matrix* and then applies
    ReLU. Because every layer's output is the next layer's right operand,
    *matrix* must map the vector to a shape it can multiply again (e.g. a
    square matrix for a 1-D vector).

    Args:
        matrix: Left operand reused by every layer.
        input_vector: Right operand of the first layer.
        num_layers: Number of matmul + ReLU layers to apply. Defaults to 5;
            a value of 0 or less returns *input_vector* unchanged.

    Returns:
        The activation produced by the last layer.
    """
    for _ in range(num_layers):
        input_vector = torch.matmul(matrix, input_vector)
        # torch.relu is out-of-place: the result must be assigned back,
        # otherwise the activation is computed and silently thrown away.
        input_vector = torch.relu(input_vector)
    return input_vector
def vadd(vector_a: torch.Tensor, vector_b: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sum of *vector_a* and *vector_b*.

    This is a thin wrapper around ``torch.add``; its broadcasting and type
    promotion rules apply.

    Args:
        vector_a: First addend.
        vector_b: Second addend.

    Returns:
        The tensor produced by ``torch.add(vector_a, vector_b)``.
    """
    return torch.add(vector_a, vector_b)
def vmul(vector_a: torch.Tensor, vector_b: torch.Tensor) -> torch.Tensor:
    """Return the element-wise product of *vector_a* and *vector_b*.

    This is a thin wrapper around ``torch.mul``; its broadcasting and type
    promotion rules apply.

    Args:
        vector_a: First factor.
        vector_b: Second factor.

    Returns:
        The tensor produced by ``torch.mul(vector_a, vector_b)``.
    """
    return torch.mul(vector_a, vector_b)
def haxpy(
    vector_a: torch.Tensor,
    vector_b: torch.Tensor,
    alpha: float = 2,
) -> torch.Tensor:
    """Return ``vector_a + alpha * vector_b`` element-wise.

    This is a thin wrapper around ``torch.add`` with its ``alpha`` scaling
    argument; note that the scale factor is applied to *vector_b*, not
    *vector_a*.

    Args:
        vector_a: Tensor added unscaled.
        vector_b: Tensor multiplied by *alpha* before the addition.
        alpha: Scale factor for *vector_b*. Defaults to 2, the value this
            workload previously hard-coded.

    Returns:
        The tensor produced by
        ``torch.add(vector_a, vector_b, alpha=alpha)``.
    """
    return torch.add(vector_a, vector_b, alpha=alpha)