From 8a2c675d7194416d222eb7a44baddca1d5cb6420 Mon Sep 17 00:00:00 2001 From: Derek Christ Date: Tue, 19 Dec 2023 15:37:32 +0100 Subject: [PATCH] Crude implementation of matrices using nalgebra --- pim-isa/src/lib.rs | 10 ---- pim-os/Cargo.toml | 7 ++- pim-os/src/main.rs | 108 +++++++++++++++++++++------------- pim-os/src/pim.rs | 1 + pim-os/src/pim/array.rs | 122 ++++++++++++++++++++++++++++++--------- pim-os/src/pim/kernel.rs | 6 +- pim-os/src/pim/matrix.rs | 117 +++++++++++++++++++++++++++++++++++++ 7 files changed, 287 insertions(+), 84 deletions(-) create mode 100644 pim-os/src/pim/matrix.rs diff --git a/pim-isa/src/lib.rs b/pim-isa/src/lib.rs index 9aacf2c..5daac81 100644 --- a/pim-isa/src/lib.rs +++ b/pim-isa/src/lib.rs @@ -46,16 +46,6 @@ pub enum Instruction { }, } -impl Instruction { - pub fn supported_source(&self, src: File) -> bool { - todo!() - } - - pub fn supported_destination(&self, src: File) -> bool { - todo!() - } -} - #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum File { GrfA { index: u8 }, diff --git a/pim-os/Cargo.toml b/pim-os/Cargo.toml index c3de38a..cdd060a 100644 --- a/pim-os/Cargo.toml +++ b/pim-os/Cargo.toml @@ -12,10 +12,11 @@ cacheless = [] [dependencies] aarch64-cpu = "9.4.0" half = { version = "2.3.1", default-features = false } -serde = { version = "1.0", default-features = false, features = ["derive"] } -serde-json-core = "0.5.1" - +nalgebra = { version = "0.32.3", default-features = false } pim-isa = { path = "../pim-isa", default-features = false } +serde-json-core = "0.5.1" +serde = { version = "1.0", default-features = false, features = ["derive"] } +num-traits = { version = "0.2.17", default-features = false } [profile.dev] panic = "abort" diff --git a/pim-os/src/main.rs b/pim-os/src/main.rs index 931a515..0ba0e4f 100644 --- a/pim-os/src/main.rs +++ b/pim-os/src/main.rs @@ -1,20 +1,27 @@ +#![feature(generic_const_exprs)] #![no_std] #![no_main] use aarch64_cpu::asm::barrier; use core::{ + cell::RefCell, fmt::Write, panic::PanicInfo, sync::atomic::{compiler_fence, Ordering}, }; use half::f16; +use nalgebra::{Const, Matrix, Matrix2, SMatrixView}; use pim::{ - array::{BankArray, ComputeArray}, + array::{PimMatrixArena, PimStorage}, + // array::PimMatrix, + // array::{BankArray, ComputeArray}, kernel::TEST_KERNEL, + matrix::{F16x1, F16x16}, state::PimState, }; use pim_isa::BankMode; use uart::Uart0; + mod boot; mod m5ops; mod pim; @@ -22,53 +29,74 @@ mod uart; #[no_mangle] pub extern "C" fn entry() -> ! { - let mut pim_state = PimState::new(&TEST_KERNEL); - let mut compute_array: ComputeArray<3> = ComputeArray([ - BankArray([f16::from_f32(0.1); 512]), - BankArray([f16::from_f32(0.2); 512]), - BankArray([f16::from_f32(0.3); 512]), - ]); - let dummy_array = BankArray::default(); let mut uart = Uart0; + let mut pim_state = PimState::new(&TEST_KERNEL); - writeln!( - &mut uart, - "PIM array is at {:x?}", - core::ptr::addr_of!(compute_array) - ) - .unwrap(); + let mut arena = RefCell::new(PimMatrixArena([[F16x16::default(); 8]; 8])); + let pim_storage0 = PimStorage { + arena: &arena, + index: 0, + }; + let pim_storage1 = PimStorage { + arena: &arena, + index: 1, + }; + let pim_storage2 = PimStorage { + arena: &arena, + index: 2, + }; + let mut matrix0 = Matrix::from_data(pim_storage0); + let mut matrix1 = Matrix::from_data(pim_storage1); + matrix0.fill_lower_triangle(F16x1(f16::ONE), 0); + matrix1.fill_upper_triangle(F16x1(f16::from_f32(2.0)), 0); + writeln!(&mut uart, "{}", matrix0).unwrap(); + writeln!(&mut uart, "{}", matrix1).unwrap(); - writeln!( - &mut uart, - "BankArray0: [{:?}, ...]\nBankArray1: [{:?}, ...]\nBankArray2: [{:?}, ...]", - compute_array.0[0].0[0], compute_array.0[1].0[0], compute_array.0[2].0[0] - ) - .unwrap(); + // let mut compute_array: ComputeArray<3> = ComputeArray([ + // BankArray([F16x16([f16::from_f32(0.1); 16]); 32]), + // BankArray([f16::from_f32(0.2); 512]), + // BankArray([f16::from_f32(0.3); 512]), + // ]); + // let dummy_array = BankArray::default(); - writeln!(&mut uart, "MAC: BankArray2 += BankArray0 * BankArray1",).unwrap(); + // writeln!( + // &mut uart, + // "PIM array is at {:x?}", + // core::ptr::addr_of!(compute_array) + // ) + // .unwrap(); - // Invalidate and flush array just in case - compute_array.invalidate_flush(); - dummy_array.invalidate_flush(); - barrier::dsb(barrier::SY); + // writeln!( + // &mut uart, + // "BankArray0: [{:?}, ...]\nBankArray1: [{:?}, ...]\nBankArray2: [{:?}, ...]", + // compute_array.0[0].0[0], compute_array.0[1].0[0], compute_array.0[2].0[0] + // ) + // .unwrap(); - pim_state.set_bank_mode(BankMode::PimAllBank); - compute_array.0[1].execute_instruction_read_dual_bank(); - compute_array.0[2].execute_instruction_read_dual_bank(); - compute_array.0[0].execute_instruction_read_dual_bank(); - compute_array.0[2].execute_instruction_write_dual_bank(); - dummy_array.execute_instruction_read_single_bank(); - pim_state.set_bank_mode(BankMode::SingleBank); + // writeln!(&mut uart, "MAC: BankArray2 += BankArray0 * BankArray1",).unwrap(); - compute_array.invalidate(); - barrier::dsb(barrier::SY); + // // Invalidate and flush array just in case + // compute_array.invalidate_flush(); + // dummy_array.invalidate_flush(); + // barrier::dsb(barrier::SY); - writeln!( - &mut uart, - "BankArray2: [{:?}, ...]", - compute_array.0[2].0[0] - ) - .unwrap(); + // pim_state.set_bank_mode(BankMode::PimAllBank); + // compute_array.0[1].execute_instruction_read_dual_bank(); + // compute_array.0[2].execute_instruction_read_dual_bank(); + // compute_array.0[0].execute_instruction_read_dual_bank(); + // compute_array.0[2].execute_instruction_write_dual_bank(); + // dummy_array.execute_instruction_read_single_bank(); + // pim_state.set_bank_mode(BankMode::SingleBank); + + // compute_array.invalidate(); + // barrier::dsb(barrier::SY); + + // writeln!( + // &mut uart, + // "BankArray2: [{:?}, ...]", + // compute_array.0[2].0[0] + // ) + // .unwrap(); // writeln!(&mut uart, "ComputeArray:\n{:#?}", compute_array.0[2]).unwrap(); diff --git a/pim-os/src/pim.rs b/pim-os/src/pim.rs index a030fb8..5d5d47a 100644 --- a/pim-os/src/pim.rs +++ b/pim-os/src/pim.rs @@ -1,4 +1,5 @@ pub mod array; pub mod config; pub mod kernel; +pub mod matrix; pub mod state; diff --git a/pim-os/src/pim/array.rs b/pim-os/src/pim/array.rs index d517032..afd3052 100644 --- a/pim-os/src/pim/array.rs +++ b/pim-os/src/pim/array.rs @@ -1,25 +1,91 @@ +use super::matrix::{F16x1, F16x16}; use aarch64_cpu::asm::barrier; -use core::arch::asm; +use core::panic; +use core::{arch::asm, cell::RefCell}; use half::f16; +use nalgebra::{Const, Dyn, RawStorage, RawStorageMut, SMatrix, Storage}; -const NUMBER_OF_BANKS: usize = 32; -const ELEMENTS_PER_CACHE_LINE: usize = 16; -const ELEMENTS_PER_BANK_ARRAY: usize = NUMBER_OF_BANKS * ELEMENTS_PER_CACHE_LINE; +// const NUMBER_OF_BANKS: usize = 32; const EVEN_BANK_INDEX: usize = 0; const ODD_BANK_INDEX: usize = 8; #[derive(Clone, Debug)] #[repr(C, align(1024))] -pub struct BankArray(pub [f16; ELEMENTS_PER_BANK_ARRAY]); +pub struct PimMatrixArena(pub [[F16x16; R]; C]); -impl Default for BankArray { - fn default() -> Self { - Self([f16::ZERO; ELEMENTS_PER_BANK_ARRAY]) +#[derive(Debug)] +pub struct PimStorage<'a, const R: usize, const C: usize> { + pub arena: &'a RefCell>, + pub index: usize, +} + +unsafe impl<'a, const R: usize, const C: usize> RawStorage, Const> + for PimStorage<'a, R, C> +{ + type RStride = Dyn; + type CStride = Dyn; + + fn ptr(&self) -> *const F16x1 { + unsafe { + (&self.arena.borrow().0[0][0] as *const F16x16 as *const F16x1).offset(self.index as _) + } + } + + fn shape(&self) -> (Const, Const) { + (Const::, Const::) + } + + fn strides(&self) -> (Self::RStride, Self::CStride) { + (Dyn(16), Dyn(16 * R)) + } + + fn is_contiguous(&self) -> bool { + false + } + + unsafe fn as_slice_unchecked(&self) -> &[F16x1] { + panic!("PimStorage is not contiguous!"); } } -impl BankArray { - pub fn execute_instruction_read_single_bank(&self) { +unsafe impl<'a, const R: usize, const C: usize> RawStorageMut, Const> + for PimStorage<'a, R, C> +{ + fn ptr_mut(&mut self) -> *mut F16x1 { + unsafe { + (&mut self.arena.borrow_mut().0[0][0] as *mut F16x16 as *mut F16x1) + .offset(self.index as _) + } + } + + unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [F16x1] { + panic!("PimStorage is not contiguous!"); + } +} + +// #[repr(C, align(1024))] +// #[derive(Clone, Debug, Default)] +// pub struct PimMatrix(pub SMatrix); + +// impl PimRegion for PimMatrix { +// const NUMBER_OF_BANKS: usize = 64; + +// fn bank_ptr(&self, bank_index: usize) -> *const f16 { +// return &self.0[bank_index].0 as _; +// } + +// fn bank_ptr_mut(&mut self, bank_index: usize) -> *mut f16 { +// return &mut self.0[bank_index].0 as _; +// } +// } + +pub trait PimRegion { + const NUMBER_OF_BANKS: usize; + + fn bank_ptr(&self, bank_index: usize) -> *const f16; + fn bank_ptr_mut(&mut self, bank_index: usize) -> *mut f16; + + fn execute_instruction_read_single_bank(&self) { if !cfg!(feature = "cacheless") { self.invalidate_bank(EVEN_BANK_INDEX); @@ -32,7 +98,7 @@ impl BankArray { barrier::dsb(barrier::SY); } - pub fn execute_instruction_read_dual_bank(&self) { + fn execute_instruction_read_dual_bank(&self) { if !cfg!(feature = "cacheless") { self.invalidate_bank(EVEN_BANK_INDEX); self.invalidate_bank(ODD_BANK_INDEX); @@ -48,13 +114,13 @@ impl BankArray { } fn read_data_bank(&self, bank_index: usize) { - let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE]; + let bank = self.bank_ptr(bank_index); unsafe { core::ptr::read_volatile(bank); } } - pub fn execute_instruction_write_single_bank(&mut self) { + fn execute_instruction_write_single_bank(&mut self) { if !cfg!(feature = "cacheless") { self.preload_zero(); barrier::dsb(barrier::SY); @@ -70,7 +136,7 @@ impl BankArray { barrier::dsb(barrier::SY); } - pub fn execute_instruction_write_dual_bank(&mut self) { + fn execute_instruction_write_dual_bank(&mut self) { if !cfg!(feature = "cacheless") { self.preload_zero(); barrier::dsb(barrier::SY); @@ -89,40 +155,40 @@ impl BankArray { } fn write_data_bank(&mut self, bank_index: usize) { - let bank = &mut self.0[bank_index * ELEMENTS_PER_CACHE_LINE]; + let bank = self.bank_ptr_mut(bank_index); unsafe { - core::ptr::write_volatile(bank, f16::ZERO); + core::ptr::write_volatile(bank, Default::default()); } } - pub fn invalidate(&self) { - (0..NUMBER_OF_BANKS).for_each(|idx| self.invalidate_bank(idx)); + fn invalidate(&self) { + (0..Self::NUMBER_OF_BANKS).for_each(|idx| self.invalidate_bank(idx)); } fn invalidate_bank(&self, bank_index: usize) { - let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE]; + let bank = self.bank_ptr(bank_index); unsafe { asm!("dc ivac, {val}", val = in(reg) bank); } } - pub fn invalidate_flush(&self) { - (0..NUMBER_OF_BANKS).for_each(|idx| self.invalidate_flush_bank(idx)); + fn invalidate_flush(&self) { + (0..Self::NUMBER_OF_BANKS).for_each(|idx| self.invalidate_flush_bank(idx)); } fn invalidate_flush_bank(&self, bank_index: usize) { - let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE]; + let bank = self.bank_ptr(bank_index); unsafe { asm!("dc civac, {val}", val = in(reg) bank); } } - pub fn preload_zero(&self) { - (0..NUMBER_OF_BANKS).for_each(|idx| self.preload_zero_bank(idx)); + fn preload_zero(&self) { + (0..Self::NUMBER_OF_BANKS).for_each(|idx| self.preload_zero_bank(idx)); } fn preload_zero_bank(&self, bank_index: usize) { - let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE]; + let bank = self.bank_ptr(bank_index); unsafe { // Preload first bank asm!("dc zva, {val}", val = in(reg) bank); @@ -132,9 +198,9 @@ impl BankArray { #[derive(Clone, Debug)] #[repr(C, align(65536))] -pub struct ComputeArray(pub [BankArray; N]); +pub struct ComputeArray(pub [T; N]); -impl ComputeArray { +impl ComputeArray { pub fn invalidate_flush(&self) { self.0 .iter() @@ -146,7 +212,7 @@ impl ComputeArray { } } -impl Default for ComputeArray { +impl Default for ComputeArray { fn default() -> Self { Self(core::array::from_fn(|_| Default::default())) } diff --git a/pim-os/src/pim/kernel.rs b/pim-os/src/pim/kernel.rs index 277550e..e11b0e9 100644 --- a/pim-os/src/pim/kernel.rs +++ b/pim-os/src/pim/kernel.rs @@ -20,13 +20,13 @@ pub const TEST_KERNEL: Kernel = Kernel([ Instruction::MAC { src0: File::Bank, src1: File::GrfA { index: 0 }, - src2: File::GrfA { index: 1 }, - dst: File::GrfA { index: 1 }, + src2: File::GrfB { index: 0 }, + dst: File::GrfB { index: 0 }, aam: false, }, Instruction::MAC { src0: File::Bank, - src1: File::GrfB { index: 0 }, + src1: File::GrfA { index: 1 }, src2: File::GrfB { index: 1 }, dst: File::GrfB { index: 1 }, aam: false, diff --git a/pim-os/src/pim/matrix.rs b/pim-os/src/pim/matrix.rs new file mode 100644 index 0000000..36db7b5 --- /dev/null +++ b/pim-os/src/pim/matrix.rs @@ -0,0 +1,117 @@ +use half::f16; + +const FLOATING_POINT_UNITS: usize = 16; + +#[repr(C)] +#[derive(Default, Clone, Copy, PartialEq)] +pub struct F16x1(pub f16); + +impl core::fmt::Debug for F16x1 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + Ok(self.0.fmt(f)?) + } +} + +impl core::fmt::Display for F16x1 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + Ok(self.0.fmt(f)?) + } +} + +impl num_traits::identities::Zero for F16x1 { + fn zero() -> Self { + Self(f16::ZERO) + } + + fn is_zero(&self) -> bool { + self.0 == f16::ZERO + } +} + +impl num_traits::identities::One for F16x1 { + fn one() -> Self { + Self(f16::ONE) + } +} + +impl core::ops::Add for F16x1 { + type Output = Self; + + fn add(self, rhs: F16x1) -> Self::Output { + Self(self.0 + rhs.0) + } +} + +impl core::ops::AddAssign for F16x1 { + fn add_assign(&mut self, rhs: F16x1) { + self.0 += rhs.0; + } +} + +impl core::ops::Mul for F16x1 { + type Output = Self; + + fn mul(self, rhs: F16x1) -> Self::Output { + Self(self.0 * rhs.0) + } +} + +impl core::ops::MulAssign for F16x1 { + fn mul_assign(&mut self, rhs: F16x1) { + self.0 *= rhs.0; + } +} + +#[repr(C)] +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub struct F16x16(pub [F16x1; FLOATING_POINT_UNITS]); + +impl num_traits::identities::Zero for F16x16 { + fn zero() -> Self { + Self([F16x1::zero(); FLOATING_POINT_UNITS]) + } + + fn is_zero(&self) -> bool { + self.0 == [F16x1::zero(); FLOATING_POINT_UNITS] + } +} + +impl num_traits::identities::One for F16x16 { + fn one() -> Self { + Self([F16x1::one(); FLOATING_POINT_UNITS]) + } +} + +impl core::ops::Add for F16x16 { + type Output = Self; + + fn add(self, rhs: F16x16) -> Self::Output { + Self(core::array::from_fn(|i| self.0[i] + rhs.0[i])) + } +} + +impl core::ops::AddAssign for F16x16 { + fn add_assign(&mut self, rhs: F16x16) { + self.0 + .iter_mut() + .zip(&rhs.0) + .for_each(|(left, right)| *left += *right); + } +} + +impl core::ops::Mul for F16x16 { + type Output = Self; + + fn mul(self, rhs: F16x16) -> Self::Output { + Self(core::array::from_fn(|i| self.0[i] * rhs.0[i])) + } +} + +impl core::ops::MulAssign for F16x16 { + fn mul_assign(&mut self, rhs: F16x16) { + self.0 + .iter_mut() + .zip(&rhs.0) + .for_each(|(left, right)| *left *= *right); + } +}