From 71c766736a58d7487bc26504cde9199cc8dbcab4 Mon Sep 17 00:00:00 2001 From: Derek Christ Date: Sun, 3 Dec 2023 15:48:31 +0100 Subject: [PATCH] Implement all missing instructions --- pim-isa/Cargo.toml | 4 --- pim-isa/src/lib.rs | 40 +++++++++++++++++++++++++---- pim-os/.cargo/config | 6 ----- pim-os/Cargo.toml | 3 +++ pim-os/build.rs | 15 +++++++++++ pim-os/src/main.rs | 16 +++++++++--- pim-os/src/pim/array.rs | 2 +- pim-os/src/pim/kernel.rs | 19 +++++++++----- pim-vm/Cargo.lock | 10 ++++---- pim-vm/src/lib.rs | 54 ++++++++++++++++++++++++++++++++++++++-- 10 files changed, 137 insertions(+), 32 deletions(-) delete mode 100644 pim-os/.cargo/config create mode 100644 pim-os/build.rs diff --git a/pim-isa/Cargo.toml b/pim-isa/Cargo.toml index c04da4a..86d4648 100644 --- a/pim-isa/Cargo.toml +++ b/pim-isa/Cargo.toml @@ -3,9 +3,5 @@ name = "pim-isa" version = "0.1.0" edition = "2021" -[features] -default = ["std"] -std = [] - [dependencies] serde = { version = "1.0", default-features = false, features = ["derive"] } diff --git a/pim-isa/src/lib.rs b/pim-isa/src/lib.rs index b81f8d1..d3a1a69 100644 --- a/pim-isa/src/lib.rs +++ b/pim-isa/src/lib.rs @@ -1,4 +1,4 @@ -#![cfg_attr(not(feature = "std"), no_std)] +#![no_std] use serde::{Deserialize, Serialize}; @@ -6,10 +6,40 @@ use serde::{Deserialize, Serialize}; pub enum Instruction { NOP, EXIT, - JUMP { offset: i16, count: u16 }, - MOV { src: File, dst: File }, - FILL { src: File, dst: File }, - ADD { src0: File, src1: File, dst: File }, + JUMP { + offset: i16, + count: u16, + }, + MOV { + src: File, + dst: File, + }, + FILL { + src: File, + dst: File, + }, + ADD { + src0: File, + src1: File, + dst: File, + }, + MUL { + src0: File, + src1: File, + dst: File, + }, + MAC { + src0: File, + src1: File, + src2: File, + dst: File, + }, + MAD { + src0: File, + src1: File, + src2: File, + dst: File, + }, } impl Instruction { diff --git a/pim-os/.cargo/config b/pim-os/.cargo/config deleted file mode 100644 index 84d4758..0000000 --- a/pim-os/.cargo/config +++ /dev/null @@ -1,6 +0,0 @@ -[build] -target = "aarch64-unknown-none" - -rustflags = [ - "-C", "link-arg=-Taarch64-gem5.ld", -] diff --git a/pim-os/Cargo.toml b/pim-os/Cargo.toml index eec0e97..8c32bee 100644 --- a/pim-os/Cargo.toml +++ b/pim-os/Cargo.toml @@ -1,7 +1,10 @@ +cargo-features = ["per-package-target"] + [package] name = "pim-os" version = "0.1.0" edition = "2021" +forced-target = "aarch64-unknown-none" [dependencies] aarch64-cpu = "9.4.0" diff --git a/pim-os/build.rs b/pim-os/build.rs new file mode 100644 index 0000000..e95e8d6 --- /dev/null +++ b/pim-os/build.rs @@ -0,0 +1,15 @@ +use std::env; +use std::fs; +use std::path::PathBuf; + +const LINKER_SCRIPT: &str = "aarch64-gem5.ld"; + +fn main() { + // Put `aarch64-gem5.ld` in our output directory and ensure it's + // on the linker search path. + let out = &PathBuf::from(env::var_os("OUT_DIR").unwrap()); + fs::copy(LINKER_SCRIPT, out.join(LINKER_SCRIPT)).unwrap(); + println!("cargo:rustc-link-search={}", out.display()); + println!("cargo:rerun-if-changed={LINKER_SCRIPT}"); + println!("cargo:rustc-link-arg=-T{LINKER_SCRIPT}"); +} diff --git a/pim-os/src/main.rs b/pim-os/src/main.rs index 23810bf..3ab8d8d 100644 --- a/pim-os/src/main.rs +++ b/pim-os/src/main.rs @@ -28,7 +28,7 @@ pub extern "C" fn entry() -> ! { let mut compute_array: ComputeArray<3> = ComputeArray([ BankArray([f16::from_f32(0.1); 512]), BankArray([f16::from_f32(0.2); 512]), - BankArray([f16::ZERO; 512]), + BankArray([f16::from_f32(0.3); 512]), ]); let dummy_array = BankArray::default(); let mut uart = Uart0; @@ -40,12 +40,22 @@ pub extern "C" fn entry() -> ! { ) .unwrap(); + writeln!( + &mut uart, + "BankArray0: [{:?}, ...]\nBankArray1: [{:?}, ...]\nBankArray2: [{:?}, ...]", + compute_array.0[0].0[0], compute_array.0[1].0[0], compute_array.0[2].0[0] + ) + .unwrap(); + + writeln!(&mut uart, "MAC: BankArray2 += BankArray0 * BankArray1",).unwrap(); + // Invalidate and flush array just in case compute_array.invalidate_flush(); dummy_array.invalidate_flush(); pim_state.set_bank_mode(BankMode::PimAllBank); compute_array.0[0].execute_instruction_read(); + compute_array.0[2].execute_instruction_read(); compute_array.0[1].execute_instruction_read(); compute_array.0[2].execute_instruction_write(); dummy_array.execute_instruction_read(); @@ -55,8 +65,8 @@ pub extern "C" fn entry() -> ! { writeln!( &mut uart, - "BankArray 0: [{:?}, ...]\nBankArray 1: [{:?}, ...]\nBankArray 2: [{:?}, ...]", - compute_array.0[0].0[0], compute_array.0[1].0[0], compute_array.0[2].0[0] + "BankArray2: [{:?}, ...]", + compute_array.0[2].0[0] ) .unwrap(); diff --git a/pim-os/src/pim/array.rs b/pim-os/src/pim/array.rs index d870779..3784082 100644 --- a/pim-os/src/pim/array.rs +++ b/pim-os/src/pim/array.rs @@ -51,7 +51,7 @@ impl BankArray { pub fn invalidate_single_bank(&self, idx: usize) { unsafe { - // Invalidate and flush first bank + // Invalidate first bank asm!("dc ivac, {val}", val = in(reg) &self.0[idx]); asm!("dsb sy"); } diff --git a/pim-os/src/pim/kernel.rs b/pim-os/src/pim/kernel.rs index 83158a8..ed8b413 100644 --- a/pim-os/src/pim/kernel.rs +++ b/pim-os/src/pim/kernel.rs @@ -5,12 +5,20 @@ pub const TEST_KERNEL: Kernel = Kernel([ src: File::Bank, dst: File::Grf { index: 0 }, }, - Instruction::ADD { - src0: File::Bank, - src1: File::Grf { index: 0 }, - dst: File::Grf { index: 0 }, + Instruction::MOV { + src: File::Bank, + dst: File::Grf { index: 1 }, + }, + Instruction::MAC { + src0: File::Grf { index: 0 }, + src1: File::Bank, + src2: File::Grf { index: 1 }, + dst: File::Grf { index: 1 }, + }, + Instruction::FILL { + src: File::Grf { index: 1 }, + dst: File::Bank, }, - Instruction::FILL { src: File::Grf { index: 0 }, dst: File::Bank }, Instruction::EXIT, Instruction::NOP, Instruction::NOP, @@ -39,5 +47,4 @@ pub const TEST_KERNEL: Kernel = Kernel([ Instruction::NOP, Instruction::NOP, Instruction::NOP, - Instruction::NOP, ]); diff --git a/pim-vm/Cargo.lock b/pim-vm/Cargo.lock index aae5113..989fb38 100644 --- a/pim-vm/Cargo.lock +++ b/pim-vm/Cargo.lock @@ -146,9 +146,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "log" @@ -230,15 +230,15 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "rustix" -version = "0.38.25" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc99bc2d4f1fed22595588a013687477aedf3cdcfb26558c559edb67b4d9b22e" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] diff --git a/pim-vm/src/lib.rs b/pim-vm/src/lib.rs index bcc975a..4b3efe5 100644 --- a/pim-vm/src/lib.rs +++ b/pim-vm/src/lib.rs @@ -124,17 +124,18 @@ impl PimVM { Instruction::JUMP { offset, count } => { pim_unit.jump_counter = match pim_unit.jump_counter { Some(jump_counter) => jump_counter.checked_sub(1), - None => Some(*count), + None => count.checked_sub(1), }; if pim_unit.jump_counter != None { - let new_pc = pim_unit.pc as i32 + *offset as i32; + let new_pc = current_pc as i32 + *offset as i32; if new_pc < 0 || new_pc >= 32 { panic!("Invalid PC {new_pc} after JUMP: {inst:?}"); } pim_unit.pc = new_pc as _; + log::debug!("PimUnit {bank_index} New PC {new_pc}: {inst:?}"); } } Instruction::MOV { src, dst } | Instruction::FILL { src, dst } => { @@ -155,6 +156,55 @@ impl PimVM { PimVM::store(*dst, pim_unit, &sum); } + Instruction::MUL { src0, src1, dst } => { + let data0 = PimVM::load(*src0, pim_unit, &bank_data); + let data1 = PimVM::load(*src1, pim_unit, &bank_data); + + let product: [f16; FP_UNITS] = data0 + .into_iter() + .zip(data1) + .map(|(src0, src1)| src0 * src1) + .collect::>() + .try_into() + .unwrap(); + + PimVM::store(*dst, pim_unit, &product); + } + Instruction::MAC { + src0, + src1, + src2, + dst, + } + | Instruction::MAD { + src0, + src1, + src2, + dst, + } => { + let data0 = PimVM::load(*src0, pim_unit, &bank_data); + let data1 = PimVM::load(*src1, pim_unit, &bank_data); + let data2 = PimVM::load(*src2, pim_unit, &bank_data); + + let product: [f16; FP_UNITS] = data0 + .into_iter() + .zip(data1) + .map(|(src0, src1)| src0 * src1) + .collect::>() + .try_into() + .unwrap(); + + let sum: [f16; FP_UNITS] = product + .into_iter() + .zip(data2) + .map(|(product, src2)| product + src2) + .collect::>() + .try_into() + .unwrap(); + + log::debug!("{data0:#?}\n{data1:#?}\n{data2:#?}\n{product:#?}\n{sum:#?}"); + PimVM::store(*dst, pim_unit, &sum); + } } }