fix: support decollate for numpy scalars #8470
Conversation
Force-pushed from 187c141 to c438fe0
monai/data/utils.py (outdated)

```python
        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
    ):
        return batch
    if isinstance(batch, np.ndarray) and batch.ndim == 0:
```
Thanks for the PR! Do you think it might be beneficial to convert the array into a tensor? That way, the data could be handled more consistently.
We could; I think it does not matter for my use cases. As long as the function handles NumPy scalars in the form of an array, it works for me!
I will add this change and convert it to a tensor there (L629) if you prefer :)
Thanks for the quick fix!
May I ask the reason for only converting to a tensor when `batch.ndim == 0` here?
I noticed different behavior when using the `decollate_batch` function on torch tensors vs. NumPy arrays (see discussion #8472), so I don't want to convert NumPy arrays to torch tensors, as that would introduce breaking changes.
This PR only addresses issue #8471, as I think that failure was not expected and NumPy scalars should be supported (?).
Force-pushed from 451c207 to 49d4954
Could we consider a more complete solution? The issue, it seems, is that 0-d arrays count as iterable but can't actually be iterated over. We already check for non-iterable things earlier in the function.
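For reference, a minimal illustration of the pitfall being described (standalone snippet, not from the PR):

```python
# A 0-d NumPy array passes the Iterable check but raises when actually iterated.
from collections.abc import Iterable

import numpy as np

scalar = np.array(1)                 # 0-d array, ndim == 0
print(isinstance(scalar, Iterable))  # True: ndarray defines __iter__, so the Iterable check passes
try:
    iter(scalar)                     # but iterating a 0-d array fails
except TypeError as exc:
    print(exc)                       # "iteration over a 0-d array"
```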
Thanks for the feedback. The initial PR was:

```python
if isinstance(batch, (float, int, str, bytes)) or (
    type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
):
    return batch
if isinstance(batch, np.ndarray) and batch.ndim == 0:
    return batch.item() if detach else batch
# rest ...
```

Is this something that you find more complete? Note I refactored the PR to convert from NumPy array to torch tensor as suggested by @KumoLiu.
What I had in mind was more of the following change:

```python
...
if batch is None or isinstance(batch, (float, int, str, bytes)):
    return batch
if getattr(batch, "ndim", -1) == 0:  # assumes only NumPy objects and PyTorch tensors have ndim
    return batch.item() if detach else batch
if isinstance(batch, torch.Tensor):
    if detach:
        batch = batch.detach()
    # REMOVE
    # if batch.ndim == 0:
    #     return batch.item() if detach else batch
...
```
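For illustration, a standalone sketch of how that proposed early check would behave on 0-d inputs; `handle_scalar` is a made-up name, not the actual `decollate_batch`:

```python
# Standalone sketch of the proposed early return; handle_scalar is a hypothetical helper.
import numpy as np
import torch


def handle_scalar(batch, detach: bool = True):
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    if getattr(batch, "ndim", -1) == 0:  # 0-d NumPy arrays and 0-d tensors both expose ndim
        return batch.item() if detach else batch
    return batch  # the real function would continue with its Tensor/Mapping/Iterable branches


print(handle_scalar(np.array(1)))                # 1 (Python int)
print(handle_scalar(torch.tensor(2.5)))          # 2.5 (Python float)
print(handle_scalar(np.array(1), detach=False))  # array(1), returned unchanged
```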
Thanks! I will update the PR to include these changes.
Walkthrough
Updated `decollate_batch` in `monai/data/utils.py`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant decollate_batch
    opt Old flow (prior to change)
        Caller->>decollate_batch: call with batch
        decollate_batch->>decollate_batch: is Tensor?
        decollate_batch->>decollate_batch: if Tensor and ndim==0 -> return item()
        decollate_batch->>Caller: return result
    end
    opt New flow (after change)
        Caller->>decollate_batch: call with batch
        decollate_batch->>decollate_batch: pre-check ndim==0 and has item()
        decollate_batch->>decollate_batch: if true and detach -> return item()
        decollate_batch->>decollate_batch: else proceed with existing branches (Tensor/non-Tensor)
        decollate_batch->>Caller: return result
    end
```
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches: ✅ 3 checks passed.
Actionable comments posted: 0
🧹 Nitpick comments (3)
monai/data/utils.py (3)
628-635: Prefer generic 0-D detection early; drop the special-case conversion.

Handle any object with `ndim == 0` up front; the torch branch can then skip its own scalar check. This reduces duplication and avoids an unnecessary conversion for NumPy scalars when `detach=False`.

Apply this diff (and remove the torch scalar early-return below):

```diff
@@
-    if isinstance(batch, (float, int, str, bytes)) or (
-        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
-    ):
-        return batch
-    if isinstance(batch, np.ndarray) and batch.ndim == 0:
-        batch = torch.from_numpy(batch)
+    if isinstance(batch, (float, int, str, bytes)) or (
+        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
+    ):
+        return batch
+    # generic 0-D objects (NumPy/Torch/others exposing ndim): treat as scalars
+    if getattr(batch, "ndim", -1) == 0:
+        return batch.item() if detach else batch
@@
-    if isinstance(batch, torch.Tensor):
-        if detach:
-            batch = batch.detach()
-        if batch.ndim == 0:
-            return batch.item() if detach else batch
+    if isinstance(batch, torch.Tensor):
+        if detach:
+            batch = batch.detach()
+        # ndim==0 handled above
```
628-630: Graceful fallback for unsupported NumPy dtypes.

`torch.from_numpy` will fail for some dtypes (e.g., datetime64, object). Consider catching the error and falling back to `.item()`, or returning the ndarray unchanged.

```diff
-    if isinstance(batch, np.ndarray) and batch.ndim == 0:
-        batch = torch.from_numpy(batch)
+    if isinstance(batch, np.ndarray) and batch.ndim == 0:
+        try:
+            batch = torch.from_numpy(batch)
+        except (TypeError, ValueError):
+            return batch.item() if detach else batch
```
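For context, a quick illustration of the dtype limitation this fallback guards against (illustrative snippet, not part of the review):

```python
# torch.from_numpy rejects some NumPy dtypes, e.g. datetime64; .item() still works as a fallback.
import numpy as np
import torch

stamp = np.array(np.datetime64("2024-01-01"))  # 0-d array with a dtype torch cannot convert
try:
    torch.from_numpy(stamp)
except TypeError as exc:
    print(f"from_numpy failed: {exc}")
print(stamp.item())  # falls back to the Python-level value (a datetime.date here)
```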
614-621: Doc/test touch-up for 0-D NumPy arrays.

Please note 0-D NumPy array behavior in the `decollate_batch` docstring and add unit tests for:
- `np.array(1)` with detach=True/False
- Nested structures containing 0-D arrays
- Edge dtypes (e.g., bool, float32)
🔇 Additional comments (1)

monai/data/utils.py, lines 628-635: Good fix: avoids the 0-D ndarray iteration error and aligns with the tensor path.

Converting 0-D NumPy arrays to torch first prevents the TypeError raised by iterating 0-D arrays and lets the existing scalar-tensor handling (`.item()` when `detach=True`) kick in. Looks correct.
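A quick sanity check of that conversion path (illustrative):

```python
# A 0-d ndarray converts to a 0-d tensor, so the existing scalar-tensor branch then applies.
import numpy as np
import torch

arr = np.array(3.5)        # 0-d NumPy array
t = torch.from_numpy(arr)  # 0-d tensor (dtype float64)
print(t.ndim, t.item())    # 0 3.5
```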
Force-pushed from ec3e5d9 to f9dba63
Sorry for the long delay again. If this is all that's needed in the end, we can merge this shortly. I think lines 605-6 are now redundant, since the added if condition will catch this case?
Also, please have a look at the DCO message about doing a remedial commit.
Force-pushed from 0e43b28 to 96e950c
Force-pushed from 96e950c to 4e17ad5
Actionable comments posted: 0
🧹 Nitpick comments (5)
monai/data/utils.py (5)
600-601: Early scalar handling looks good; tiny robustness nit.

The generic `ndim == 0` fast-path correctly covers NumPy and Torch 0-d cases. Consider guarding with `callable` to be extra safe:

```diff
-    if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"):
+    if getattr(batch, "ndim", -1) == 0 and callable(getattr(batch, "item", None)):
```
605-606: Remove the now-redundant 0-d Tensor branch.

The new early check (lines 600-601) already handles 0-d tensors, so this block is dead code. Simplify:

```diff
     if isinstance(batch, torch.Tensor):
         if detach:
             batch = batch.detach()
-        if batch.ndim == 0:
-            return batch.item() if detach else batch
         out_list = torch.unbind(batch, dim=0)
```
527-535: Mark 0-d NumPy arrays as non-iterable to avoid a zip TypeError when `detach=False`.

When decollating mappings/lists containing a 0-d NumPy array and a sized sequence with `pad=True`, `zip(*deco.values())` can attempt to iterate the 0-d array and raise `TypeError: iteration over a 0-d array`. Mirror the new scalar logic here:

```diff
-    for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco):
-        if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or (isinstance(v, torch.Tensor) and v.ndim == 0):
+    for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco):
+        is_scalarish = getattr(v, "ndim", -1) == 0  # covers torch/NumPy 0-d
+        if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or is_scalarish:
             non_iterable.append(k)
         elif isinstance(v, Sized):
             batch_size = max(batch_size, len(v))
```

This keeps behavior consistent for `detach=False` and avoids surprises in mixed containers.
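A bare-bones repro of that zip pitfall, outside of MONAI:

```python
# zip() calls iter() on each value; the 0-d array raises "iteration over a 0-d array".
import numpy as np

deco = {"x": np.array(1), "y": [10, 20, 30]}
try:
    list(zip(*deco.values()))
except TypeError as exc:
    print(exc)  # iteration over a 0-d array
```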
588-590: Docstring clarity: include NumPy 0-d arrays.

Mention that 0-d NumPy arrays are also returned as Python scalars when `detach=True`:

```diff
-        detach: whether to detach the tensors. Scalars tensors will be detached into number types
-            instead of torch tensors.
+        detach: whether to detach tensors. 0-d tensors/NumPy arrays will be returned as Python numbers
+            when True; otherwise kept as tensor/ndarray.
```
538-593: Add minimal tests for 0-d cases (NumPy/Torch, detach True/False, in containers).

Recommend adding unit tests to lock in behavior:
- bare inputs: `decollate_batch(np.array(1), detach=True/False)`, `decollate_batch(torch.tensor(1), detach=True/False)`
- inside a Mapping, with pad both True and False: `{"x": np.array(1), "y": [10, 20, 30]}`, `{"x": torch.tensor(1), "y": [10, 20, 30]}`
- lists of scalars: `[np.array(1), np.array(2)]`, `[torch.tensor(1), torch.tensor(2)]`

I can draft these if helpful.
Force-pushed from 3b30673 to a8f8915
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/data/utils.py (4)
633-635: Remove the now-redundant 0-D Tensor check.

The new early return already handles 0-D tensors; this block is dead code and can go.

```diff
-        if batch.ndim == 0:
-            return batch.item() if detach else batch
```
616-618: Docstring: clarify 0-D NumPy/tensor behavior with `detach`.

Update to reflect the new unified handling.

```diff
-        detach: whether to detach the tensors. Scalars tensors will be detached into number types
-            instead of torch tensors.
+        detach: whether to detach the tensors. For 0-D inputs (NumPy arrays or PyTorch tensors),
+            returns their Python scalar via ``item()`` when True; otherwise returns the original object.
```

As per coding guidelines.
566-621: Add unit tests for 0-D NumPy/tensor cases.

Please add coverage to lock this in. Example pytest snippets:

```python
import numpy as np
import torch

from monai.data.utils import decollate_batch


def test_decollate_batch_numpy_0d_detach_true():
    a = np.array(7, dtype=np.int32)
    assert decollate_batch(a, detach=True) == 7


def test_decollate_batch_numpy_0d_detach_false():
    a = np.array(7.0, dtype=np.float32)
    out = decollate_batch(a, detach=False)
    assert isinstance(out, np.ndarray) and out.ndim == 0 and out.item() == 7.0


def test_decollate_batch_nested_with_numpy_0d():
    data = {"v": np.array(3)}
    out = decollate_batch(data, detach=True)
    assert out == {"v": 3}


def test_decollate_batch_torch_0d_detach_true():
    t = torch.tensor(5.0, requires_grad=True)
    assert decollate_batch(t, detach=True) == 5.0


def test_decollate_batch_torch_0d_detach_false():
    t = torch.tensor(5)
    out = decollate_batch(t, detach=False)
    assert isinstance(out, torch.Tensor) and out.ndim == 0 and out.item() == 5
```

Happy to open a follow-up PR with these. As per coding guidelines.
628-629: Nit: ensure `item` is callable.

Guard against exotic objects where `item` exists but isn't callable.

```diff
-    if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"):
+    if getattr(batch, "ndim", -1) == 0 and callable(getattr(batch, "item", None)):
```
🔇 Additional comments (1)

monai/data/utils.py, lines 628-629: Solid central fix for 0-D arrays/tensors.

The early `ndim == 0` + `item()` handling neatly covers NumPy scalars and PyTorch scalars and avoids the iterator pitfall. Nice.
Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com>
Force-pushed from a5a4c45 to 1e0a554
Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com>
Fixes #8471

Description

This PR supports NumPy scalars (e.g. in the form of `np.array(1)`) in the `decollate_batch` function (fixes issue #8471).

Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.
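As a usage sketch of the intended behavior after this fix (expected outputs are illustrative, based on the review discussion above):

```python
# Expected behavior sketch: 0-d NumPy inputs no longer raise inside decollate_batch.
import numpy as np

from monai.data.utils import decollate_batch

print(decollate_batch(np.array(1)))                  # expected: 1 (Python scalar; detach defaults to True)
print(decollate_batch(np.array(2.5), detach=False))  # expected: the 0-d value returned without .item()
```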