Skip to content

Commit 4786140

Browse files
authored
Add an Allocatable trait to the runner interface (#179)
This PR adds an `Allocatable` trait to the runner interface. This generally makes it cleaner to write generic functions over Carton types without needing to reference implementation details in trait bounds (e.g. `where InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>`). Note: this PR does touch files inside `do_not_modify`, but it does so in a way that does not affect the wire protocol.
1 parent 60afa98 commit 4786140

File tree

7 files changed

+73
-52
lines changed

7 files changed

+73
-52
lines changed

source/carton-runner-interface/src/do_not_modify/alloc.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,17 @@ pub trait AsPtr<T> {
3131
fn as_mut_ptr(&mut self) -> *mut T;
3232
}
3333

34-
pub trait TypedAlloc<T> {
35-
type Output: AsPtr<T>;
34+
pub trait Allocator {
35+
type Output;
36+
}
3637

38+
pub trait TypedAlloc<T>: Allocator
39+
where
40+
Self::Output: AsPtr<T>,
41+
{
3742
fn alloc(&self, numel: usize) -> Self::Output;
3843
}
44+
45+
pub trait AllocatableBy<A: Allocator>: Sized {
46+
fn alloc(allocator: &A, numel: usize) -> A::Output;
47+
}

source/carton-runner-interface/src/do_not_modify/alloc_inline.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use once_cell::sync::Lazy;
2323
use serde::{Deserialize, Serialize};
2424

2525
use super::{
26-
alloc::{AsPtr, NumericTensorType, TypedAlloc},
26+
alloc::{AllocatableBy, Allocator, AsPtr, NumericTensorType, TypedAlloc},
2727
alloc_pool::{PoolAllocator, PoolItem},
2828
storage::TensorStorage,
2929
};
@@ -79,12 +79,14 @@ impl<T> AsPtr<T> for InlineTensorStorage {
7979
}
8080
}
8181

82+
impl Allocator for InlineAllocator {
83+
type Output = InlineTensorStorage;
84+
}
85+
8286
for_each_numeric_carton_type! {
8387
$(
8488
/// We're using a macro here instead of a generic impl because rust gives misleading error messages otherwise.
8589
impl TypedAlloc<$RustType> for InlineAllocator {
86-
type Output = InlineTensorStorage;
87-
8890
fn alloc(&self, numel: usize) -> Self::Output {
8991
// We need to convert to size_bytes since we always use a Vec<u8>
9092
let size_bytes = numel * std::mem::size_of::<$RustType>();
@@ -101,8 +103,6 @@ for_each_numeric_carton_type! {
101103
}
102104

103105
impl TypedAlloc<String> for InlineAllocator {
104-
type Output = InlineTensorStorage;
105-
106106
fn alloc(&self, numel: usize) -> Self::Output {
107107
let out = if !self.use_pool {
108108
vec![String::default(); numel].into()
@@ -115,10 +115,8 @@ impl TypedAlloc<String> for InlineAllocator {
115115
}
116116

117117
// Copy the data
118-
impl<T: NumericTensorType + Default + Copy> From<ndarray::ArrayViewD<'_, T>>
119-
for TensorStorage<T, InlineTensorStorage>
120-
where
121-
InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
118+
impl<T: NumericTensorType + Default + Copy + AllocatableBy<InlineAllocator>>
119+
From<ndarray::ArrayViewD<'_, T>> for TensorStorage<T, InlineTensorStorage>
122120
{
123121
fn from(view: ndarray::ArrayViewD<'_, T>) -> Self {
124122
// Alloc a tensor
@@ -152,17 +150,14 @@ impl From<ndarray::ArrayViewD<'_, String>> for TensorStorage<String, InlineTenso
152150

153151
// Allocates a contiguous tensor with a shape and type
154152
#[cfg(feature = "benchmark")]
155-
pub fn alloc_tensor_no_pool<T: Default + Clone>(
153+
pub fn alloc_tensor_no_pool<T: Default + Clone + AllocatableBy<InlineAllocator>>(
156154
shape: Vec<u64>,
157-
) -> TensorStorage<T, InlineTensorStorage>
158-
where
159-
InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
160-
{
155+
) -> TensorStorage<T, InlineTensorStorage> {
161156
static POOL_ALLOCATOR: Lazy<InlineAllocator> = Lazy::new(|| InlineAllocator::without_pool());
162157

163158
let numel = shape.iter().product::<u64>().max(1) as usize;
164159

165-
let data = <InlineAllocator as TypedAlloc<T>>::alloc(&POOL_ALLOCATOR, numel);
160+
let data = T::alloc(&POOL_ALLOCATOR, numel);
166161

167162
TensorStorage {
168163
data,
@@ -172,15 +167,14 @@ where
172167
}
173168
}
174169

175-
pub fn alloc_tensor<T: Default + Clone>(shape: Vec<u64>) -> TensorStorage<T, InlineTensorStorage>
176-
where
177-
InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
178-
{
170+
pub fn alloc_tensor<T: Default + Clone + AllocatableBy<InlineAllocator>>(
171+
shape: Vec<u64>,
172+
) -> TensorStorage<T, InlineTensorStorage> {
179173
static POOL_ALLOCATOR: Lazy<InlineAllocator> = Lazy::new(|| InlineAllocator::new());
180174

181175
let numel = shape.iter().product::<u64>().max(1) as usize;
182176

183-
let data = <InlineAllocator as TypedAlloc<T>>::alloc(&POOL_ALLOCATOR, numel);
177+
let data = T::alloc(&POOL_ALLOCATOR, numel);
184178

185179
TensorStorage {
186180
data,
@@ -190,11 +184,17 @@ where
190184
}
191185
}
192186

193-
impl<T: Default + Clone> TensorStorage<T, InlineTensorStorage>
187+
impl<T: Default + Clone + AllocatableBy<InlineAllocator>> TensorStorage<T, InlineTensorStorage> {
188+
pub fn new(shape: Vec<u64>) -> TensorStorage<T, InlineTensorStorage> {
189+
alloc_tensor(shape)
190+
}
191+
}
192+
193+
impl<T> AllocatableBy<InlineAllocator> for T
194194
where
195195
InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
196196
{
197-
pub fn new(shape: Vec<u64>) -> TensorStorage<T, InlineTensorStorage> {
198-
alloc_tensor(shape)
197+
fn alloc(allocator: &InlineAllocator, numel: usize) -> InlineTensorStorage {
198+
<InlineAllocator as TypedAlloc<T>>::alloc(allocator, numel)
199199
}
200200
}

source/carton-runner-interface/src/do_not_modify/alloc_shm.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use dashmap::DashMap;
2525
use once_cell::sync::Lazy;
2626

2727
use super::{
28-
alloc::{AsPtr, NumericTensorType, TypedAlloc},
28+
alloc::{AsPtr, NumericTensorType, TypedAlloc, Allocator, AllocatableBy},
2929
alloc_pool::{AllocItem, PoolAllocator, PoolItem},
3030
storage::TensorStorage,
3131
};
@@ -260,12 +260,14 @@ impl SHMAllocator {
260260
}
261261
}
262262

263+
impl Allocator for SHMAllocator {
264+
type Output = SHMTensorStorage;
265+
}
266+
263267
for_each_numeric_carton_type! {
264268
$(
265269
/// We're using a macro here instead of a generic impl because rust gives misleading error messages otherwise.
266270
impl TypedAlloc<$RustType> for SHMAllocator {
267-
type Output = SHMTensorStorage;
268-
269271
fn alloc(&self, numel: usize) -> Self::Output {
270272
// We need to convert to size_bytes
271273
let size_bytes = numel * std::mem::size_of::<$RustType>();
@@ -282,8 +284,6 @@ for_each_numeric_carton_type! {
282284
}
283285

284286
impl TypedAlloc<String> for SHMAllocator {
285-
type Output = SHMTensorStorage;
286-
287287
fn alloc(&self, numel: usize) -> Self::Output {
288288
let out = if !self.use_pool {
289289
vec![String::default(); numel].into()
@@ -384,17 +384,15 @@ impl From<ndarray::ArrayViewD<'_, String>> for TensorStorage<String, SHMTensorSt
384384

385385
// Allocates a contiguous tensor with a shape and type
386386
#[cfg(feature = "benchmark")]
387-
pub fn alloc_tensor_no_pool<T: Default + Clone>(
387+
pub fn alloc_tensor_no_pool<T: Default + Clone + AllocatableBy<SHMAllocator>>(
388388
shape: Vec<u64>,
389389
) -> TensorStorage<T, SHMTensorStorage>
390-
where
391-
SHMAllocator: TypedAlloc<T, Output = SHMTensorStorage>,
392390
{
393391
static POOL_ALLOCATOR: Lazy<SHMAllocator> = Lazy::new(|| SHMAllocator::without_pool());
394392

395393
let numel = shape.iter().product::<u64>().max(1) as usize;
396394

397-
let data = <SHMAllocator as TypedAlloc<T>>::alloc(&POOL_ALLOCATOR, numel);
395+
let data = T::alloc(&POOL_ALLOCATOR, numel);
398396

399397
TensorStorage {
400398
data,
@@ -404,15 +402,13 @@ where
404402
}
405403
}
406404

407-
pub fn alloc_tensor<T: Default + Clone>(shape: Vec<u64>) -> TensorStorage<T, SHMTensorStorage>
408-
where
409-
SHMAllocator: TypedAlloc<T, Output = SHMTensorStorage>,
405+
pub fn alloc_tensor<T: Default + Clone + AllocatableBy<SHMAllocator>>(shape: Vec<u64>) -> TensorStorage<T, SHMTensorStorage>
410406
{
411407
static POOL_ALLOCATOR: Lazy<SHMAllocator> = Lazy::new(|| SHMAllocator::new());
412408

413409
let numel = shape.iter().product::<u64>().max(1) as usize;
414410

415-
let data = <SHMAllocator as TypedAlloc<T>>::alloc(&POOL_ALLOCATOR, numel);
411+
let data = T::alloc(&POOL_ALLOCATOR, numel);
416412

417413
TensorStorage {
418414
data,
@@ -421,3 +417,13 @@ where
421417
pd: PhantomData,
422418
}
423419
}
420+
421+
422+
impl<T> AllocatableBy<SHMAllocator> for T
423+
where
424+
SHMAllocator: TypedAlloc<T, Output = SHMTensorStorage>,
425+
{
426+
fn alloc(allocator: &SHMAllocator, numel: usize) -> SHMTensorStorage {
427+
<SHMAllocator as TypedAlloc<T>>::alloc(allocator, numel)
428+
}
429+
}

source/carton-runner-interface/src/do_not_modify/types.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
pub use carton_macros::for_each_carton_type;
15+
pub use carton_macros::{for_each_carton_type, for_each_numeric_carton_type};
1616
use serde::{Deserialize, Serialize};
1717
use std::collections::HashMap;
1818

19-
use super::{alloc_inline::InlineTensorStorage, comms::Comms};
19+
use super::{
20+
alloc::AllocatableBy,
21+
alloc_inline::{InlineAllocator, InlineTensorStorage},
22+
comms::Comms,
23+
};
2024

2125
#[derive(Debug, Serialize, Deserialize)]
2226
pub(crate) struct RPCRequest {
@@ -247,6 +251,9 @@ for_each_carton_type! {
247251

248252
pub type TensorStorage<T> = super::storage::TensorStorage<T, InlineTensorStorage>;
249253

254+
pub trait Allocatable: AllocatableBy<InlineAllocator> {}
255+
impl<T> Allocatable for T where T: AllocatableBy<InlineAllocator> {}
256+
250257
for_each_carton_type! {
251258
$(
252259
impl From<TensorStorage<$RustType>> for Tensor {

source/carton-runner-interface/src/runner.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,8 @@ use std::{collections::HashMap, sync::Arc};
1717
use crate::{
1818
client::Client,
1919
do_not_modify::comms::OwnedComms,
20-
do_not_modify::{
21-
alloc::TypedAlloc,
22-
alloc_inline::{InlineAllocator, InlineTensorStorage},
23-
types::{Device, RPCRequestData, RPCResponseData, SealHandle, Tensor},
24-
},
25-
types::{Handle, RunnerOpt, TensorStorage},
20+
do_not_modify::types::{Device, RPCRequestData, RPCResponseData, SealHandle, Tensor},
21+
types::{Allocatable, Handle, RunnerOpt, TensorStorage},
2622
};
2723

2824
use futures::Stream;
@@ -299,9 +295,11 @@ impl Runner {
299295
}
300296
}
301297

302-
pub fn alloc_tensor<T: Clone + Default>(&self, shape: Vec<u64>) -> Result<Tensor, String>
298+
pub fn alloc_tensor<T: Clone + Default + Allocatable>(
299+
&self,
300+
shape: Vec<u64>,
301+
) -> Result<Tensor, String>
303302
where
304-
InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
305303
Tensor: From<TensorStorage<T>>,
306304
{
307305
Ok(TensorStorage::new(shape).into())

source/carton-runner-wasm/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ publish = false
77
exclude = ["tests/test_model", "carton-wasm-interface"]
88

99
[dependencies]
10-
carton = { path = "../carton" }
1110
carton-runner-interface = { path = "../carton-runner-interface" }
1211
color-eyre = "0.6.2"
1312
lunchbox = { version = "0.1", default-features = false }
@@ -28,4 +27,5 @@ semver = "1.0.20"
2827
[dev-dependencies]
2928
escargot = "0.5.8"
3029
paste = "1.0.14"
31-
tempfile = "3.8.0"
30+
tempfile = "3.8.0"
31+
carton = { path = "../carton" }

source/carton-runner-wasm/src/types.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use color_eyre::eyre::{ensure, eyre};
22
use color_eyre::{Report, Result};
33

4-
use carton::types::for_each_numeric_carton_type;
5-
use carton_runner_interface::types::{Tensor as CartonTensor, TensorStorage as CartonStorage};
4+
use carton_runner_interface::types::{
5+
for_each_numeric_carton_type, Tensor as CartonTensor, TensorStorage as CartonStorage,
6+
};
67

78
use crate::component::{Dtype, Tensor as WasmTensor, TensorNumeric, TensorString};
89

0 commit comments

Comments
 (0)