TFLITE NMS kernel Inconsistent Outputs and Out of Memory issues #553

Open
pkgoogle opened this issue Jan 22, 2025 · 4 comments
Labels:
type:bug (Bug)
type:memory (An issue with memory, memory performance, or memory leaks)
type:precision/accuracy (For issues where the precision/accuracy appear incorrect)

Comments

@pkgoogle

Original Issue: tensorflow/tensorflow#84033
Original Author: @sdp009

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

binary

TensorFlow version

tf 2.16.2, tf 2.18, tf 2.19.0-dev2024122

Custom code

Yes

OS platform and distribution

Linux Ubuntu 22

Mobile device

Android

Python version

3.10.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

The TFLite NMS kernel output does not match the TensorFlow NMS output. Although TFLite NMS is a layer with a dynamic output shape, it pads the "selected_indices" output with 0's up to "max_output_size", which defeats the purpose of a dynamic output.

The TFLite NMS output must match the TF NMS output exactly.

For a large "max_output_size", TFLite NMS is extremely slow and often runs out of memory on Android devices. The Gather ops that follow NMS also suffer heavily from the 0's appended to the TFLite NMS output.

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/non_max_suppression.cc#L190C44-L190C59

Please fix this behavior and ensure that the TF and TFLite NMS outputs are exactly the same.
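
To put the padding cost in perspective, here is a back-of-the-envelope sketch in plain Python. It assumes, as reported above, that "selected_indices" is materialized as an int32 buffer at the full "max_output_size", and it uses a small six-element stand-in to show how padding inflates the work of a downstream Gather. The variable names are illustrative and are not part of any TensorFlow API.

```python
# Assumption: selected_indices is an int32 buffer padded to max_output_size,
# as described in this report.
INT32_BYTES = 4
max_output_size = 2**31 - 1  # tf.int32.max, the value used in the repro
num_valid = 3                # the repro selects three boxes: [3, 0, 5]

padded_bytes = max_output_size * INT32_BYTES
valid_bytes = num_valid * INT32_BYTES
print(f"padded buffer: {padded_bytes / 2**30:.2f} GiB")
print(f"valid indices: {valid_bytes} bytes")

# Small stand-in for the padded output (capacity 6 instead of 2**31 - 1).
padded_indices = [3, 0, 5, 0, 0, 0]
boxes = [
    [0.0, 0.0, 1.0, 1.0],
    [0.0, 0.1, 1.0, 1.1],
    [0.0, -0.1, 1.0, 0.9],
    [0.0, 10.0, 1.0, 11.0],
    [0.0, 10.1, 1.0, 11.1],
    [0.0, 100.0, 1.0, 101.0],
]

# A list-based stand-in for Gather: every padded 0 fetches box 0 again.
gathered = [boxes[i] for i in padded_indices]
extra_rows = len(gathered) - num_valid
print(f"rows gathered: {len(gathered)}, padding copies of box 0: {extra_rows}")
```

At the full "max_output_size" from the repro, the index buffer alone is roughly 8 GiB, before any Gather output is allocated.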

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np

# Test inputs from : https://github.com/onnx/onnx/blob/main/docs/Operators.md#NonMaxSuppression : nonmaxsuppression_limit_output_size

boxes = np.array(
    [
        [0.0, 0.0, 1.0, 1.0],
        [0.0, 0.1, 1.0, 1.1],
        [0.0, -0.1, 1.0, 0.9],
        [0.0, 10.0, 1.0, 11.0],
        [0.0, 10.1, 1.0, 11.1],
        [0.0, 100.0, 1.0, 101.0],
    ]
).astype(np.float32)
scores = np.array([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]).astype(np.float32)

max_output_size = tf.constant(tf.int32.max, dtype=tf.int32)
iou_threshold = 0.5

selected_indices = tf.image.non_max_suppression(
    boxes, scores, max_output_size, iou_threshold
)
print(selected_indices)    # returns expected output : tf.Tensor([3 0 5], shape=(3,), dtype=int32)

@tf.function(input_signature=[
    tf.TensorSpec(shape=[6, 4], dtype=tf.float32),
    tf.TensorSpec(shape=[6], dtype=tf.float32),
])
def nms_function(boxes, scores):
    return tf.image.non_max_suppression(boxes, scores, max_output_size=tf.constant(tf.int32.max, dtype=tf.int32), iou_threshold=0.5)

concrete_function = nms_function.get_concrete_function()
print(concrete_function(boxes, scores))    # returns expected output : tf.Tensor([3 0 5], shape=(3,), dtype=int32)


converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
tflite_model = converter.convert()

with open('test_nms.tflite', 'wb') as f:
    f.write(tflite_model)

interpreter = tf.lite.Interpreter(model_path='test_nms.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(input_details[0]['index'], boxes)
interpreter.set_tensor(input_details[1]['index'], scores)

interpreter.invoke()

selected_indices = interpreter.get_tensor(output_details[0]['index'])
print(selected_indices)    # returns incorrect output appended with 0's : [3 0 5 ... 0 0 0]
print(selected_indices.shape)    # incorrect output of shape : (2147483647,)

# above causes OOM error

Relevant log output

TF output = tf.Tensor([3 0 5], shape=(3,), dtype=int32)

TF Lite output = [3 0 5 ... 0 0 0]  ; shape = (2147483647,)
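
As an independent cross-check of the expected result, the following plain-Python sketch runs greedy NMS on the same inputs. It is a simplified reference, not the TensorFlow or TFLite kernel: it assumes boxes are given as [y1, x1, y2, x2] with y1 <= y2 and x1 <= x2, and that a box is suppressed when its IoU with an already selected box exceeds the threshold. The function names `iou` and `reference_nms` are hypothetical.

```python
def iou(a, b):
    """Intersection over union of two boxes given as [y1, x1, y2, x2]."""
    inter_h = min(a[2], b[2]) - max(a[0], b[0])
    inter_w = min(a[3], b[3]) - max(a[1], b[1])
    if inter_h <= 0 or inter_w <= 0:
        return 0.0
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def reference_nms(boxes, scores, max_output_size, iou_threshold):
    """Greedy NMS: visit boxes by descending score, keep those not suppressed."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    selected = []
    for i in order:
        if len(selected) >= max_output_size:
            break
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in selected):
            selected.append(i)
    return selected  # length is the number of kept boxes, with no padding


boxes = [
    [0.0, 0.0, 1.0, 1.0],
    [0.0, 0.1, 1.0, 1.1],
    [0.0, -0.1, 1.0, 0.9],
    [0.0, 10.0, 1.0, 11.0],
    [0.0, 10.1, 1.0, 11.1],
    [0.0, 100.0, 1.0, 101.0],
]
scores = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]

result = reference_nms(boxes, scores, max_output_size=2**31 - 1, iou_threshold=0.5)
print(result)  # [3, 0, 5]
```

The result has three elements regardless of how large "max_output_size" is, which is the behavior the TF output above shows.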

pkgoogle self-assigned this on Jan 22, 2025

sdp009 commented Jan 23, 2025

@pkgoogle, is it possible to provide a tentative ETA for this?
Can we expect a fix in an upcoming TF release?

pkgoogle commented Jan 24, 2025

Hi @sdp009, I tried converting with AI-Edge-Torch, but it runs into conversion issues with NMS. Since neither workflow works, this is definitely a bug.

My attempt:

import torch
import torchvision
import ai_edge_torch


# Test inputs (same as TensorFlow example)
boxes = torch.tensor(
    [
        [0.0, 0.0, 1.0, 1.0],
        [0.0, 0.1, 1.0, 1.1],
        [0.0, -0.1, 1.0, 0.9],
        [0.0, 10.0, 1.0, 11.0],
        [0.0, 10.1, 1.0, 11.1],
        [0.0, 100.0, 1.0, 101.0],
    ]
).float()

scores = torch.tensor([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]).float()
idxs = torch.tensor([0, 1, 2, 3, 4, 5]).int()

iou_threshold = 0.5

# PyTorch NMS (Batched NMS for multiclass)
selected_indices = torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
print("PyTorch NMS indices:", selected_indices)

# PyTorch NMS model
class NMSModel(torch.nn.Module):
    def __init__(self, iou_threshold):
        super().__init__()
        self.iou_threshold = iou_threshold

    def forward(self, boxes, scores, idxs):
        selected_indices = torchvision.ops.batched_nms(boxes, scores, idxs, self.iou_threshold)
        return selected_indices

# Instantiate the model
model = NMSModel(iou_threshold)

edge_model = ai_edge_torch.convert(model.eval(), (boxes, scores, idxs))
edge_model.export("nms_model.tflite")

print("Edge NMS indices:", edge_model(boxes, scores, idxs))

My error:

edge_model = ai_edge_torch.convert(model.eval(), (boxes, scores, idxs))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 261, in convert
    return Converter().convert(
           ^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 172, in convert
    return conversion.convert_signatures(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/conversion.py", line 146, in convert_signatures
    tflite_model = lowertools.exported_programs_to_tflite(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 71, in exported_programs_to_tflite
    bundles: list[utils.MlirBundle] = [
                                      ^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 72, in <listcomp>
    utils.exported_program_to_mlir(exported, sig.flat_args)
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/odml_torch_utils.py", line 235, in exported_program_to_mlir
    return odml_torch.export.exported_program_to_mlir(exported_program)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/odml_torch/export.py", line 344, in exported_program_to_mlir
    interpreter.run(*temp_func.arguments, enable_io_processing=False)
  File "/usr/local/google/home/pisethk/envs/aet_head/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/odml_torch/export.py", line 102, in run_node
    res = super().run_node(node)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/odml_torch/export.py", line 124, in call_function
    raise RuntimeError(f"Lowering not found: {target}")
RuntimeError: Lowering not found: torchvision.nms.default

pkgoogle added the type:bug, type:precision/accuracy, and type:memory labels on Jan 24, 2025

sdp009 commented Jan 24, 2025

Hi @pkgoogle,

Although there is a conversion failure when translating torchvision.ops.batched_nms to the TFLite NMS op, I would suggest tracking that in a separate, lower-priority GitHub issue.

What I am requesting here is a fix to the TFLite kernel implementation, whose output is not consistent with TF/PyTorch.

The TFLite NMS kernel unnecessarily appends zeros to the output until it reaches max_output_size. For a large max_output_size, this causes OOM errors on edge devices.
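
Until the kernel is changed, one possible mitigation on the model side is to avoid passing an unbounded max_output_size in the first place. NMS can never select more boxes than it receives, so clamping the limit to the number of input boxes loses nothing. The sketch below is plain Python with a hypothetical helper name, `effective_output_capacity`; it is not a TFLite API, and it assumes the number of boxes is known statically, as it is in the repro ([6, 4]).

```python
def effective_output_capacity(max_output_size, num_boxes):
    """Largest number of indices NMS could ever return for num_boxes inputs."""
    return max(0, min(max_output_size, num_boxes))


INT32_BYTES = 4
requested = 2**31 - 1  # tf.int32.max, as in the repro
num_boxes = 6          # static first dimension of the boxes input in the repro

capacity = effective_output_capacity(requested, num_boxes)
print(capacity)                # 6
print(capacity * INT32_BYTES)  # 24 bytes instead of roughly 8 GiB
```

Under the padding behavior described above, building the repro with max_output_size set to 6 would bound the padded output at six elements. The trailing zeros would still be present, so this reduces the memory cost without making the output match TF.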

Please refer to the following link for the relevant code: non_max_suppression.cc#L190

Using the TF code from the issue description above, please generate a sample TFLite model and expedite the fix for the NMS kernel implementation. NMS is a crucial operation for many vision models, and fixing it will significantly benefit many use cases.

Thank you.

@pkgoogle

Hi @sdp009, thanks for the code pointer; that's good to know. I'm trying to understand the Gather segfault rationale behind the code in question, so that we can hopefully find a solution that fixes both problems. If you feel you have a good root solution, please feel free to propose a PR if I don't get to it first. Thanks.
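
One plausible reading of that rationale, offered here as an assumption rather than a statement about the kernel source: if the unused tail of a fixed-capacity index buffer held arbitrary values, a downstream Gather could read out of bounds, whereas zero-filled entries always point at a valid row. The plain-Python sketch below models this with a hypothetical bounds-checked `gather` helper and made-up stale values. It also shows that trimming to the number of valid indices avoids both the out-of-bounds read and the padding.

```python
def gather(params, indices):
    """Hypothetical bounds-checked gather over a list of rows.

    A native kernel without such a check could read invalid memory instead
    of raising.
    """
    rows = []
    for i in indices:
        if not 0 <= i < len(params):
            raise IndexError(f"index {i} out of range for {len(params)} rows")
        rows.append(params[i])
    return rows


boxes = [
    [0.0, 0.0, 1.0, 1.0],
    [0.0, 0.1, 1.0, 1.1],
    [0.0, -0.1, 1.0, 0.9],
    [0.0, 10.0, 1.0, 11.0],
    [0.0, 10.1, 1.0, 11.1],
    [0.0, 100.0, 1.0, 101.0],
]
valid = [3, 0, 5]
capacity = 6  # small stand-in for max_output_size

stale_tail = valid + [1717986918, -7, 42]          # made-up leftover values
zero_tail = valid + [0] * (capacity - len(valid))  # zero-filled padding

try:
    gather(boxes, stale_tail)
    stale_ok = True
except IndexError:
    stale_ok = False

padded_rows = gather(boxes, zero_tail)                # safe, but 3 extra rows
trimmed_rows = gather(boxes, zero_tail[:len(valid)])  # safe and exact
print(stale_ok, len(padded_rows), len(trimmed_rows))  # False 6 3
```

Trimming requires knowing how many indices are valid, because a padded 0 cannot be told apart from a genuinely selected box 0 by value alone.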
