use half::f16; use pim_isa::{BankMode, File, Instruction, Kernel, PimConfig}; #[cxx::bridge(namespace = "pim_vm")] mod ffi { pub enum BankMode { SingleBank, AllBank, PimAllBank, } extern "Rust" { type PimVM; fn new_pim_vm(num_pim_units: u32) -> Box; fn reset(&mut self); fn apply_config(&mut self, config: &str); fn bank_mode(&self) -> BankMode; fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]); fn execute_write(&mut self, bank_index: u32) -> [u8; 32]; fn init_logger(); } } fn init_logger() { env_logger::init(); } const GRF_A_BIT_OFFSET: usize = 10; const GRF_B_BIT_OFFSET: usize = 13; const BURST_LENGTH: usize = 32; const GRF_NUM_REGISTERS: usize = 8; const SRF_A_NUM_REGISTERS: usize = 8; const SRF_M_NUM_REGISTERS: usize = 8; const FP_UNITS: usize = 16; type GrfRegister = [f16; FP_UNITS]; #[derive(Clone, Debug)] struct PimUnit { grf_a: [GrfRegister; GRF_NUM_REGISTERS], grf_b: [GrfRegister; GRF_NUM_REGISTERS], srf_a: [f16; SRF_A_NUM_REGISTERS], srf_m: [f16; SRF_A_NUM_REGISTERS], pc: u8, jump_counter: Option, } impl Default for PimUnit { fn default() -> Self { Self { grf_a: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS], grf_b: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS], srf_a: [f16::ZERO; SRF_A_NUM_REGISTERS], srf_m: [f16::ZERO; SRF_M_NUM_REGISTERS], pc: 0, jump_counter: None, } } } #[derive(Debug)] struct PimVM { pim_units: Vec, pim_config: pim_isa::PimConfig, } impl PimVM { fn reset(&mut self) { for unit in self.pim_units.iter_mut() { unit.pc = 0; unit.jump_counter = None; } } fn apply_config(&mut self, config_str: &str) { log::debug!("Config string:\n{config_str}"); self.pim_config = serde_json::from_str::(config_str).unwrap(); self.reset(); log::debug!("Apply pim config:\n{:?}", self.pim_config); } fn bank_mode(&self) -> ffi::BankMode { match self.pim_config.bank_mode { BankMode::SingleBank => ffi::BankMode::SingleBank, BankMode::AllBank => ffi::BankMode::AllBank, BankMode::PimAllBank => ffi::BankMode::PimAllBank, } } } fn new_pim_vm(num_pim_units: u32) -> Box { Box::new(PimVM { pim_units: vec![PimUnit::default(); num_pim_units as _], pim_config: PimConfig { bank_mode: BankMode::SingleBank, kernel: Kernel::NOP, }, }) } #[repr(C)] struct BankData([f16; FP_UNITS]); impl PimVM { pub fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]) { assert_eq!(bank_data.len(), BURST_LENGTH); let pim_unit = &mut self.pim_units[bank_index as usize]; let mut inst = self.pim_config.kernel.0[pim_unit.pc as usize]; pim_unit.pc += 1; let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111; let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111; log::debug!("PimUnit {bank_index} Execute PC {}: {inst:?}", pim_unit.pc); // The JUMP instruction is zero-cycle and not actually executed while let Instruction::JUMP { offset, count } = inst { pim_unit.jump_counter = match pim_unit.jump_counter { Some(jump_counter) => jump_counter.checked_sub(1), None => count.checked_sub(1), }; if pim_unit.jump_counter != None { let new_pc = pim_unit.pc as i32 + offset as i32; if new_pc < 0 || new_pc >= 32 { panic!("Invalid PC {new_pc} after JUMP: {inst:?}"); } pim_unit.pc = new_pc as _; log::debug!("PimUnit {bank_index} New PC {new_pc}: {inst:?}"); } inst = self.pim_config.kernel.0[pim_unit.pc as usize]; pim_unit.pc += 1; } match inst { Instruction::NOP => (), Instruction::EXIT => { pim_unit.jump_counter = None; pim_unit.pc = 0; } Instruction::JUMP { .. } => unreachable!(), Instruction::MOV { src, dst } | Instruction::FILL { src, dst } => { let data = PimVM::load(src, pim_unit, &bank_data); PimVM::store(dst, pim_unit, &data); } Instruction::ADD { src0, mut src1, mut dst, aam, } => { if aam { src1 = if let File::GrfA { index: _ } = src1 { File::GrfA { index: aam_grf_a_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; dst = if let File::GrfB { index: _ } = dst { File::GrfB { index: aam_grf_b_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; } let data0 = PimVM::load(src0, pim_unit, &bank_data); let data1 = PimVM::load(src1, pim_unit, &bank_data); let sum: [f16; FP_UNITS] = data0 .into_iter() .zip(data1) .map(|(src0, src1)| src0 + src1) .collect::>() .try_into() .unwrap(); PimVM::store(dst, pim_unit, &sum); } Instruction::MUL { src0, mut src1, mut dst, aam, } => { if aam { src1 = if let File::GrfA { index: _ } = src1 { File::GrfA { index: aam_grf_a_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; dst = if let File::GrfB { index: _ } = dst { File::GrfB { index: aam_grf_b_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; } let data0 = PimVM::load(src0, pim_unit, &bank_data); let data1 = PimVM::load(src1, pim_unit, &bank_data); let product: [f16; FP_UNITS] = data0 .into_iter() .zip(data1) .map(|(src0, src1)| src0 * src1) .collect::>() .try_into() .unwrap(); PimVM::store(dst, pim_unit, &product); } Instruction::MAC { src0, mut src1, mut src2, mut dst, aam, } | Instruction::MAD { src0, mut src1, mut src2, mut dst, aam, } => { if aam { src1 = if let File::GrfA { index: _ } = src1 { File::GrfA { index: aam_grf_a_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; src2 = if let File::GrfB { index: _ } = src2 { File::GrfB { index: aam_grf_b_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; dst = if let File::GrfB { index: _ } = dst { File::GrfB { index: aam_grf_b_index as _, } } else { panic!("Invalid operand in address-aligned-mode"); }; } assert_eq!(src2, dst); let data0 = PimVM::load(src0, pim_unit, &bank_data); let data1 = PimVM::load(src1, pim_unit, &bank_data); let data2 = PimVM::load(src2, pim_unit, &bank_data); let product: [f16; FP_UNITS] = data0 .into_iter() .zip(data1) .map(|(src0, src1)| src0 * src1) .collect::>() .try_into() .unwrap(); let sum: [f16; FP_UNITS] = product .into_iter() .zip(data2) .map(|(product, src2)| product + src2) .collect::>() .try_into() .unwrap(); log::debug!("{data0:?}\n{data1:?}\n{data2:?}\n{product:?}\n{sum:?}"); PimVM::store(dst, pim_unit, &sum); } } } pub fn execute_write(&mut self, bank_index: u32) -> [u8; BURST_LENGTH] { let pim_unit = &mut self.pim_units[bank_index as usize]; let current_pc = pim_unit.pc; pim_unit.pc += 1; let inst = &self.pim_config.kernel.0[current_pc as usize]; log::debug!("PimUnit {bank_index} Execute PC {current_pc}: {inst:?}"); let data = match inst { Instruction::FILL { src, dst } => { let data: [f16; FP_UNITS] = match src { File::GrfA { index } => pim_unit.grf_a[*index as usize], File::GrfB { index } => pim_unit.grf_b[*index as usize], _ => panic!("Unsupported src operand: {src:?}"), }; if *dst != File::Bank { panic!("Unsupported dst operand: {dst:?}") } data } _ => panic!("Unsupported instruction for write: {inst:?}"), }; unsafe { std::mem::transmute(data) } } fn load(src: File, pim_unit: &PimUnit, bank_data: &[u8]) -> [f16; FP_UNITS] { match src { File::GrfA { index } => pim_unit.grf_a[index as usize], File::GrfB { index } => pim_unit.grf_b[index as usize], File::SrfM { index } => [pim_unit.srf_m[index as usize]; FP_UNITS], File::SrfA { index } => [pim_unit.srf_a[index as usize]; FP_UNITS], File::Bank => unsafe { std::ptr::read(bank_data.as_ptr() as *const BankData).0 }, } } fn store(dst: File, pim_unit: &mut PimUnit, data: &[f16; FP_UNITS]) { match dst { File::GrfA { index } => pim_unit.grf_a[index as usize] = data.clone(), File::GrfB { index } => pim_unit.grf_b[index as usize] = data.clone(), File::SrfM { index } => pim_unit.srf_m[index as usize] = data[0], File::SrfA { index } => pim_unit.srf_a[index as usize] = data[0], File::Bank => panic!("Unsupported dst operand: {dst:?}"), } } }