-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove requires_grad from regular TensorProxy and add specialized TensorProxy with requires_grad and is_leaf attributes #1599
base: proxy-update3
Are you sure you want to change the base?
Conversation
@@ -1868,6 +1844,52 @@ def real(self): | |||
return method(self) | |||
|
|||
|
|||
class RuntimeTensorProxy(TensorProxy): | |||
def __init__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you add a docstring explaining this class's use vs. TensorProxy?
I think the idea of a Separately, how challenging would it be to fix the existing meta functions to propagate the requires_grad and is_leaf attributes correctly? Is that something we should start working on? |
thunder/core/prims.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes in this file should be submitted in a separate PR.
… test_auto_register_torchops.py tests
PyTorch tensors have attributes that affect the computation and Thunder should mirror those in TensorProxies. However, some of the attributes (
requires_grad
,is_leaf
,stride
) are tricky to model precisely and propagate the values in Thunder's meta functions. This PR removes the possibility of querying and specifyingrequires_grad
on regular TensorProxies which are used for intermediate values and reintroduces the attribute with a special classRuntimeTensorProxy
that should be constructed through a corresponding real PyTorch tensor.A new attribute is added,
is_leaf
, which is useful for raising errors when in-place copies should be disallowed (#1577, #1458).It's not part of this PR but an extension to query static strides information will be easy to add in the future.
Ref #1577, #1570.