
Conversation

@arthurdjn arthurdjn commented Jun 3, 2025

Fixes #8471

Description

This PR adds support for numpy scalars (e.g. np.array(1)) in the decollate_batch function (fixes issue #8471).
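For context, this is the kind of value the function previously choked on: a 0-d numpy array cannot be iterated, which is what generic decollation effectively attempts (a quick illustration, assuming only that numpy is installed):

```python
import numpy as np

scalar = np.array(1)       # 0-d ("scalar") array
print(scalar.ndim)         # 0
print(scalar.item())       # 1, a plain Python int

try:
    for _ in scalar:       # what generic unbinding/decollation would attempt
        pass
except TypeError as err:
    print(err)             # e.g. "iteration over a 0-d array"
```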

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Summary by CodeRabbit

  • New Features
    • Improved batch decollation to seamlessly handle scalar inputs (0‑dimensional values), producing correct outputs without relying on tensor-specific behavior.
  • Bug Fixes
    • Resolved issues where scalar values in batches could lead to unexpected behavior during decollation.
    • Ensures consistent results whether detaching or not, reducing edge-case errors for scalar data.

@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch 5 times, most recently from 187c141 to c438fe0 Compare June 3, 2025 13:03
@arthurdjn arthurdjn marked this pull request as ready for review June 3, 2025 14:20
Review thread on this hunk in decollate_batch:

type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
):
return batch
if isinstance(batch, np.ndarray) and batch.ndim == 0:
KumoLiu (Contributor)
Thanks for the PR! Do you think it might be beneficial to convert the array into a tensor? This way, the data could be handled more consistently.

arthurdjn (Author)
We could; it does not matter for my use cases. As long as the function handles numpy scalars in the form of an array, it works for me!

I will add this change and convert it to a tensor there (L629) if you prefer :)

Contributor

Thanks for the quick fix!
May I ask the reason for only converting to a tensor when batch.ndim == 0 here?

arthurdjn (Author) commented Jun 6, 2025

I noticed different behavior when using the decollate_batch function on torch tensors vs. numpy arrays (see discussion #8472), so I don't want to convert numpy arrays to torch tensors, as that would introduce breaking changes.

This PR only addresses issue #8471, as I think that behavior was not expected and should be supported.

@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch from 451c207 to 49d4954 Compare June 4, 2025 09:00
@arthurdjn arthurdjn requested a review from KumoLiu June 4, 2025 09:01
ericspod (Member) commented Jun 10, 2025

Could we consider a more complete solution? The issue, it seems, is that 0-d arrays are iterable but can't be iterated over. We already check for non-iterable things in decollate_batch here. Can we modify this to correctly pick up when the batch is a 0-d array and just return it in that case? Or return its contents?
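The quirk described here can be checked directly: a 0-d array still registers as collections.abc.Iterable (ndarray defines __iter__ on the class), so the existing non-iterable guard does not catch it, while an ndim-based check does (illustration, assuming numpy is installed):

```python
from collections.abc import Iterable
import numpy as np

batch = np.array(3.5)                 # 0-d array
print(isinstance(batch, Iterable))    # True: ndarray implements __iter__
print(getattr(batch, "ndim", -1))     # 0, so an ndim-based check catches it

try:
    iter(batch)
except TypeError:
    print("...but actually iterating it raises TypeError")
```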

arthurdjn (Author) commented Jun 10, 2025

Thanks for the feedback. The initial PR was:

if isinstance(batch, (float, int, str, bytes)) or (
    type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
):
    return batch
if isinstance(batch, np.ndarray) and batch.ndim == 0:
    return batch.item() if detach else batch
# rest ...

Is this something that you find more complete?

Note

I refactored the PR to convert from numpy array to torch tensor as suggested by @KumoLiu.

ericspod (Member)

> Is this something that you find more complete?

What I had in mind was more of the following change:

...
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    if getattr(batch, "ndim", -1) == 0:  # assumes only Numpy objects and Pytorch tensors have ndim
        return batch.item() if detach else batch
    if isinstance(batch, torch.Tensor):
        if detach:
            batch = batch.detach()
# REMOVE
#      if batch.ndim == 0:
#          return batch.item() if detach else batch
...
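The pre-check suggested above can be sketched with the standard library alone (the FakeScalar stand-in and decollate_scalar name are hypothetical; in practice the branch triggers for 0-d numpy arrays and 0-d torch tensors, which both expose ndim and item()):

```python
class FakeScalar:
    """Minimal stand-in for a 0-d numpy array or torch tensor."""
    ndim = 0

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def decollate_scalar(batch, detach=True):
    """Mimics the proposed early-return pre-check in decollate_batch."""
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    # assumes only numpy objects and torch tensors expose ndim
    if getattr(batch, "ndim", -1) == 0:
        return batch.item() if detach else batch
    raise NotImplementedError("rest of decollate_batch elided")


print(decollate_scalar(FakeScalar(7)))                # 7 (plain Python scalar)
print(decollate_scalar(FakeScalar(7), detach=False))  # the 0-d object itself
print(decollate_scalar(2.5))                          # 2.5, returned untouched
```

Because the check uses getattr rather than isinstance, it covers both array libraries without importing either, which is the appeal of handling the scalar case before the tensor-specific branch.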

arthurdjn (Author)
Thanks! I will update the PR to include these changes.

coderabbitai bot commented Sep 5, 2025

Walkthrough

Updated decollate_batch in monai/data/utils.py to handle 0‑dimensional inputs earlier. If the input batch has ndim == 0 and supports item(), the function returns batch.item() when detach is True; otherwise it returns the batch unchanged. Existing behavior for non-tensor inputs and higher-dimensional tensor inputs remains unchanged. No changes to public APIs or exports.

Changes

Cohort / File(s) Summary
decollate scalar pre-check
monai/data/utils.py
Added an early check in decollate_batch for 0‑dimensional (scalar-like) inputs that returns batch.item() when detach is True (else returns the original batch). Removed the previous 0‑d handling inside the Tensor-specific branch so scalar handling occurs before generic unbinding; retain existing detach behavior.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant decollate_batch
    opt Old flow (prior to change)
        Caller->>decollate_batch: call with batch
        decollate_batch->>decollate_batch: is Tensor?
        decollate_batch->>decollate_batch: if Tensor and ndim==0 -> return item()
        decollate_batch->>Caller: return result
    end
    opt New flow (after change)
        Caller->>decollate_batch: call with batch
        decollate_batch->>decollate_batch: pre-check ndim==0 and has item()
        decollate_batch->>decollate_batch: if true and detach -> return item()
        decollate_batch->>decollate_batch: else proceed with existing branches (Tensor/non-Tensor)
        decollate_batch->>Caller: return result
    end

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes


Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Title Check ✅ Passed The title “fix: support decollate for numpy scalars” clearly and concisely describes the primary change in this pull request by highlighting the added support for numpy scalar inputs in the decollate_batch function. It uses a conventional commit prefix and avoids unnecessary detail or file listings, making it easy for teammates to understand the main purpose at a glance. The phrasing directly matches the core functionality being modified and contains no vague or generic terms.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Description Check ✅ Passed The pull request description includes the required “Fixes” reference, a clear “### Description” section summarizing the change, and a “### Types of changes” checkbox list mirroring the repository template; the author has marked all applicable items and provided sufficient context about the non-breaking nature of the update and the successful test runs.

coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
monai/data/utils.py (3)

628-635: Prefer generic 0-D detection early; drop special-case conversion.

Handle any object with ndim==0 up-front, then the torch branch can skip its own scalar check. This reduces duplication and avoids unnecessary conversion for NumPy scalars when detach=False.

Apply this diff (and remove the torch scalar early-return below):

@@
-    if isinstance(batch, (float, int, str, bytes)) or (
-        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
-    ):
-        return batch
-    if isinstance(batch, np.ndarray) and batch.ndim == 0:
-        batch = torch.from_numpy(batch)
+    if isinstance(batch, (float, int, str, bytes)) or (
+        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
+    ):
+        return batch
+    # generic 0-D objects (NumPy/Torch/others exposing ndim): treat as scalars
+    if getattr(batch, "ndim", -1) == 0:
+        return batch.item() if detach else batch
@@
-    if isinstance(batch, torch.Tensor):
-        if detach:
-            batch = batch.detach()
-        if batch.ndim == 0:
-            return batch.item() if detach else batch
+    if isinstance(batch, torch.Tensor):
+        if detach:
+            batch = batch.detach()
+        # ndim==0 handled above

628-630: Graceful fallback for unsupported NumPy dtypes.

torch.from_numpy will fail for some dtypes (e.g., datetime64, object). Consider catching and falling back to .item() or returning the ndarray unchanged.

-    if isinstance(batch, np.ndarray) and batch.ndim == 0:
-        batch = torch.from_numpy(batch)
+    if isinstance(batch, np.ndarray) and batch.ndim == 0:
+        try:
+            batch = torch.from_numpy(batch)
+        except (TypeError, ValueError):
+            return batch.item() if detach else batch

614-621: Doc/test touch-up for 0-D NumPy arrays.

Please note 0-D NumPy array behavior in the decollate_batch docstring and add unit tests for:

  • np.array(1) with detach=True/False
  • Nested structures containing 0-D arrays
  • Edge dtype (e.g., bool, float32)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0968da2 and c1339ec.

📒 Files selected for processing (1)
  • monai/data/utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/data/utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
🔇 Additional comments (1)
monai/data/utils.py (1)

628-635: Good fix: avoids 0-D ndarray iteration error and aligns with tensor path.

Converting 0-D NumPy arrays to torch first prevents the TypeError raised by iterating 0-D arrays and lets the existing scalar-tensor handling (.item() when detach=True) kick in. Looks correct.

@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch 2 times, most recently from ec3e5d9 to f9dba63 Compare September 5, 2025 10:28
ericspod (Member)
Sorry for the long delay again. If this is all that's needed in the end, we can merge this shortly. I think lines 605-6 are now redundant, since the added if condition will catch this case?

ericspod (Member)
Also please have a look at the DCO message for doing a remedial commit.

@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch from 0e43b28 to 96e950c Compare September 29, 2025 09:16
@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch from 96e950c to 4e17ad5 Compare September 29, 2025 09:18
coderabbitai bot left a comment
Actionable comments posted: 0

🧹 Nitpick comments (5)
monai/data/utils.py (5)

600-601: Early scalar handling looks good; tiny robustness nit.

The generic ndim==0 fast-path correctly covers NumPy and Torch 0‑d cases. Consider guarding with callable to be extra safe:

- if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"):
+ if getattr(batch, "ndim", -1) == 0 and callable(getattr(batch, "item", None)):

605-606: Remove now‑redundant 0‑d Tensor branch.

The new early check (Lines 600‑601) already handles 0‑d tensors, so this block is dead code. Simplify:

 if isinstance(batch, torch.Tensor):
     if detach:
         batch = batch.detach()
-    if batch.ndim == 0:
-        return batch.item() if detach else batch
     out_list = torch.unbind(batch, dim=0)

527-535: Mark 0‑d NumPy arrays as non‑iterable to avoid zip TypeError when detach=False.

When decollating mappings/lists containing a 0‑d NumPy array and a sized sequence with pad=True, zip(*deco.values()) can attempt to iterate the 0‑d array and raise TypeError: iteration over a 0-d array. Mirror the new scalar logic here:

-    for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco):
-        if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or (isinstance(v, torch.Tensor) and v.ndim == 0):
+    for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco):
+        is_scalarish = getattr(v, "ndim", -1) == 0  # covers torch/NumPy 0-d
+        if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or is_scalarish:
             non_iterable.append(k)
         elif isinstance(v, Sized):
             batch_size = max(batch_size, len(v))

This keeps behavior consistent for detach=False and avoids surprises in mixed containers.


588-590: Docstring clarity: include NumPy 0‑d arrays.

Mention that 0‑d NumPy arrays are also returned as Python scalars when detach=True:

-        detach: whether to detach the tensors. Scalars tensors will be detached into number types
-            instead of torch tensors.
+        detach: whether to detach tensors. 0‑d tensors/NumPy arrays will be returned as Python numbers
+            when True; otherwise kept as tensor/ndarray.

538-593: Add minimal tests for 0‑d cases (NumPy/Torch, detach True/False, in containers).

Recommend adding unit tests to lock behavior:

  • bare inputs:
    • decollate_batch(np.array(1), detach=True/False)
    • decollate_batch(torch.tensor(1), detach=True/False)
  • inside Mapping with pad both True/False:
    • {"x": np.array(1), "y": [10, 20, 30]}
    • {"x": torch.tensor(1), "y": [10, 20, 30]}
  • list of scalars:
    • [np.array(1), np.array(2)]
    • [torch.tensor(1), torch.tensor(2)]

I can draft these if helpful.

📜 Review details: same configuration and path-based instructions as the first review above; reviewing changes between 0e43b28 and ab61dff. Files selected for processing: monai/data/utils.py (1 hunks). 18 checks skipped due to the 90000ms timeout.

@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch 2 times, most recently from 3b30673 to a8f8915 Compare September 29, 2025 09:48
coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
monai/data/utils.py (4)

633-635: Remove now‑redundant 0‑D Tensor check.

The new early return already handles 0‑D Tensors; this block is dead code and can go.

-        if batch.ndim == 0:
-            return batch.item() if detach else batch

616-618: Docstring: clarify 0‑D NumPy/Tensor behavior with detach.

Update to reflect the new unified handling.

-        detach: whether to detach the tensors. Scalars tensors will be detached into number types
-            instead of torch tensors.
+        detach: whether to detach the tensors. For 0‑D inputs (NumPy arrays or PyTorch tensors),
+            returns their Python scalar via ``item()`` when True; otherwise returns the original object.

As per coding guidelines.


566-621: Add unit tests for 0‑D NumPy/Tensor cases.

Please add coverage to lock this in.

Example pytest snippets:

import numpy as np
import torch
from monai.data.utils import decollate_batch

def test_decollate_batch_numpy_0d_detach_true():
    a = np.array(7, dtype=np.int32)
    assert decollate_batch(a, detach=True) == 7

def test_decollate_batch_numpy_0d_detach_false():
    a = np.array(7.0, dtype=np.float32)
    out = decollate_batch(a, detach=False)
    assert isinstance(out, np.ndarray) and out.ndim == 0 and out.item() == 7.0

def test_decollate_batch_nested_with_numpy_0d():
    data = {"v": np.array(3)}
    out = decollate_batch(data, detach=True)
    assert out == {"v": 3}

def test_decollate_batch_torch_0d_detach_true():
    t = torch.tensor(5.0, requires_grad=True)
    assert decollate_batch(t, detach=True) == 5.0

def test_decollate_batch_torch_0d_detach_false():
    t = torch.tensor(5)
    out = decollate_batch(t, detach=False)
    assert isinstance(out, torch.Tensor) and out.ndim == 0 and out.item() == 5

Happy to open a follow‑up PR with these. As per coding guidelines.


628-629: Nit: ensure item is callable.

Guard against exotic objects where item exists but isn’t callable.

-    if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"):
+    if getattr(batch, "ndim", -1) == 0 and callable(getattr(batch, "item", None)):
📜 Review details: same configuration and path-based instructions as the first review above; reviewing changes between ab61dff and 3b30673. Files selected for processing: monai/data/utils.py (1 hunks). 19 checks skipped due to the 90000ms timeout.
🔇 Additional comments (1)
monai/data/utils.py (1)

628-629: Solid central fix for 0‑D arrays/tensors.

Early ndim == 0 + item() handling neatly covers NumPy scalars and PyTorch scalars and avoids the iterator pitfall. Nice.

Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com>
@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch from a5a4c45 to 1e0a554 Compare September 29, 2025 09:51
Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com>
Development

Successfully merging this pull request may close these issues.

decollate batch different behavior with numpy and torch for scalars
3 participants