diff --git a/pim-isa/src/lib.rs b/pim-isa/src/lib.rs index d3a1a69..9aacf2c 100644 --- a/pim-isa/src/lib.rs +++ b/pim-isa/src/lib.rs @@ -22,39 +22,44 @@ pub enum Instruction { src0: File, src1: File, dst: File, + aam: bool, }, MUL { src0: File, src1: File, dst: File, + aam: bool, }, MAC { src0: File, src1: File, src2: File, dst: File, + aam: bool, }, MAD { src0: File, src1: File, src2: File, dst: File, + aam: bool, }, } impl Instruction { pub fn supported_source(&self, src: File) -> bool { - true + todo!() } pub fn supported_destination(&self, src: File) -> bool { - true + todo!() } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum File { - Grf { index: u8 }, + GrfA { index: u8 }, + GrfB { index: u8 }, SrfM { index: u8 }, SrfA { index: u8 }, Bank, diff --git a/pim-os/src/boot.rs b/pim-os/src/boot.rs new file mode 100644 index 0000000..3935089 --- /dev/null +++ b/pim-os/src/boot.rs @@ -0,0 +1,8 @@ +use core::arch::global_asm; + +global_asm!(include_str!("start.s")); + +extern "C" { + pub fn set_page_table_cache(); + pub fn set_page_table_non_cache(); +} diff --git a/pim-os/src/pim/kernel.rs b/pim-os/src/pim/kernel.rs index ed8b413..708c15e 100644 --- a/pim-os/src/pim/kernel.rs +++ b/pim-os/src/pim/kernel.rs @@ -3,20 +3,21 @@ use pim_isa::{File, Instruction, Kernel}; pub const TEST_KERNEL: Kernel = Kernel([ Instruction::MOV { src: File::Bank, - dst: File::Grf { index: 0 }, + dst: File::GrfA { index: 0 }, }, Instruction::MOV { src: File::Bank, - dst: File::Grf { index: 1 }, + dst: File::GrfA { index: 1 }, }, Instruction::MAC { - src0: File::Grf { index: 0 }, - src1: File::Bank, - src2: File::Grf { index: 1 }, - dst: File::Grf { index: 1 }, + src0: File::Bank, + src1: File::GrfA { index: 0 }, + src2: File::GrfA { index: 1 }, + dst: File::GrfA { index: 1 }, + aam: false }, Instruction::FILL { - src: File::Grf { index: 1 }, + src: File::GrfA { index: 1 }, dst: File::Bank, }, Instruction::EXIT, diff --git a/pim-vm/src/lib.rs b/pim-vm/src/lib.rs index 4b3efe5..28bb9d7 100644 --- a/pim-vm/src/lib.rs +++ b/pim-vm/src/lib.rs @@ -16,7 +16,7 @@ mod ffi { fn reset(&mut self); fn apply_config(&mut self, config: &str); fn bank_mode(&self) -> BankMode; - fn execute_read(&mut self, bank_index: u32, data: &[u8]); + fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]); fn execute_write(&mut self, bank_index: u32) -> [u8; 32]; fn init_logger(); @@ -27,9 +27,10 @@ fn init_logger() { env_logger::init(); } +const COLUMN_BIT_OFFSET: usize = 10; const BURST_LENGTH: usize = 32; -const GRF_NUM_REGISTERS: usize = 16; +const GRF_NUM_REGISTERS: usize = 8; const SRF_A_NUM_REGISTERS: usize = 8; const SRF_M_NUM_REGISTERS: usize = 8; @@ -38,7 +39,8 @@ type GrfRegister = [f16; FP_UNITS]; #[derive(Clone, Debug)] struct PimUnit { - grf: [GrfRegister; GRF_NUM_REGISTERS], + grf_a: [GrfRegister; GRF_NUM_REGISTERS], + grf_b: [GrfRegister; GRF_NUM_REGISTERS], srf_a: [f16; SRF_A_NUM_REGISTERS], srf_m: [f16; SRF_A_NUM_REGISTERS], pc: u8, @@ -48,7 +50,8 @@ struct PimUnit { impl Default for PimUnit { fn default() -> Self { Self { - grf: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS], + grf_a: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS], + grf_b: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS], srf_a: [f16::ZERO; SRF_A_NUM_REGISTERS], srf_m: [f16::ZERO; SRF_M_NUM_REGISTERS], pc: 0, @@ -103,7 +106,7 @@ fn new_pim_vm(num_pim_units: u32) -> Box { struct BankData([f16; FP_UNITS]); impl PimVM { - pub fn execute_read(&mut self, bank_index: u32, bank_data: &[u8]) { + pub fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]) { assert_eq!(bank_data.len(), BURST_LENGTH); let pim_unit = &mut self.pim_units[bank_index as usize]; @@ -111,7 +114,10 @@ impl PimVM { let current_pc = pim_unit.pc; pim_unit.pc += 1; - let inst = &self.pim_config.kernel.0[current_pc as usize]; + let inst = self.pim_config.kernel.0[current_pc as usize]; + + let aam_grf_a_index = (address >> COLUMN_BIT_OFFSET) & 0b111; + let aam_grf_b_index = (address >> COLUMN_BIT_OFFSET + 3) & 0b111; log::debug!("PimUnit {bank_index} Execute PC {current_pc}: {inst:?}"); @@ -128,7 +134,7 @@ impl PimVM { }; if pim_unit.jump_counter != None { - let new_pc = current_pc as i32 + *offset as i32; + let new_pc = current_pc as i32 + offset as i32; if new_pc < 0 || new_pc >= 32 { panic!("Invalid PC {new_pc} after JUMP: {inst:?}"); @@ -139,12 +145,35 @@ impl PimVM { } } Instruction::MOV { src, dst } | Instruction::FILL { src, dst } => { - let data = PimVM::load(*src, pim_unit, &bank_data); - PimVM::store(*dst, pim_unit, &data); + let data = PimVM::load(src, pim_unit, &bank_data); + PimVM::store(dst, pim_unit, &data); } - Instruction::ADD { src0, src1, dst } => { - let data0 = PimVM::load(*src0, pim_unit, &bank_data); - let data1 = PimVM::load(*src1, pim_unit, &bank_data); + Instruction::ADD { + src0, + mut src1, + mut dst, + aam, + } => { + if aam { + src1 = if let File::GrfA { index: _ } = src1 { + File::GrfA { + index: aam_grf_a_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + + dst = if let File::GrfB { index: _ } = dst { + File::GrfB { + index: aam_grf_b_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + } + + let data0 = PimVM::load(src0, pim_unit, &bank_data); + let data1 = PimVM::load(src1, pim_unit, &bank_data); let sum: [f16; FP_UNITS] = data0 .into_iter() @@ -154,11 +183,34 @@ impl PimVM { .try_into() .unwrap(); - PimVM::store(*dst, pim_unit, &sum); + PimVM::store(dst, pim_unit, &sum); } - Instruction::MUL { src0, src1, dst } => { - let data0 = PimVM::load(*src0, pim_unit, &bank_data); - let data1 = PimVM::load(*src1, pim_unit, &bank_data); + Instruction::MUL { + src0, + mut src1, + mut dst, + aam, + } => { + if aam { + src1 = if let File::GrfA { index: _ } = src1 { + File::GrfA { + index: aam_grf_a_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + + dst = if let File::GrfB { index: _ } = dst { + File::GrfB { + index: aam_grf_b_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + } + + let data0 = PimVM::load(src0, pim_unit, &bank_data); + let data1 = PimVM::load(src1, pim_unit, &bank_data); let product: [f16; FP_UNITS] = data0 .into_iter() @@ -168,23 +220,53 @@ impl PimVM { .try_into() .unwrap(); - PimVM::store(*dst, pim_unit, &product); + PimVM::store(dst, pim_unit, &product); } Instruction::MAC { src0, - src1, - src2, - dst, + mut src1, + mut src2, + mut dst, + aam, } | Instruction::MAD { src0, - src1, - src2, - dst, + mut src1, + mut src2, + mut dst, + aam, } => { - let data0 = PimVM::load(*src0, pim_unit, &bank_data); - let data1 = PimVM::load(*src1, pim_unit, &bank_data); - let data2 = PimVM::load(*src2, pim_unit, &bank_data); + if aam { + src1 = if let File::GrfA { index: _ } = src1 { + File::GrfA { + index: aam_grf_a_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + + src2 = if let File::GrfB { index: _ } = src2 { + File::GrfB { + index: aam_grf_b_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + + dst = if let File::GrfB { index: _ } = dst { + File::GrfB { + index: aam_grf_b_index as _, + } + } else { + panic!("Invalid operand in address-aligned-mode"); + }; + } + + assert_eq!(src2, dst); + + let data0 = PimVM::load(src0, pim_unit, &bank_data); + let data1 = PimVM::load(src1, pim_unit, &bank_data); + let data2 = PimVM::load(src2, pim_unit, &bank_data); let product: [f16; FP_UNITS] = data0 .into_iter() @@ -203,7 +285,7 @@ impl PimVM { .unwrap(); log::debug!("{data0:#?}\n{data1:#?}\n{data2:#?}\n{product:#?}\n{sum:#?}"); - PimVM::store(*dst, pim_unit, &sum); + PimVM::store(dst, pim_unit, &sum); } } } @@ -221,7 +303,8 @@ impl PimVM { let data = match inst { Instruction::FILL { src, dst } => { let data: [f16; FP_UNITS] = match src { - File::Grf { index } => pim_unit.grf[*index as usize], + File::GrfA { index } => pim_unit.grf_a[*index as usize], + File::GrfB { index } => pim_unit.grf_b[*index as usize], _ => panic!("Unsupported src operand: {src:?}"), }; @@ -239,7 +322,8 @@ impl PimVM { fn load(src: File, pim_unit: &PimUnit, bank_data: &[u8]) -> [f16; FP_UNITS] { match src { - File::Grf { index } => pim_unit.grf[index as usize], + File::GrfA { index } => pim_unit.grf_a[index as usize], + File::GrfB { index } => pim_unit.grf_b[index as usize], File::SrfM { index } => [pim_unit.srf_m[index as usize]; FP_UNITS], File::SrfA { index } => [pim_unit.srf_a[index as usize]; FP_UNITS], File::Bank => unsafe { std::ptr::read(bank_data.as_ptr() as *const BankData).0 }, @@ -248,7 +332,8 @@ impl PimVM { fn store(dst: File, pim_unit: &mut PimUnit, data: &[f16; FP_UNITS]) { match dst { - File::Grf { index } => pim_unit.grf[index as usize] = data.clone(), + File::GrfA { index } => pim_unit.grf_a[index as usize] = data.clone(), + File::GrfB { index } => pim_unit.grf_b[index as usize] = data.clone(), File::SrfM { index } => pim_unit.srf_m[index as usize] = data[0], File::SrfA { index } => pim_unit.srf_a[index as usize] = data[0], File::Bank => panic!("Unsupported dst operand: {dst:?}"),