You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
PR tensorflow#22437: Added frontend attribute handling to explicit_stream_annotation_async_wrapper
Imported from GitHub PR openxla/xla#22437
This is a small change that ensures the frontend attributes are correctly passed to both the `async-start` and `async-done` created pairs. This also clears the scheduling attributes that are directly on the call operation and inner ops.
The specific goal of this change is to have stable support combining the scheduling group ids with stream annotation in JAX.
```python
with set_xla_metadata(_scheduling_group_id=1):
result = compute_on("gpu_stream:1")(jitted_func)(...)
```
Currently, the issue stems from the `set_xla_metadata` context manager, which will apply the frontend attribute to all operations, including the ones within our `jitted_func`. When the same scheduling annotations is found in two `HloComputation`s, an error is raised in `LegalizeSchedulingAnnotations`. This is intended to avoid hitting this check, and cleaning up the annotations on the wrapped streamed computation.
Copybara import of the project:
--
994c2eee3c946102270587681f5c17b994cbb6a9 by chaser <chaser@nvidia.com>:
Added frontend attributed handling
--
9db58b2b988dc2288d42126271223f924aac19f9 by chaser <chaser@nvidia.com>:
Added clearing of scheduling annotations
--
a83e32a34ba5d64a29c7f01b03536f27decd8125 by chaser <chaser@nvidia.com>:
Added HloInstruction.erase_frontend_attribute
Merging this change closestensorflow#22437
PiperOrigin-RevId: 731960979
0 commit comments