Zero tensor in transformer_z_image #12742

@lime-j

Describe the bug

In src/diffusers/models/transformers/transformer_z_image.py, image_padding_len is used to pad the image sequence to a multiple of SEQ_MULTI_OF. However, when the image sequence length is already a multiple of SEQ_MULTI_OF, image_padding_len is 0 and the padding step creates a tensor with a zero-sized dimension. On PyTorch/XLA on TPU this triggers "INVALID_ARGUMENT: Concatenate expects at least one argument." (the stack trace below points at the repeat op). It may also fail on devices other than CUDA. I created a PR to fix this.
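To illustrate the failure mode, here is a minimal, framework-free sketch of the padding arithmetic and of one possible guard. It is not the code from transformer_z_image.py or from the PR: the helper names padding_len and pad_rows are hypothetical, the value 32 for SEQ_MULTI_OF is an assumption, and plain Python lists stand in for tensors.

```python
# Hypothetical sketch: plain lists stand in for tensors, and the
# value of SEQ_MULTI_OF is assumed for illustration only.
SEQ_MULTI_OF = 32


def padding_len(seq_len: int, multiple: int = SEQ_MULTI_OF) -> int:
    """Number of extra rows needed to reach the next multiple of `multiple`."""
    return (-seq_len) % multiple


def pad_rows(rows: list, multiple: int = SEQ_MULTI_OF) -> list:
    """Pad `rows` to a multiple of `multiple` by repeating the last row.

    When no padding is needed, return the input unchanged instead of
    building an empty padding block and concatenating it. In tensor
    terms, this avoids calling repeat with a count of 0 followed by a
    concatenation with a zero-sized operand.
    """
    pad = padding_len(len(rows), multiple)
    if pad == 0:
        return rows  # already aligned: skip the repeat/concat entirely
    return rows + [rows[-1]] * pad


if __name__ == "__main__":
    print(padding_len(4096))                # 0: already a multiple of 32
    print(padding_len(4097))                # 31
    print(len(pad_rows([[0.0]] * 4097)))    # 4128
```

The point of the guard is that the zero-padding case never reaches the backend at all, so it does not depend on how a given device handles zero-sized operands.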

Reproduction

  import torch
  import torch_xla
  from diffusers import ZImagePipeline

  # 1. Load the pipeline on the XLA device (eager mode, no graph tracing)
  torch_xla.experimental.eager_mode(True)
  device = torch_xla.device()
  model = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16).to(device)
  prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."

  # 2. Generate Image
  image = model(
      prompt=prompt,
      height=1024,
      width=1024,
      num_inference_steps=9,  # This actually results in 8 DiT forwards
      guidance_scale=0.0,     # Guidance should be 0 for the Turbo models
  ).images[0]

  image.save("example.png")

Logs

Here is what I got on TPU v4-8:


F1128 18:19:17.110023 2059309 debug_macros.h:21] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: Concatenate expects at least one argument.
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
        torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
        torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)
        torch_xla::RepeatOutputShape(torch::lazy::Value const&, absl::lts_20230802::Span<long const>)
        std::_Function_handler<xla::Shape (), torch_xla::Repeat::Repeat(torch::lazy::Value const&, std::vector<long, std::allocator<long> > const&)::{lambda()#1}>::_M_invoke(std::_Any_data const&)
        torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
        torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
        torch_xla::Repeat::Repeat(torch::lazy::Value const&, std::vector<long, std::allocator<long> > const&)
        std::shared_ptr<torch::lazy::Node> torch_xla::MakeNode<torch_xla::Repeat, torch::lazy::Value, std::vector<long, std::allocator<long> > >(torch::lazy::Value&&, std::vector<long, std::allocator<long> >&&)
        torch_xla::XLANativeFunctions::repeat(at::Tensor const&, c10::ArrayRef<long>)



        at::_ops::repeat::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>)


        at::_ops::repeat::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>)


        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***

*** Check failure stack trace: ***
    @     0x7ff6e2df191f  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7ff6dbea620d  ConsumeValue<>()
    @     0x7ff6dbea60d4  torch_xla::ShapeHelper::ShapeOfXlaOp()
    @     0x7ff6dbe49b31  torch_xla::InferOutputShape()
    @     0x7ff6dbe79672  torch_xla::RepeatOutputShape()
    @     0x7ff6dbe0c26c  std::_Function_handler<>::_M_invoke()
    @     0x7ff6dbe9c5b9  torch_xla::XlaNode::GetOpShape()
    @     0x7ff6dbe9d093  torch_xla::XlaNode::XlaNode()
    @     0x7ff6dbe124fb  torch_xla::Repeat::Repeat()
    @     0x7ff6dbe18117  torch_xla::MakeNode<>()
    @     0x7ff6dbe04738  torch_xla::XLANativeFunctions::repeat()
    @     0x7ff6dbd63e46  at::(anonymous namespace)::(anonymous namespace)::wrapper_XLA__repeat()
    @     0x7ff6dbdc6f4c  c10::impl::make_boxed_from_unboxed_functor<>::call()
    @     0x7ff871bf73dc  (anonymous namespace)::functionalizeFallback()
    @     0x7ff872b1a8f1  at::_ops::repeat::redispatch()
https://symbolize.stripped_domain/r/?trace=7ff8938969fc,7ff89384251f&map= 
*** SIGABRT received by PID 2059309 (TID 2059309) on cpu 72 from PID 2059309; stack trace: ***
PC: @     0x7ff8938969fc  (unknown)  pthread_kill
    @     0x7ff6329b47e5       1904  (unknown)
    @     0x7ff893842520  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7ff8938969fc,7ff6329b47e4,7ff89384251f&map= 
E1128 18:19:17.457902 2059309 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E1128 18:19:17.457917 2059309 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E1128 18:19:17.457922 2059309 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E1128 18:19:17.457927 2059309 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E1128 18:19:17.457947 2059309 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1128 18:19:17.457952 2059309 coredump_hook.cc:457] RAW: Dumping core locally.
E1128 18:19:17.617125 2059309 process_state.cc:808] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

System Info

I'm on a TPU v4-8.

  • 🤗 Diffusers version: 0.36.0.dev0
  • Platform: Linux-5.19.0-1022-gcp-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.8.0+cu128 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.0
  • Transformers version: 4.57.3
  • Accelerate version: 1.12.0
  • PEFT version: 0.18.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NA
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @DN6
