3x3 matrix multiplication

This commit is contained in:
2024-01-03 18:50:54 +01:00
parent 6380385bd0
commit 34c8ab84fb
7 changed files with 161 additions and 91 deletions

18
pim-os/Cargo.lock generated
View File

@@ -59,6 +59,16 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "embedded-alloc"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddae17915accbac2cfbc64ea0ae6e3b330e6ea124ba108dada63646fd3c6f815"
dependencies = [
"critical-section",
"linked_list_allocator",
]
[[package]]
name = "half"
version = "2.3.1"
@@ -91,6 +101,12 @@ dependencies = [
"stable_deref_trait",
]
[[package]]
name = "linked_list_allocator"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afa463f5405ee81cdb9cc2baf37e08ec7e4c8209442b5d72c04cfb2cd6e6286"
[[package]]
name = "lock_api"
version = "0.4.11"
@@ -172,6 +188,8 @@ name = "pim-os"
version = "0.1.0"
dependencies = [
"aarch64-cpu",
"critical-section",
"embedded-alloc",
"half",
"nalgebra",
"num-traits",

View File

@@ -17,6 +17,8 @@ pim-isa = { path = "../pim-isa", default-features = false }
serde-json-core = "0.5.1"
serde = { version = "1.0", default-features = false, features = ["derive"] }
num-traits = { version = "0.2.17", default-features = false }
embedded-alloc = "0.5.1"
critical-section = "1.1.2"
[profile.dev]
panic = "abort"

View File

@@ -1,7 +1,8 @@
MEMORY
{
bootmem : ORIGIN = 0x0, LENGTH = 0x100000
dram : ORIGIN = 0x80000000, LENGTH = 0x100000000
dram : ORIGIN = 0x80000000, LENGTH = 0x20000000
dram_pim : ORIGIN = 0xA0000000, LENGTH = 0x20000000
}
ENTRY(_start)
@@ -18,4 +19,6 @@ SECTIONS
. = ALIGN(8);
. = . + 0x100000; # 1 MiB Stack
LD_STACK_PTR = .;
.pim_data : { KEEP(*(.pim_data)) } > dram_pim
}

View File

@@ -2,13 +2,18 @@
#![no_std]
#![no_main]
extern crate alloc;
use aarch64_cpu::asm::barrier;
use alloc::{boxed::Box, rc::Rc};
use core::{
cell::RefCell,
fmt::Write,
mem::MaybeUninit,
panic::PanicInfo,
sync::atomic::{compiler_fence, Ordering},
};
use embedded_alloc::Heap;
use half::f16;
use nalgebra::Matrix;
use pim::{
@@ -21,39 +26,49 @@ use pim_isa::BankMode;
use uart::Uart0;
mod boot;
mod critical_section;
mod m5ops;
mod pim;
mod uart;
/// Global heap allocator used by the `alloc` crate types in this binary (e.g. `Box`, `Rc`).
/// It is created empty; `entry()` calls `init` on it with the `PIM_ARENA` memory before
/// anything is allocated.
#[global_allocator]
static PIM_ALLOC: Heap = Heap::empty();
/// Size in bytes of the arena handed to `PIM_ALLOC` (0x2000000 = 32 MiB).
const PIM_ARENA_SIZE: usize = 0x2000000;
/// Backing storage for `PIM_ALLOC`, declared as `MaybeUninit<u8>` so no initial contents
/// are required. The `link_section` attribute emits it into `.pim_data`, which the linker
/// script in this commit maps to the `dram_pim` memory region.
#[link_section = ".pim_data"]
static mut PIM_ARENA: [MaybeUninit<u8>; PIM_ARENA_SIZE] = [MaybeUninit::uninit(); PIM_ARENA_SIZE];
#[no_mangle]
pub extern "C" fn entry() -> ! {
unsafe {
PIM_ALLOC.init(PIM_ARENA.as_ptr() as usize, PIM_ARENA_SIZE);
}
let mut uart = Uart0;
let mut pim_state = PimState::new(&MATRIX_MUL);
pim_state.set_kernel();
let pim_matrix_arena0 = RefCell::new(PimMatrixArena(
let mut pim_matrix_arena0 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
let pim_matrix_arena1 = RefCell::new(PimMatrixArena(
)));
let mut pim_matrix_arena1 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
let pim_matrix_arena2 = RefCell::new(PimMatrixArena(
)));
let mut pim_matrix_arena2 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
)));
let pim_storage0 = PimStorage {
arena: &pim_matrix_arena0,
index: 0,
row_major: false,
};
let pim_storage1 = PimStorage {
arena: &pim_matrix_arena1,
index: 0,
row_major: false,
};
let pim_storage2 = PimStorage {
arena: &pim_matrix_arena2,
index: 0,
row_major: false,
};
let mut matrix0 = Matrix::from_data(pim_storage0);
@@ -70,19 +85,19 @@ pub extern "C" fn entry() -> ! {
pim_matrix_arena0.borrow_mut().invalidate_flush();
pim_matrix_arena1.borrow_mut().invalidate_flush();
pim_matrix_arena2.borrow_mut().invalidate_flush();
let dummy_array = DummyArray([F16x16::default(); NUMBER_OF_BANKS]);
let mut dummy_array = Box::new(DummyArray([F16x16::default(); NUMBER_OF_BANKS]));
barrier::dsb(barrier::SY);
// execute_matrix_add(&pim_matrix_arena0, &pim_matrix_arena1, &dummy_array);
execute_matrix_multiply(
&mut pim_state,
&pim_matrix_arena0,
&pim_matrix_arena1,
&pim_matrix_arena2,
&dummy_array,
&mut pim_matrix_arena0.borrow_mut(),
&mut pim_matrix_arena1.borrow_mut(),
&mut pim_matrix_arena2.borrow_mut(),
dummy_array.as_mut(),
);
pim_matrix_arena2.borrow_mut().invalidate();
pim_matrix_arena2.borrow().invalidate();
barrier::dsb(barrier::SY);
writeln!(&mut uart, "{matrix2}").unwrap();

View File

@@ -9,7 +9,7 @@ const EVEN_BANK_INDEX: usize = 0;
const ODD_BANK_INDEX: usize = 8;
#[derive(Clone, Debug)]
#[repr(C, align(1024))]
#[repr(C, align(65536))]
pub struct PimMatrixArena<const R: usize, const C: usize>(pub [[[F16x16; NUMBER_OF_BANKS]; R]; C]);
impl<const R: usize, const C: usize> PimRegion for PimMatrixArena<R, C> {
@@ -24,11 +24,10 @@ impl<const R: usize, const C: usize> PimRegion for PimMatrixArena<R, C> {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct PimStorage<'a, const R: usize, const C: usize> {
pub arena: &'a RefCell<PimMatrixArena<R, C>>,
pub index: usize,
pub row_major: bool,
}
unsafe impl<'a, const R: usize, const C: usize> RawStorage<F16x1, Const<R>, Const<C>>
@@ -46,11 +45,7 @@ unsafe impl<'a, const R: usize, const C: usize> RawStorage<F16x1, Const<R>, Cons
}
fn strides(&self) -> (Self::RStride, Self::CStride) {
if self.row_major {
(Dyn(16 * R * NUMBER_OF_BANKS), Dyn(16 * NUMBER_OF_BANKS))
} else {
(Dyn(16 * NUMBER_OF_BANKS), Dyn(16 * R * NUMBER_OF_BANKS))
}
(Dyn(16 * NUMBER_OF_BANKS), Dyn(16 * R * NUMBER_OF_BANKS))
}
fn is_contiguous(&self) -> bool {
@@ -92,6 +87,9 @@ pub trait PimRegion {
self.read_data_bank(EVEN_BANK_INDEX + i * NUMBER_OF_BANKS);
barrier::dsb(barrier::SY);
self.invalidate_bank(EVEN_BANK_INDEX + i * NUMBER_OF_BANKS);
barrier::dsb(barrier::SY);
}
fn execute_instruction_read_dual_bank(&self) {

View File

@@ -1,4 +1,5 @@
use core::cell::RefCell;
use core::fmt::Write;
use pim_isa::{BankMode, File, Instruction, Kernel};
@@ -137,6 +138,10 @@ pub const MATRIX_MUL: Kernel = Kernel([
dst: File::GrfB { index: 0 },
aam: false,
},
// Instruction::JUMP {
// offset: -1,
// count: 2,
// },
Instruction::FILL {
src: File::GrfB { index: 0 },
dst: File::Bank,
@@ -170,45 +175,31 @@ pub const MATRIX_MUL: Kernel = Kernel([
pub fn execute_matrix_multiply<const R: usize, const C: usize>(
pim_state: &mut PimState,
pim_matrix_arena0: &RefCell<PimMatrixArena<R, C>>,
pim_matrix_arena1: &RefCell<PimMatrixArena<R, C>>,
pim_matrix_arena2: &RefCell<PimMatrixArena<R, C>>,
dummy_array: &DummyArray,
pim_matrix_arena0: &mut PimMatrixArena<R, C>,
pim_matrix_arena1: &mut PimMatrixArena<R, C>,
pim_matrix_arena2: &mut PimMatrixArena<R, C>,
dummy_array: &mut DummyArray,
) {
pim_state.set_bank_mode(BankMode::PimAllBank);
let mut index = 0;
while index < (R * C) {
let left_index = index % R;
let right_index = (index / R) * R;
let mut i = 0;
while i < (R * C) {
let left_index = i % R;
let right_index = (i / R) * R;
pim_matrix_arena0
.borrow()
.execute_instruction_read_single_bank(left_index);
pim_matrix_arena0
.borrow()
.execute_instruction_read_single_bank(left_index + R);
pim_matrix_arena0
.borrow()
.execute_instruction_read_single_bank(left_index + R * 2);
pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 0);
pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 1);
pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 2);
pim_matrix_arena1
.borrow()
.execute_instruction_read_single_bank(right_index);
pim_matrix_arena1
.borrow()
.execute_instruction_read_single_bank(right_index + 1);
pim_matrix_arena1
.borrow()
.execute_instruction_read_single_bank(right_index + 2);
pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 0);
pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 1);
pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 2);
pim_matrix_arena2
.borrow_mut()
.execute_instruction_write_single_bank(index);
pim_matrix_arena2.execute_instruction_write_single_bank(i);
dummy_array.execute_instruction_read_single_bank(0);
index += 1;
i += 1;
}
pim_state.set_bank_mode(BankMode::SingleBank);