Commit 3909ef1

Merge pull request #20 from bminixhofer/symbolic-dim
Support for symbolic dims
2 parents 1b08682 + 13bc0bf commit 3909ef1

File tree

13 files changed: +472 / -114 lines


Cargo.lock

Lines changed: 220 additions & 21 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 4 additions & 4 deletions
@@ -10,10 +10,10 @@ crate-type = ["cdylib", "rlib"]
 [dependencies]
 wasm-bindgen = "^0.2"
 js-sys = "0.3.39"
-tract-core = "0.9.2"
-tract-onnx = "0.9.2"
-tract-hir = "0.9.2"
-tract-tensorflow = "0.9.2"
+tract-core = "0.11.1"
+tract-onnx = "0.11.1"
+tract-hir = "0.11.1"
+tract-tensorflow = "0.11.1"
 console_error_panic_hook = "0.1.1"
 ndarray-rand = { version = "0.11.0", optional = true }
 serde = { version = "1.0", features = ["derive"], optional = true }

README.md

Lines changed: 22 additions & 8 deletions
@@ -80,19 +80,29 @@ tractjs.load("./path/to/your/model").then((model) => {
 
 ### My model with dynamic input dimensions doesn't work. Why?
 
-Currently, tract requires fully determined input dimensions to optimize a model. There are two options:
+Currently, tract has some restrictions on dynamic dimensions. If your model has a dynamic dimension, there are multiple solutions:
 
-1. Turn `optimize` off:
+1. Declare a dynamic dimension via an input fact. Input facts are a way to provide additional information about input type and shape that cannot be inferred from the model data:
 
 ```js
 const model = await tractjs.load("path/to/your/model", {
-  optimize: false,
+  inputFacts: {
+    0: ["float32", [1, "s", 224, 224]],
+  },
 });
 ```
 
-This will however _significantly_ impact performance.
+This dimension must then be made concrete at prediction time:
+
+```js
+model.predict(input, {
+  "s": 3 // or some other value
+})
+```
+
+The API supports multiple dynamic dimensions, but currently it will probably only work with one.
 
-2. Set fixed input dimensions via input facts. Input facts are a way to provide additional information about input type and shape that can not be inferred via the model data:
+2. Set fixed input dimensions via input facts. This is not ideal, because the model can then only be passed inputs with this exact shape:
 
 ```js
 const model = await tractjs.load("path/to/your/model", {
@@ -104,13 +114,17 @@ const model = await tractjs.load("path/to/your/model", {
 });
 ```
 
-Be aware that the model will only work properly with inputs of this exact shape though.
+3. Turn `optimize` off. This is the nuclear option: it turns off all optimizations that rely on information about the input shape. This makes sure your model works (even with multiple dynamic dimensions), but it will _significantly_ impact performance:
 
-There is ongoing work in tract to allow dynamically sized inputs.
+```js
+const model = await tractjs.load("path/to/your/model", {
+  optimize: false,
+});
+```
 
 ### What about size?
 
-At the time of writing, tractjs is very large for web standards (8.5MB raw, 2.5MB gzipped). This is due to tract being quite large, and due to some overhead from inlining the WASM. But it's not as bad as it sounds. You can load tractjs lazily along your demo, where you will likely have to load significantly large weights too.
+At the time of writing, tractjs is very large by web standards (6.2MB raw, 2.1MB gzipped). This is due to tract being quite large, and due to some overhead from inlining the WASM. But it's not as bad as it sounds. You can load tractjs lazily alongside your demo, where you will likely have to load fairly large model weights anyway.
 
 If you are working on a very size-sensitive application, get in touch and we can work on decreasing the size. There are some more optimizations to be done (e.g. an option not to inline WASM, and removing panics from the build). There is also ongoing work in tract to decrease size.
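Taken together, the two new README snippets describe one workflow: declare the symbolic dimension when loading, then bind it when predicting. The sketch below combines them; the model path, the shapes, the symbol name "s", the `tractjs.Tensor` constructor, and `predict` resolving asynchronously like `load` are illustrative assumptions, not content of this commit:

```js
// Sketch combining the two README snippets above; paths and values are placeholders.
const model = await tractjs.load("path/to/your/model", {
  inputFacts: {
    // "s" declares a symbolic (dynamic) dimension in an otherwise fixed shape.
    0: ["float32", [1, "s", 224, 224]],
  },
});

// Build an input whose symbolic axis has length 3 (assumes the tractjs.Tensor
// constructor taking a typed array and a shape).
const input = new tractjs.Tensor(new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);

// Bind "s" to 3 for this prediction; the bound value has to match the input's shape.
const outputs = await model.predict(input, { s: 3 });
```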

quality/tests/reference.rs

Lines changed: 5 additions & 3 deletions
@@ -125,8 +125,7 @@ fn custom_input_tf<P1: AsRef<Path>, P2: AsRef<Path>>(
     input_file: P1,
     output_file: P2,
 ) -> TractResult<()> {
-    let mut cursor =
-        Cursor::new(include_bytes!("../models/data/squeezenet_1_1/model.pb") as &[u8]);
+    let mut cursor = Cursor::new(include_bytes!("../models/data/squeezenet_1_1/model.pb") as &[u8]);
     let model = tract_tensorflow::tensorflow()
         .model_for_read(&mut cursor)?
         .with_input_fact(
@@ -137,7 +136,10 @@ fn custom_input_tf<P1: AsRef<Path>, P2: AsRef<Path>>(
         .into_optimized()?
         .into_runnable()?;
 
-    let inputs = tvec![random_from_shape((1, 27, 27, 128)), random_from_shape((1, 27, 27, 128))];
+    let inputs = tvec![
+        random_from_shape((1, 27, 27, 128)),
+        random_from_shape((1, 27, 27, 128))
+    ];
     serialize_tensors(&inputs, input_file);
 
     let preds = model.run(inputs)?;

src/lib.rs

Lines changed: 110 additions & 14 deletions
@@ -1,5 +1,6 @@
-use js_sys::{Array, Error, Object, Uint32Array};
+use js_sys::{Array, Error, JsString, Object, Uint32Array};
 use std::io::Cursor;
+use tract_core::internal::ToDim;
 use tract_hir::prelude::*;
 use tract_onnx::prelude::*;
 use tract_tensorflow::prelude::*;
@@ -31,8 +32,9 @@ impl<T: std::fmt::Debug> TractResultExt<T> for TractResult<T> {
     }
 }
 
-fn fact_from_js(input: JsValue) -> InferenceFact {
-    let input: Array = input.dyn_into().expect("fact must be an Array.");
+fn fact_from_js(input: JsValue) -> (InferenceFact, bool) {
+    let mut has_symbolic_dim = false;
+    let input: Array = input.dyn_into().expect("Fact must be an Array.");
     let dtype = if let Some(string) = input.get(0).as_string() {
         Some(match string.as_str() {
             "int8" => i8::datum_type(),
@@ -52,19 +54,77 @@ fn fact_from_js(input: JsValue) -> InferenceFact {
         Some(
             shape
                 .iter()
-                .map(|x| x.as_f64().expect("fact[1] must be an Array of numbers.") as usize)
+                .map(|x| {
+                    let mut dim: Option<TDim> = None;
+
+                    if let Some(number) = x.as_f64() {
+                        dim = Some((number as isize).to_dim());
+                    } else if let Some(string) = x.as_string() {
+                        has_symbolic_dim = true;
+                        dim = Some(Symbol::from(string.chars().next().expect("fact[1][i] must consist of exactly one character if string.")).into());
+                    } else if let Ok(object) = x.dyn_into::<Object>() {
+                        // could be made prettier - maybe with serde, but would add an additional dependency
+                        let id = js_sys::Reflect::get(&object, &JsString::from("id")).expect("fact[1][i] must have 'id' key if object.")
+                            .as_string().expect("fact[1][i]['id'] must be a string.");
+
+                        let intercept = js_sys::Reflect::get(&object, &JsString::from("intercept")).expect("fact[1][i] must have 'intercept' key if object.")
+                            .as_f64().expect("fact[1][i]['intercept'] must be a number.") as isize;
+
+                        let slope = js_sys::Reflect::get(&object, &JsString::from("slope")).expect("fact[1][i] must have 'slope' key if object.")
+                            .as_f64().expect("fact[1][i]['slope'] must be a number.");
+
+                        let symbol: TDim = Symbol::from(id.chars().next().expect("fact[1][i] must consist of exactly one character if string.")).into();
+                        has_symbolic_dim = true;
+
+                        dim = Some((symbol * slope) + intercept)
+                    }
+
+                    dim.expect("fact[1][i] must be one of: number, string, or object with slope, intercept and id keys.")
+                })
                 .collect::<Vec<_>>(),
         )
     } else {
         None
     };
 
-    match (dtype, shape) {
-        (Some(dtype), Some(shape)) => InferenceFact::dt_shape(dtype, shape),
-        (Some(dtype), None) => InferenceFact::dt(dtype),
-        (None, Some(shape)) => InferenceFact::shape(shape),
-        (None, None) => panic!("either dtype or shape must be specified."),
+    (
+        match (dtype, shape) {
+            (Some(dtype), Some(shape)) => InferenceFact::dt_shape(dtype, shape),
+            (Some(dtype), None) => InferenceFact::dt(dtype),
+            (None, Some(shape)) => InferenceFact::shape(shape),
+            (None, None) => panic!("either dtype or shape must be specified."),
+        },
+        has_symbolic_dim,
+    )
+}
+
+fn symbol_values_from_js(input: JsValue) -> SymbolValues {
+    let input: Object = input.dyn_into().expect("SymbolValues must be object.");
+    let entries = Object::entries(&input);
+
+    let mut symbol_values = SymbolValues::default();
+
+    for i in 0..entries.length() {
+        let val: Array = entries
+            .get(i)
+            .dyn_into()
+            .expect(".entries[i] must be an array.");
+
+        symbol_values = symbol_values.with(
+            val.get(0)
+                .as_string()
+                .expect("entries[i][0] must be a string.")
+                .chars()
+                .next()
+                .expect("entries[i][0] must consist of exactly one character.")
+                .into(),
+            val.get(1)
+                .as_f64()
+                .expect("entries[i][1] must be a number.") as i64,
+        )
     }
+
+    symbol_values
 }
 
 #[wasm_bindgen]
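On the JavaScript side, `fact_from_js` above means each entry of a fact's shape may now be a plain number, a single-character string naming a symbol, or an object with `id`, `slope`, and `intercept` keys interpreted as the affine expression `slope * id + intercept`; `symbol_values_from_js` then expects an object mapping single-character symbol names to integers. The following is an illustrative sketch of those accepted forms; the README only documents the string form, so whether the TypeScript wrapper's typings expose the object form is an assumption:

```js
// Illustrative only: the dimension forms accepted by fact_from_js in this commit.
const inputFacts = {
  0: ["float32", [1, 3, 224, 224]],                                   // fixed dims
  1: ["float32", [1, "s", 224, 224]],                                 // symbolic dim "s"
  2: ["float32", [1, { id: "s", slope: 2, intercept: 1 }, 224, 224]], // affine dim: 2 * s + 1
};

// Symbol bindings as consumed by symbol_values_from_js: single-character
// symbol names mapped to integer values.
const symbolValues = { s: 3 };
```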
@@ -189,14 +249,33 @@ impl From<CoreTensor> for Tensor {
 
 enum Model {
     Inference(tract_hir::infer::InferenceSimplePlan<InferenceModel>),
+    Optimized(TypedModel),
     Typed(TypedSimplePlan<TypedModel>),
 }
 
+#[wasm_bindgen]
+extern "C" {
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(input: &str);
+}
+
 impl Model {
-    fn run(&self, inputs: TVec<Tensor>) -> TractResult<TVec<Arc<Tensor>>> {
+    fn run(
+        &self,
+        inputs: TVec<Tensor>,
+        symbol_values: &SymbolValues,
+    ) -> TractResult<TVec<Arc<Tensor>>> {
         match self {
             Model::Inference(x) => x.run(inputs),
             Model::Typed(x) => x.run(inputs),
+            Model::Optimized(x) => {
+                let model = x
+                    .concretize_dims(symbol_values)?
+                    .optimize()?
+                    .into_runnable()?;
+
+                model.run(inputs)
+            }
         }
     }
 }
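One consequence of the `Model::Optimized` arm above: when a model was loaded with `optimize: true` and a symbolic dimension, every call to `run` concretizes the dimensions for the given symbol values and re-runs optimization before executing. From JavaScript this means each `predict` call, and in particular each new symbol value, pays that re-optimization cost. A small hedged usage sketch, where `makeInput` is a hypothetical helper that builds a tensor whose symbolic axis has the given length:

```js
// Each call below triggers concretize_dims + optimize + into_runnable in the
// core before inference runs, so varying "s" between calls works but is not free.
const out3 = await model.predict(makeInput(3), { s: 3 });
const out5 = await model.predict(makeInput(5), { s: 5 });
```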
@@ -207,6 +286,12 @@ impl From<tract_hir::infer::InferenceSimplePlan<InferenceModel>> for Model {
     }
 }
 
+impl From<TypedModel> for Model {
+    fn from(input: TypedModel) -> Self {
+        Model::Optimized(input)
+    }
+}
+
 impl From<TypedSimplePlan<TypedModel>> for Model {
     fn from(input: TypedSimplePlan<TypedModel>) -> Self {
         Model::Typed(input)
@@ -231,6 +316,7 @@ impl CoreModel {
         console_error_panic_hook::set_once();
 
         let mut reader = Cursor::new(data);
+        let mut model_has_symbolic_dim = false;
 
         let mut model = if use_onnx {
             onnx().model_for_read(&mut reader)
@@ -243,7 +329,8 @@ impl CoreModel {
             .iter()
             .zip(Object::values(&input_facts).iter())
         {
-            let fact = fact_from_js(fact);
+            let (fact, fact_has_symbolic_dim) = fact_from_js(fact);
+            model_has_symbolic_dim = model_has_symbolic_dim || fact_has_symbolic_dim;
 
             model
                 .set_input_fact(
@@ -275,22 +362,31 @@ impl CoreModel {
                 .map_js_error()?;
         }
 
-        let model: Model = if optimize {
+        let model: Model = if optimize && !model_has_symbolic_dim {
             model
                 .into_optimized()
                 .map_js_error()?
                 .into_runnable()
                 .map_js_error()?
                 .into()
+        } else if optimize {
+            model.into_optimized().map_js_error()?.into()
         } else {
             model.into_runnable().map_js_error()?.into()
         };
 
         Ok(CoreModel { model })
     }
 
-    pub fn predict(&self, data: CoreTensorVec) -> Result<CoreTensorVec, JsValue> {
-        let outputs = self.model.run(data.into_tvec()).map_js_error()?;
+    pub fn predict(
+        &self,
+        data: CoreTensorVec,
+        symbol_values: JsValue,
+    ) -> Result<CoreTensorVec, JsValue> {
+        let outputs = self
+            .model
+            .run(data.into_tvec(), &symbol_values_from_js(symbol_values))
+            .map_js_error()?;
         Ok(CoreTensorVec::from_slice(&outputs))
     }
 }
