diff --git a/pim-os/src/bin/samsung_matrix_vector_multiply.rs b/pim-os/src/bin/samsung_matrix_vector_multiply.rs index bb7f240..480c635 100644 --- a/pim-os/src/bin/samsung_matrix_vector_multiply.rs +++ b/pim-os/src/bin/samsung_matrix_vector_multiply.rs @@ -26,17 +26,10 @@ const X16_COLUMNS: usize = COLUMNS / 16; pub extern "C" fn main() { pim::state::set_kernel(&samsung_matrix_vector_mul::KERNEL); - let matrix = Box::new(pim::continuous_array::Matrix::( - SMatrix::from_fn(|r, c| { - if c > 0 { - return F16x16::zero(); - } - - let mut entry = F16x16::zero(); - entry.0.iter_mut().take(r).for_each(|val| *val = F16x1::one()); - - entry - }))); + let mut matrix = SMatrix::<_, ROWS, COLUMNS>::zeros(); + matrix.fill_lower_triangle(F16x1::one(), 0); + + let pim_matrix = Box::new(pim::continuous_array::Matrix::from(matrix)); let input_vector = SVector::<_, X16_COLUMNS>::from_element(F16x16::one()); let interleaved_input_vector = Box::new(interleaved_array::Vector::from(input_vector)); @@ -53,7 +46,7 @@ pub extern "C" fn main() { pim::state::set_bank_mode(BankMode::PimAllBank); samsung_matrix_vector_mul::execute( - matrix.as_ref(), + pim_matrix.as_ref(), interleaved_input_vector.as_ref(), output_partial_sum_vector.as_mut(), dummy.as_ref(), diff --git a/pim-os/src/lib.rs b/pim-os/src/lib.rs index f0537aa..5ee2f57 100644 --- a/pim-os/src/lib.rs +++ b/pim-os/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(iter_array_chunks)] #![no_std] use core::sync::atomic::{compiler_fence, Ordering}; diff --git a/pim-os/src/pim/continuous_array.rs b/pim-os/src/pim/continuous_array.rs index 3451acf..2497696 100644 --- a/pim-os/src/pim/continuous_array.rs +++ b/pim-os/src/pim/continuous_array.rs @@ -1,4 +1,4 @@ -use super::vector::F16x16; +use super::vector::{F16x1, F16x16}; use core::fmt::Display; use nalgebra::SMatrix; @@ -11,3 +11,18 @@ impl Display for Matrix { self.0.fmt(f) } } + +impl From> + for Matrix +{ + fn from(matrix: SMatrix) -> Self { + Self(SMatrix::from_row_iterator( + matrix + .transpose() + .iter() + .map(|e| *e) + .array_chunks::<16>() + .map(|chunk| F16x16(chunk)), + )) + } +}