Skip to content

Refactor(handler part 17): extract MLX helper mixins#625

Merged
ChuxiJ merged 3 commits intoace-step:mainfrom
1larity:feat/handler-decomp-part-17
Feb 18, 2026
Merged

Refactor(handler part 17): extract MLX helper mixins#625
ChuxiJ merged 3 commits intoace-step:mainfrom
1larity:feat/handler-decomp-part-17

Conversation

@1larity
Copy link
Contributor

@1larity 1larity commented Feb 17, 2026

Summary

This PR delivers handler decomposition part 17 by extracting remaining MLX-specific helper logic from �cestep/handler.py into focused mixins while preserving runtime behaviour.

What Changed

  • Extracted MLX DiT initialization into MlxDitInitMixin.
  • Extracted MLX VAE initialization into MlxVaeInitMixin.
  • Extracted native MLX VAE decode helpers into MlxVaeDecodeNativeMixin.
  • Extracted native MLX VAE encode helpers into MlxVaeEncodeNativeMixin.
  • Wired new mixins via Acestep/core/generation/handler/init.py and AceStepHandler inheritance.
  • Removed the duplicated inline service_generate implementation from �cestep/handler.py so ServiceGenerateMixin remains the single active orchestration path.

Safety / Behaviour

  • No intentional behaviour changes in generation flow; this is a structural decomposition.
  • Added explicit resolver guards for MLX encode/decode callables to avoid eager fallback evaluation when mlx_vae is unset.

Tests Added

  • Acestep/core/generation/handler/mlx_dit_init_test.py
  • Acestep/core/generation/handler/mlx_vae_init_test.py
  • Acestep/core/generation/handler/mlx_vae_native_test.py

Validation Run

  • python acestep/core/generation/handler/mlx_dit_init_test.py (pass)
  • python acestep/core/generation/handler/mlx_vae_init_test.py (pass)
  • python acestep/core/generation/handler/mlx_vae_native_test.py (pass)
  • python acestep/core/generation/handler/service_generate_test.py (pass)
  • python acestep/core/generation/handler/init_service_test.py (pass)

Standards Compliance

  • Docstrings present for all new/modified modules, classes, and functions in scope.
  • All new files are within LOC hard cap (<= 200).
  • Changes are scoped to handler decomposition part 17 only.

Summary by CodeRabbit

  • New Features

    • MLX acceleration added for DiT and VAE paths, including native encode/decode with overlap-tiling for improved performance.
  • Refactor

    • MLX behavior modularized into reusable mixins, simplifying the main handler and reducing in-file MLX complexity.
  • Stability

    • Graceful fallback when MLX is unavailable or compilation fails; non-fatal warnings preserve functionality.
  • Tests

    • New unit tests covering MLX init, encode/decode, tiling, compilation success and fallback.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 17, 2026

📝 Walkthrough

Walkthrough

Extracts MLX-related initialization and VAE/DiT encode/decode implementations into four new mixin modules, exports them from the handler package, updates AceStepHandler to inherit those mixins, removes the prior in-file MLX implementations, and adds unit tests for each mixin.

Changes

Cohort / File(s) Summary
Handler Package Exports
acestep/core/generation/handler/__init__.py
Added imports and __all__ entries for four MLX mixins: MlxDitInitMixin, MlxVaeInitMixin, MlxVaeDecodeNativeMixin, MlxVaeEncodeNativeMixin.
MLX DiT Init + Tests
acestep/core/generation/handler/mlx_dit_init.py, acestep/core/generation/handler/mlx_dit_init_test.py
New MlxDitInitMixin with _init_mlx_dit(compile_model: bool=False) -> bool implementing guarded MLX DiT init, state updates, and graceful fallback; unit tests for unavailable and success paths.
MLX VAE Init + Tests
acestep/core/generation/handler/mlx_vae_init.py, acestep/core/generation/handler/mlx_vae_init_test.py
New MlxVaeInitMixin with _init_mlx_vae() -> bool that loads/converts VAE from PyTorch, optional FP16 casting, optional mx.compile of encode/decode, state tracking, fallback behavior; tests cover success, FP16/compile paths, and compile-failure fallback.
MLX VAE Encode/Decode + Tests
acestep/core/generation/handler/mlx_vae_decode_native.py, acestep/core/generation/handler/mlx_vae_encode_native.py, acestep/core/generation/handler/mlx_vae_native_test.py
New MlxVaeDecodeNativeMixin and MlxVaeEncodeNativeMixin providing native MLX encode/decode with dtype handling, chunking/overlap tiling, progress reporting, timing logs, and per-sample tiling logic; tests validate shapes, tiling behavior, progress updates, and error paths.
Handler Refactor
acestep/handler.py
Augmented AceStepHandler bases to include the four new MLX mixins and removed prior in-file MLX helpers and service_generate MLX-specific implementations.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Handler
  participant MlxInit as MlxVaeInitMixin
  participant MxCore as mlx.core
  participant MlxModel as MLX_VAE_Model

  Handler->>MlxInit: _init_mlx_vae()
  MlxInit->>MxCore: mlx_available() / import core utils
  alt mlx available
    MlxInit->>MlxModel: convert_and_load(from_pytorch_config)
    MlxModel-->>MlxInit: model instance
    MlxInit->>MxCore: mx.compile(decode/encode paths) (optional)
    MxCore-->>MlxInit: compiled callables or raises
    MlxInit-->>Handler: initialization success (use_mlx_vae=True, dtype, compiled flags)
  else mlx unavailable
    MlxInit-->>Handler: initialization failed (use_mlx_vae=False)
  end
Loading
sequenceDiagram
  autonumber
  participant Handler
  participant MlxDecode as MlxVaeDecodeNativeMixin
  participant MxCore as mlx.core
  participant DecodeFn as decode_fn (compiled or runtime)

  Handler->>MlxDecode: _mlx_vae_decode(latents_torch)
  MlxDecode->>MxCore: convert latents -> mx.array, cast dtype
  MlxDecode->>DecodeFn: call per-sample decode (tiling if needed)
  DecodeFn-->>MlxDecode: decoded sample arrays
  MlxDecode->>MxCore: mx.eval / mx.concatenate / mx.clear_cache
  MlxDecode-->>Handler: returns torch tensor (batch, channels, samples)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • ChuxiJ

"I hop and nibble through the change,
New mixins set the MLX in range,
DiT and VAE now neatly split,
Handler light, my tail's a twitch —
Code carrots ready, let's arrange! 🥕"

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Refactor(handler part 17): extract MLX helper mixins' directly and clearly summarizes the main change: extracting MLX helper logic into mixins as part of handler decomposition. It is specific, concise, and accurately reflects the primary objective of the changeset.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@1larity 1larity marked this pull request as ready for review February 17, 2026 19:24
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
acestep/core/generation/handler/mlx_vae_init_test.py (1)

146-173: Add environment variable patch for test isolation.

This test doesn't patch ACESTEP_MLX_VAE_FP16, unlike test_init_mlx_vae_success_sets_compiled_callables. If the ambient environment has this variable set, the FP16 code path may execute differently, potentially causing non-deterministic behavior. Consider adding the patch for consistency:

🔧 Proposed fix to improve test isolation
         with patch.dict(
             sys.modules,
             {
                 "mlx": fake_mlx_pkg,
                 "mlx.core": fake_mx_core,
                 "mlx.utils": fake_utils,
                 "acestep.models.mlx": fake_mlx,
                 "acestep.models.mlx.vae_model": fake_vae_model,
                 "acestep.models.mlx.vae_convert": fake_vae_convert,
             },
         ):
-            self.assertTrue(host._init_mlx_vae())
+            with patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False):
+                self.assertTrue(host._init_mlx_vae())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/mlx_vae_init_test.py` around lines 146 - 173,
The test test_init_mlx_vae_compile_failure_falls_back can be nondeterministic if
ACESTEP_MLX_VAE_FP16 is present in the environment; to isolate it, add an
environment patch inside the with patch.dict(...) context (or wrap the outer
with) to set ACESTEP_MLX_VAE_FP16 to a known value (e.g., "0" or remove it) so
the _init_mlx_vae codepath is deterministic. Modify the existing with
patch.dict(sys.modules, {...}) block to also patch os.environ (via
patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False) or include it
in a nested with) so host._init_mlx_vae(), host._mlx_compiled_decode and
host._mlx_compiled_encode_sample exercise the intended non-FP16 path
consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/core/generation/handler/mlx_vae_init.py`:
- Around line 56-70: The logger always reports compiled=True even when
mx.compile() fails; update the init to record whether compilation succeeded
(e.g., a local boolean like compiled = True set in the try block and compiled =
False in the except) and then use that flag in the logger.info call instead of
the hardcoded "compiled=True"; ensure the try/except still assigns
self._mlx_compiled_decode and self._mlx_compiled_encode_sample (and
self.mlx_vae/self._mlx_vae_dtype/self.use_mlx_vae) as before but log f"[MLX-VAE]
Native MLX VAE initialized (dtype={vae_dtype}, compiled={compiled})." so the
message accurately reflects the outcome of mx.compile() for functions
mx.compile(mlx_vae.decode) and mx.compile(mlx_vae.encode_and_sample).

In `@acestep/core/generation/handler/mlx_vae_native_test.py`:
- Around line 97-120: The tests
test_mlx_decode_single_without_tiling_uses_decode_fn and
test_mlx_decode_single_with_tiling_concatenates_trimmed_chunks assign a lambda
to the variable decode_fn (triggering Ruff E731); replace each lambda assignment
with a local def function (e.g., def decode_fn(chunk): return np.repeat(chunk,
2, axis=1)) so the tests calling host._mlx_decode_single use a named local
function instead of a lambda; update both occurrences where decode_fn is
defined.

---

Nitpick comments:
In `@acestep/core/generation/handler/mlx_vae_init_test.py`:
- Around line 146-173: The test test_init_mlx_vae_compile_failure_falls_back can
be nondeterministic if ACESTEP_MLX_VAE_FP16 is present in the environment; to
isolate it, add an environment patch inside the with patch.dict(...) context (or
wrap the outer with) to set ACESTEP_MLX_VAE_FP16 to a known value (e.g., "0" or
remove it) so the _init_mlx_vae codepath is deterministic. Modify the existing
with patch.dict(sys.modules, {...}) block to also patch os.environ (via
patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False) or include it
in a nested with) so host._init_mlx_vae(), host._mlx_compiled_decode and
host._mlx_compiled_encode_sample exercise the intended non-FP16 path
consistently.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/core/generation/handler/mlx_vae_init.py`:
- Around line 12-17: Update the _init_mlx_vae docstring to briefly state its
purpose and list the key inputs it relies on (self.vae and the
ACESTEP_MLX_VAE_FP16 flag/env), the side effects/mutations it performs (sets
self.mlx_vae, self.use_mlx_vae, and updates/compiles any MLX VAE callables), the
returned value (bool), and any exceptions/errors it can raise or conditions when
it returns False; reference the method name _init_mlx_vae and the attributes
self.vae, self.mlx_vae, and self.use_mlx_vae so readers can quickly locate the
related code.
- Around line 74-78: On MLX VAE initialization failure inside the except block
of the MLX VAE init routine, explicitly reset the compiled callable and dtype
state by clearing self._mlx_compiled_decode, self._mlx_compiled_encode_sample,
and self._mlx_vae_dtype in addition to setting self.mlx_vae = None and
self.use_mlx_vae = False so the exception path mirrors the clean initial state;
update the except handler in mlx_vae_init.py where the logger.warning and these
flags are set to also assign None (or the appropriate empty value) to those
three attributes.

@ChuxiJ ChuxiJ merged commit b0ee1b0 into ace-step:main Feb 18, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments