From 34c8ab84fb2f880afd280d592c065dc1f1a02390 Mon Sep 17 00:00:00 2001 From: Derek Christ Date: Wed, 3 Jan 2024 18:50:54 +0100 Subject: [PATCH] 3x3 matrix multiplication --- pim-os/Cargo.lock | 18 ++++++ pim-os/Cargo.toml | 2 + pim-os/aarch64-gem5.ld | 5 +- pim-os/src/main.rs | 45 ++++++++++----- pim-os/src/pim/array.rs | 14 ++--- pim-os/src/pim/kernel.rs | 51 +++++++---------- pim-vm/src/lib.rs | 117 ++++++++++++++++++++++++++------------- 7 files changed, 161 insertions(+), 91 deletions(-) diff --git a/pim-os/Cargo.lock b/pim-os/Cargo.lock index 29a0be0..edd1100 100644 --- a/pim-os/Cargo.lock +++ b/pim-os/Cargo.lock @@ -59,6 +59,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "embedded-alloc" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddae17915accbac2cfbc64ea0ae6e3b330e6ea124ba108dada63646fd3c6f815" +dependencies = [ + "critical-section", + "linked_list_allocator", +] + [[package]] name = "half" version = "2.3.1" @@ -91,6 +101,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "linked_list_allocator" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afa463f5405ee81cdb9cc2baf37e08ec7e4c8209442b5d72c04cfb2cd6e6286" + [[package]] name = "lock_api" version = "0.4.11" @@ -172,6 +188,8 @@ name = "pim-os" version = "0.1.0" dependencies = [ "aarch64-cpu", + "critical-section", + "embedded-alloc", "half", "nalgebra", "num-traits", diff --git a/pim-os/Cargo.toml b/pim-os/Cargo.toml index cdd060a..1849a98 100644 --- a/pim-os/Cargo.toml +++ b/pim-os/Cargo.toml @@ -17,6 +17,8 @@ pim-isa = { path = "../pim-isa", default-features = false } serde-json-core = "0.5.1" serde = { version = "1.0", default-features = false, features = ["derive"] } num-traits = { version = "0.2.17", default-features = false } +embedded-alloc = "0.5.1" +critical-section = "1.1.2" [profile.dev] panic = "abort" diff --git a/pim-os/aarch64-gem5.ld b/pim-os/aarch64-gem5.ld index 0190bf7..ed3874a 100644 --- a/pim-os/aarch64-gem5.ld +++ b/pim-os/aarch64-gem5.ld @@ -1,7 +1,8 @@ MEMORY { bootmem : ORIGIN = 0x0, LENGTH = 0x100000 - dram : ORIGIN = 0x80000000, LENGTH = 0x100000000 + dram : ORIGIN = 0x80000000, LENGTH = 0x20000000 + dram_pim : ORIGIN = 0xA0000000, LENGTH = 0x20000000 } ENTRY(_start) @@ -18,4 +19,6 @@ SECTIONS . = ALIGN(8); . = . + 0x100000; # 1 MiB Stack LD_STACK_PTR = .; + + .pim_data : { KEEP(*(.pim_data)) } > dram_pim } diff --git a/pim-os/src/main.rs b/pim-os/src/main.rs index efc93ee..dc738ab 100644 --- a/pim-os/src/main.rs +++ b/pim-os/src/main.rs @@ -2,13 +2,18 @@ #![no_std] #![no_main] +extern crate alloc; + use aarch64_cpu::asm::barrier; +use alloc::{boxed::Box, rc::Rc}; use core::{ cell::RefCell, fmt::Write, + mem::MaybeUninit, panic::PanicInfo, sync::atomic::{compiler_fence, Ordering}, }; +use embedded_alloc::Heap; use half::f16; use nalgebra::Matrix; use pim::{ @@ -21,39 +26,49 @@ use pim_isa::BankMode; use uart::Uart0; mod boot; +mod critical_section; mod m5ops; mod pim; mod uart; +#[global_allocator] +static PIM_ALLOC: Heap = Heap::empty(); + +const PIM_ARENA_SIZE: usize = 0x2000000; + +#[link_section = ".pim_data"] +static mut PIM_ARENA: [MaybeUninit; PIM_ARENA_SIZE] = [MaybeUninit::uninit(); PIM_ARENA_SIZE]; + #[no_mangle] pub extern "C" fn entry() -> ! { + unsafe { + PIM_ALLOC.init(PIM_ARENA.as_ptr() as usize, PIM_ARENA_SIZE); + } + let mut uart = Uart0; let mut pim_state = PimState::new(&MATRIX_MUL); pim_state.set_kernel(); - let pim_matrix_arena0 = RefCell::new(PimMatrixArena( + let mut pim_matrix_arena0 = Rc::new(RefCell::new(PimMatrixArena( [[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3], - )); - let pim_matrix_arena1 = RefCell::new(PimMatrixArena( + ))); + let mut pim_matrix_arena1 = Rc::new(RefCell::new(PimMatrixArena( [[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3], - )); - let pim_matrix_arena2 = RefCell::new(PimMatrixArena( + ))); + let mut pim_matrix_arena2 = Rc::new(RefCell::new(PimMatrixArena( [[[F16x16::default(); NUMBER_OF_BANKS]; 3]; 3], - )); + ))); let pim_storage0 = PimStorage { arena: &pim_matrix_arena0, index: 0, - row_major: false, }; let pim_storage1 = PimStorage { arena: &pim_matrix_arena1, index: 0, - row_major: false, }; let pim_storage2 = PimStorage { arena: &pim_matrix_arena2, index: 0, - row_major: false, }; let mut matrix0 = Matrix::from_data(pim_storage0); @@ -70,19 +85,19 @@ pub extern "C" fn entry() -> ! { pim_matrix_arena0.borrow_mut().invalidate_flush(); pim_matrix_arena1.borrow_mut().invalidate_flush(); pim_matrix_arena2.borrow_mut().invalidate_flush(); - let dummy_array = DummyArray([F16x16::default(); NUMBER_OF_BANKS]); + let mut dummy_array = Box::new(DummyArray([F16x16::default(); NUMBER_OF_BANKS])); barrier::dsb(barrier::SY); // execute_matrix_add(&pim_matrix_arena0, &pim_matrix_arena1, &dummy_array); execute_matrix_multiply( &mut pim_state, - &pim_matrix_arena0, - &pim_matrix_arena1, - &pim_matrix_arena2, - &dummy_array, + &mut pim_matrix_arena0.borrow_mut(), + &mut pim_matrix_arena1.borrow_mut(), + &mut pim_matrix_arena2.borrow_mut(), + dummy_array.as_mut(), ); - pim_matrix_arena2.borrow_mut().invalidate(); + pim_matrix_arena2.borrow().invalidate(); barrier::dsb(barrier::SY); writeln!(&mut uart, "{matrix2}").unwrap(); diff --git a/pim-os/src/pim/array.rs b/pim-os/src/pim/array.rs index 647d3ac..8ede96f 100644 --- a/pim-os/src/pim/array.rs +++ b/pim-os/src/pim/array.rs @@ -9,7 +9,7 @@ const EVEN_BANK_INDEX: usize = 0; const ODD_BANK_INDEX: usize = 8; #[derive(Clone, Debug)] -#[repr(C, align(1024))] +#[repr(C, align(65536))] pub struct PimMatrixArena(pub [[[F16x16; NUMBER_OF_BANKS]; R]; C]); impl PimRegion for PimMatrixArena { @@ -24,11 +24,10 @@ impl PimRegion for PimMatrixArena { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct PimStorage<'a, const R: usize, const C: usize> { pub arena: &'a RefCell>, pub index: usize, - pub row_major: bool, } unsafe impl<'a, const R: usize, const C: usize> RawStorage, Const> @@ -46,11 +45,7 @@ unsafe impl<'a, const R: usize, const C: usize> RawStorage, Cons } fn strides(&self) -> (Self::RStride, Self::CStride) { - if self.row_major { - (Dyn(16 * R * NUMBER_OF_BANKS), Dyn(16 * NUMBER_OF_BANKS)) - } else { - (Dyn(16 * NUMBER_OF_BANKS), Dyn(16 * R * NUMBER_OF_BANKS)) - } + (Dyn(16 * NUMBER_OF_BANKS), Dyn(16 * R * NUMBER_OF_BANKS)) } fn is_contiguous(&self) -> bool { @@ -92,6 +87,9 @@ pub trait PimRegion { self.read_data_bank(EVEN_BANK_INDEX + i * NUMBER_OF_BANKS); barrier::dsb(barrier::SY); + + self.invalidate_bank(EVEN_BANK_INDEX + i * NUMBER_OF_BANKS); + barrier::dsb(barrier::SY); } fn execute_instruction_read_dual_bank(&self) { diff --git a/pim-os/src/pim/kernel.rs b/pim-os/src/pim/kernel.rs index 163f10e..fac35d6 100644 --- a/pim-os/src/pim/kernel.rs +++ b/pim-os/src/pim/kernel.rs @@ -1,4 +1,5 @@ use core::cell::RefCell; +use core::fmt::Write; use pim_isa::{BankMode, File, Instruction, Kernel}; @@ -137,6 +138,10 @@ pub const MATRIX_MUL: Kernel = Kernel([ dst: File::GrfB { index: 0 }, aam: false, }, + // Instruction::JUMP { + // offset: -1, + // count: 2, + // }, Instruction::FILL { src: File::GrfB { index: 0 }, dst: File::Bank, @@ -170,45 +175,31 @@ pub const MATRIX_MUL: Kernel = Kernel([ pub fn execute_matrix_multiply( pim_state: &mut PimState, - pim_matrix_arena0: &RefCell>, - pim_matrix_arena1: &RefCell>, - pim_matrix_arena2: &RefCell>, - dummy_array: &DummyArray, + pim_matrix_arena0: &mut PimMatrixArena, + pim_matrix_arena1: &mut PimMatrixArena, + pim_matrix_arena2: &mut PimMatrixArena, + dummy_array: &mut DummyArray, ) { pim_state.set_bank_mode(BankMode::PimAllBank); - let mut index = 0; - while index < (R * C) { - let left_index = index % R; - let right_index = (index / R) * R; + let mut i = 0; + while i < (R * C) { + let left_index = i % R; + let right_index = (i / R) * R; - pim_matrix_arena0 - .borrow() - .execute_instruction_read_single_bank(left_index); - pim_matrix_arena0 - .borrow() - .execute_instruction_read_single_bank(left_index + R); - pim_matrix_arena0 - .borrow() - .execute_instruction_read_single_bank(left_index + R * 2); + pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 0); + pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 1); + pim_matrix_arena0.execute_instruction_read_single_bank(left_index + R * 2); - pim_matrix_arena1 - .borrow() - .execute_instruction_read_single_bank(right_index); - pim_matrix_arena1 - .borrow() - .execute_instruction_read_single_bank(right_index + 1); - pim_matrix_arena1 - .borrow() - .execute_instruction_read_single_bank(right_index + 2); + pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 0); + pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 1); + pim_matrix_arena1.execute_instruction_read_single_bank(right_index + 2); - pim_matrix_arena2 - .borrow_mut() - .execute_instruction_write_single_bank(index); + pim_matrix_arena2.execute_instruction_write_single_bank(i); dummy_array.execute_instruction_read_single_bank(0); - index += 1; + i += 1; } pim_state.set_bank_mode(BankMode::SingleBank); diff --git a/pim-vm/src/lib.rs b/pim-vm/src/lib.rs index b44bfe8..ad46baf 100644 --- a/pim-vm/src/lib.rs +++ b/pim-vm/src/lib.rs @@ -131,7 +131,7 @@ impl PimVM { let pim_unit = &mut self.pim_units[pim_unit_index as usize]; - let mut inst = self.kernel.0[pim_unit.pc as usize]; + let inst = self.kernel.0[pim_unit.pc as usize]; if pim_unit_index == 0 { log::debug!( @@ -140,38 +140,15 @@ impl PimVM { ); } - pim_unit.pc += 1; - let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111; let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111; - // The JUMP instruction is zero-cycle and not actually executed - while let Instruction::JUMP { offset, count } = inst { - pim_unit.jump_counter = match pim_unit.jump_counter { - Some(jump_counter) => jump_counter.checked_sub(1), - None => count.checked_sub(1), - }; - - if pim_unit.jump_counter != None { - let new_pc = pim_unit.pc as i32 + offset as i32; - - if new_pc < 0 || new_pc >= 32 { - panic!("Invalid PC {new_pc} after JUMP: {inst:?}"); - } - - pim_unit.pc = new_pc as _; - if pim_unit_index == 0 { - log::debug!("PimUnit {pim_unit_index} New PC {new_pc}: {inst:?}"); - } - } - - inst = self.kernel.0[pim_unit.pc as usize]; - pim_unit.pc += 1; - } - match inst { Instruction::NOP => (), - Instruction::EXIT => pim_unit.reset(), + Instruction::EXIT => { + pim_unit.reset(); + return; + } Instruction::JUMP { .. } => unreachable!(), Instruction::MOV { src, dst } | Instruction::FILL { src, dst } => { let data = PimVM::load(src, pim_unit, &bank_data); @@ -267,6 +244,9 @@ impl PimVM { } => { if aam { src1 = if let File::GrfA { index: _ } = src1 { + if pim_unit_index == 0 { + log::debug!("AAM index GrfA {aam_grf_a_index}"); + } File::GrfA { index: aam_grf_a_index as _, } @@ -275,6 +255,9 @@ impl PimVM { }; src2 = if let File::GrfB { index: _ } = src2 { + if pim_unit_index == 0 { + log::debug!("AAM index GrfB {aam_grf_a_index}"); + } File::GrfB { index: aam_grf_b_index as _, } @@ -326,6 +309,36 @@ impl PimVM { PimVM::store(dst, pim_unit, &sum); } } + + pim_unit.pc += 1; + + // The JUMP instruction is zero-cycle and not actually executed + while let Instruction::JUMP { offset, count } = self.kernel.0[pim_unit.pc as usize] { + pim_unit.jump_counter = match pim_unit.jump_counter { + Some(jump_counter) => jump_counter.checked_sub(1), + None => count.checked_sub(1), + }; + + if pim_unit.jump_counter != None { + let new_pc = pim_unit.pc as i32 + offset as i32; + + if new_pc < 0 || new_pc >= 32 { + panic!("Invalid PC {new_pc} after JUMP: {inst:?}"); + } + + pim_unit.pc = new_pc as _; + } else { + pim_unit.pc += 1; + } + + if pim_unit_index == 0 { + log::debug!( + "PimUnit {pim_unit_index} JUMP to PC {}: {:?}", + pim_unit.pc, + self.kernel.0[pim_unit.pc as usize] + ); + } + } } pub fn execute_write(&mut self, bank_index: u32) -> [u8; BURST_LENGTH] { @@ -336,24 +349,24 @@ impl PimVM { }; let pim_unit = &mut self.pim_units[pim_unit_index as usize]; - - let current_pc = pim_unit.pc; - pim_unit.pc += 1; - - let inst = &self.kernel.0[current_pc as usize]; + let inst = self.kernel.0[pim_unit.pc as usize]; if pim_unit_index == 0 { - log::debug!("PimUnit {pim_unit_index} Execute PC {current_pc}: {inst:?}"); + log::debug!( + "PimUnit {pim_unit_index} Execute PC {}: {inst:?}", + pim_unit.pc + ); } + let data = match inst { Instruction::FILL { src, dst } => { let data: [f16; FP_UNITS] = match src { - File::GrfA { index } => pim_unit.grf_a[*index as usize], - File::GrfB { index } => pim_unit.grf_b[*index as usize], + File::GrfA { index } => pim_unit.grf_a[index as usize], + File::GrfB { index } => pim_unit.grf_b[index as usize], _ => panic!("Unsupported src operand: {src:?}"), }; - if *dst != File::Bank { + if dst != File::Bank { panic!("Unsupported dst operand: {dst:?}") } @@ -366,6 +379,36 @@ impl PimVM { _ => panic!("Unsupported instruction for write: {inst:?}"), }; + pim_unit.pc += 1; + + // The JUMP instruction is zero-cycle and not actually executed + while let Instruction::JUMP { offset, count } = self.kernel.0[pim_unit.pc as usize] { + pim_unit.jump_counter = match pim_unit.jump_counter { + Some(jump_counter) => jump_counter.checked_sub(1), + None => count.checked_sub(1), + }; + + if pim_unit.jump_counter != None { + let new_pc = pim_unit.pc as i32 + offset as i32; + + if new_pc < 0 || new_pc >= 32 { + panic!("Invalid PC {new_pc} after JUMP: {inst:?}"); + } + + pim_unit.pc = new_pc as _; + } else { + pim_unit.pc += 1; + } + + if pim_unit_index == 0 { + log::debug!( + "PimUnit {pim_unit_index} JUMP to PC {}: {:?}", + pim_unit.pc, + self.kernel.0[pim_unit.pc as usize] + ); + } + } + unsafe { std::mem::transmute(data) } }