3x3 matrix multiplication

This commit is contained in:
2024-01-03 18:50:54 +01:00
parent 6380385bd0
commit 34c8ab84fb
7 changed files with 161 additions and 91 deletions

View File

@@ -2,13 +2,18 @@
#![no_std]
#![no_main]
extern crate alloc;
use aarch64_cpu::asm::barrier;
use alloc::{boxed::Box, rc::Rc};
use core::{
cell::RefCell,
fmt::Write,
mem::MaybeUninit,
panic::PanicInfo,
sync::atomic::{compiler_fence, Ordering},
};
use embedded_alloc::Heap;
use half::f16;
use nalgebra::Matrix;
use pim::{
@@ -21,39 +26,49 @@ use pim_isa::BankMode;
use uart::Uart0;
mod boot;
mod critical_section;
mod m5ops;
mod pim;
mod uart;
#[global_allocator]
static PIM_ALLOC: Heap = Heap::empty();
const PIM_ARENA_SIZE: usize = 0x2000000;
#[link_section = ".pim_data"]
static mut PIM_ARENA: [MaybeUninit<u8>; PIM_ARENA_SIZE] = [MaybeUninit::uninit(); PIM_ARENA_SIZE];
#[no_mangle]
pub extern "C" fn entry() -> ! {
unsafe {
PIM_ALLOC.init(PIM_ARENA.as_ptr() as usize, PIM_ARENA_SIZE);
}
let mut uart = Uart0;
let mut pim_state = PimState::new(&MATRIX_MUL);
pim_state.set_kernel();
let pim_matrix_arena0 = RefCell::new(PimMatrixArena(
let mut pim_matrix_arena0 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
let pim_matrix_arena1 = RefCell::new(PimMatrixArena(
)));
let mut pim_matrix_arena1 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
let pim_matrix_arena2 = RefCell::new(PimMatrixArena(
)));
let mut pim_matrix_arena2 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
)));
let pim_storage0 = PimStorage {
arena: &pim_matrix_arena0,
index: 0,
row_major: false,
};
let pim_storage1 = PimStorage {
arena: &pim_matrix_arena1,
index: 0,
row_major: false,
};
let pim_storage2 = PimStorage {
arena: &pim_matrix_arena2,
index: 0,
row_major: false,
};
let mut matrix0 = Matrix::from_data(pim_storage0);
@@ -70,19 +85,19 @@ pub extern "C" fn entry() -> ! {
pim_matrix_arena0.borrow_mut().invalidate_flush();
pim_matrix_arena1.borrow_mut().invalidate_flush();
pim_matrix_arena2.borrow_mut().invalidate_flush();
let dummy_array = DummyArray([F16x16::default(); NUMBER_OF_BANKS]);
let mut dummy_array = Box::new(DummyArray([F16x16::default(); NUMBER_OF_BANKS]));
barrier::dsb(barrier::SY);
// execute_matrix_add(&pim_matrix_arena0, &pim_matrix_arena1, &dummy_array);
execute_matrix_multiply(
&mut pim_state,
&pim_matrix_arena0,
&pim_matrix_arena1,
&pim_matrix_arena2,
&dummy_array,
&mut pim_matrix_arena0.borrow_mut(),
&mut pim_matrix_arena1.borrow_mut(),
&mut pim_matrix_arena2.borrow_mut(),
dummy_array.as_mut(),
);
pim_matrix_arena2.borrow_mut().invalidate();
pim_matrix_arena2.borrow().invalidate();
barrier::dsb(barrier::SY);
writeln!(&mut uart, "{matrix2}").unwrap();