Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove from baseutils import * #1421

Merged
merged 6 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from types import CodeType, FunctionType, MethodType, EllipsisType
from typing import List, Dict, Tuple, Set, Deque, Any, NamedTuple
from typing import List, Dict, Tuple, Set, Deque, Any, NamedTuple, Optional
from numbers import Number
from collections import deque
from collections.abc import Mapping, Sequence, Iterable
from collections.abc import Mapping, Sequence, Iterable, Callable
import inspect
from inspect import Parameter
import string
Expand All @@ -11,6 +11,7 @@
import dis
import linecache
import dataclasses
import sys

import torch

Expand All @@ -19,7 +20,6 @@
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.baseutils import *
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's OK to replace uses of baseutils with explicit baseutils.foo calls in this file, but conceptually baseutils, codeutils and utils offer a progressive expansion of utility functions. Most files can just import utils (and doing so gives them all the functions in codeutils and baseutils), but some files can only import codeutils or baseutils. baseutils depends only on Python and PyTorch. codeutils depends on thunder's dtypes, devices and pytree submodules. utils also depends on prims, trace and proxies.

For files that cannot import utils but can import codeutils, they probably want baseutils, too, just like how importing utils provides access to the functions in baseutils and codeutils.

There are some files, like trace.py, that import both codeutils and baseutils, but they could just import codeutils, instead.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not aim to minimize the number of import lines but maximize that there is one standard way to access a given function and only deviate from it when there is a need to. I don't mind if we use baseutils.foo or utils.foo by default but let's not do a thing codeutils.foo if you import codeutils and otherwise baseutils.foo .


#
# Functions related to analyzing and printing functions and arguments
Expand Down Expand Up @@ -78,7 +78,7 @@ def is_printable(x: Any) -> tuple[bool, None | tuple[str, Any]]:

if isinstance(x, ContextObject):
return True, None
if is_collection(x):
if baseutils.is_collection(x):
# TODO RC1 Fix collection printing by testing if each item is printable and gathering the imports
# required (if any)
flat, _ = tree_flatten(x)
Expand All @@ -96,7 +96,7 @@ def is_literal(x: Any) -> bool:
if isinstance(x, (ContextObject, ProxyInterface)):
return False

if is_collection(x):
if baseutils.is_collection(x):
flat, _ = tree_flatten(x)
for f in flat:
if is_literal(f):
Expand All @@ -106,7 +106,7 @@ def is_literal(x: Any) -> bool:
return True


def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, Optional[tuple[str, Any]]]:
def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | None]:
can_print, module_info = is_printable(x)
if can_print:
return x, module_info
Expand All @@ -126,8 +126,8 @@ def to_printable(
trace: Optional,
x: Any,
*,
import_ctx: Optional[dict] = None,
object_ctx: Optional[dict] = None,
import_ctx: dict | None = None,
object_ctx: dict | None = None,
) -> Printable:
# Short-circuits if x is a Proxy
if isinstance(x, ProxyInterface):
Expand All @@ -139,7 +139,7 @@ def to_printable(
# Return the instance as printable object (as function `prettyprint` knows how to deal with it).
return x

if is_collection(x):
if baseutils.is_collection(x):
# specify namespace="" to avoid flattening dataclasses
flat, spec = tree_flatten(x, namespace="")
if flat and flat[0] is x:
Expand Down Expand Up @@ -193,7 +193,7 @@ def prettyprint(

m = partial(_qm, quote_markers=_quote_markers)

if literals_as_underscores and is_literal(x) and not is_collection(x):
if literals_as_underscores and is_literal(x) and not baseutils.is_collection(x):
return m("_")

if type(x) is str:
Expand Down Expand Up @@ -232,7 +232,7 @@ def prettyprint(
call_repr_str = ",".join(call_repr)
return m(f"{name}({call_repr_str})")

if is_collection(x):
if baseutils.is_collection(x):
# specify namespace="" to avoid flattening dataclasses
flat, spec = tree_flatten(x, namespace="")
printed = tuple(
Expand Down Expand Up @@ -260,7 +260,7 @@ def prettyprint(
return m(f"{baseutils.print_type(x, with_quotes=False)}")

# Handles objects that this doesn't know how to serialize as a string
return m(f"(object of type {print_type(type(x), with_quotes=False)})")
return m(f"(object of type {baseutils.print_type(type(x), with_quotes=False)})")


# Use dis.Positions in 3.11+ and make it up in <3.11
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def default_python_printer(
kwarg_str = ", ".join(f"{k}={codeutils.prettyprint(v)}" for k, v in kwarg_printables.items())

result_str: str
if bsym.output is None or (codeutils.is_collection(bsym.output) and len(bsym.output) == 0):
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "
Expand Down
Loading