Refactor(handler part 17): extract MLX helper mixins#625
Refactor(handler part 17): extract MLX helper mixins#625ChuxiJ merged 3 commits intoace-step:mainfrom
Conversation
📝 WalkthroughWalkthroughExtracts MLX-related initialization and VAE/DiT encode/decode implementations into four new mixin modules, exports them from the handler package, updates AceStepHandler to inherit those mixins, removes the prior in-file MLX implementations, and adds unit tests for each mixin. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Handler
participant MlxInit as MlxVaeInitMixin
participant MxCore as mlx.core
participant MlxModel as MLX_VAE_Model
Handler->>MlxInit: _init_mlx_vae()
MlxInit->>MxCore: mlx_available() / import core utils
alt mlx available
MlxInit->>MlxModel: convert_and_load(from_pytorch_config)
MlxModel-->>MlxInit: model instance
MlxInit->>MxCore: mx.compile(decode/encode paths) (optional)
MxCore-->>MlxInit: compiled callables or raises
MlxInit-->>Handler: initialization success (use_mlx_vae=True, dtype, compiled flags)
else mlx unavailable
MlxInit-->>Handler: initialization failed (use_mlx_vae=False)
end
sequenceDiagram
autonumber
participant Handler
participant MlxDecode as MlxVaeDecodeNativeMixin
participant MxCore as mlx.core
participant DecodeFn as decode_fn (compiled or runtime)
Handler->>MlxDecode: _mlx_vae_decode(latents_torch)
MlxDecode->>MxCore: convert latents -> mx.array, cast dtype
MlxDecode->>DecodeFn: call per-sample decode (tiling if needed)
DecodeFn-->>MlxDecode: decoded sample arrays
MlxDecode->>MxCore: mx.eval / mx.concatenate / mx.clear_cache
MlxDecode-->>Handler: returns torch tensor (batch, channels, samples)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
acestep/core/generation/handler/mlx_vae_init_test.py (1)
146-173: Add environment variable patch for test isolation.This test doesn't patch
ACESTEP_MLX_VAE_FP16, unliketest_init_mlx_vae_success_sets_compiled_callables. If the ambient environment has this variable set, the FP16 code path may execute differently, potentially causing non-deterministic behavior. Consider adding the patch for consistency:🔧 Proposed fix to improve test isolation
with patch.dict( sys.modules, { "mlx": fake_mlx_pkg, "mlx.core": fake_mx_core, "mlx.utils": fake_utils, "acestep.models.mlx": fake_mlx, "acestep.models.mlx.vae_model": fake_vae_model, "acestep.models.mlx.vae_convert": fake_vae_convert, }, ): - self.assertTrue(host._init_mlx_vae()) + with patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False): + self.assertTrue(host._init_mlx_vae())🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/core/generation/handler/mlx_vae_init_test.py` around lines 146 - 173, The test test_init_mlx_vae_compile_failure_falls_back can be nondeterministic if ACESTEP_MLX_VAE_FP16 is present in the environment; to isolate it, add an environment patch inside the with patch.dict(...) context (or wrap the outer with) to set ACESTEP_MLX_VAE_FP16 to a known value (e.g., "0" or remove it) so the _init_mlx_vae codepath is deterministic. Modify the existing with patch.dict(sys.modules, {...}) block to also patch os.environ (via patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False) or include it in a nested with) so host._init_mlx_vae(), host._mlx_compiled_decode and host._mlx_compiled_encode_sample exercise the intended non-FP16 path consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/core/generation/handler/mlx_vae_init.py`:
- Around line 56-70: The logger always reports compiled=True even when
mx.compile() fails; update the init to record whether compilation succeeded
(e.g., a local boolean like compiled = True set in the try block and compiled =
False in the except) and then use that flag in the logger.info call instead of
the hardcoded "compiled=True"; ensure the try/except still assigns
self._mlx_compiled_decode and self._mlx_compiled_encode_sample (and
self.mlx_vae/self._mlx_vae_dtype/self.use_mlx_vae) as before but log f"[MLX-VAE]
Native MLX VAE initialized (dtype={vae_dtype}, compiled={compiled})." so the
message accurately reflects the outcome of mx.compile() for functions
mx.compile(mlx_vae.decode) and mx.compile(mlx_vae.encode_and_sample).
In `@acestep/core/generation/handler/mlx_vae_native_test.py`:
- Around line 97-120: The tests
test_mlx_decode_single_without_tiling_uses_decode_fn and
test_mlx_decode_single_with_tiling_concatenates_trimmed_chunks assign a lambda
to the variable decode_fn (triggering Ruff E731); replace each lambda assignment
with a local def function (e.g., def decode_fn(chunk): return np.repeat(chunk,
2, axis=1)) so the tests calling host._mlx_decode_single use a named local
function instead of a lambda; update both occurrences where decode_fn is
defined.
---
Nitpick comments:
In `@acestep/core/generation/handler/mlx_vae_init_test.py`:
- Around line 146-173: The test test_init_mlx_vae_compile_failure_falls_back can
be nondeterministic if ACESTEP_MLX_VAE_FP16 is present in the environment; to
isolate it, add an environment patch inside the with patch.dict(...) context (or
wrap the outer with) to set ACESTEP_MLX_VAE_FP16 to a known value (e.g., "0" or
remove it) so the _init_mlx_vae codepath is deterministic. Modify the existing
with patch.dict(sys.modules, {...}) block to also patch os.environ (via
patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False) or include it
in a nested with) so host._init_mlx_vae(), host._mlx_compiled_decode and
host._mlx_compiled_encode_sample exercise the intended non-FP16 path
consistently.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/core/generation/handler/mlx_vae_init.py`:
- Around line 12-17: Update the _init_mlx_vae docstring to briefly state its
purpose and list the key inputs it relies on (self.vae and the
ACESTEP_MLX_VAE_FP16 flag/env), the side effects/mutations it performs (sets
self.mlx_vae, self.use_mlx_vae, and updates/compiles any MLX VAE callables), the
returned value (bool), and any exceptions/errors it can raise or conditions when
it returns False; reference the method name _init_mlx_vae and the attributes
self.vae, self.mlx_vae, and self.use_mlx_vae so readers can quickly locate the
related code.
- Around line 74-78: On MLX VAE initialization failure inside the except block
of the MLX VAE init routine, explicitly reset the compiled callable and dtype
state by clearing self._mlx_compiled_decode, self._mlx_compiled_encode_sample,
and self._mlx_vae_dtype in addition to setting self.mlx_vae = None and
self.use_mlx_vae = False so the exception path mirrors the clean initial state;
update the except handler in mlx_vae_init.py where the logger.warning and these
flags are set to also assign None (or the appropriate empty value) to those
three attributes.
Summary
This PR delivers handler decomposition part 17 by extracting remaining MLX-specific helper logic from �cestep/handler.py into focused mixins while preserving runtime behaviour.
What Changed
Safety / Behaviour
Tests Added
Validation Run
Standards Compliance
Summary by CodeRabbit
New Features
Refactor
Stability
Tests