|
21 | 21 | from . import _eval_mode |
22 | 22 | from . import _invocation |
23 | 23 | import copy |
| 24 | +import nvtx |
24 | 25 |
|
25 | 26 |
|
26 | 27 | class BatchedSlice: |
@@ -154,21 +155,26 @@ def _is_external(self) -> bool: |
154 | 155 | return self._wraps_external_data |
155 | 156 |
|
@staticmethod
def broadcast(sample, batch_size: int, device: Optional[Device] = None) -> "Batch":
    """Build a Batch containing ``batch_size`` replicas of ``sample``.

    Parameters
    ----------
    sample
        The value to replicate. Must not already be a ``Batch``. If it is a
        tensor-like object (per ``_is_tensor_type``) it is wrapped in a
        ``Tensor``; otherwise it is converted via ``numpy.array``.
    batch_size : int
        Number of replicas in the resulting batch.
    device : Optional[Device]
        Target device forwarded to ``Tensor`` / ``Batch`` construction.

    Returns
    -------
    Batch
        A batch of ``batch_size`` identical samples.

    Raises
    ------
    ValueError
        If ``sample`` is already a ``Batch``.
    """
    if isinstance(sample, Batch):
        raise ValueError("Cannot broadcast a Batch")

    # Tensor-like samples are wrapped once and the wrapper is repeated.
    if _is_tensor_type(sample):
        wrapped = Tensor(sample, device=device)
        return Batch([wrapped] * batch_size)

    import numpy as np

    with nvtx.annotate("to numpy and stack", domain="batch"):
        converted = np.array(sample)
        # Narrow 64-bit dtypes to their 32-bit counterparts before handing
        # the data to the backend.
        narrowing = {
            np.dtype(np.float64): np.float32,
            np.dtype(np.int64): np.int32,
            np.dtype(np.uint64): np.uint32,
        }
        target = narrowing.get(converted.dtype)
        if target is not None:
            converted = converted.astype(target)
        # Prepend a batch axis and replicate along it.
        stacked = np.repeat(converted[np.newaxis], batch_size, axis=0)

    with nvtx.annotate("to backend", domain="batch"):
        tensor_list = _backend.TensorListCPU(stacked)
    with nvtx.annotate("create batch", domain="batch"):
        return Batch(tensor_list, device=device)
172 | 178 |
|
173 | 179 | @property |
174 | 180 | def dtype(self) -> DType: |
|
0 commit comments