TFLITE NMS kernel Inconsistent Outputs and Out of Memory issues #553
Comments
@pkgoogle, is it possible to provide a tentative ETA for this?
Hi @sdp009, I tried converting with AI-Edge-Torch, but it runs into conversion issues with NMS as well. Since neither workflow works, this is definitely a bug. My attempt:

```python
import torch
import torchvision
import ai_edge_torch

# Test inputs (same as the TensorFlow example)
boxes = torch.tensor(
    [
        [0.0, 0.0, 1.0, 1.0],
        [0.0, 0.1, 1.0, 1.1],
        [0.0, -0.1, 1.0, 0.9],
        [0.0, 10.0, 1.0, 11.0],
        [0.0, 10.1, 1.0, 11.1],
        [0.0, 100.0, 1.0, 101.0],
    ]
).float()
scores = torch.tensor([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]).float()
idxs = torch.tensor([0, 1, 2, 3, 4, 5]).int()
iou_threshold = 0.5

# PyTorch NMS (batched NMS for the multiclass case)
selected_indices = torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
print("PyTorch NMS indices:", selected_indices)

# PyTorch NMS model: wrap batched_nms in a module so it can be converted
class NMSModel(torch.nn.Module):
    def __init__(self, iou_threshold):
        super().__init__()
        self.iou_threshold = iou_threshold

    def forward(self, boxes, scores, idxs):
        selected_indices = torchvision.ops.batched_nms(
            boxes, scores, idxs, self.iou_threshold
        )
        return selected_indices

# Instantiate, convert, and export the model
model = NMSModel(iou_threshold)
edge_model = ai_edge_torch.convert(model.eval(), (boxes, scores, idxs))
edge_model.export("nms_model.tflite")
print("Edge NMS indices:", edge_model(boxes, scores, idxs))
```

My error:

```
edge_model = ai_edge_torch.convert(model.eval(), (boxes, scores, idxs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 261, in convert
return Converter().convert(
^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 172, in convert
return conversion.convert_signatures(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/conversion.py", line 146, in convert_signatures
tflite_model = lowertools.exported_programs_to_tflite(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 71, in exported_programs_to_tflite
bundles: list[utils.MlirBundle] = [
^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 72, in <listcomp>
utils.exported_program_to_mlir(exported, sig.flat_args)
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/odml_torch_utils.py", line 235, in exported_program_to_mlir
return odml_torch.export.exported_program_to_mlir(exported_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/odml_torch/export.py", line 344, in exported_program_to_mlir
interpreter.run(*temp_func.arguments, enable_io_processing=False)
File "/usr/local/google/home/pisethk/envs/aet_head/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/odml_torch/export.py", line 102, in run_node
res = super().run_node(node)
^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/odml_torch/export.py", line 124, in call_function
raise RuntimeError(f"Lowering not found: {target}")
RuntimeError: Lowering not found: torchvision.nms.default
```
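For reference, the failing target is the torchvision custom op itself. A quick way to confirm it survives into the exported graph (a sketch, assuming a torch version that ships `torch.export`, and reusing `NMSModel` and the inputs above):

```python
# Export the same module and print its graph; the op the lowering step
# cannot handle appears as torch.ops.torchvision.nms.default.
exported = torch.export.export(NMSModel(iou_threshold).eval(), (boxes, scores, idxs))
print(exported.graph)
```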
Hi @pkgoogle, although there is a conversion failure when translating torchvision.ops.batched_nms to the TFLite NMS op, I would suggest filing a separate GitHub issue for that at a lower priority. What I am requesting here is a fix in the TFLite kernel implementation, where the output is not consistent with TF/PyTorch. The TFLite NMS kernel unnecessarily appends zeros to the output until it reaches max_output_size; for a large max_output_size, this causes OOM errors on edge devices. Please refer to the following link for the relevant code: non_max_suppression.cc#L190
Using the TF code from the issue description, please generate a sample TFLite model and expedite the fix for the NMS kernel implementation. NMS is a crucial operation for many vision models, and fixing it will significantly improve many use cases. Thank you.
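For scale, a back-of-the-envelope check (not from the original thread) of what the padded `selected_indices` output alone costs at the `max_output_size` seen in the log below:

```python
# The kernel pads selected_indices with zeros up to max_output_size entries.
max_output_size = 2**31 - 1   # 2147483647, the shape reported in the log output
bytes_per_index = 4           # int32 indices
print(max_output_size * bytes_per_index / 2**30, "GiB")  # ~8 GiB for one tensor
```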
Hi @sdp009, thanks for the code pointer; that's good to know. I'm trying to understand the Gather-segfault rationale this code is meant to address, so that we can hopefully find a solution that fixes both issues. If you feel you have a good root solution, please feel free to propose it in a PR if I don't get to it first. Thanks.
Original Issue: tensorflow/tensorflow#84033
Original Author: @sdp009
Issue type: Bug
Have you reproduced the bug with TensorFlow Nightly? Yes
Source: binary
TensorFlow version: tf 2.16.2, tf 2.18, tf 2.19.0-dev2024122
Custom code: Yes
OS platform and distribution: Linux Ubuntu 22
Mobile device: Android
Python version: 3.10.12
Bazel version: No response
GCC/compiler version: No response
CUDA/cuDNN version: No response
GPU model and memory: No response
Current behavior?
The TFLite NMS kernel output is not the same as the TensorFlow NMS output. Although TFLite NMS is a dynamic-output-shape layer, it appends zeros to the "selected_indices" output up to "max_output_size", defeating the purpose of the dynamic output. The TFLite NMS output should match the TF NMS output exactly.
For a large "max_output_size", TFLite NMS becomes extremely slow and often goes out of memory on Android devices. The Gather ops that follow the NMS also suffer heavily because of the zeros appended to the TFLite NMS output.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/non_max_suppression.cc#L190C44-L190C59
Requesting a fix for this behavior so that the TF and TFLite NMS outputs are exactly the same; a possible interim mitigation is sketched below.
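Until the kernel is fixed, one possible model-side mitigation (a sketch, not from the original report; it assumes the exporting graph can be changed, with boxes and scores as in the reproduction below) is to cap max_output_size at a realistic value and slice with the valid count before any Gather:

```python
import tensorflow as tf

# non_max_suppression_padded also returns the number of valid indices, so the
# zero padding can be sliced away before it ever reaches a Gather op.
selected, num_valid = tf.image.non_max_suppression_padded(
    boxes, scores, max_output_size=100, iou_threshold=0.5,
    pad_to_max_output_size=True)
selected = selected[:num_valid]          # drop the zero padding
kept_boxes = tf.gather(boxes, selected)  # gather only real detections
```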
Standalone code to reproduce the issue
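The original standalone script was not mirrored here; the following is a minimal reconstruction (an assumption pieced together from the inputs above and the log output below, not necessarily the author's exact code):

```python
import tensorflow as tf

boxes = tf.constant(
    [[0.0, 0.0, 1.0, 1.0],
     [0.0, 0.1, 1.0, 1.1],
     [0.0, -0.1, 1.0, 0.9],
     [0.0, 10.0, 1.0, 11.0],
     [0.0, 10.1, 1.0, 11.1],
     [0.0, 100.0, 1.0, 101.0]], dtype=tf.float32)
scores = tf.constant([0.9, 0.75, 0.6, 0.95, 0.5, 0.3], dtype=tf.float32)
max_output_size = 2**31 - 1  # matches the (2147483647,) output shape in the log
iou_threshold = 0.5

# TensorFlow reference result: tf.Tensor([3 0 5], shape=(3,), dtype=int32)
print("TF output =", tf.image.non_max_suppression(
    boxes, scores, max_output_size, iou_threshold))

@tf.function(input_signature=[tf.TensorSpec([6, 4], tf.float32),
                              tf.TensorSpec([6], tf.float32)])
def nms(boxes, scores):
  return tf.image.non_max_suppression(
      boxes, scores, max_output_size, iou_threshold)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [nms.get_concrete_function()])
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
for detail in interpreter.get_input_details():
  # Match inputs by rank: boxes are [6, 4], scores are [6].
  tensor = boxes if len(detail["shape"]) == 2 else scores
  interpreter.set_tensor(detail["index"], tensor.numpy())
interpreter.invoke()  # the zero-padded output is huge and can OOM on device
out = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])
print("TF Lite output =", out, "; shape =", out.shape)
```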
Relevant log output
```
TF output = tf.Tensor([3 0 5], shape=(3,), dtype=int32)
TF Lite output = [3 0 5 ... 0 0 0] ; shape = (2147483647,)
```