Crude implementation of matrices using nalgebra
This commit is contained in:
@@ -1,25 +1,91 @@
|
||||
use super::matrix::{F16x1, F16x16};
|
||||
use aarch64_cpu::asm::barrier;
|
||||
use core::arch::asm;
|
||||
use core::panic;
|
||||
use core::{arch::asm, cell::RefCell};
|
||||
use half::f16;
|
||||
use nalgebra::{Const, Dyn, RawStorage, RawStorageMut, SMatrix, Storage};
|
||||
|
||||
const NUMBER_OF_BANKS: usize = 32;
|
||||
const ELEMENTS_PER_CACHE_LINE: usize = 16;
|
||||
const ELEMENTS_PER_BANK_ARRAY: usize = NUMBER_OF_BANKS * ELEMENTS_PER_CACHE_LINE;
|
||||
// const NUMBER_OF_BANKS: usize = 32;
|
||||
const EVEN_BANK_INDEX: usize = 0;
|
||||
const ODD_BANK_INDEX: usize = 8;
|
||||
|
||||
/// Fixed-size buffer of `f16` values covering one bank array.
///
/// The wrapped array holds `ELEMENTS_PER_BANK_ARRAY` elements, which is
/// `NUMBER_OF_BANKS * ELEMENTS_PER_CACHE_LINE` (32 * 16 = 512) per the constants above.
/// `repr(C, align(1024))` forces the value onto a 1024-byte boundary.
// NOTE(review): 1024 presumably equals the buffer size (512 elements * 2 bytes per `f16`) — confirm.
#[derive(Clone, Debug)]
|
||||
#[repr(C, align(1024))]
|
||||
pub struct BankArray(pub [f16; ELEMENTS_PER_BANK_ARRAY]);
|
||||
/// Backing memory for `PimStorage`: `C` outer arrays, each holding `R` `F16x16` entries.
// NOTE(review): presumably column-major with one `F16x16` slot per matrix element; the strides
// reported by the `RawStorage` impl for `PimStorage` (16 and 16 * R) are consistent with that — confirm.
pub struct PimMatrixArena<const R: usize, const C: usize>(pub [[F16x16; R]; C]);
|
||||
|
||||
impl Default for BankArray {
|
||||
fn default() -> Self {
|
||||
Self([f16::ZERO; ELEMENTS_PER_BANK_ARRAY])
|
||||
/// View into a shared `PimMatrixArena`, used as the storage type for the nalgebra
/// `RawStorage` / `RawStorageMut` impls on this type.
#[derive(Debug)]
|
||||
pub struct PimStorage<'a, const R: usize, const C: usize> {
|
||||
// Shared arena, borrowed through `RefCell` each time `ptr()` / `ptr_mut()` is called.
pub arena: &'a RefCell<PimMatrixArena<R, C>>,
|
||||
// Offset added to the arena's base pointer in `ptr()` / `ptr_mut()`, counted in `F16x1` elements.
pub index: usize,
|
||||
}
|
||||
|
||||
/// Read-only nalgebra storage backed by a `PimMatrixArena`.
///
/// The shape is the compile-time `R` x `C`; both strides are dynamic (`Dyn`) and the
/// storage reports itself as non-contiguous, so the slice accessor always panics.
unsafe impl<'a, const R: usize, const C: usize> RawStorage<F16x1, Const<R>, Const<C>>
|
||||
for PimStorage<'a, R, C>
|
||||
{
|
||||
type RStride = Dyn;
|
||||
type CStride = Dyn;
|
||||
|
||||
/// Returns the address of the arena's first `F16x16`, reinterpreted as `*const F16x1`
/// and advanced by `self.index` elements.
///
/// # Panics
/// `RefCell::borrow` panics if the arena is mutably borrowed at the time of the call.
// NOTE(review): the `Ref` guard is a temporary and is gone once this function returns, so the
// returned pointer is not tracked by the `RefCell` afterwards — confirm callers never hold a
// conflicting mutable borrow while dereferencing it.
// NOTE(review): `.0[0][0]` indexes the first element, so `R == 0` or `C == 0` would panic here — confirm
// such shapes are never instantiated.
fn ptr(&self) -> *const F16x1 {
|
||||
unsafe {
|
||||
// `self.index as _` converts the `usize` index to the `isize` that `offset` expects.
(&self.arena.borrow().0[0][0] as *const F16x16 as *const F16x1).offset(self.index as _)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the static dimensions `(R, C)`.
fn shape(&self) -> (Const<R>, Const<C>) {
|
||||
(Const::<R>, Const::<C>)
|
||||
}
|
||||
|
||||
/// Returns a row stride of 16 and a column stride of `16 * R`, in `F16x1` elements.
// NOTE(review): 16 is presumably the number of `F16x1` lanes in one `F16x16`, i.e. each matrix
// element occupies a whole `F16x16` slot — confirm against the `matrix` module.
fn strides(&self) -> (Self::RStride, Self::CStride) {
|
||||
(Dyn(16), Dyn(16 * R))
|
||||
}
|
||||
|
||||
/// Always `false`: the strides above leave gaps between consecutive elements.
fn is_contiguous(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Never returns a slice.
///
/// # Panics
/// Always panics, because the storage is not contiguous.
unsafe fn as_slice_unchecked(&self) -> &[F16x1] {
|
||||
panic!("PimStorage is not contiguous!");
|
||||
}
|
||||
}
|
||||
|
||||
impl BankArray {
|
||||
pub fn execute_instruction_read_single_bank(&self) {
|
||||
/// Mutable counterpart of the `RawStorage` impl for `PimStorage`.
unsafe impl<'a, const R: usize, const C: usize> RawStorageMut<F16x1, Const<R>, Const<C>>
|
||||
for PimStorage<'a, R, C>
|
||||
{
|
||||
/// Returns the address of the arena's first `F16x16`, reinterpreted as `*mut F16x1`
/// and advanced by `self.index` elements.
///
/// # Panics
/// `RefCell::borrow_mut` panics if the arena is already borrowed (shared or mutable)
/// at the time of the call.
// NOTE(review): the `RefMut` guard is a temporary and is gone once this function returns, so
// writes through the returned pointer are not tracked by the `RefCell` — confirm that no two
// `PimStorage` views over the same arena are written and read at overlapping offsets.
fn ptr_mut(&mut self) -> *mut F16x1 {
|
||||
unsafe {
|
||||
(&mut self.arena.borrow_mut().0[0][0] as *mut F16x16 as *mut F16x1)
|
||||
// `self.index as _` converts the `usize` index to the `isize` that `offset` expects.
.offset(self.index as _)
|
||||
}
|
||||
}
|
||||
|
||||
/// Never returns a slice.
///
/// # Panics
/// Always panics, because the storage is not contiguous.
unsafe fn as_mut_slice_unchecked(&mut self) -> &mut [F16x1] {
|
||||
panic!("PimStorage is not contiguous!");
|
||||
}
|
||||
}
|
||||
|
||||
// #[repr(C, align(1024))]
|
||||
// #[derive(Clone, Debug, Default)]
|
||||
// pub struct PimMatrix(pub SMatrix<f16, 8, 8>);
|
||||
|
||||
// impl PimRegion for PimMatrix {
|
||||
// const NUMBER_OF_BANKS: usize = 64;
|
||||
|
||||
// fn bank_ptr(&self, bank_index: usize) -> *const f16 {
|
||||
// return &self.0[bank_index].0 as _;
|
||||
// }
|
||||
|
||||
// fn bank_ptr_mut(&mut self, bank_index: usize) -> *mut f16 {
|
||||
// return &mut self.0[bank_index].0 as _;
|
||||
// }
|
||||
// }
|
||||
|
||||
pub trait PimRegion {
|
||||
const NUMBER_OF_BANKS: usize;
|
||||
|
||||
fn bank_ptr(&self, bank_index: usize) -> *const f16;
|
||||
fn bank_ptr_mut(&mut self, bank_index: usize) -> *mut f16;
|
||||
|
||||
fn execute_instruction_read_single_bank(&self) {
|
||||
if !cfg!(feature = "cacheless") {
|
||||
self.invalidate_bank(EVEN_BANK_INDEX);
|
||||
|
||||
@@ -32,7 +98,7 @@ impl BankArray {
|
||||
barrier::dsb(barrier::SY);
|
||||
}
|
||||
|
||||
pub fn execute_instruction_read_dual_bank(&self) {
|
||||
fn execute_instruction_read_dual_bank(&self) {
|
||||
if !cfg!(feature = "cacheless") {
|
||||
self.invalidate_bank(EVEN_BANK_INDEX);
|
||||
self.invalidate_bank(ODD_BANK_INDEX);
|
||||
@@ -48,13 +114,13 @@ impl BankArray {
|
||||
}
|
||||
|
||||
/// Issues a volatile read at the start of bank `bank_index` and discards the value read.
fn read_data_bank(&self, bank_index: usize) {
|
||||
// NOTE(review): this binding is immediately shadowed by the next `let bank`; it looks like the
// pre-change row of the diff this text was taken from — confirm only one of the two is present.
let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE];
|
||||
let bank = self.bank_ptr(bank_index);
|
||||
unsafe {
|
||||
// Volatile so the access is not optimized away even though the result is unused.
core::ptr::read_volatile(bank);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute_instruction_write_single_bank(&mut self) {
|
||||
fn execute_instruction_write_single_bank(&mut self) {
|
||||
if !cfg!(feature = "cacheless") {
|
||||
self.preload_zero();
|
||||
barrier::dsb(barrier::SY);
|
||||
@@ -70,7 +136,7 @@ impl BankArray {
|
||||
barrier::dsb(barrier::SY);
|
||||
}
|
||||
|
||||
pub fn execute_instruction_write_dual_bank(&mut self) {
|
||||
fn execute_instruction_write_dual_bank(&mut self) {
|
||||
if !cfg!(feature = "cacheless") {
|
||||
self.preload_zero();
|
||||
barrier::dsb(barrier::SY);
|
||||
@@ -89,40 +155,40 @@ impl BankArray {
|
||||
}
|
||||
|
||||
/// Issues a volatile write of a zero/default `f16` at the start of bank `bank_index`.
fn write_data_bank(&mut self, bank_index: usize) {
|
||||
// NOTE(review): this binding is immediately shadowed by the next `let bank`; it looks like the
// pre-change row of the diff this text was taken from — confirm only one of the two is present.
let bank = &mut self.0[bank_index * ELEMENTS_PER_CACHE_LINE];
|
||||
let bank = self.bank_ptr_mut(bank_index);
|
||||
unsafe {
|
||||
// NOTE(review): two writes to the same address appear here (`f16::ZERO`, then
// `Default::default()`); presumably the old and new rows of the diff — confirm.
core::ptr::write_volatile(bank, f16::ZERO);
|
||||
core::ptr::write_volatile(bank, Default::default());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn invalidate(&self) {
|
||||
(0..NUMBER_OF_BANKS).for_each(|idx| self.invalidate_bank(idx));
|
||||
/// Calls `invalidate_bank` for every bank index in `0..Self::NUMBER_OF_BANKS`.
fn invalidate(&self) {
|
||||
(0..Self::NUMBER_OF_BANKS).for_each(|idx| self.invalidate_bank(idx));
|
||||
}
|
||||
|
||||
/// Executes `dc ivac` on the address of bank `bank_index`.
// On AArch64, `DC IVAC` invalidates the data cache line holding the given virtual address
// (to the point of coherency) without writing it back.
fn invalidate_bank(&self, bank_index: usize) {
|
||||
// NOTE(review): this binding is immediately shadowed by the next `let bank`; it looks like the
// pre-change row of the diff this text was taken from — confirm only one of the two is present.
let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE];
|
||||
let bank = self.bank_ptr(bank_index);
|
||||
unsafe {
|
||||
asm!("dc ivac, {val}", val = in(reg) bank);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn invalidate_flush(&self) {
|
||||
(0..NUMBER_OF_BANKS).for_each(|idx| self.invalidate_flush_bank(idx));
|
||||
/// Calls `invalidate_flush_bank` for every bank index in `0..Self::NUMBER_OF_BANKS`.
fn invalidate_flush(&self) {
|
||||
(0..Self::NUMBER_OF_BANKS).for_each(|idx| self.invalidate_flush_bank(idx));
|
||||
}
|
||||
|
||||
/// Executes `dc civac` on the address of bank `bank_index`.
// On AArch64, `DC CIVAC` cleans (writes back) and then invalidates the data cache line holding
// the given virtual address, to the point of coherency.
fn invalidate_flush_bank(&self, bank_index: usize) {
|
||||
// NOTE(review): this binding is immediately shadowed by the next `let bank`; it looks like the
// pre-change row of the diff this text was taken from — confirm only one of the two is present.
let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE];
|
||||
let bank = self.bank_ptr(bank_index);
|
||||
unsafe {
|
||||
asm!("dc civac, {val}", val = in(reg) bank);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preload_zero(&self) {
|
||||
(0..NUMBER_OF_BANKS).for_each(|idx| self.preload_zero_bank(idx));
|
||||
/// Calls `preload_zero_bank` for every bank index in `0..Self::NUMBER_OF_BANKS`.
fn preload_zero(&self) {
|
||||
(0..Self::NUMBER_OF_BANKS).for_each(|idx| self.preload_zero_bank(idx));
|
||||
}
|
||||
|
||||
fn preload_zero_bank(&self, bank_index: usize) {
|
||||
let bank = &self.0[bank_index * ELEMENTS_PER_CACHE_LINE];
|
||||
let bank = self.bank_ptr(bank_index);
|
||||
unsafe {
|
||||
// Preload first bank
|
||||
asm!("dc zva, {val}", val = in(reg) bank);
|
||||
@@ -132,9 +198,9 @@ impl BankArray {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[repr(C, align(65536))]
|
||||
pub struct ComputeArray<const N: usize>(pub [BankArray; N]);
|
||||
pub struct ComputeArray<T: PimRegion, const N: usize>(pub [T; N]);
|
||||
|
||||
impl<const N: usize> ComputeArray<N> {
|
||||
impl<T: PimRegion, const N: usize> ComputeArray<T, N> {
|
||||
pub fn invalidate_flush(&self) {
|
||||
self.0
|
||||
.iter()
|
||||
@@ -146,7 +212,7 @@ impl<const N: usize> ComputeArray<N> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Default for ComputeArray<N> {
|
||||
impl<T: PimRegion + Default, const N: usize> Default for ComputeArray<T, N> {
|
||||
fn default() -> Self {
|
||||
Self(core::array::from_fn(|_| Default::default()))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user