Files
pim-rs/pim-vm/src/lib.rs
2023-12-12 17:15:29 +01:00

348 lines
11 KiB
Rust

use half::f16;
use pim_isa::{BankMode, File, Instruction, Kernel, PimConfig};
// C++ interface of the VM, generated by `cxx` into the C++ namespace `pim_vm`.
// NOTE(review): none of the functions below return `Result`, so a panic on the Rust side
// (malformed config, unsupported instruction, out-of-range bank index, ...) cannot surface as a
// C++ exception; per the cxx documentation such a panic aborts the process — confirm this is
// acceptable for the host simulator.
#[cxx::bridge(namespace = "pim_vm")]
mod ffi {
// Mirror of `pim_isa::BankMode` visible to C++; converted in `PimVM::bank_mode`.
pub enum BankMode {
SingleBank,
AllBank,
PimAllBank,
}
extern "Rust" {
type PimVM;
// Allocates a VM with `num_pim_units` PIM units; `bank_index` below selects one of them.
fn new_pim_vm(num_pim_units: u32) -> Box<PimVM>;
// Rewinds every unit's PC and JUMP counter; register contents are kept.
fn reset(&mut self);
// Parses `config` as a JSON-serialized `PimConfig`, applies it, and resets all units.
fn apply_config(&mut self, config: &str);
// Bank mode of the currently applied configuration.
fn bank_mode(&self) -> BankMode;
// Executes the next kernel instruction of unit `bank_index` for a read; `bank_data` is the
// burst read from the bank and must be exactly `BURST_LENGTH` (32) bytes.
fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]);
// Executes the next kernel instruction of unit `bank_index` for a write and returns the burst
// to store. The array length must stay equal to `BURST_LENGTH` on the Rust side.
fn execute_write(&mut self, bank_index: u32) -> [u8; 32];
// Installs `env_logger` as the backend for this crate's `log` output.
fn init_logger();
}
}
/// Installs `env_logger` as the backend for this crate's `log` output (filtered, e.g., via
/// `RUST_LOG`).
///
/// Safe to call more than once. If a logger is already installed, the call is a no-op instead
/// of a panic. This matters because the function is invoked from C++ through the cxx bridge,
/// where a Rust panic cannot be handled by the caller.
fn init_logger() {
    // `env_logger::init` panics when a global logger already exists; `try_init` reports it as
    // an error, which is deliberately ignored here (the existing logger keeps working).
    let _ = env_logger::try_init();
}
// --- Address decoding for address-aligned mode (AAM) ---

/// Bit position, within the access address, of the GRF_A register index used in AAM.
const GRF_A_BIT_OFFSET: usize = 10;
/// Bit position, within the access address, of the GRF_B register index used in AAM.
const GRF_B_BIT_OFFSET: usize = 13;

// --- Data path widths ---

/// Number of bytes exchanged with a bank per access (`bank_data` in, write result out).
const BURST_LENGTH: usize = 32;
/// Number of f16 lanes in one register / operand.
const FP_UNITS: usize = 16;

// --- Register file sizes ---

/// Registers in each general register file (GRF_A and GRF_B).
const GRF_NUM_REGISTERS: usize = 8;
/// Registers in scalar register file SRF_A.
const SRF_A_NUM_REGISTERS: usize = 8;
/// Registers in scalar register file SRF_M.
const SRF_M_NUM_REGISTERS: usize = 8;
/// One register of GRF_A or GRF_B: `FP_UNITS` f16 lanes.
type GrfRegister = [f16; FP_UNITS];
#[derive(Clone, Debug)]
struct PimUnit {
grf_a: [GrfRegister; GRF_NUM_REGISTERS],
grf_b: [GrfRegister; GRF_NUM_REGISTERS],
srf_a: [f16; SRF_A_NUM_REGISTERS],
srf_m: [f16; SRF_A_NUM_REGISTERS],
pc: u8,
jump_counter: Option<u16>,
}
impl Default for PimUnit {
fn default() -> Self {
Self {
grf_a: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS],
grf_b: [[f16::ZERO; FP_UNITS]; GRF_NUM_REGISTERS],
srf_a: [f16::ZERO; SRF_A_NUM_REGISTERS],
srf_m: [f16::ZERO; SRF_M_NUM_REGISTERS],
pc: 0,
jump_counter: None,
}
}
}
#[derive(Debug)]
struct PimVM {
pim_units: Vec<PimUnit>,
pim_config: pim_isa::PimConfig,
}
impl PimVM {
fn reset(&mut self) {
for unit in self.pim_units.iter_mut() {
unit.pc = 0;
unit.jump_counter = None;
}
}
fn apply_config(&mut self, config_str: &str) {
log::debug!("Config string:\n{config_str}");
self.pim_config = serde_json::from_str::<pim_isa::PimConfig>(config_str).unwrap();
self.reset();
log::debug!("Apply pim config:\n{:?}", self.pim_config);
}
fn bank_mode(&self) -> ffi::BankMode {
match self.pim_config.bank_mode {
BankMode::SingleBank => ffi::BankMode::SingleBank,
BankMode::AllBank => ffi::BankMode::AllBank,
BankMode::PimAllBank => ffi::BankMode::PimAllBank,
}
}
}
fn new_pim_vm(num_pim_units: u32) -> Box<PimVM> {
Box::new(PimVM {
pim_units: vec![PimUnit::default(); num_pim_units as _],
pim_config: PimConfig {
bank_mode: BankMode::SingleBank,
kernel: Kernel::NOP,
},
})
}
/// Raw contents of one bank burst: `FP_UNITS` f16 lanes (`BURST_LENGTH` bytes in total).
///
/// `repr(C)` pins the layout so `PimVM::load` can reinterpret the bytes read from the bank as
/// this type. NOTE(review): this assumes the C++ side delivers f16 values in native byte
/// order — confirm.
#[repr(C)]
struct BankData([f16; FP_UNITS]);
impl PimVM {
pub fn execute_read(&mut self, bank_index: u32, address: u32, bank_data: &[u8]) {
assert_eq!(bank_data.len(), BURST_LENGTH);
let pim_unit = &mut self.pim_units[bank_index as usize];
let mut inst = self.pim_config.kernel.0[pim_unit.pc as usize];
pim_unit.pc += 1;
let aam_grf_a_index = (address >> GRF_A_BIT_OFFSET) & 0b111;
let aam_grf_b_index = (address >> GRF_B_BIT_OFFSET) & 0b111;
log::debug!("PimUnit {bank_index} Execute PC {}: {inst:?}", pim_unit.pc);
// The JUMP instruction is zero-cycle and not actually executed
while let Instruction::JUMP { offset, count } = inst {
pim_unit.jump_counter = match pim_unit.jump_counter {
Some(jump_counter) => jump_counter.checked_sub(1),
None => count.checked_sub(1),
};
if pim_unit.jump_counter != None {
let new_pc = pim_unit.pc as i32 + offset as i32;
if new_pc < 0 || new_pc >= 32 {
panic!("Invalid PC {new_pc} after JUMP: {inst:?}");
}
pim_unit.pc = new_pc as _;
log::debug!("PimUnit {bank_index} New PC {new_pc}: {inst:?}");
}
inst = self.pim_config.kernel.0[pim_unit.pc as usize];
pim_unit.pc += 1;
}
match inst {
Instruction::NOP => (),
Instruction::EXIT => {
pim_unit.jump_counter = None;
pim_unit.pc = 0;
}
Instruction::JUMP { .. } => unreachable!(),
Instruction::MOV { src, dst } | Instruction::FILL { src, dst } => {
let data = PimVM::load(src, pim_unit, &bank_data);
PimVM::store(dst, pim_unit, &data);
}
Instruction::ADD {
src0,
mut src1,
mut dst,
aam,
} => {
if aam {
src1 = if let File::GrfA { index: _ } = src1 {
File::GrfA {
index: aam_grf_a_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
dst = if let File::GrfB { index: _ } = dst {
File::GrfB {
index: aam_grf_b_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
}
let data0 = PimVM::load(src0, pim_unit, &bank_data);
let data1 = PimVM::load(src1, pim_unit, &bank_data);
let sum: [f16; FP_UNITS] = data0
.into_iter()
.zip(data1)
.map(|(src0, src1)| src0 + src1)
.collect::<Vec<_>>()
.try_into()
.unwrap();
PimVM::store(dst, pim_unit, &sum);
}
Instruction::MUL {
src0,
mut src1,
mut dst,
aam,
} => {
if aam {
src1 = if let File::GrfA { index: _ } = src1 {
File::GrfA {
index: aam_grf_a_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
dst = if let File::GrfB { index: _ } = dst {
File::GrfB {
index: aam_grf_b_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
}
let data0 = PimVM::load(src0, pim_unit, &bank_data);
let data1 = PimVM::load(src1, pim_unit, &bank_data);
let product: [f16; FP_UNITS] = data0
.into_iter()
.zip(data1)
.map(|(src0, src1)| src0 * src1)
.collect::<Vec<_>>()
.try_into()
.unwrap();
PimVM::store(dst, pim_unit, &product);
}
Instruction::MAC {
src0,
mut src1,
mut src2,
mut dst,
aam,
}
| Instruction::MAD {
src0,
mut src1,
mut src2,
mut dst,
aam,
} => {
if aam {
src1 = if let File::GrfA { index: _ } = src1 {
File::GrfA {
index: aam_grf_a_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
src2 = if let File::GrfB { index: _ } = src2 {
File::GrfB {
index: aam_grf_b_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
dst = if let File::GrfB { index: _ } = dst {
File::GrfB {
index: aam_grf_b_index as _,
}
} else {
panic!("Invalid operand in address-aligned-mode");
};
}
assert_eq!(src2, dst);
let data0 = PimVM::load(src0, pim_unit, &bank_data);
let data1 = PimVM::load(src1, pim_unit, &bank_data);
let data2 = PimVM::load(src2, pim_unit, &bank_data);
let product: [f16; FP_UNITS] = data0
.into_iter()
.zip(data1)
.map(|(src0, src1)| src0 * src1)
.collect::<Vec<_>>()
.try_into()
.unwrap();
let sum: [f16; FP_UNITS] = product
.into_iter()
.zip(data2)
.map(|(product, src2)| product + src2)
.collect::<Vec<_>>()
.try_into()
.unwrap();
log::debug!("{data0:?}\n{data1:?}\n{data2:?}\n{product:?}\n{sum:?}");
PimVM::store(dst, pim_unit, &sum);
}
}
}
pub fn execute_write(&mut self, bank_index: u32) -> [u8; BURST_LENGTH] {
let pim_unit = &mut self.pim_units[bank_index as usize];
let current_pc = pim_unit.pc;
pim_unit.pc += 1;
let inst = &self.pim_config.kernel.0[current_pc as usize];
log::debug!("PimUnit {bank_index} Execute PC {current_pc}: {inst:?}");
let data = match inst {
Instruction::FILL { src, dst } => {
let data: [f16; FP_UNITS] = match src {
File::GrfA { index } => pim_unit.grf_a[*index as usize],
File::GrfB { index } => pim_unit.grf_b[*index as usize],
_ => panic!("Unsupported src operand: {src:?}"),
};
if *dst != File::Bank {
panic!("Unsupported dst operand: {dst:?}")
}
data
}
_ => panic!("Unsupported instruction for write: {inst:?}"),
};
unsafe { std::mem::transmute(data) }
}
fn load(src: File, pim_unit: &PimUnit, bank_data: &[u8]) -> [f16; FP_UNITS] {
match src {
File::GrfA { index } => pim_unit.grf_a[index as usize],
File::GrfB { index } => pim_unit.grf_b[index as usize],
File::SrfM { index } => [pim_unit.srf_m[index as usize]; FP_UNITS],
File::SrfA { index } => [pim_unit.srf_a[index as usize]; FP_UNITS],
File::Bank => unsafe { std::ptr::read(bank_data.as_ptr() as *const BankData).0 },
}
}
fn store(dst: File, pim_unit: &mut PimUnit, data: &[f16; FP_UNITS]) {
match dst {
File::GrfA { index } => pim_unit.grf_a[index as usize] = data.clone(),
File::GrfB { index } => pim_unit.grf_b[index as usize] = data.clone(),
File::SrfM { index } => pim_unit.srf_m[index as usize] = data[0],
File::SrfA { index } => pim_unit.srf_a[index as usize] = data[0],
File::Bank => panic!("Unsupported dst operand: {dst:?}"),
}
}
}