Describe the bug
In src/diffusers/models/transformers/transformer_z_image.py, an image_padding_len is used to pad the image sequence to a multiple of SEQ_MULTI_OF. However, when the image sequence length is already a multiple of SEQ_MULTI_OF, this creates a tensor with a zero-sized dimension, which triggers INVALID_ARGUMENT: Concatenate expects at least one argument. under PyTorch/XLA on TPU. It may also fail on devices other than CUDA. I created a PR to fix this.
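For context, here is a minimal sketch of the kind of guard the fix applies. The function name, the SEQ_MULTI_OF value, and the tensor layout below are illustrative assumptions, not the actual diffusers code; the point is simply to skip building and concatenating the padding tensor when image_padding_len is 0.

import torch

SEQ_MULTI_OF = 32  # assumed value, for illustration only

def pad_image_tokens(image_tokens: torch.Tensor) -> torch.Tensor:
    # image_tokens: (batch, seq_len, dim); pad seq_len up to a multiple of SEQ_MULTI_OF.
    seq_len = image_tokens.shape[1]
    image_padding_len = (-seq_len) % SEQ_MULTI_OF
    if image_padding_len == 0:
        # Sequence length is already a multiple: returning early avoids creating a
        # zero-length padding tensor, which XLA's concatenate/repeat lowering rejects.
        return image_tokens
    pad = image_tokens.new_zeros(image_tokens.shape[0], image_padding_len, image_tokens.shape[2])
    return torch.cat([image_tokens, pad], dim=1)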
Reproduction
import torch
import torch_xla

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

model = ZImageModule.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16).to(device)

prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."

# Generate the image
image = model(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=9,  # This actually results in 8 DiT forwards
    guidance_scale=0.0,     # Guidance should be 0 for the Turbo models
).images[0]
image.save("example.png")
Logs
Here is what I got on TPU v4-8:
F1128 18:19:17.110023 2059309 debug_macros.h:21] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: Concatenate expects at least one argument.
*** Begin stack trace ***
tsl::CurrentStackTrace[abi:cxx11]()
xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)
torch_xla::RepeatOutputShape(torch::lazy::Value const&, absl::lts_20230802::Span<long const>)
std::_Function_handler<xla::Shape (), torch_xla::Repeat::Repeat(torch::lazy::Value const&, std::vector<long, std::allocator<long> > const&)::{lambda()#1}>::_M_invoke(std::_Any_data const&)
torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
torch_xla::Repeat::Repeat(torch::lazy::Value const&, std::vector<long, std::allocator<long> > const&)
std::shared_ptr<torch::lazy::Node> torch_xla::MakeNode<torch_xla::Repeat, torch::lazy::Value, std::vector<long, std::allocator<long> > >(torch::lazy::Value&&, std::vector<long, std::allocator<long> >&&)
torch_xla::XLANativeFunctions::repeat(at::Tensor const&, c10::ArrayRef<long>)
at::_ops::repeat::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>)
at::_ops::repeat::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>)
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
*** Check failure stack trace: ***
@ 0x7ff6e2df191f absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7ff6dbea620d ConsumeValue<>()
@ 0x7ff6dbea60d4 torch_xla::ShapeHelper::ShapeOfXlaOp()
@ 0x7ff6dbe49b31 torch_xla::InferOutputShape()
@ 0x7ff6dbe79672 torch_xla::RepeatOutputShape()
@ 0x7ff6dbe0c26c std::_Function_handler<>::_M_invoke()
@ 0x7ff6dbe9c5b9 torch_xla::XlaNode::GetOpShape()
@ 0x7ff6dbe9d093 torch_xla::XlaNode::XlaNode()
@ 0x7ff6dbe124fb torch_xla::Repeat::Repeat()
@ 0x7ff6dbe18117 torch_xla::MakeNode<>()
@ 0x7ff6dbe04738 torch_xla::XLANativeFunctions::repeat()
@ 0x7ff6dbd63e46 at::(anonymous namespace)::(anonymous namespace)::wrapper_XLA__repeat()
@ 0x7ff6dbdc6f4c c10::impl::make_boxed_from_unboxed_functor<>::call()
@ 0x7ff871bf73dc (anonymous namespace)::functionalizeFallback()
@ 0x7ff872b1a8f1 at::_ops::repeat::redispatch()
https://symbolize.stripped_domain/r/?trace=7ff8938969fc,7ff89384251f&map=
*** SIGABRT received by PID 2059309 (TID 2059309) on cpu 72 from PID 2059309; stack trace: ***
PC: @ 0x7ff8938969fc (unknown) pthread_kill
@ 0x7ff6329b47e5 1904 (unknown)
@ 0x7ff893842520 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7ff8938969fc,7ff6329b47e4,7ff89384251f&map=
E1128 18:19:17.457902 2059309 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E1128 18:19:17.457917 2059309 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E1128 18:19:17.457922 2059309 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E1128 18:19:17.457927 2059309 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E1128 18:19:17.457947 2059309 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1128 18:19:17.457952 2059309 coredump_hook.cc:457] RAW: Dumping core locally.
E1128 18:19:17.617125 2059309 process_state.cc:808] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
System Info
I'm on a TPU v4-8.
- 🤗 Diffusers version: 0.36.0.dev0
- Platform: Linux-5.19.0-1022-gcp-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.8.0+cu128 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.0
- Transformers version: 4.57.3
- Accelerate version: 1.12.0
- PEFT version: 0.18.0
- Bitsandbytes version: not installed
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NA
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?