diff --git a/constensor-core/src/cpu_storage/mod.rs b/constensor-core/src/cpu_storage/mod.rs index 68111b8..d94e9ce 100644 --- a/constensor-core/src/cpu_storage/mod.rs +++ b/constensor-core/src/cpu_storage/mod.rs @@ -8,6 +8,7 @@ use pool::{BufferPool, PooledBuffer}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; use crate::device::Dev; +use crate::storage::Storage; use crate::Shape; use crate::{ storage::{BackendDevice, BackendStorage}, @@ -26,8 +27,11 @@ pub struct CpuStorage(pub(crate) Vec); impl BackendStorage for CpuStorage { fn to_cpu_storage(&self) -> Result>> { - // Note: copying all data here. - Ok(Cow::Owned(self.clone())) + Ok(Cow::Borrowed(self)) + } + fn cast(&self) -> Result> { + let new = self.0.iter().map(|x| U::from_f64(x.to_f64())); + Ok(Storage::Cpu(CpuStorage(new.collect()))) } } diff --git a/constensor-core/src/cuda_backend/mod.rs b/constensor-core/src/cuda_backend/mod.rs index bf62c8d..cae4205 100644 --- a/constensor-core/src/cuda_backend/mod.rs +++ b/constensor-core/src/cuda_backend/mod.rs @@ -13,7 +13,7 @@ use std::sync::{ }; use std::{ borrow::Cow, - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque}, fs, hash::{DefaultHasher, Hash, Hasher}, marker::PhantomData, @@ -24,7 +24,7 @@ use std::{ use crate::{ cpu_storage::CpuStorage, device::Dev, - storage::{BackendDevice, BackendStorage}, + storage::{BackendDevice, BackendStorage, Storage}, CompiledGraph, DType, GraphNode, Op, Result, Shape, }; @@ -38,11 +38,14 @@ unsafe impl Send for CudaRng {} pub struct CudaDevice { context: Arc, stream: Arc, - modules: Arc>>>, + modules: Arc>>>, + module_cache_order: Arc>>, streams: Arc>>, stream_index: Arc, } +const MAX_CACHED_KERNELS: usize = 128; + impl CudaDevice { pub(crate) fn new(ordinal: usize) -> Result { let context = cudarc::driver::CudaContext::new(ordinal).w()?; @@ -57,7 +60,8 @@ impl CudaDevice { Ok(Self { context, stream, - modules: Arc::new(RwLock::new(vec![])), + modules: 
Arc::new(RwLock::new(HashMap::new())), + module_cache_order: Arc::new(Mutex::new(VecDeque::new())), streams, stream_index, }) @@ -74,9 +78,29 @@ impl CudaDevice { } pub(crate) fn load_func(&self, function_name: &str, ptx: Ptx) -> Result { + // If we've already loaded this kernel, skip reloading + { + let modules_read = self.modules.read().unwrap(); + if let Some(module) = modules_read.get(function_name) { + return module.load_function(function_name).w(); + } + } + + // Otherwise compile and load let module = self.context.load_module(ptx).w()?; let func = module.load_function(function_name).w()?; - self.modules.write().unwrap().push(module); + // Insert into cache and cap size + { + let mut modules_write = self.modules.write().unwrap(); + let mut order = self.module_cache_order.lock().unwrap(); + modules_write.insert(function_name.to_string(), module.clone()); + order.push_back(function_name.to_string()); + if order.len() > MAX_CACHED_KERNELS { + if let Some(old) = order.pop_front() { + modules_write.remove(&old); + } + } + } Ok(func) } } @@ -100,6 +124,77 @@ impl BackendStorage for CudaStorage { let data = self.device.stream().memcpy_dtov(&self.slice).w()?; Ok(Cow::Owned(CpuStorage(data))) } + fn cast(&self) -> Result> { + let function_name = format!("cast_{}_to_{}", T::NAME, U::NAME); + + let template_kernel = format!( + r#" + typedef unsigned char uint8_t; + typedef unsigned int uint32_t; + typedef long long int int64_t; + {} + {} + + template + __device__ void {function_name}_kernel(T *in, U *out, const size_t numel) {{ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; + i += blockDim.x * gridDim.x) {{ + out[i] = static_cast(in[i]); + }} + }} + + extern "C" __global__ void {function_name}({} *in, {} *out, const size_t numel) {{ + {function_name}_kernel(in, out, numel); + }} + + "#, + T::C_DEP.unwrap_or(""), + U::C_DEP.unwrap_or(""), + T::C_NAME, + U::C_NAME, + ); + + // Always recompile PTX to avoid using stale cached files + let ptx = 
compile_ptx(template_kernel.clone())?; + + let ptx_str = ptx.to_src(); + if let Some(home) = dirs::home_dir() { + let path = format!( + "{}/.cache/constensor/ptx/{function_name}.ptx", + home.display() + ); + let path = Path::new(&path); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, ptx_str)?; + } + + let stream = self.device.select_stream(); + let n_elems = self.slice.len(); + + let out = unsafe { stream.alloc::(n_elems) }.w()?; + + let func = self.device.load_func(&function_name, ptx)?; + + let cfg = LaunchConfig::for_num_elems(n_elems as u32); + + let mut builder = stream.launch_builder(&func); + builder.arg(&self.slice); + builder.arg(&out); + builder.arg(&n_elems); + unsafe { builder.launch(cfg).w()? }; + + // Record an event once this kernel completes + let event = self.device.context.new_event(None).w()?; + event.record(&stream).w()?; + + Ok(Storage::Cuda(CudaStorage { + slice: out, + device: self.device.clone(), + event, + })) + } } pub enum CudaCompiledKernel { @@ -250,6 +345,7 @@ fn cuda_include_dir() -> Option { fn compile_ptx(template_kernel: String) -> Result { cudarc::nvrtc::compile_ptx_with_opts( template_kernel, + // Compile PTX without hardcoding an architecture so it can JIT to the current device CompileOptions { use_fast_math: Some(true), include_paths: vec![cuda_include_dir() @@ -257,7 +353,6 @@ fn compile_ptx(template_kernel: String) -> Result { .join("include") .display() .to_string()], - arch: Some("sm_90"), ..Default::default() }, ) @@ -304,6 +399,14 @@ impl CudaDevice { header.hash(&mut hasher); let function_name = format!("jit_kernel_{}_{}", hasher.finish(), T::NAME); + // If we've already compiled this kernel, skip PTX compilation + if let Some(module) = self.modules.read().unwrap().get(&function_name) { + let func = module.load_function(&function_name).w()?; + let n_elems: usize = shape.iter().product(); + let data = unsafe { self.stream.alloc::(n_elems) }.w()?; + return Ok((func, data)); + } 
+ let template_kernel = format!( r#" typedef unsigned char uint8_t; typedef unsigned int uint32_t; typedef long long int int64_t; diff --git a/constensor-core/src/storage.rs b/constensor-core/src/storage.rs index 8394207..d274f07 100644 --- a/constensor-core/src/storage.rs +++ b/constensor-core/src/storage.rs @@ -18,10 +18,19 @@ impl Storage { Self::Cuda(cuda) => cuda.to_cpu_storage(), } } + + pub(crate) fn cast(&self) -> Result> { + match self { + Self::Cpu(cpu) => cpu.cast::(), + #[cfg(feature = "cuda")] + Self::Cuda(cuda) => cuda.cast::(), + } + } } pub trait BackendStorage { fn to_cpu_storage(&self) -> Result>>; + fn cast(&self) -> Result>; } pub trait BackendDevice { diff --git a/constensor-core/src/tensor/concretetensor.rs b/constensor-core/src/tensor/concretetensor.rs index 5d92107..3b07361 100644 --- a/constensor-core/src/tensor/concretetensor.rs +++ b/constensor-core/src/tensor/concretetensor.rs @@ -85,6 +85,15 @@ tensor_api!(Cpu); #[cfg(feature = "cuda")] tensor_api!(Cuda<0>); +impl Tensor { + /// Cast this tensor to a different dtype `U`, returning a new tensor on the same device. + pub fn cast(&self) -> Result> { + // delegate to the backend storage cast for this tensor's device + let storage = self.storage.cast::()?; + Ok(from_storage::(Arc::new(storage))) + } +} + /*macro_rules! binary_op { ($trait:ident, $fn:ident) => { impl $trait for Tensor { diff --git a/constensor-core/tests/cast.rs b/constensor-core/tests/cast.rs new file mode 100644 index 0000000..b17a3fd --- /dev/null +++ b/constensor-core/tests/cast.rs @@ -0,0 +1,64 @@ +#[cfg(feature = "cuda")] +use constensor_core::Cuda; +use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R1, R2, R3}; + +macro_rules! 
test_for_device_cast { + ($dev:ty, $name:ident) => { + mod $name { + use super::*; + + // Test casting a 1D tensor from f32 to f64 + #[test] + fn cast_f32_to_f64_1d() { + let mut graph = Graph::empty(); + let _x = GraphTensor::, f32, $dev>::fill(&mut graph, 1.5); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); + let casted = tensor.cast::().unwrap(); + let data = casted.data().unwrap().into_owned(); + assert_eq!(data, vec![1.5_f64; 4]); + } + + // Test casting a 2D tensor from f64 to f32 + #[test] + fn cast_f64_to_f32_2d() { + let mut graph = Graph::empty(); + let _x = GraphTensor::, f64, $dev>::fill(&mut graph, 2.75); + let compiled: CompiledGraph, f64, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); + let casted = tensor.cast::().unwrap(); + let data = casted.data().unwrap().into_owned(); + assert_eq!(data, vec![vec![2.75_f32; 3]; 2]); + } + + // Test casting a 3D tensor from i32 to f32 + #[test] + fn cast_i32_to_f32_3d() { + let mut graph = Graph::empty(); + let _x = GraphTensor::, i32, $dev>::fill(&mut graph, 7); + let compiled: CompiledGraph, i32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); + let casted = tensor.cast::().unwrap(); + let data = casted.data().unwrap().into_owned(); + let expected = vec![vec![vec![7.0_f32; 3]; 2]; 1]; + assert_eq!(data, expected); + } + + // Test casting from f32 to i32 truncates toward zero + #[test] + fn cast_f32_to_i32_truncate() { + let mut graph = Graph::empty(); + let _x = GraphTensor::, f32, $dev>::fill(&mut graph, 1.9); + let compiled: CompiledGraph, f32, $dev> = graph.compile().unwrap(); + let tensor = compiled.run().unwrap(); + let casted = tensor.cast::().unwrap(); + let data = casted.data().unwrap().into_owned(); + assert_eq!(data, vec![1_i32; 3]); + } + } + }; +} + +test_for_device_cast!(Cpu, cpu_tests_cast); +#[cfg(feature = "cuda")] +test_for_device_cast!(Cuda<0>, cuda_tests_cast);