Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor hardshrink_opinfo with singularity_fn_producer #1517

Merged
merged 3 commits into from
Dec 6, 2024

Conversation

beverlylytle
Copy link
Collaborator

Many linear activation functions of torch.nn.functional have partial derivatives with jump discontinuities at dynamically defined values, eg, hardshrink has a kwarg lambd which sets the relevant discontinuities at +/-lambd. The test test_vjp_correctness relies on using the technique of computing finite differences to approximate these partial derivatives to validate Thunder's computation of the partials. These finite differences behave badly around these discontinuities. Currently, each OpInfo allows the supplement of a singularity_fn to push test input values away from the discontinuities, but it only allows for a single singularity_fn, which cannot reflect the dynamic variation of the "bad" points. This PR introduces a singularity_fn_producer, which is a function mapping a SampleInput to a singularity_fn, allowing the singularity_fn to reflect the kwargs of the SampleInput.

@mruberry
Copy link
Collaborator

mruberry commented Dec 5, 2024

The test failure is unrelated to this PR, fyi @t-vi. It is

FAILED thunder/tests/test_grad.py::test_vjp_correctness_celu_torch_cpu_thunder.dtypes.float64 - AssertionError: Scalars are not close!

Expected 12.54497983051468 but got 12.544988595853251.
Absolute difference: 8.765338570526637e-06 (up to 1e-07 allowed)
Relative difference: 6.98712846807903e-07 (up to 1e-07 allowed)

which is tracked by #1514

Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!

This is a good generalization of singularity functions. It's nice that it doesn't require rewriting existing OpInfos that use singularity functions. In the future we may want to remove the singularity_fn option for OpInfos, since it's now just a sugar for specifying a singularity_fn_producer that ignores its inputs

@mruberry mruberry enabled auto-merge (squash) December 5, 2024 19:10
@mruberry mruberry merged commit 4410127 into main Dec 6, 2024
41 checks passed
@mruberry mruberry deleted the singularity_prod branch December 6, 2024 09:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants