Skip to content

Commit f888c13

Browse files
committed
fix model_loaded race condition with lock
1 parent bdf25fc commit f888c13

File tree

3 files changed

+49
-17
lines changed

3 files changed

+49
-17
lines changed

pytrickle/client.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676

7777
# Guard to ensure model is loaded once per client lifecycle
7878
self._model_loaded = False
79+
self._model_load_lock = asyncio.Lock()
7980

8081
async def start(self, request_id: str = "default"):
8182
"""Start the trickle client."""
@@ -183,21 +184,26 @@ def set_target_fps(self, target_fps: Optional[float]):
183184

184185
async def _ensure_model_loaded(self):
185186
"""Load the model once on the same event loop before processing begins."""
186-
if not self._model_loaded:
187-
# Transition to LOADING while model warms up, if state is available
188-
try:
189-
if getattr(self.frame_processor, "state", None) is not None:
190-
self.frame_processor.state.set_state(PipelineState.LOADING)
191-
await self.frame_processor.load_model()
192-
# Mark startup complete; this moves LOADING → IDLE per state machine
193-
if getattr(self.frame_processor, "state", None) is not None:
194-
self.frame_processor.state.set_startup_complete()
195-
self._model_loaded = True
196-
except Exception as e:
197-
# Reflect error in state if available, then propagate
198-
if getattr(self.frame_processor, "state", None) is not None:
199-
self.frame_processor.state.set_error(str(e))
200-
raise
187+
# Transition to LOADING while model warms up, if state is available
188+
try:
189+
if getattr(self.frame_processor, "state", None) is not None:
190+
self.frame_processor.state.set_state(PipelineState.LOADING)
191+
192+
# Use the thread-safe wrapper
193+
await self.frame_processor.ensure_model_loaded()
194+
195+
# Mark startup complete; this moves LOADING → IDLE per state machine
196+
if getattr(self.frame_processor, "state", None) is not None:
197+
self.frame_processor.state.set_startup_complete()
198+
199+
# Update our local flag for consistency
200+
self._model_loaded = True
201+
202+
except Exception as e:
203+
# Reflect error in state if available, then propagate
204+
if getattr(self.frame_processor, "state", None) is not None:
205+
self.frame_processor.state.set_error(str(e))
206+
raise
201207

202208
async def _on_protocol_error(self, error_type: str, exception: Optional[Exception] = None):
203209
"""Handle protocol errors and shutdown events."""

pytrickle/frame_processor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
making it easy to integrate AI models and async pipelines with PyTrickle.
55
"""
66

7+
import asyncio
78
import logging
89
from abc import ABC, abstractmethod
910
from typing import Optional, Any, Dict, List
@@ -55,11 +56,25 @@ def __init__(
5556
"""
5657
self.error_callback = error_callback
5758
self.state: Optional[StreamState] = None
59+
60+
# Model loading protection
61+
self._model_load_lock = asyncio.Lock()
62+
self._model_loaded = False
5863

5964
def attach_state(self, state: StreamState) -> None:
6065
"""Attach a pipeline state manager and set IDLE if model already loaded."""
6166
self.state = state
6267

68+
async def ensure_model_loaded(self, **kwargs):
    """Coroutine-safe wrapper that loads the model at most once.

    Concurrent callers on the same event loop are serialized by an
    asyncio.Lock, so ``load_model()`` runs exactly once per processor
    lifetime; later calls return immediately. NOTE: asyncio.Lock is NOT
    thread-safe — this guards coroutines on one event loop only.

    Args:
        **kwargs: forwarded verbatim to ``load_model()`` on first load.
    """
    async with self._model_load_lock:
        if self._model_loaded:
            logger.debug("Model already loaded for %s", self.__class__.__name__)
            return
        await self.load_model(**kwargs)
        self._model_loaded = True
        logger.debug("Model loaded for %s", self.__class__.__name__)
6378
@abstractmethod
6479
async def load_model(self, **kwargs):
6580
"""

pytrickle/stream_processor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def __init__(
6363
self.frame_skip_config = frame_skip_config
6464
self.server_kwargs = server_kwargs
6565

66+
# Add lock for model loading protection
67+
self._model_load_lock = asyncio.Lock()
68+
self._model_loaded = False
69+
6670
# Create internal frame processor
6771
self._frame_processor = _InternalFrameProcessor(
6872
video_processor=video_processor,
@@ -86,7 +90,7 @@ def __init__(
8690
self._frame_processor.attach_state(self.server.state)
8791
except Exception:
8892
# If attach fails for any reason, log and continue (non-fatal)
89-
logger.debug("Failed to attach server state to frame processor")
93+
logger.warning("Failed to attach server state to frame processor")
9094

9195
# Register server startup hook to preload model on same event loop
9296
async def _on_startup(_app):
@@ -95,10 +99,17 @@ async def _background_preload():
9599
try:
96100
if getattr(self._frame_processor, "state", None) is not None:
97101
self._frame_processor.state.set_state(PipelineState.LOADING)
98-
await self._frame_processor.load_model()
102+
103+
# Use the coroutine-safe wrapper (asyncio.Lock serializes loads on one event loop)
104+
await self._frame_processor.ensure_model_loaded()
105+
99106
if getattr(self._frame_processor, "state", None) is not None:
100107
self._frame_processor.state.set_startup_complete()
108+
109+
# Update our local flag for consistency
110+
self._model_loaded = True
101111
logger.info(f"StreamProcessor '{self.name}' model preloaded on server startup")
112+
102113
except Exception as e:
103114
if getattr(self._frame_processor, "state", None) is not None:
104115
self._frame_processor.state.set_error(str(e))

0 commit comments

Comments
 (0)