|
| 1 | +from functools import partial |
1 | 2 | from typing import Any
|
2 | 3 |
|
3 | 4 | from thunder.core import prims
|
@@ -45,9 +46,9 @@ def read(x: VariableInterface | Any) -> Any:
|
45 | 46 | def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
|
46 | 47 | if not isinstance(v, VariableInterface):
|
47 | 48 | return
|
48 |
| - # Duplicates are allowed and overwritten |
49 | 49 | if v.name in env:
|
50 | 50 | if allow_duplicates:
|
| 51 | + # Duplicates are allowed and not overwritten |
51 | 52 | return
|
52 | 53 | raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
|
53 | 54 | env[v.name] = val
|
@@ -104,9 +105,9 @@ def read(x: VariableInterface | Any) -> Any:
|
104 | 105 | def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
|
105 | 106 | if not isinstance(v, VariableInterface):
|
106 | 107 | return
|
107 |
| - # Duplicates are allowed and overwritten |
108 | 108 | if v.name in env:
|
109 | 109 | if allow_duplicates:
|
| 110 | + # Duplicates are allowed and not overwritten |
110 | 111 | return
|
111 | 112 | raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
|
112 | 113 | env[v.name] = val
|
@@ -203,3 +204,179 @@ def do_swap(v):
|
203 | 204 | return new_trace, tree_map(read, trace.output), env
|
204 | 205 |
|
205 | 206 | return new_trace, tree_map(read, trace.output)
|
| 207 | + |
| 208 | + |
class TraceSubstitutionProcessor:
    """Process a trace in an interpretation-style way by looping over its bound symbols.

    This processing aims to preserve as much information on the proxies as possible.

    Args:
        trace: trace to process
        *args: accepted so subclasses can forward extra arguments; unused by this base class
        **kwargs: accepted so subclasses can forward extra arguments; unused by this base class

    The user is expected to subclass this class and implement process_bsym with the help of
    add_unprocessed_bsyms (useful eg for using subsymbols to compute a symbol),
    add_processed_bsyms, add_bsyms_from_function, and set_result.

    process_args may be called before processing to bind the original trace's
    args / kwargs to replacement values.

    Calling the instantiated object initiates the processing and returns
    the new trace and a mapping of the outputs.

    See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing.
    """

    # Sentinel meaning "no replacement result has been set for the current bound symbol".
    NULL = object()

    def __init__(self, trace, *args, **kwargs):
        # Maps variable names of the original trace to their values in the new trace.
        self.env = {}
        self.trace = trace
        self.new_trace = from_trace(self.trace)
        # Set by process_args; __call__ only copies skip-listed symbols verbatim when this is True.
        self.have_processed_args = False

        # Scratch state that process_args / __call__ reset as they go. It is
        # initialized here so the helper methods can be used (e.g. by a subclass)
        # before processing starts without hitting an AttributeError.
        self.swap_map = {}
        self.unprocessed_bsyms = []
        self.new_bsyms = []
        self.replacement_result = self.NULL

    def read(self, x: "VariableInterface | Any") -> Any:
        """Return the value recorded for variable ``x``; non-variables are returned unchanged.

        Raises:
            KeyError: if ``x`` is a variable whose name has not been written to ``self.env``.
        """
        if isinstance(x, VariableInterface):
            return self.env[x.name]
        else:
            return x

    def write(self, v: "VariableInterface | Any", val: Any, allow_duplicates=True) -> None:
        """Record ``val`` as the value of variable ``v``; non-variables are ignored.

        Raises:
            ValueError: if ``v`` was already written and ``allow_duplicates`` is False.
        """
        if not isinstance(v, VariableInterface):
            return
        if v.name in self.env:
            if allow_duplicates:
                # Duplicates are allowed and not overwritten
                return
            raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
        self.env[v.name] = val

    def add_to_swap_map(self, old, new):
        """Register in ``self.swap_map`` that the proxy ``new`` is to be replaced by ``old``.

        Nothing is recorded when ``old`` and ``new`` are the same object. When ``new``
        is a VJPDual, its primal is the proxy that gets mapped and ``new.primal`` is
        rebound to ``old``.
        """
        if old is new:
            return
        if isinstance(old, ProxyInterface):
            # NOTE(review): write() keys self.env by ``.name`` while this membership test
            # uses variableify(new) - confirm that the two key types actually match.
            if isinstance(new, ProxyInterface) and variableify(new) in self.env:
                # the new isn't new, but something returned the input
                # this means we need to map the old to the new
                old, new = new, old
            elif isinstance(old, TensorProxyInterface):
                # should we have a fix shapes pass? the sharding
                # (FSDP, tensor parallel) transforms do "break" shape metadata
                self.new_trace.names.remove(old.name)  # taken by the .replace proxy
                if isinstance(new, VJPDual):
                    old = old.replace(shape=new.primal._shape)
                else:
                    old = old.replace(shape=new._shape)

        if isinstance(new, VJPDual):
            self.swap_map[variableify(new.primal)] = old
            new.primal = old
        else:
            assert isinstance(new, ProxyInterface), (old, new)
            self.swap_map[variableify(new)] = old

    def do_swap(self, v):
        """Return the ``self.swap_map`` replacement for ``v`` (``v`` itself if it has none).

        A VJPDual is updated in place (primal and residuals are swapped recursively)
        and returned; values that are not proxies are returned unchanged.
        """
        if isinstance(v, VJPDual):
            v.primal = tree_map(self.do_swap, v.primal)
            v.residuals = tree_map(self.do_swap, v.residuals)
            return v
        if not isinstance(v, ProxyInterface):
            return v
        return self.swap_map.get(variableify(v), v)

    def add_unprocessed_bsyms(self, bsyms):
        """Put ``bsyms`` at the front of the work queue so they are processed next, in order."""
        self.unprocessed_bsyms[:0] = bsyms

    def add_bsyms_from_function(self, fn, /, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and collect the bound symbols it records.

        The symbols recorded into a fresh scope of the new trace are appended to
        ``self.new_bsyms`` and the return value of ``fn`` is registered via set_result.

        Returns:
            whatever ``fn`` returned.
        """
        self.new_trace.push_scope([])
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always pop the scope pushed above, so that an exception raised by
            # ``fn`` does not leave a stray scope on the new trace.
            recorded_bsyms = self.new_trace.pop_scope()
        self.new_bsyms += recorded_bsyms
        self.set_result(result)
        return result

    def add_processed_bsyms(self, bsyms):
        """Append already processed ``bsyms`` to the replacement of the current symbol."""
        self.new_bsyms += bsyms

    def set_result(self, result):
        """Set the value(s) that replace the output of the symbol currently being processed."""
        self.replacement_result = result

    def process_bsym(self, bsym):
        """Process one bound symbol; subclasses must override this.

        Implementations communicate through add_unprocessed_bsyms, add_processed_bsyms,
        add_bsyms_from_function and set_result. __call__ asserts that set_result was
        called whenever new bound symbols were produced.
        """
        raise NotImplementedError("This needs to be implemented in subclasses")

    def process_args(self, *args, **kwargs):
        """Bind the args / kwargs of the original trace to the given replacement values.

        Positional arguments are paired with ``self.trace.args``. Keyword arguments are
        paired with ``self.trace.kwargs`` by iteration order of their values (not by key).
        """
        self.have_processed_args = True
        with tracectx(self.new_trace):
            self.swap_map = {}

            safe_map_flat(self.add_to_swap_map, list(self.trace.args), list(args))
            safe_map_flat(self.add_to_swap_map, list(self.trace.kwargs.values()), list(kwargs.values()))
            args, kwargs = tree_map(self.do_swap, (args, kwargs))

            safe_map_flat(self.write, list(self.trace.args), list(args))
            safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values()))

    def __call__(self):
        """Run the processing over all bound symbols of the original trace.

        Returns:
            a tuple ``(new_trace, outputs)`` where ``outputs`` is the output of the
            original trace with every variable replaced by its recorded value.

        Raises:
            RuntimeError: if assigning a replacement result to the outputs of the
                original symbol fails with an AssertionError (likely because the
                number of outputs differs).
        """
        with tracectx(self.new_trace):
            self.unprocessed_bsyms = self.trace.bound_symbols[:]

            while self.unprocessed_bsyms:
                bsym = self.unprocessed_bsyms.pop(0)

                if self.have_processed_args and bsym.sym.id in trace_interpreter_skip_list:
                    self.new_trace.bound_symbols.append(bsym.from_bsym())
                    continue

                # The looked-up values are not used here, but the lookup raises a
                # KeyError if an input of the symbol has not been written yet.
                tree_map(self.read, bsym.args)
                tree_map(self.read, bsym.kwargs)

                # this should be prettier
                self.replacement_result = self.NULL
                self.new_bsyms = []

                self.process_bsym(bsym)

                if self.new_bsyms:
                    assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"

                if self.replacement_result is not self.NULL:
                    self.swap_map = {}

                    # TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
                    if not self.new_bsyms:
                        # empty result means we want to swap references to the old
                        # result to the new result (which will be one of the args)
                        safe_map_flat(
                            self.add_to_swap_map,
                            list(sequencify(self.replacement_result)),
                            list(sequencify(bsym.output)),
                        )
                    else:
                        safe_map_flat(
                            self.add_to_swap_map,
                            list(sequencify(bsym.output)),
                            list(sequencify(self.replacement_result)),
                        )

                ### replace bsyms

                for new_bsym in self.new_bsyms:
                    # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
                    self.new_trace.bound_symbols.append(
                        new_bsym.from_bsym_swap_proxies(self.swap_map).from_bsym(
                            source_filename=bsym.source_filename, source_positions=bsym.source_positions
                        )
                    )

                result = tree_map(self.do_swap, self.replacement_result)

                # we need to allow duplicates here because the re-interpretation is not necessarily DCEed
                # when subsymbols are flattened into the trace after re-execution.
                try:
                    safe_map_flat(
                        partial(self.write, allow_duplicates=True),
                        list(sequencify(bsym.output)),
                        list(sequencify(result)),
                    )
                except AssertionError as e:
                    # Only names available in this scope are used and both sides go
                    # through sequencify, so building the message cannot itself fail.
                    raise RuntimeError(
                        f"Error while assigning the replacement result to the output of the original symbol {bsym}."
                        " This is likely due to a mismatch in the number of outputs."
                        f" The original symbol has {len(sequencify(bsym.output))} outputs and the replacement has {len(sequencify(result))} outputs."
                    ) from e

        return self.new_trace, tree_map(self.read, self.trace.output)
0 commit comments