Skip to content

Commit

Permalink
imports
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 27, 2024
1 parent 9d79b8d commit b55706b
Showing 1 changed file with 34 additions and 75 deletions.
109 changes: 34 additions & 75 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,71 @@
import thunder
from __future__ import annotations
import math
from typing import Any, Optional, Dict, Tuple, Literal
import builtins
from typing import Any
import collections
from collections.abc import ValuesView, Iterable, Iterator
from collections.abc import Callable, Sequence
import weakref
import random
from functools import partial, wraps, reduce
import linecache
import operator
import copy
from functools import wraps
import contextvars
from contextlib import contextmanager
import dis
import warnings
from enum import Enum, auto
from io import StringIO
import inspect
import time

from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
import thunder.clang as clang
import thunder.core.transforms
from thunder.core.baseutils import run_once

from types import (
CellType,
ClassMethodDescriptorType,
CodeType,
CoroutineType,
FrameType,
FunctionType,
MethodType,
MethodDescriptorType,
ModuleType,
NoneType,
BuiltinFunctionType,
BuiltinMethodType,
MethodWrapperType,
WrapperDescriptorType,
TracebackType,
CellType,
ModuleType,
CodeType,
BuiltinFunctionType,
FunctionType,
MethodType,
GetSetDescriptorType,
MethodDescriptorType,
NoneType,
UnionType,
WrapperDescriptorType,
)

import torch
import torch.utils.checkpoint

from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
import thunder.clang as clang
from thunder.core import dtypes
import thunder.core.transforms
from thunder.core.proxies import (
AnyProxy,
DistParallelType,
proxy,
NumberProxy,
Proxy,
ProxyTag,
AnyProxy,
NumberProxy,
StringProxy,
TensorProxy,
FutureTensorProxy,
make_proxy_name,
Variable,
variableify,
unvariableify,
is_proxy_name_available,
proxy,
unvariableify,
variableify,
)
from thunder.core.trace import set_tracectx, reset_tracectx, tracectx, from_trace
from thunder.core.interpreter import (
InterpreterLogItem,
InterpreterFrame,
interpret,
_interpret_call,
CapsuleType,
default_callbacks,
INTERPRETER_CALLBACKS,
INTERPRETER_SIGNALS,
default_opcode_interpreter,
_default_lookaside_map,
InterpreterRuntimeCtx,
ProvenanceRecord,
PseudoInst,
WrappedValue,
_interpret_call,
default_callbacks,
default_lookaside,
do_raise,
get_interpreterruntimectx,
InterpreterRuntimeCtx,
interpret,
interpreter_needs_wrap,
is_opaque,
Py_NULL,
member_descriptor,
WrappedValue,
unwrap,
wrap,
wrap_const,
PseudoInst,
ProvenanceRecord,
interpreter_needs_wrap,
)
from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language
from thunder.core.baseutils import extract_callable_name
from thunder.core.codeutils import get_siginfo, SigInfo
from thunder.core.codeutils import SigInfo
import thunder.core.prims as prims
from thunder.common import transform_for_execution
from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions
from thunder.core.symbol import Symbol, BoundSymbol, is_traceable

from thunder.extend import Executor
from thunder.common import CompileData, CompileStats
from thunder.core.symbol import Symbol
from thunder.core.trace import TraceCtx, TraceResults
from thunder.torch import _torch_to_thunder_function_map
from thunder.clang import _clang_fn_set
from thunder.core.pytree import tree_map, tree_iter
from thunder.core.compile_data import compile_data_and_stats

#
# jit_ext.py implements extensions of thunder's interpreter
Expand Down Expand Up @@ -266,7 +221,9 @@ def proxify(self, value: WrappedValue) -> Any:
DistParallelType.REPLICATED,
DistParallelType.FULLY_SHARDED,
):
p_new = thunder.distributed.prims.synchronize(
from thunder.distributed.prims import synchronize

p_new = synchronize(
p,
self._process_group_for_ddp,
)
Expand Down Expand Up @@ -889,8 +846,8 @@ def autocast_exit(autocast_obj, exc_type, exc_val, exc_tb):

@register_general_jit_lookaside(torch.finfo)
@interpreter_needs_wrap
def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
torch_dtype = thunder.dtypes.to_torch_dtype(dtype)
def _general_jit_torch_finfo_lookaside(dtype: dtypes.dtype):
torch_dtype = dtypes.to_torch_dtype(dtype)
res = torch.finfo(torch_dtype)
return res

Expand Down Expand Up @@ -1400,6 +1357,8 @@ def get_parameter_or_buffer_or_submodule_name_and_root(provenance):


def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs):
from thunder import _get_cache_info

already_unpacked: dict[int, Proxy] = {}
orig_modules: dict[int, Proxy] = {}

Expand Down Expand Up @@ -1671,7 +1630,7 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:

prim(*args)

cache_info = thunder._get_cache_info()
cache_info = _get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
if cache_info:
cache_info_p = Proxy(name="cache_info")
Expand Down

0 comments on commit b55706b

Please sign in to comment.