"""Small PyTorch vector/matrix helpers: gemv, gemv_layers, vadd, vmul, haxpy."""
import torch


def gemv(matrix, input_vector):
    """Return the matrix-vector product of ``matrix`` and ``input_vector``.

    This is a direct call to ``torch.matmul``, so its shape, broadcasting and
    dtype rules apply unchanged.
    """
    product = torch.matmul(matrix, input_vector)
    return product


def gemv_layers(matrix, input_vector, *, num_layers=5):
    """Apply ``num_layers`` rounds of ``relu(matrix @ input_vector)``.

    Each round's output is fed back in as the next round's input. The result
    of one round must therefore be a valid right-hand operand for the next
    ``torch.matmul`` with the same ``matrix``. For a 2-D matrix and 1-D
    vector, this means the matrix must be square.

    Args:
        matrix: Left operand of every ``torch.matmul`` call.
        input_vector: Initial right operand. It is not modified, because
            every round produces a new tensor.
        num_layers: Number of matmul+ReLU rounds. It defaults to 5, the
            previously hard-coded count; 0 returns ``input_vector`` unchanged.

    Returns:
        The tensor produced by the final round.
    """
    for _ in range(num_layers):
        # torch.relu is out-of-place: its result must be rebound, otherwise
        # the activation is silently dropped.
        input_vector = torch.relu(torch.matmul(matrix, input_vector))
    return input_vector


def vadd(vector_a, vector_b):
    """Return the element-wise sum ``vector_a + vector_b``.

    This is a direct call to ``torch.add`` with no scaling, so PyTorch's
    broadcasting and type-promotion rules apply.
    """
    total = torch.add(vector_a, vector_b)
    return total


def vmul(vector_a, vector_b):
    """Return the element-wise (Hadamard) product of the two tensors.

    This is a direct call to ``torch.mul``, so PyTorch's broadcasting and
    type-promotion rules apply.
    """
    elementwise = torch.mul(vector_a, vector_b)
    return elementwise


def haxpy(vector_a, vector_b, *, alpha=2):
    """Return ``vector_a + alpha * vector_b`` (an AXPY-style update).

    Args:
        vector_a: Tensor added unscaled.
        vector_b: Tensor scaled by ``alpha`` before the addition.
        alpha: Scalar multiplier for ``vector_b``. It defaults to 2, the value
            previously hard-coded, so existing two-argument calls are
            unchanged.

    Returns:
        A new tensor from ``torch.add``; neither input is modified.
    """
    return torch.add(vector_a, vector_b, alpha=alpha)