3x3 matrix multiplication

This commit is contained in:
2024-01-03 18:50:54 +01:00
parent 6380385bd0
commit 34c8ab84fb
7 changed files with 161 additions and 91 deletions

18
pim-os/Cargo.lock generated
View File

@@ -59,6 +59,16 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "embedded-alloc"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddae17915accbac2cfbc64ea0ae6e3b330e6ea124ba108dada63646fd3c6f815"
dependencies = [
"critical-section",
"linked_list_allocator",
]
[[package]]
name = "half"
version = "2.3.1"
@@ -91,6 +101,12 @@ dependencies = [
"stable_deref_trait",
]
[[package]]
name = "linked_list_allocator"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afa463f5405ee81cdb9cc2baf37e08ec7e4c8209442b5d72c04cfb2cd6e6286"
[[package]]
name = "lock_api"
version = "0.4.11"
@@ -172,6 +188,8 @@ name = "pim-os"
version = "0.1.0"
dependencies = [
"aarch64-cpu",
"critical-section",
"embedded-alloc",
"half",
"nalgebra",
"num-traits",

View File

@@ -17,6 +17,8 @@ pim-isa = { path = "../pim-isa", default-features = false }
serde-json-core = "0.5.1"
serde = { version = "1.0", default-features = false, features = ["derive"] }
num-traits = { version = "0.2.17", default-features = false }
embedded-alloc = "0.5.1"
critical-section = "1.1.2"
[profile.dev]
panic = "abort"

View File

@@ -1,7 +1,8 @@
MEMORY
{
bootmem : ORIGIN = 0x0, LENGTH = 0x100000
dram : ORIGIN = 0x80000000, LENGTH = 0x100000000
dram : ORIGIN = 0x80000000, LENGTH = 0x20000000
dram_pim : ORIGIN = 0xA0000000, LENGTH = 0x20000000
}
ENTRY(_start)
@@ -18,4 +19,6 @@ SECTIONS
. = ALIGN(8);
. = . + 0x100000; # 1 MiB Stack
LD_STACK_PTR = .;
.pim_data : { KEEP(*(.pim_data)) } > dram_pim
}

View File

@@ -2,13 +2,18 @@
#![no_std]
#![no_main]
extern crate alloc;
use aarch64_cpu::asm::barrier;
use alloc::{boxed::Box, rc::Rc};
use core::{
cell::RefCell,
fmt::Write,
mem::MaybeUninit,
panic::PanicInfo,
sync::atomic::{compiler_fence, Ordering},
};
use embedded_alloc::Heap;
use half::f16;
use nalgebra::Matrix;
use pim::{
@@ -21,39 +26,49 @@ use pim_isa::BankMode;
use uart::Uart0;
mod boot;
mod critical_section;
mod m5ops;
mod pim;
mod uart;
#[global_allocator]
static PIM_ALLOC: Heap = Heap::empty();
const PIM_ARENA_SIZE: usize = 0x2000000;
#[link_section = ".pim_data"]
static mut PIM_ARENA: [MaybeUninit<u8>; PIM_ARENA_SIZE] = [MaybeUninit::uninit(); PIM_ARENA_SIZE];
#[no_mangle]
pub extern "C" fn entry() -> ! {
unsafe {
PIM_ALLOC.init(PIM_ARENA.as_ptr() as usize, PIM_ARENA_SIZE);
}
let mut uart = Uart0;
let mut pim_state = PimState::new(&MATRIX_MUL);
pim_state.set_kernel();
let pim_matrix_arena0 = RefCell::new(PimMatrixArena(
let mut pim_matrix_arena0 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
let pim_matrix_arena1 = RefCell::new(PimMatrixArena(
)));
let mut pim_matrix_arena1 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
let pim_matrix_arena2 = RefCell::new(PimMatrixArena(
)));
let mut pim_matrix_arena2 = Rc::new(RefCell::new(PimMatrixArena(
[[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3],
));
)));
let pim_storage0 = PimStorage {
arena: &pim_matrix_arena0,
index: 0,
row_major: false,
};
let pim_storage1 = PimStorage {
arena: &pim_matrix_arena1,
index: 0,
row_major: false,
};
let pim_storage2 = PimStorage {
arena: &pim_matrix_arena2,
index: 0,
row_major: false,
};
let mut matrix0 = Matrix::from_data(pim_storage0);
@@ -70,19 +85,19 @@ pub extern "C" fn entry() -> ! {
pim_matrix_arena0.borrow_mut().invalidate_flush();
pim_matrix_arena1.borrow_mut().invalidate_flush();
pim_matrix_arena2.borrow_mut().invalidate_flush();
let dummy_array = DummyArray([F16x16::default(); NUMBER_OF_BANKS]);
let mut dummy_array = Box::new(DummyArray([F16x16::default(); NUMBER_OF_BANKS]));
barrier::dsb(barrier::SY);
// execute_matrix_add(&pim_matrix_arena0, &pim_matrix_arena1, &dummy_array);
execute_matrix_multiply(
&mut pim_state,
&pim_matrix_arena0,
&pim_matrix_arena1,
&pim_matrix_arena2,
&dummy_array,
&mut pim_matrix_arena0.borrow_mut(),
&mut pim_matrix_arena1.borrow_mut(),
&mut pim_matrix_arena2.borrow_mut(),
dummy_array.as_mut(),
);
pim_matrix_arena2.borrow_mut().invalidate();
pim_matrix_arena2.borrow().invalidate();
barrier::dsb(barrier::SY);
writeln!(&mut uart, "{matrix2}").unwrap();

View File

@@ -9,7 +9,7 @@ const EVEN_BANK_INDEX: usize = 0;
const ODD_BANK_INDEX: usize = 8;
#[derive(Clone, Debug)]
#[repr(C, align(1024))]
#[repr(C, align(65536))]
pub struct PimMatrixArena<const R: usize, const C: usize>(pub [[[F16x16; NUMBER_OF_BANKS]; R]; C]);
impl<const R: usize, const C: usize> PimRegion for PimMatrixArena<R, C> {
@@ -24,11 +24,10 @@ impl<const R: usize, const C: usize> PimRegion for PimMatrixArena<R, C> {
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct PimStorage<'a, const R: usize, const C: usize> {
pub arena: &'a RefCell<PimMatrixArena<R, C>>,
pub index: usize,
pub row_major: bool,
}
unsafe impl<'a, const R: usize, const C: usize> RawStorage<F16x1, Const<R>, Const<C>>
@@ -46,12 +45,8 @@ unsafe impl<'a, const R: usize, const C: usize> RawStorage<F16x1, Const<R>, Cons
}
fn strides(&self) -> (Self::RStride, Self::CStride) {
if self.row_major {
(Dyn(16 * R * NUMBER_OF_BANKS), Dyn(16 * NUMBER_OF_BANKS))
} else {
(Dyn(16 * NUMBER_OF_BANKS), Dyn(16 * R * NUMBER_OF_BANKS))
}
}
fn is_contiguous(&self) -> bool {
false
@@ -92,6 +87,9 @@ pub trait PimRegion {
self.read_data_bank(EVEN_BANK_INDEX + i * NUMBER_OF_BANKS);
barrier::dsb(barrier::SY);
self.invalidate_bank(EVEN_BANK_INDEX + i * NUMBER_OF_BANKS);
barrier::dsb(barrier::SY);
}
fn execute_instruction_read_dual_bank(&self) {

View File

@@ -1,4 +1,5 @@
use core::cell::RefCell;
use core::fmt::Write;
use pim_isa::{BankMode, File, Instruction, Kernel};
@@ -137,6 +138,10 @@ pub const MATRIX_MUL: Kernel = Kernel([
dst: File::GrfB { index: 0 },
aam: false,
},
// Instruction::JUMP {
// offset: -1,
// count: 2,
// },
Instruction::FILL {
src: File::GrfB { index: 0 },
dst: File::Bank,
@@ -170,45 +175,31 @@ pub const MATRIX_MUL: Kernel = Kernel([
pub fn execute_matrix_multiply<const R: usize, const C: usize>(
pim_state: &mut PimState,
pim_matrix_arena0: &RefCell<PimMatrixArena<R, C>>,
pim_matrix_arena1: &RefCell<PimMatrixArena<R, C>>,
pim_matrix_arena2: &RefCell<PimMatrixArena<R, C>>,
dummy_array: &DummyArray,
pim_matrix_arena0: &mut PimMatrixArena<R, C>,
pim_matrix_arena1: &mut PimMatrixArena<R, C>,
pim_matrix_arena2: &mut PimMatrixArena<R, C>,
dummy_array: &mut DummyArray,
) {
pim_state.set_bank_mode(BankMode::PimAllBank);
let mut index = 0;
while index < (R * C) {
let left_index = index % R;
let right_index = (index / R) * R;
let mut i = 0;
while i < (R * C) {
let left_index = i % R;
let right_index = (i / R) * R;
pim_matrix_arena0
.borrow()
.execute_instruction_read_single_bank(left_index);
pim_matrix_arena0
.borrow()
.execute_instruction_read_single_bank(left_index + R);
pim_matrix_arena0
.borrow()
.execute_instruction_read_single_bank(left_index + R * 2);
pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 0);
pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 1);
pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 2);
pim_matrix_arena1
.borrow()
.execute_instruction_read_single_bank(right_index);
pim_matrix_arena1
.borrow()
.execute_instruction_read_single_bank(right_index + 1);
pim_matrix_arena1
.borrow()
.execute_instruction_read_single_bank(right_index + 2);
pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 0);
pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 1);
pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 2);
pim_matrix_arena2
.borrow_mut()
.execute_instruction_write_single_bank(index);
pim_matrix_arena2.execute_instruction_write_single_bank(i);
dummy_array.execute_instruction_read_single_bank(0);
index += 1;
i += 1;
}
pim_state.set_bank_mode(BankMode::SingleBank);

View File

@@ -131,7 +131,7 @@ impl PimVM {
let pim_unit = &mut self.pim_units[pim_unit_index as usize];
let mut inst = self.kernel.0[pim_unit.pc as usize];
let inst = self.kernel.0[pim_unit.pc as usize];
if pim_unit_index == 0 {
log::debug!(
@@ -140,38 +140,15 @@ impl PimVM {
);
}
pim_unit.pc += 1;
let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111;
let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111;
// The JUMP instruction is zero-cycle and not actually executed
while let Instruction::JUMP { offset, count } = inst {
pim_unit.jump_counter = match pim_unit.jump_counter {
Some(jump_counter) => jump_counter.checked_sub(1),
None => count.checked_sub(1),
};
if pim_unit.jump_counter != None {
let new_pc = pim_unit.pc as i32 + offset as i32;
if new_pc < 0 || new_pc >= 32 {
panic!("Invalid PC {new_pc} after JUMP: {inst:?}");
}
pim_unit.pc = new_pc as _;
if pim_unit_index == 0 {
log::debug!("PimUnit {pim_unit_index} New PC {new_pc}: {inst:?}");
}
}
inst = self.kernel.0[pim_unit.pc as usize];
pim_unit.pc += 1;
}
match inst {
Instruction::NOP => (),
Instruction::EXIT => pim_unit.reset(),
Instruction::EXIT => {
pim_unit.reset();
return;
}
Instruction::JUMP { .. } => unreachable!(),
Instruction::MOV { src, dst } | Instruction::FILL { src, dst } => {
let data = PimVM::load(src, pim_unit, &bank_data);
@@ -267,6 +244,9 @@ impl PimVM {
} => {
if aam {
src1 = if let File::GrfA { index: _ } = src1 {
if pim_unit_index == 0 {
log::debug!("AAM index GrfA {aam_grf_a_index}");
}
File::GrfA {
index: aam_grf_a_index as _,
}
@@ -275,6 +255,9 @@ impl PimVM {
};
src2 = if let File::GrfB { index: _ } = src2 {
if pim_unit_index == 0 {
    // Report the GrfB index that AAM substitutes into src2 just below
    // (previously this printed the GrfA index by mistake).
    log::debug!("AAM index GrfB {aam_grf_b_index}");
}
File::GrfB {
index: aam_grf_b_index as _,
}
@@ -326,6 +309,36 @@ impl PimVM {
PimVM::store(dst, pim_unit, &sum);
}
}
pim_unit.pc += 1;
// The JUMP instruction is zero-cycle and not actually executed
while let Instruction::JUMP { offset, count } = self.kernel.0[pim_unit.pc as usize] {
pim_unit.jump_counter = match pim_unit.jump_counter {
Some(jump_counter) => jump_counter.checked_sub(1),
None => count.checked_sub(1),
};
if pim_unit.jump_counter != None {
let new_pc = pim_unit.pc as i32 + offset as i32;
if new_pc < 0 || new_pc >= 32 {
panic!("Invalid PC {new_pc} after JUMP: {inst:?}");
}
pim_unit.pc = new_pc as _;
} else {
pim_unit.pc += 1;
}
if pim_unit_index == 0 {
log::debug!(
"PimUnit {pim_unit_index} JUMP to PC {}: {:?}",
pim_unit.pc,
self.kernel.0[pim_unit.pc as usize]
);
}
}
}
pub fn execute_write(&mut self, bank_index: u32) -> [u8; BURST_LENGTH] {
@@ -336,24 +349,24 @@ impl PimVM {
};
let pim_unit = &mut self.pim_units[pim_unit_index as usize];
let current_pc = pim_unit.pc;
pim_unit.pc += 1;
let inst = &self.kernel.0[current_pc as usize];
let inst = self.kernel.0[pim_unit.pc as usize];
if pim_unit_index == 0 {
log::debug!("PimUnit {pim_unit_index} Execute PC {current_pc}: {inst:?}");
log::debug!(
"PimUnit {pim_unit_index} Execute PC {}: {inst:?}",
pim_unit.pc
);
}
let data = match inst {
Instruction::FILL { src, dst } => {
let data: [f16; FP_UNITS] = match src {
File::GrfA { index } => pim_unit.grf_a[*index as usize],
File::GrfB { index } => pim_unit.grf_b[*index as usize],
File::GrfA { index } => pim_unit.grf_a[index as usize],
File::GrfB { index } => pim_unit.grf_b[index as usize],
_ => panic!("Unsupported src operand: {src:?}"),
};
if *dst != File::Bank {
if dst != File::Bank {
panic!("Unsupported dst operand: {dst:?}")
}
@@ -366,6 +379,36 @@ impl PimVM {
_ => panic!("Unsupported instruction for write: {inst:?}"),
};
pim_unit.pc += 1;
// The JUMP instruction is zero-cycle and not actually executed
while let Instruction::JUMP { offset, count } = self.kernel.0[pim_unit.pc as usize] {
pim_unit.jump_counter = match pim_unit.jump_counter {
Some(jump_counter) => jump_counter.checked_sub(1),
None => count.checked_sub(1),
};
if pim_unit.jump_counter != None {
let new_pc = pim_unit.pc as i32 + offset as i32;
if new_pc < 0 || new_pc >= 32 {
panic!("Invalid PC {new_pc} after JUMP: {inst:?}");
}
pim_unit.pc = new_pc as _;
} else {
pim_unit.pc += 1;
}
if pim_unit_index == 0 {
log::debug!(
"PimUnit {pim_unit_index} JUMP to PC {}: {:?}",
pim_unit.pc,
self.kernel.0[pim_unit.pc as usize]
);
}
}
unsafe { std::mem::transmute(data) }
}