Skip to content

Commit

Permalink
Import updates of utils, baseutils, and codeutils. (#1444)
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar authored Nov 18, 2024
1 parent 4914141 commit 6f50e52
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 40 deletions.
59 changes: 49 additions & 10 deletions thunder/core/baseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,63 @@
# This feature is available in Python 3.7 and later.
# This import (like all __future__ imports) must be at the beginning of the file.
from __future__ import annotations
from collections.abc import Sequence
from enum import Enum
from types import MappingProxyType, ModuleType, CodeType, EllipsisType, FunctionType, MethodType
from typing import TYPE_CHECKING
import collections.abc
import dis
import functools
import inspect
import os
import dis

import sys
import collections.abc
from numbers import Number
from typing import Any, Type, Union, Optional, Tuple, List
from collections.abc import Callable
from collections.abc import Sequence
from types import MappingProxyType, ModuleType, CodeType, EllipsisType, FunctionType, MethodType
import re
import inspect
import sys

import torch
import numpy as np

if TYPE_CHECKING:
from collections.abc import Callable
from numbers import Number
from typing import Any


__all__ = [
"BoundSymbolInterface",
"NumberProxyInterface",
"ProxyInterface",
"SymbolInterface",
"TagBase",
"TensorProxyInterface",
"TermColors",
"TorchAutogradFunctionCtxProxyInterface",
"build_callable",
"check",
"check_type",
"check_types",
"check_valid_length",
"check_valid_shape",
"default_dataclass_params",
"extract_callable_name",
"fnprint",
"get_module",
"indent",
"init_colors",
"init_windows_terminal",
"is_base_printable",
"is_base_printable_literal",
"is_base_printable_type",
"is_base_printable_value",
"is_collection",
"print_base_printable",
"print_base_type",
"print_number",
"print_type",
"run_once",
"sequencify",
"warn_term_variable_once",
]


#
# Common utilities importable by any other file
Expand Down
46 changes: 31 additions & 15 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,40 @@
from types import CodeType, FunctionType, MethodType, EllipsisType
from typing import List, Dict, Tuple, Set, Deque, Any, NamedTuple, Optional
from numbers import Number
from collections import deque
from collections.abc import Mapping, Sequence, Iterable, Callable
import inspect
from inspect import Parameter
import string
import functools
from __future__ import annotations
from functools import partial
from inspect import Parameter
from typing import TYPE_CHECKING, NamedTuple
import dataclasses
import dis
import functools
import inspect
import linecache
import dataclasses
import sys

import torch

import thunder.core.baseutils as baseutils
from thunder.core.baseutils import ProxyInterface, check
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.pytree import tree_flatten, tree_unflatten

if TYPE_CHECKING:
from typing import Any
from collections.abc import Callable, Sequence
from thunder.core.trace import TraceCtx


__all__ = [
"ContextObject",
"SigInfo",
"get_siginfo",
"get_source_line",
"indent_string",
"is_literal",
"is_printable",
"is_simple_printable_collection",
"module_shortname",
"prettyprint",
"to_printable",
]

#
# Functions related to analyzing and printing functions and arguments
#
Expand Down Expand Up @@ -106,7 +120,7 @@ def is_literal(x: Any) -> bool:
return True


def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | None]:
def _to_printable(tracectx: TraceCtx | None, x: Any) -> tuple[Any, tuple[str, Any] | None]:
can_print, module_info = is_printable(x)
if can_print:
return x, module_info
Expand All @@ -123,7 +137,7 @@ def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | No

# TODO Improve type annotations
def to_printable(
trace: Optional,
trace: TraceCtx | None,
x: Any,
*,
import_ctx: dict | None = None,
Expand Down Expand Up @@ -302,7 +316,9 @@ def __repr__(self):
# TODO Print the original signature's type annotations
# TODO Maybe be clear about what inputs are const and what aren't?
# TODO Improve this signature's type annotations
def prettyprint(self, *, trace: Optional = None, import_ctx: Optional = None, object_ctx=None) -> str:
def prettyprint(
self, *, trace: TraceCtx | None = None, import_ctx: Any | None = None, object_ctx: Any | None = None
) -> str:
def _arg_printer(name: str, has_default: bool, default: Any = None) -> str:
# NOTE In this case the argument has a default value, like 'a' in foo(a=5)
if has_default:
Expand Down
33 changes: 18 additions & 15 deletions thunder/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import sys
import os
from __future__ import annotations
from collections import defaultdict, deque, UserDict
from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence, Mapping
from enum import Enum
from functools import reduce, wraps
import itertools
from itertools import chain
from functools import reduce
from numbers import Number
from typing import Any, overload, Generic, Optional, TypeVar, TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Hashable, Iterable, Iterator, Sequence
from collections import defaultdict
from types import MappingProxyType
from typing import overload, Generic, TypeVar, TYPE_CHECKING
import itertools
import os

from typing_extensions import Self
import torch

import thunder.core.dtypes as dtypes
from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map
Expand All @@ -20,6 +20,9 @@
from thunder.core.trace import TraceCtx
import thunder.core.prims as prims

if TYPE_CHECKING:
from typing import Any

# This file defines utilities that can be used when defining primitive operations.

# This file depends on proxies.py and the dtypes submodule.
Expand Down Expand Up @@ -729,17 +732,17 @@ def __len__(self) -> int:
return len(self.d)

# -
def __sub__(self, other: "_OrderedSet") -> Self:
def __sub__(self, other: _OrderedSet) -> Self:
return self.__class__(k for k in self if k not in other)

def __and__(self, other: "_OrderedSet") -> Self:
def __and__(self, other: _OrderedSet) -> Self:
return self.__class__(k for k in self if k in other)

def __or__(self, other: "_OrderedSet") -> Self:
def __or__(self, other: _OrderedSet) -> Self:
return self.__class__(itertools.chain(self, other))

# NOTE: actual set signature is (self, *others)
def difference(self, other: "_OrderedSet") -> Self:
def difference(self, other: _OrderedSet) -> Self:
return self - other

def add(self, x: T | T1):
Expand All @@ -753,7 +756,7 @@ def discard(self, x: T | T1):
def issubset(self, other):
return all((e in other) for e in self)

def union(self, *others: "Sequence[_OrderedSet]") -> Self:
def union(self, *others: Sequence[_OrderedSet]) -> Self:
return self.__class__(itertools.chain(self, *others))

def update(self, x: Iterable[T | T1]) -> None:
Expand Down Expand Up @@ -791,7 +794,7 @@ def __missing__(self, key: T) -> T1:
if TYPE_CHECKING:
_UserDictT = dict
else:
_UserDictT = collections.UserDict
_UserDictT = UserDict


class FrozenDict(_UserDictT[T, T1], Mapping[T, T1]):
Expand Down

0 comments on commit 6f50e52

Please sign in to comment.