Skip to content

Commit a714004

Browse files
committed
Constructing batches from non-tensor objects.
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
1 parent f1d1b64 commit a714004

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

dali/python/backend_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1136,7 +1136,7 @@ void ExposeTensorList(py::module &m) {
11361136
"tl"_a,
11371137
"layout"_a = py::none())
11381138
.def(py::init([](py::buffer b, string layout = "", bool is_pinned = false) {
1139-
DomainTimeRange range("TensorListCPU::init from a buffer", kCPUTensorColor);
1139+
DomainTimeRange range("TensorListCPU::init from a buffer", kCPUTensorColor);
11401140
// We need to verify that the input data is C_CONTIGUOUS
11411141
// and of a type that we can work with in the backend
11421142
py::buffer_info info = b.request();

dali/python/nvidia/dali/experimental/dali2/_batch.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
from . import _eval_mode
2222
from . import _invocation
2323
import copy
24+
import nvtx
2425

2526

2627
class BatchedSlice:
@@ -154,21 +155,26 @@ def _is_external(self) -> bool:
154155
return self._wraps_external_data
155156

156157
@staticmethod
157-
def broadcast(sample, batch_size: int) -> "Batch":
158+
def broadcast(sample, batch_size: int, device: Optional[Device] = None) -> "Batch":
158159
if isinstance(sample, Batch):
159160
raise ValueError("Cannot broadcast a Batch")
160161
if _is_tensor_type(sample):
161-
return Batch([Tensor(sample)] * batch_size)
162+
return Batch([Tensor(sample, device=device)] * batch_size)
162163
import numpy as np
163-
arr = np.array(batch_size)
164-
if arr.dtype == np.float64:
165-
arr = arr.astype(np.float32)
166-
elif arr.dtype == np.int64:
167-
arr = arr.astype(np.int32)
168-
elif arr.dtype == np.uint64:
169-
arr = arr.astype(np.uint32)
170-
arr = np.stack([arr] * batch_size)
171-
return Batch(_backend.TensorListCPU(arr))
164+
with nvtx.annotate("to numpy and stack", domain="batch"):
165+
arr = np.array(sample)
166+
if arr.dtype == np.float64:
167+
arr = arr.astype(np.float32)
168+
elif arr.dtype == np.int64:
169+
arr = arr.astype(np.int32)
170+
elif arr.dtype == np.uint64:
171+
arr = arr.astype(np.uint32)
172+
arr = np.repeat(arr[np.newaxis], batch_size, axis=0)
173+
174+
with nvtx.annotate("to backend", domain="batch"):
175+
tl = _backend.TensorListCPU(arr)
176+
with nvtx.annotate("create batch", domain="batch"):
177+
return Batch(tl, device=device)
172178

173179
@property
174180
def dtype(self) -> DType:

dali/python/nvidia/dali/experimental/dali2/_op_builder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def _to_batch(x, batch_size, device=None):
131131
raise ValueError(f"Unexpected batch size: {actual_batch_size} != {batch_size}")
132132
return Batch(x, device=device)
133133

134-
return Batch.broadcast(_to_tensor(x, device=device), batch_size)
134+
return Batch.broadcast(x, batch_size, device=device)
135135

136136

137137
_unsupported_args = {"bytes_per_sample_hint", "preserve"}

0 commit comments

Comments (0)