Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions constensor-core/src/cpu_storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use pool::{BufferPool, PooledBuffer};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};

use crate::device::Dev;
use crate::storage::Storage;
use crate::Shape;
use crate::{
storage::{BackendDevice, BackendStorage},
Expand All @@ -26,8 +27,11 @@ pub struct CpuStorage<T: DType>(pub(crate) Vec<T>);

impl<T: DType> BackendStorage<T> for CpuStorage<T> {
    /// Return this storage as host data.
    ///
    /// CPU data is already host-resident, so we hand back a borrow instead
    /// of cloning the whole buffer. (The diff had left the old
    /// `Cow::Owned(self.clone())` line next to the new one; only the
    /// borrowing version is kept.)
    fn to_cpu_storage(&self) -> Result<Cow<CpuStorage<T>>> {
        Ok(Cow::Borrowed(self))
    }

    /// Element-wise cast to dtype `U`, routing every value through `f64`.
    ///
    /// NOTE(review): the `f64` round-trip is lossy for integer values wider
    /// than 53 bits of precision (e.g. large i64/u64) — confirm that is
    /// acceptable for all supported dtypes.
    fn cast<U: DType>(&self) -> Result<Storage<U>> {
        let converted: Vec<U> = self.0.iter().map(|x| U::from_f64(x.to_f64())).collect();
        Ok(Storage::Cpu(CpuStorage(converted)))
    }
}

Expand Down
115 changes: 109 additions & 6 deletions constensor-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::sync::{
};
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
collections::{HashMap, HashSet, VecDeque},
fs,
hash::{DefaultHasher, Hash, Hasher},
marker::PhantomData,
Expand All @@ -24,7 +24,7 @@ use std::{
use crate::{
cpu_storage::CpuStorage,
device::Dev,
storage::{BackendDevice, BackendStorage},
storage::{BackendDevice, BackendStorage, Storage},
CompiledGraph, DType, GraphNode, Op, Result, Shape,
};

Expand All @@ -38,11 +38,14 @@ unsafe impl Send for CudaRng {}
pub struct CudaDevice {
context: Arc<cudarc::driver::CudaContext>,
stream: Arc<cudarc::driver::CudaStream>,
modules: Arc<RwLock<Vec<Arc<CudaModule>>>>,
modules: Arc<RwLock<HashMap<String, Arc<CudaModule>>>>,
module_cache_order: Arc<Mutex<VecDeque<String>>>,
streams: Arc<Vec<Arc<CudaStream>>>,
stream_index: Arc<AtomicUsize>,
}

const MAX_CACHED_KERNELS: usize = 128;

impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
Expand All @@ -57,7 +60,8 @@ impl CudaDevice {
Ok(Self {
context,
stream,
modules: Arc::new(RwLock::new(vec![])),
modules: Arc::new(RwLock::new(HashMap::new())),
module_cache_order: Arc::new(Mutex::new(VecDeque::new())),
streams,
stream_index,
})
Expand All @@ -74,9 +78,29 @@ impl CudaDevice {
}

pub(crate) fn load_func(&self, function_name: &str, ptx: Ptx) -> Result<CudaFunction> {
    // Fast path: the module for this kernel is already cached — return the
    // function without recompiling or reloading the PTX.
    {
        let modules_read = self.modules.read().unwrap();
        if let Some(module) = modules_read.get(function_name) {
            return module.load_function(function_name).w();
        }
    }

    // Slow path: load the PTX into a fresh module and resolve the function.
    let module = self.context.load_module(ptx).w()?;
    let func = module.load_function(function_name).w()?;

    // Insert into the cache and cap its size. (The diff had left the old
    // `self.modules.write().unwrap().push(module)` line here; it is gone,
    // which also lets us move `module` into the map without cloning.)
    {
        let mut modules_write = self.modules.write().unwrap();
        let mut order = self.module_cache_order.lock().unwrap();
        // Another thread may have inserted the same key between dropping the
        // read lock above and taking the write lock here. Only record the key
        // in the eviction queue when it is genuinely new, otherwise `order`
        // accumulates duplicates and a later eviction would drop a live entry.
        if modules_write
            .insert(function_name.to_string(), module)
            .is_none()
        {
            order.push_back(function_name.to_string());
        }
        if order.len() > MAX_CACHED_KERNELS {
            if let Some(oldest) = order.pop_front() {
                modules_write.remove(&oldest);
            }
        }
    }
    Ok(func)
}
}
Expand All @@ -100,6 +124,77 @@ impl<T: DType> BackendStorage<T> for CudaStorage<T> {
let data = self.device.stream().memcpy_dtov(&self.slice).w()?;
Ok(Cow::Owned(CpuStorage(data)))
}
/// Element-wise cast of this CUDA storage to dtype `U`, performed on-device
/// by a JIT-compiled kernel.
///
/// Builds a CUDA source string parameterized by the two dtypes, compiles it
/// to PTX, launches a grid-stride cast kernel on a round-robin stream, and
/// records an event so consumers can synchronize before reading the output.
fn cast<U: DType>(&self) -> Result<Storage<U>> {
    // One kernel per (source, target) dtype pair, e.g. `cast_f32_to_f64`.
    // `load_func` below dedupes by this name, so each pair compiles once.
    let function_name = format!("cast_{}_to_{}", T::NAME, U::NAME);

    // CUDA source template: the first two `{}` slots inject any per-dtype
    // helper definitions (C_DEP), the last two the concrete C type names.
    let template_kernel = format!(
        r#"
typedef unsigned char uint8_t;
typedef unsigned int uint32_t;
typedef long long int int64_t;
{}
{}

template <typename T, typename U>
__device__ void {function_name}_kernel(T *in, U *out, const size_t numel) {{
  for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {{
    out[i] = static_cast<U>(in[i]);
  }}
}}

extern "C" __global__ void {function_name}({} *in, {} *out, const size_t numel) {{
  {function_name}_kernel(in, out, numel);
}}

"#,
        T::C_DEP.unwrap_or(""),
        U::C_DEP.unwrap_or(""),
        T::C_NAME,
        U::C_NAME,
    );

    // Always recompile the PTX rather than reading a cached file, so template
    // changes cannot be shadowed by stale on-disk PTX. (`clone()` removed:
    // `template_kernel` was not used again afterwards.)
    let ptx = compile_ptx(template_kernel)?;

    // Best-effort dump of the generated PTX under ~/.cache for inspection.
    // NOTE(review): nothing visible here reads this file back — confirm it
    // is intended purely as a debugging artifact.
    let ptx_str = ptx.to_src();
    if let Some(home) = dirs::home_dir() {
        let path = format!(
            "{}/.cache/constensor/ptx/{function_name}.ptx",
            home.display()
        );
        let path = Path::new(&path);
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        fs::write(path, ptx_str)?;
    }

    let stream = self.device.select_stream();
    let n_elems = self.slice.len();

    // `alloc` returns uninitialized device memory; every element is written
    // by the kernel below before anything reads it.
    let out = unsafe { stream.alloc::<U>(n_elems) }.w()?;

    let func = self.device.load_func(&function_name, ptx)?;

    let cfg = LaunchConfig::for_num_elems(n_elems as u32);

    let mut builder = stream.launch_builder(&func);
    builder.arg(&self.slice);
    builder.arg(&out);
    builder.arg(&n_elems);
    // SAFETY: argument order and types match the kernel signature generated
    // above (in pointer, out pointer, element count).
    unsafe { builder.launch(cfg).w()? };

    // Record an event once this kernel completes so later consumers can wait
    // on the cast before touching `out`.
    let event = self.device.context.new_event(None).w()?;
    event.record(&stream).w()?;

    Ok(Storage::Cuda(CudaStorage {
        slice: out,
        device: self.device.clone(),
        event,
    }))
}
}

pub enum CudaCompiledKernel<T: DType> {
Expand Down Expand Up @@ -250,14 +345,14 @@ fn cuda_include_dir() -> Option<PathBuf> {
fn compile_ptx(template_kernel: String) -> Result<Ptx> {
cudarc::nvrtc::compile_ptx_with_opts(
template_kernel,
// Compile PTX without hardcoding an architecture so it can JIT to the current device
CompileOptions {
use_fast_math: Some(true),
include_paths: vec![cuda_include_dir()
.unwrap()
.join("include")
.display()
.to_string()],
arch: Some("sm_90"),
..Default::default()
},
)
Expand Down Expand Up @@ -304,6 +399,14 @@ impl CudaDevice {
header.hash(&mut hasher);
let function_name = format!("jit_kernel_{}_{}", hasher.finish(), T::NAME);

// If we've already compiled this kernel, skip PTX compilation
if let Some(module) = self.modules.read().unwrap().get(&function_name) {
let func = module.load_function(&function_name).w()?;
let n_elems: usize = shape.iter().product();
let data = unsafe { self.stream.alloc::<T>(n_elems) }.w()?;
return Ok((func, data));
}

let template_kernel = format!(
r#"
typedef unsigned char uint8_t;
Expand Down
9 changes: 9 additions & 0 deletions constensor-core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ impl<T: DType> Storage<T> {
Self::Cuda(cuda) => cuda.to_cpu_storage(),
}
}

/// Cast the contained storage to dtype `U`, dispatching to whichever
/// backend currently owns the data.
pub(crate) fn cast<U: DType>(&self) -> Result<Storage<U>> {
    match self {
        Self::Cpu(storage) => storage.cast::<U>(),
        #[cfg(feature = "cuda")]
        Self::Cuda(storage) => storage.cast::<U>(),
    }
}
}

/// Operations every device backend's storage must support.
pub trait BackendStorage<T: DType> {
    /// Materialize the data on the host. Backends whose data already lives
    /// on the host may return `Cow::Borrowed`; others produce an owned copy.
    fn to_cpu_storage(&self) -> Result<Cow<CpuStorage<T>>>;
    /// Element-wise cast of this storage to dtype `U`, staying on the same
    /// backend (the returned `Storage` wraps the same device variant).
    fn cast<U: DType>(&self) -> Result<Storage<U>>;
}

pub trait BackendDevice {
Expand Down
9 changes: 9 additions & 0 deletions constensor-core/src/tensor/concretetensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ tensor_api!(Cpu);
#[cfg(feature = "cuda")]
tensor_api!(Cuda<0>);

impl<S: Shape, T: DType, D: Dev> Tensor<S, T, D> {
    /// Cast this tensor's elements to dtype `U`, keeping shape `S` and
    /// device `D` unchanged.
    ///
    /// Dispatches to the storage backend's `cast`, so the conversion runs on
    /// whichever device currently holds the data (not only the CPU).
    pub fn cast<U: DType>(&self) -> Result<Tensor<S, U, D>> {
        // Delegate the element-wise conversion to the backend storage and
        // wrap the result in a new tensor handle.
        let storage = self.storage.cast::<U>()?;
        Ok(from_storage::<S, U, D>(Arc::new(storage)))
    }
}

/*macro_rules! binary_op {
($trait:ident, $fn:ident) => {
impl<S: Shape, D: DType> $trait for Tensor<S, D> {
Expand Down
64 changes: 64 additions & 0 deletions constensor-core/tests/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#[cfg(feature = "cuda")]
use constensor_core::Cuda;
use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R1, R2, R3};

// Instantiates the full cast test-suite once per device type, so the CPU and
// (feature-gated) CUDA backends run identical assertions.
macro_rules! test_for_device_cast {
    ($dev:ty, $name:ident) => {
        mod $name {
            use super::*;

            // Widening cast on a 1D tensor: f32 -> f64 is exact for 1.5.
            #[test]
            fn cast_f32_to_f64_1d() {
                let mut graph = Graph::empty();
                let _x = GraphTensor::<R1<4>, f32, $dev>::fill(&mut graph, 1.5);
                let compiled: CompiledGraph<R1<4>, f32, $dev> = graph.compile().unwrap();
                let tensor = compiled.run().unwrap();
                let casted = tensor.cast::<f64>().unwrap();
                let data = casted.data().unwrap().into_owned();
                assert_eq!(data, vec![1.5_f64; 4]);
            }

            // Narrowing cast on a 2D tensor: 2.75 is exactly representable
            // in f32, so f64 -> f32 loses nothing here.
            #[test]
            fn cast_f64_to_f32_2d() {
                let mut graph = Graph::empty();
                let _x = GraphTensor::<R2<2, 3>, f64, $dev>::fill(&mut graph, 2.75);
                let compiled: CompiledGraph<R2<2, 3>, f64, $dev> = graph.compile().unwrap();
                let tensor = compiled.run().unwrap();
                let casted = tensor.cast::<f32>().unwrap();
                let data = casted.data().unwrap().into_owned();
                assert_eq!(data, vec![vec![2.75_f32; 3]; 2]);
            }

            // Int -> float cast on a 3D tensor.
            #[test]
            fn cast_i32_to_f32_3d() {
                let mut graph = Graph::empty();
                let _x = GraphTensor::<R3<1, 2, 3>, i32, $dev>::fill(&mut graph, 7);
                let compiled: CompiledGraph<R3<1, 2, 3>, i32, $dev> = graph.compile().unwrap();
                let tensor = compiled.run().unwrap();
                let casted = tensor.cast::<f32>().unwrap();
                let data = casted.data().unwrap().into_owned();
                let expected = vec![vec![vec![7.0_f32; 3]; 2]; 1];
                assert_eq!(data, expected);
            }

            // Float -> int cast truncates toward zero: 1.9 -> 1.
            #[test]
            fn cast_f32_to_i32_truncate() {
                let mut graph = Graph::empty();
                let _x = GraphTensor::<R1<3>, f32, $dev>::fill(&mut graph, 1.9);
                let compiled: CompiledGraph<R1<3>, f32, $dev> = graph.compile().unwrap();
                let tensor = compiled.run().unwrap();
                let casted = tensor.cast::<i32>().unwrap();
                let data = casted.data().unwrap().into_owned();
                assert_eq!(data, vec![1_i32; 3]);
            }
        }
    };
}

test_for_device_cast!(Cpu, cpu_tests_cast);
#[cfg(feature = "cuda")]
test_for_device_cast!(Cuda<0>, cuda_tests_cast);
Loading