
@JoeyTeng JoeyTeng commented Jun 2, 2023

This resolves #331, #108, #47, and #67, and is related to #302.

Example Colab (adapting most of the code from the Brax team's Brax Training notebook)

Known issue (now resolved): the plane may not display fully and correctly, due to unclipped vertices behind or at the camera plane. This has been fixed by implementing rasterisation based on homogeneous interpolation in JaxRenderer.

Currently this only makes changes to the v2 API, and keeps them minimal. Feel free to tell me if you think it should be back-ported to the v1 API as well. I will try to optimise the performance further later.

Update:

  1. with 8a111d9, I jitted the rendering of each frame when rendering a batch of states.
  2. 828848a refactors batch rendering a bit, with no significant performance improvement (if any); see the simple benchmark here
  3. with 0.3.0: Performance Improvement JoeyTeng/jaxrenderer#2 (release 0.3.0), performance improved by about 10x. Rendering one frame of the Ant environment at 960x540 resolution with 1x SSAA now takes about 500 ms on a T4 (2 fps), and ~70-100 ms on an A100 (10-14 fps). The previous CPU implementation takes about 200 ms per frame (5 fps).
  4. with Lower minimum Python version to 3.8; Improve typing annotations JoeyTeng/jaxrenderer#3 (release 0.3.1), the minimum supported Python version is lowered to 3.8, matching brax.
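
The per-frame jitting from item 1 can be sketched as follows. This is a minimal illustration, not the actual jaxrenderer API: `render_frame` is a hypothetical stand-in that just shades a gradient from a scalar state, but the point is the same — compile the renderer once with `jax.jit` and reuse it for every frame in the batch, rather than re-tracing inside the loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-frame renderer (the real jaxrenderer signature differs);
# this stand-in shades a 64x64 gradient from a scalar simulation state.
def render_frame(state):
    ys = jnp.linspace(0.0, 1.0, 64)[:, None]
    xs = jnp.linspace(0.0, 1.0, 64)[None, :]
    return jnp.sin(state * 10.0) * ys + jnp.cos(state * 10.0) * xs

# Jit once; the compiled function is then reused for every frame,
# avoiding repeated tracing/compilation inside the rendering loop.
render_jit = jax.jit(render_frame)

states = jnp.linspace(0.0, 1.0, 8)  # a batch of 8 simulation states
frames = [render_jit(s) for s in states]
```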

@JoeyTeng JoeyTeng marked this pull request as ready for review June 3, 2023 15:44
@JoeyTeng
Author

Friendly ping @erikfrey. With the release of jaxrenderer 0.3.0, I believe this is now a suitable replacement for the current CPU renderer, pytinyrenderer, as 1) it is a pure JAX implementation; and 2) it performs better on high-end GPUs such as the A100 (>2x speedup). Let me know what you think :)

@erikfrey
Contributor

Hi Joey - thanks for this tremendous effort! Unfortunately we're bound by certain restrictions that make it very difficult for us to depend on external packages unless they go through an arduous vetting process behind the scenes.

I'm curious how your approach performs on CPU?

You've probably already found that using XLA on its own limits your performance here: XLA wasn't designed for rasterisation-like operations, and it does not know how to use the underlying GPU primitives to do this performantly.

If you're interested, there is a way to register custom XLA ops, so that you could call into CUDA's rasterization functions. Tensorflow graphics does such a thing:

https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/rendering/opengl/rasterizer_op.cc

You could then make such an op accessible to JAX. This would be very fast, but also quite tricky to get all the plumbing to work. Just thought I'd put it out there in case you're curious to hack on such a thing.
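
Short of a full custom XLA op with CUDA plumbing, the general shape of calling host-side rendering code from jit-compiled JAX can be sketched with `jax.pure_callback`. This is only an illustration of the bridging pattern, not the fast path described above: the callback round-trips through host Python, whereas a registered custom call would stay on-device. The `host_rasterize` function here is a hypothetical NumPy stand-in for a native rasteriser.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side rasteriser: in a real integration this would call
# into CUDA/OpenGL; here it is plain NumPy so the sketch runs anywhere.
# It splats each normalised (x, y) vertex into a 32x32 image as one pixel.
def host_rasterize(vertices: np.ndarray) -> np.ndarray:
    img = np.zeros((32, 32), dtype=np.float32)
    for x, y in np.asarray(vertices):
        img[int(y * 31), int(x * 31)] = 1.0
    return img

def rasterize(vertices):
    out_shape = jax.ShapeDtypeStruct((32, 32), jnp.float32)
    # pure_callback lets jit-compiled code call back into host Python;
    # a custom XLA op would avoid this host round-trip entirely.
    return jax.pure_callback(host_rasterize, out_shape, vertices)

verts = jnp.array([[0.1, 0.2], [0.9, 0.5]])
image = jax.jit(rasterize)(verts)
```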

Either way though, it would be hard for us to accept this PR due to the policies that control which external packages we can rely on. Sorry about that! But I'll leave the issue open for a while - I'd love to hear if you continue hacking on this.

@JoeyTeng
Author

Thank you so much for your suggestions! I am also thinking about a customised op, as that would dramatically improve performance.

Regarding current performance: on high-end GPUs, rendering is actually faster than on CPU. The current implementation is not optimised for TPU, so performance on TPUs is poor, in both execution time and memory usage (up to 98% padding...). I will benchmark and improve performance over a batch of small images (e.g. 84x84, as used in RL for Atari environments) to see whether my implementation can benefit from rendering batches of environments in parallel (although we would definitely benefit if using a TPU Pod and pmap/xmap-ing all environment simulation + rendering over devices).

@sai-prasanna

How about a brax-contrib package, or something like that, where this could be added?

@btaba
Collaborator

btaba commented Jan 31, 2025

We have since released demos with Madrona-MJX; I strongly recommend folks try it out for large-batch rendering on GPU. https://github.com/shacklettbp/madrona_mjx

@btaba
Collaborator

btaba commented Apr 23, 2025

Closing the PR, since we recommend trying out https://github.com/shacklettbp/madrona_mjx as referenced at https://playground.mujoco.org/

@btaba btaba closed this Apr 23, 2025
@btaba
Collaborator

btaba commented Apr 23, 2025

Never mind, re-opening the PR as it is a great pedagogical example! But we don't expect to merge it in.

@btaba btaba reopened this Apr 23, 2025
