Enables compatibility between diffusers CPU offloading and xFuser parallelism #147

Open · BBuf wants to merge 1 commit into main
Conversation

@BBuf commented Dec 20, 2024

The previous incompatibility was caused by diffusers not being aware of the local rank in distributed environments, so it always assumed it was running on rank 0. As a result, the model.to(device) call at line 1174 in pipeline_utils.py copied the DiT model from every rank onto rank 0's GPU, causing out-of-memory (OOM) errors.

The bug was fixed by passing the device corresponding to the local_rank to pipeline.enable_sequential_cpu_offload. As a result, diffusers' CPU offloading and xFuser parallelism can now be used together.
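A minimal sketch of the idea, for illustration only (the pipeline object and the exact keyword argument are assumptions here; depending on the diffusers version the argument may be device= or gpu_id=, and the actual change is in the PR diff):

import os
import torch

# torchrun sets LOCAL_RANK for each spawned process; use it instead of assuming rank 0.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}")

# Offload weights to CPU and stream them back to this rank's own GPU,
# rather than every process defaulting to cuda:0.
# `pipeline` is assumed to be the already-constructed HunyuanVideo pipeline.
pipeline.enable_sequential_cpu_offload(device=device)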

@feifeibear (Contributor) left a comment:

LGTM

@eppane commented Dec 20, 2024

Running:

torchrun --nproc_per_node=4 sample_video.py --video-size 624 832 --video-length 129 --infer-steps 50 --prompt "a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background" --flow-reverse --ulysses-degree 4 --ring-degree 1 --use-cpu-offload --save-path ./results

produces the following video:

sample_video_pr_147.mp4

@feifeibear any suggestions on what is going on? Did you get correct output, @BBuf? Thank you! 🙏

@BBuf (Author) commented Dec 20, 2024

(quoting @eppane's comment above)

You can try turning off CPU offload (i.e. drop --use-cpu-offload from the command). If the problem still persists, it indicates an issue with the model itself: it cannot properly generate videos at the 624x832 resolution.

@eppane commented Dec 20, 2024

@BBuf thanks for the quick reply, I see. I was just following the "Supported Parallel Configurations" listed here, which indicate that 832x624 or 624x832 at 129 frames should work on 4 GPUs (4x1, 2x2, or 1x4).

Without CPU offloading, it produces an OOM error:

2024-12-20 12:29:52.879 | INFO     | hyvideo.inference:predict:580 - Input (height, width, video_length) = (624, 832, 129)
2024-12-20 12:29:52.954 | DEBUG    | hyvideo.inference:predict:640 - 
                        height: 624
                         width: 832
                  video_length: 129
                        prompt: ['a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background']
                    neg_prompt: ['Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion']
                          seed: None
                   infer_steps: 50
         num_videos_per_prompt: 1
                guidance_scale: 1.0
                      n_tokens: 66924
                    flow_shift: 7.0
       embedded_guidance_scale: 6.0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [10:22<00:00, 12.46s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [10:26<00:00, 12.53s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [10:23<00:00, 12.48s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [10:23<00:00, 12.47s/it]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspaces/HunyuanVideo/sample_video.py", line 58, in <module>
[rank1]:     main()
[rank1]:   File "/workspaces/HunyuanVideo/sample_video.py", line 32, in main
[rank1]:     outputs = hunyuan_video_sampler.predict(
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/inference.py", line 646, in predict
[rank1]:     samples = self.pipeline(
[rank1]:               ^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py", line 1076, in __call__
[rank1]:     image = self.vae.decode(
[rank1]:             ^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
[rank1]:     return method(self, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/autoencoder_kl_causal_3d.py", line 336, in decode
[rank1]:     decoded = self._decode(z).sample
[rank1]:               ^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/autoencoder_kl_causal_3d.py", line 301, in _decode
[rank1]:     return self.temporal_tiled_decode(z, return_dict=return_dict)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/autoencoder_kl_causal_3d.py", line 512, in temporal_tiled_decode
[rank1]:     decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/autoencoder_kl_causal_3d.py", line 443, in spatial_tiled_decode
[rank1]:     decoded = self.decoder(tile)
[rank1]:               ^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/vae.py", line 281, in forward
[rank1]:     sample = up_block(sample, latent_embeds)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/unet_causal_3d_blocks.py", line 762, in forward
[rank1]:     hidden_states = upsampler(hidden_states)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/unet_causal_3d_blocks.py", line 178, in forward
[rank1]:     hidden_states = self.conv(hidden_states)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspaces/HunyuanVideo/hyvideo/vae/unet_causal_3d_blocks.py", line 74, in forward
[rank1]:     return self.conv(x)
[rank1]:            ^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 725, in forward
[rank1]:     return self._conv_forward(input, self.weight, self.bias)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 720, in _conv_forward
[rank1]:     return F.conv3d(
[rank1]:            ^^^^^^^^^
[rank1]: torch.OutOfMemoryError: HIP out of memory. Tried to allocate 4.16 GiB. GPU 1 has a total capacity of 63.98 GiB of which 3.34 GiB is free. Of the allocated memory 55.96 GiB is allocated by PyTorch, and 2.68 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

@eppane commented Dec 20, 2024

@BBuf what parameters did you try (assuming you got a sensible output video)?

@BBuf (Author) commented Dec 20, 2024

@eppane I tried the following command on an A800 node:

torchrun --nproc_per_node=8 sample_video.py --video-size 720 1280 --video-length 129 --infer-steps 30 --prompt "A cute rabbit family eating dinner in their burrow." --use-cpu-offload --flow-reverse --save-path ./results --ring-degree 4 --ulysses-degree 2 --seed 42

And the result is normal:

(image: generated video result)

@eppane commented Dec 23, 2024

Quoting @BBuf: "If the problem still persists, it indicates an issue with the model itself: it cannot properly generate videos at the 624x832 resolution."

The same thing happens with --video-size 720 1280. @BBuf have you experimented with other configurations, such as 4 GPUs and 1x4?
