diff --git a/pim-os/src/main.rs b/pim-os/src/main.rs index e368f88..53034f2 100644 --- a/pim-os/src/main.rs +++ b/pim-os/src/main.rs @@ -18,7 +18,7 @@ use half::f16; use nalgebra::Matrix; use pim::{ array::{DummyArray, PimMatrixArena, PimStorage, NUMBER_OF_BANKS}, - kernel::{execute_matrix_multiply, MATRIX_MUL}, + kernel::{execute_matrix_multiply_rowwise, MATRIX_MUL}, state::PimState, vector::{F16x1, F16x16}, }; @@ -49,13 +49,13 @@ pub extern "C" fn entry() -> ! { pim_state.set_kernel(); let pim_matrix_arena0 = Rc::new(RefCell::new(PimMatrixArena( - [[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3], + [[[F16x16::default(); NUMBER_OF_BANKS]; 8]; 8], ))); let pim_matrix_arena1 = Rc::new(RefCell::new(PimMatrixArena( - [[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3], + [[[F16x16::default(); NUMBER_OF_BANKS]; 8]; 8], ))); let pim_matrix_arena2 = Rc::new(RefCell::new(PimMatrixArena( - [[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3], + [[[F16x16::default(); NUMBER_OF_BANKS]; 8]; 8], ))); let pim_storage0 = PimStorage { arena: &pim_matrix_arena0, @@ -92,7 +92,7 @@ pub extern "C" fn entry() -> ! { let mut dummy_array = Box::new(DummyArray([F16x16::default(); NUMBER_OF_BANKS])); barrier::dsb(barrier::SY); - execute_matrix_multiply( + execute_matrix_multiply_rowwise( &mut pim_state, &mut pim_matrix_arena0.borrow_mut(), &mut pim_matrix_arena1.borrow_mut(), diff --git a/pim-os/src/pim/kernel.rs b/pim-os/src/pim/kernel.rs index 170312e..166d84c 100644 --- a/pim-os/src/pim/kernel.rs +++ b/pim-os/src/pim/kernel.rs @@ -115,35 +115,69 @@ pub const MATRIX_MUL: Kernel = Kernel([ src: File::Bank, dst: File::GrfA { index: 2 }, }, + Instruction::MOV { + src: File::Bank, + dst: File::GrfA { index: 3 }, + }, + Instruction::MOV { + src: File::Bank, + dst: File::GrfA { index: 4 }, + }, + Instruction::MOV { + src: File::Bank, + dst: File::GrfA { index: 5 }, + }, + Instruction::MOV { + src: File::Bank, + dst: File::GrfA { index: 6 }, + }, + Instruction::MOV { + src: File::Bank, + dst: File::GrfA { index: 7 }, + }, Instruction::MAC { src0: File::Bank, src1: File::GrfA { index: 0 }, src2: File::GrfB { index: 0 }, dst: File::GrfB { index: 0 }, - aam: false, + aam: true, }, - Instruction::MAC { - src0: File::Bank, - src1: File::GrfA { index: 1 }, - src2: File::GrfB { index: 0 }, - dst: File::GrfB { index: 0 }, - aam: false, + Instruction::JUMP { + offset: -1, + count: 63, }, - Instruction::MAC { - src0: File::Bank, - src1: File::GrfA { index: 2 }, - src2: File::GrfB { index: 0 }, - dst: File::GrfB { index: 0 }, - aam: false, - }, - // Instruction::JUMP { - // offset: -1, - // count: 2, - // }, Instruction::FILL { src: File::GrfB { index: 0 }, dst: File::Bank, }, + Instruction::FILL { + src: File::GrfB { index: 1 }, + dst: File::Bank, + }, + Instruction::FILL { + src: File::GrfB { index: 2 }, + dst: File::Bank, + }, + Instruction::FILL { + src: File::GrfB { index: 3 }, + dst: File::Bank, + }, + Instruction::FILL { + src: File::GrfB { index: 4 }, + dst: File::Bank, + }, + Instruction::FILL { + src: File::GrfB { index: 5 }, + dst: File::Bank, + }, + Instruction::FILL { + src: File::GrfB { index: 6 }, + dst: File::Bank, + }, + Instruction::FILL { + src: File::GrfB { index: 7 }, + dst: File::Bank, + }, Instruction::EXIT, Instruction::NOP, Instruction::NOP, @@ -158,20 +192,9 @@ pub const MATRIX_MUL: Kernel = Kernel([ Instruction::NOP, Instruction::NOP, Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, - Instruction::NOP, ]); -pub fn execute_matrix_multiply( +pub fn execute_matrix_multiply_elementwise( pim_state: &mut PimState, pim_matrix_arena0: &mut PimMatrixArena, pim_matrix_arena1: &mut PimMatrixArena, @@ -181,15 +204,15 @@ pub fn execute_matrix_multiply( pim_state.set_bank_mode(BankMode::PimAllBank); for i in 0..(R * C) { - let left_index = i % R; - let right_index = (i / R) * R; + let start_column = i % R; + let start_row = (i / R) * R; - for k in 0..R { - pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * k); + for j in 0..C { + pim_matrix_arena0.execute_instruction_read_single_bank(start_column + R * j); } - for k in 0..C { - pim_matrix_arena1.execute_instruction_read_single_bank(right_index + k); + for j in 0..R { + pim_matrix_arena1.execute_instruction_read_single_bank(start_row + j); } pim_matrix_arena2.execute_instruction_write_single_bank(i); @@ -199,3 +222,33 @@ pub fn execute_matrix_multiply( pim_state.set_bank_mode(BankMode::SingleBank); } + +pub fn execute_matrix_multiply_rowwise( + pim_state: &mut PimState, + pim_matrix_arena0: &mut PimMatrixArena, + pim_matrix_arena1: &mut PimMatrixArena, + pim_matrix_arena2: &mut PimMatrixArena, + dummy_array: &mut DummyArray, +) { + pim_state.set_bank_mode(BankMode::PimAllBank); + + for row in 0..R { + for i in 0..C { + pim_matrix_arena0.execute_instruction_read_single_bank(row + R * i); + } + + for column in 0..C { + for i in 0..R { + pim_matrix_arena1.execute_instruction_read_single_bank(column * R + i); + } + } + + for column in 0..C { + pim_matrix_arena2.execute_instruction_write_single_bank(column * R + row); + } + + dummy_array.execute_instruction_read_single_bank(0); + } + + pim_state.set_bank_mode(BankMode::SingleBank); +} diff --git a/pim-vm/src/lib.rs b/pim-vm/src/lib.rs index ad46baf..7f4c226 100644 --- a/pim-vm/src/lib.rs +++ b/pim-vm/src/lib.rs @@ -133,16 +133,16 @@ impl PimVM { let inst = self.kernel.0[pim_unit.pc as usize]; + let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111; + let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111; + if pim_unit_index == 0 { log::debug!( - "PimUnit {pim_unit_index} Execute PC {}: {inst:?}", + "PimUnit {pim_unit_index} at {address:#x} (B{aam_grf_b_index}, A{aam_grf_a_index}) Execute PC {}: {inst:?}", pim_unit.pc ); } - let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111; - let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111; - match inst { Instruction::NOP => (), Instruction::EXIT => { @@ -244,9 +244,9 @@ impl PimVM { } => { if aam { src1 = if let File::GrfA { index: _ } = src1 { - if pim_unit_index == 0 { - log::debug!("AAM index GrfA {aam_grf_a_index}"); - } + // if pim_unit_index == 0 { + // log::debug!("AAM index GrfA {aam_grf_a_index}"); + // } File::GrfA { index: aam_grf_a_index as _, } @@ -255,9 +255,9 @@ impl PimVM { }; src2 = if let File::GrfB { index: _ } = src2 { - if pim_unit_index == 0 { - log::debug!("AAM index GrfB {aam_grf_a_index}"); - } + // if pim_unit_index == 0 { + // log::debug!("AAM index GrfB {aam_grf_a_index}"); + // } File::GrfB { index: aam_grf_b_index as _, } @@ -296,16 +296,16 @@ impl PimVM { .try_into() .unwrap(); - if pim_unit_index == 0 { - log::debug!( - "\n{:?}\n{:?}\n{:?}\n{:?}\n{:?}", - data0[0], - data1[0], - data2[0], - product[0], - sum[0] - ); - } + // if pim_unit_index == 0 { + // log::debug!( + // "\n{:?}\n{:?}\n{:?}\n{:?}\n{:?}", + // data0[0], + // data1[0], + // data2[0], + // product[0], + // sum[0] + // ); + // } PimVM::store(dst, pim_unit, &sum); } } @@ -370,9 +370,9 @@ impl PimVM { panic!("Unsupported dst operand: {dst:?}") } - if pim_unit_index == 0 { - log::debug!("Store {data:?}"); - } + // if pim_unit_index == 0 { + // log::debug!("Store {data:?}"); + // } data }