use super::matrix::{F16x1, F16x16}; use aarch64_cpu::asm::barrier; use core::panic; use core::{arch::asm, cell::RefCell}; use half::f16; use nalgebra::{Const, Dyn, RawStorage, RawStorageMut, SMatrix, Storage}; // const NUMBER_OF_BANKS: usize = 32; const EVEN_BANK_INDEX: usize = 0; const ODD_BANK_INDEX: usize = 8; #[derive(Clone, Debug)] #[repr(C, align(1024))] pub struct PimMatrixArena(pub [[F16x16; R]; C]); #[derive(Debug)] pub struct PimStorage<'a, const R: usize, const C: usize> { pub arena: &'a RefCell>, pub index: usize, } unsafe impl<'a, const R: usize, const C: usize> RawStorage, Const> for PimStorage<'a, R, C> { type RStride = Dyn; type CStride = Dyn; fn ptr(&self) -> *const F16x1 { unsafe { (&self.arena.borrow().0[0][0] as *const F16x16 as *const F16x1).offset(self.index as _) } } fn shape(&self) -> (Const, Const) { (Const::, Const::) } fn strides(&self) -> (Self::RStride, Self::CStride) { (Dyn(16), Dyn(16 * R)) } fn is_contiguous(&self) -> bool { false } unsafe fn as_slice_unchecked(&self) -> &[F16x1] { panic!("PimStorage is not contiguous!"); } } unsafe impl<'a, const R: usize, const C: usize> RawStorageMut, Const> for PimStorage<'a, R, C> { fn ptr_mut(&mut self) -> *mut F16x1 { unsafe { (&mut self.arena.borrow_mut().0[0][0] as *mut F16x16 as *mut F16x1) .offset(self.index as _) } } unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [F16x1] { panic!("PimStorage is not contiguous!"); } } // #[repr(C, align(1024))] // #[derive(Clone, Debug, Default)] // pub struct PimMatrix(pub SMatrix); // impl PimRegion for PimMatrix { // const NUMBER_OF_BANKS: usize = 64; // fn bank_ptr(&self, bank_index: usize) -> *const f16 { // return &self.0[bank_index].0 as _; // } // fn bank_ptr_mut(&mut self, bank_index: usize) -> *mut f16 { // return &mut self.0[bank_index].0 as _; // } // } pub trait PimRegion { const NUMBER_OF_BANKS: usize; fn bank_ptr(&self, bank_index: usize) -> *const f16; fn bank_ptr_mut(&mut self, bank_index: usize) -> *mut f16; fn execute_instruction_read_single_bank(&self) { if !cfg!(feature = "cacheless") { self.invalidate_bank(EVEN_BANK_INDEX); barrier::dsb(barrier::SY); } // Read from first bank self.read_data_bank(EVEN_BANK_INDEX); barrier::dsb(barrier::SY); } fn execute_instruction_read_dual_bank(&self) { if !cfg!(feature = "cacheless") { self.invalidate_bank(EVEN_BANK_INDEX); self.invalidate_bank(ODD_BANK_INDEX); barrier::dsb(barrier::SY); } // Read from first and second bank self.read_data_bank(EVEN_BANK_INDEX); self.read_data_bank(ODD_BANK_INDEX); barrier::dsb(barrier::SY); } fn read_data_bank(&self, bank_index: usize) { let bank = self.bank_ptr(bank_index); unsafe { core::ptr::read_volatile(bank); } } fn execute_instruction_write_single_bank(&mut self) { if !cfg!(feature = "cacheless") { self.preload_zero(); barrier::dsb(barrier::SY); } // Write to first bank self.write_data_bank(EVEN_BANK_INDEX); if !cfg!(feature = "cacheless") { self.invalidate_flush_bank(EVEN_BANK_INDEX); } barrier::dsb(barrier::SY); } fn execute_instruction_write_dual_bank(&mut self) { if !cfg!(feature = "cacheless") { self.preload_zero(); barrier::dsb(barrier::SY); } // Write to first and second bank self.write_data_bank(EVEN_BANK_INDEX); self.write_data_bank(ODD_BANK_INDEX); if !cfg!(feature = "cacheless") { self.invalidate_flush_bank(EVEN_BANK_INDEX); self.invalidate_flush_bank(ODD_BANK_INDEX); } barrier::dsb(barrier::SY); } fn write_data_bank(&mut self, bank_index: usize) { let bank = self.bank_ptr_mut(bank_index); unsafe { core::ptr::write_volatile(bank, Default::default()); } } fn invalidate(&self) { (0..Self::NUMBER_OF_BANKS).for_each(|idx| self.invalidate_bank(idx)); } fn invalidate_bank(&self, bank_index: usize) { let bank = self.bank_ptr(bank_index); unsafe { asm!("dc ivac, {val}", val = in(reg) bank); } } fn invalidate_flush(&self) { (0..Self::NUMBER_OF_BANKS).for_each(|idx| self.invalidate_flush_bank(idx)); } fn invalidate_flush_bank(&self, bank_index: usize) { let bank = self.bank_ptr(bank_index); unsafe { asm!("dc civac, {val}", val = in(reg) bank); } } fn preload_zero(&self) { (0..Self::NUMBER_OF_BANKS).for_each(|idx| self.preload_zero_bank(idx)); } fn preload_zero_bank(&self, bank_index: usize) { let bank = self.bank_ptr(bank_index); unsafe { // Preload first bank asm!("dc zva, {val}", val = in(reg) bank); } } } #[derive(Clone, Debug)] #[repr(C, align(65536))] pub struct ComputeArray(pub [T; N]); impl ComputeArray { pub fn invalidate_flush(&self) { self.0 .iter() .for_each(|bank_array| bank_array.invalidate_flush()); } pub fn invalidate(&self) { self.0.iter().for_each(|bank_array| bank_array.invalidate()); } } impl Default for ComputeArray { fn default() -> Self { Self(core::array::from_fn(|_| Default::default())) } }