diff --git a/pim-os/src/bin/samsung_matrix_vector_multiply.rs b/pim-os/src/bin/samsung_matrix_vector_multiply.rs index 480c635..c8ddb41 100644 --- a/pim-os/src/bin/samsung_matrix_vector_multiply.rs +++ b/pim-os/src/bin/samsung_matrix_vector_multiply.rs @@ -18,8 +18,9 @@ use pim_os::{ uart::Uart0, }; -const ROWS: usize = 16; +const ROWS: usize = 32; const COLUMNS: usize = 128; +const X16_ROWS: usize = ROWS / 16; const X16_COLUMNS: usize = COLUMNS / 16; #[no_mangle] @@ -29,7 +30,7 @@ pub extern "C" fn main() { let mut matrix = SMatrix::<_, ROWS, COLUMNS>::zeros(); matrix.fill_lower_triangle(F16x1::one(), 0); - let pim_matrix = Box::new(pim::continuous_array::Matrix::from(matrix)); + let pim_matrix = Box::new(pim::continuous_array::Matrix::::from(matrix)); let input_vector = SVector::<_, X16_COLUMNS>::from_element(F16x16::one()); let interleaved_input_vector = Box::new(interleaved_array::Vector::from(input_vector)); diff --git a/pim-os/src/pim/continuous_array.rs b/pim-os/src/pim/continuous_array.rs index 2497696..39207b2 100644 --- a/pim-os/src/pim/continuous_array.rs +++ b/pim-os/src/pim/continuous_array.rs @@ -4,25 +4,31 @@ use nalgebra::SMatrix; #[repr(C, align(65536))] #[derive(Debug)] -pub struct Matrix(pub SMatrix); +pub struct Matrix(pub [SMatrix; X16R]); -impl Display for Matrix { +impl Display for Matrix { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.fmt(f) + for block in self.0.iter() { + block.fmt(f)? + } + Ok(()) } } -impl From> - for Matrix +impl + From> for Matrix { fn from(matrix: SMatrix) -> Self { - Self(SMatrix::from_row_iterator( - matrix - .transpose() - .iter() - .map(|e| *e) - .array_chunks::<16>() - .map(|chunk| F16x16(chunk)), - )) + Self(core::array::from_fn(|i| { + SMatrix::from_row_iterator( + matrix + .fixed_rows::<16>(i * 16) + .transpose() + .iter() + .map(|e| *e) + .array_chunks::<16>() + .map(|chunk| F16x16(chunk)), + ) + })) } } diff --git a/pim-os/src/pim/interleaved_array.rs b/pim-os/src/pim/interleaved_array.rs index aea4b25..99e7ec0 100644 --- a/pim-os/src/pim/interleaved_array.rs +++ b/pim-os/src/pim/interleaved_array.rs @@ -3,19 +3,19 @@ use nalgebra::SVector; #[repr(C, align(512))] #[derive(Debug)] -pub struct Vector(pub [[F16x16; NUMBER_OF_BANKS]; R]); +pub struct Vector(pub [[F16x16; NUMBER_OF_BANKS]; X16R]); -impl Default for Vector { +impl Default for Vector { fn default() -> Self { - Self([[F16x16::default(); NUMBER_OF_BANKS]; R]) + Self([[F16x16::default(); NUMBER_OF_BANKS]; X16R]) } } -impl From> for Vector { - fn from(input_vector: SVector) -> Self { +impl From> for Vector { + fn from(input_vector: SVector) -> Self { let mut interleaved_vector = Self::default(); - for block_index in 0..R { + for block_index in 0..X16R { let element = input_vector[block_index]; for k in 0..NUMBER_OF_BANKS { interleaved_vector.0[block_index][k] = element; diff --git a/pim-os/src/pim/kernel/samsung_matrix_vector_mul.rs b/pim-os/src/pim/kernel/samsung_matrix_vector_mul.rs index 1a7a0e6..6f23c5b 100644 --- a/pim-os/src/pim/kernel/samsung_matrix_vector_mul.rs +++ b/pim-os/src/pim/kernel/samsung_matrix_vector_mul.rs @@ -46,12 +46,16 @@ pub const KERNEL: Kernel = Kernel([ }, Instruction::JUMP { offset: -1, - count: 7, + count: 15, }, Instruction::FILL { src: File::GrfB { index: 0 }, dst: File::Bank, }, + Instruction::FILL { + src: File::GrfB { index: 1 }, + dst: File::Bank, + }, Instruction::EXIT, Instruction::NOP, Instruction::NOP, @@ -72,12 +76,11 @@ pub const KERNEL: Kernel = Kernel([ Instruction::NOP, Instruction::NOP, Instruction::NOP, - Instruction::NOP, ]); -pub fn execute( - matrix: &Matrix, - input_vector: &interleaved_array::Vector, +pub fn execute( + matrix: &Matrix, + input_vector: &interleaved_array::Vector, output_partial_sum_vector: &mut SVector, dummy: &impl PimOperand, ) { @@ -85,11 +88,18 @@ pub fn execute( block.execute_read(); } - for column_block in matrix.0.fixed_rows::<1>(0).iter() { - column_block.execute_read(); + for row_block in matrix.0.iter() { + for column_block in row_block.fixed_rows::<1>(0).iter() { + column_block.execute_read(); + } } - output_partial_sum_vector.execute_write(); + for chunk in output_partial_sum_vector + .fixed_rows_with_step_mut::(0, 16) + .iter_mut() + { + chunk.execute_write(); + } dummy.execute_read(); } diff --git a/pim-os/src/pim/operation.rs b/pim-os/src/pim/operation.rs index 728fb2e..146a480 100644 --- a/pim-os/src/pim/operation.rs +++ b/pim-os/src/pim/operation.rs @@ -1,13 +1,17 @@ +use aarch64_cpu::asm::barrier; + pub trait PimOperand { fn ptr(&self) -> *const u8; fn ptr_mut(&mut self) -> *mut u8; fn execute_read(&self) { unsafe { core::ptr::read_volatile(self.ptr()) }; + barrier::dsb(barrier::SY); } fn execute_write(&mut self) { unsafe { core::ptr::write_volatile(self.ptr_mut(), Default::default()) }; + barrier::dsb(barrier::SY); } } diff --git a/pim-vm/src/lib.rs b/pim-vm/src/lib.rs index 273cdf0..367718d 100644 --- a/pim-vm/src/lib.rs +++ b/pim-vm/src/lib.rs @@ -16,7 +16,14 @@ mod ffi { fn reset(&mut self); fn apply_config(&mut self, config: &str); fn bank_mode(&self) -> BankMode; - fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]); + fn execute_read( + &mut self, + bank_index: u32, + address: u32, + row: u32, + column: u32, + bank_data: &[u8], + ); fn execute_write(&mut self, bank_index: u32) -> [u8; 32]; fn init_logger(); @@ -27,8 +34,10 @@ fn init_logger() { env_logger::init(); } -const GRF_A_BIT_OFFSET: usize = 10; -const GRF_B_BIT_OFFSET: usize = 13; +const GRF_A_BIT_OFFSET: usize = 2; +const GRF_B_BIT_OFFSET: usize = 5; +const COLUMN_BITS : usize = 7; + const BURST_LENGTH: usize = 32; const GRF_NUM_REGISTERS: usize = 8; @@ -120,7 +129,14 @@ fn new_pim_vm(num_banks: u32) -> Box { struct BankData([f16; FP_UNITS]); impl PimVM { - pub fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]) { + pub fn execute_read( + &mut self, + bank_index: u32, + address: u32, + row: u32, + column: u32, + bank_data: &[u8], + ) { assert_eq!(bank_data.len(), BURST_LENGTH); let pim_unit_index = if cfg!(feature = "shared_pim_units") { @@ -133,8 +149,9 @@ impl PimVM { let inst = self.kernel.0[pim_unit.pc as usize]; - let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111; - let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111; + let row_column_bits = (row << COLUMN_BITS) | column; + let aam_grf_a_index = (row_column_bits >> GRF_A_BIT_OFFSET) & 0b111; + let aam_grf_b_index = (row_column_bits >> GRF_B_BIT_OFFSET) & 0b111; if pim_unit_index == 0 { log::debug!(