8
8
from typing import Any , ParamSpec , TypeVar
9
9
10
10
from observability_utils .tracing import (
11
- add_span_attributes ,
12
11
get_context_propagator ,
13
12
get_tracer ,
14
13
start_as_current_span ,
@@ -44,17 +43,19 @@ class WorkerDispatcher:
44
43
45
44
_config : ApplicationConfig
46
45
_subprocess : PoolClass | None
47
- _use_subprocess : bool
48
46
_state : EnvironmentResponse
49
47
50
48
def __init__ (
51
49
self ,
52
50
config : ApplicationConfig | None = None ,
53
- use_subprocess : bool = True ,
51
+ subprocess_factory : Callable [[], PoolClass ] | None = None ,
54
52
) -> None :
53
+ def default_subprocess_factory ():
54
+ return Pool (initializer = _init_worker , processes = 1 )
55
+
55
56
self ._config = config or ApplicationConfig ()
56
57
self ._subprocess = None
57
- self ._use_subprocess = use_subprocess
58
+ self ._subprocess_factory = subprocess_factory or default_subprocess_factory
58
59
self ._state = EnvironmentResponse (
59
60
initialized = False ,
60
61
)
@@ -68,12 +69,8 @@ def reload(self):
68
69
69
70
@start_as_current_span (TRACER )
70
71
def start (self ):
71
- add_span_attributes (
72
- {"_use_subprocess" : self ._use_subprocess , "_config" : str (self ._config )}
73
- )
74
72
try :
75
- if self ._use_subprocess :
76
- self ._subprocess = Pool (initializer = _init_worker , processes = 1 )
73
+ self ._subprocess = self ._subprocess_factory ()
77
74
self .run (setup , self ._config )
78
75
self ._state = EnvironmentResponse (initialized = True )
79
76
except Exception as e :
@@ -107,40 +104,25 @@ def run(
107
104
function : Callable [P , T ],
108
105
* args : P .args ,
109
106
** kwargs : P .kwargs ,
110
- ) -> T :
111
- """Calls the supplied function, which is modified to accept a dict as it's new
112
- first param, before being passed to the subprocess runner, or just run in place.
113
- """
114
- add_span_attributes ({"use_subprocess" : self ._use_subprocess })
115
- if self ._use_subprocess :
116
- return self ._run_in_subprocess (function , * args , ** kwargs )
117
- else :
118
- return function (* args , ** kwargs )
119
-
120
- @start_as_current_span (TRACER , "function" , "args" , "kwargs" )
121
- def _run_in_subprocess (
122
- self ,
123
- function : Callable [P , T ],
124
- * args : P .args ,
125
- ** kwargs : P .kwargs ,
126
107
) -> T :
127
108
"""Call the supplied function, passing the current Span ID, if one
128
- exists,from the observability context inro the _rpc caller function.
109
+ exists,from the observability context into the import_and_run_function
110
+ caller function.
111
+
129
112
When this is deserialized in and run by the subprocess, this will allow
130
113
its functions to use the corresponding span as their parent span."""
114
+
131
115
if self ._subprocess is None :
132
116
raise InvalidRunnerStateError ("Subprocess runner has not been started" )
133
117
if not (hasattr (function , "__name__" ) and hasattr (function , "__module__" )):
134
118
raise RpcError (f"{ function } is anonymous, cannot be run in subprocess" )
135
- if not callable (function ):
136
- raise RpcError (f"{ function } is not Callable, cannot be run in subprocess" )
137
119
try :
138
120
return_type = inspect .signature (function ).return_annotation
139
121
except TypeError :
140
122
return_type = None
141
123
142
124
return self ._subprocess .apply (
143
- _rpc ,
125
+ import_and_run_function ,
144
126
(
145
127
function .__module__ ,
146
128
function .__name__ ,
@@ -164,7 +146,7 @@ def __init__(self, message):
164
146
class RpcError (Exception ): ...
165
147
166
148
167
- def _rpc (
149
+ def import_and_run_function (
168
150
module_name : str ,
169
151
function_name : str ,
170
152
expected_type : type [T ] | None ,
0 commit comments