Functional JIT loading closures sharp edge #132

Open
riccardofelluga opened this issue Apr 4, 2024 · 3 comments
Labels
bug, design required, jit

Comments

@riccardofelluga
Collaborator

riccardofelluga commented Apr 4, 2024

Strategy required

This issue resumes from PR2410. We need to decide on a strategy for the closures sharp edge. Let's start simple: I think we can all agree that jitting foo below hits a sharp edge:

x = 5
def foo():
    return x

That's because foo reads a variable defined outside of the jitted scope. Here is where things get interesting, though: should we consider the following a sharp edge?

def foo(x):
    def bar():
        return x
    return bar()

I assume that, since we captured x when jitting foo, this should not be a sharp edge for bar, because the variable was declared in the jitted scope (or, in this case, captured). To handle such a case we can remember which variables we captured and look them up when we see a freevar. However, @mruberry raises an interesting point: what happens if the variable gets deleted? How can we deal with something like:

def foo():
  a = 5

  def bar():
    nonlocal a
    del a

  bar()

  return a
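For reference, here is how plain CPython treats the three snippets above, as a minimal sketch that does not involve thunder at all. The function names (reads_global, captures_argument, deletes_captured) are made up for illustration. It shows that a global read and a captured variable are distinguishable from the code object alone, and that the deletion case leaves the shared cell empty:

```python
x = 5

def reads_global():
    return x  # global load: the name is looked up at call time

def captures_argument(x):
    def bar():
        return x  # free-variable load from a closure cell
    return bar

def deletes_captured():
    a = 5
    def bar():
        nonlocal a
        del a  # empties the cell shared with the enclosing function
    bar()
    return a  # the cell is empty by the time we get here

# A global read is listed in co_names and not in co_freevars.
global_names = reads_global.__code__.co_names
global_free = reads_global.__code__.co_freevars

# A captured variable is listed in co_freevars of the inner function
# and in co_cellvars of the function that owns it.
bar = captures_argument(3)
free_names = bar.__code__.co_freevars
owner_cells = captures_argument.__code__.co_cellvars
captured_value = bar.__closure__[0].cell_contents

# Reading the deleted variable raises NameError or a subclass of it.
try:
    deletes_captured()
    deletion_error = None
except NameError as exc:
    deletion_error = exc
```

So whatever definition we pick, a lookup table of captured variables would also need to handle a captured name whose cell has been emptied at runtime.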

In conclusion, what do you think the definition of a sharp edge should be in this context?

cc @apaz-cli @t-vi @mruberry

riccardofelluga added the bug and design required labels Apr 4, 2024
@mruberry
Collaborator

mruberry commented Apr 4, 2024

I was thinking more about this, and I wonder if we should just define thunder.functional.jit to bind all global and nonlocal accesses at compile time. This would be consistent with how it wants to treat functions. That way, if a global int with the value 5 is loaded, the 5 would just become a constant on future calls to thunder.functional.jit. This behavior is similar to numba's, and I think the functional jit and numba have a lot of UX overlap.
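To make the "bind at compile time" semantics concrete, here is a rough sketch in plain Python. freeze_globals is a hypothetical helper invented for this illustration, not a thunder API, and it only snapshots the globals a function names directly; a real implementation would also have to handle nonlocals and nested calls.

```python
import builtins
import types

def freeze_globals(fn):
    """Hypothetical helper: snapshot the globals `fn` names, at wrap time."""
    snapshot = {
        name: fn.__globals__[name]
        for name in fn.__code__.co_names
        if name in fn.__globals__
    }
    snapshot["__builtins__"] = builtins  # keep builtins reachable
    # Rebuild the function over the frozen namespace instead of the live one.
    return types.FunctionType(
        fn.__code__, snapshot, fn.__name__, fn.__defaults__, fn.__closure__
    )

scale = 5

def times_scale(v):
    return v * scale

frozen = freeze_globals(times_scale)

scale = 7                       # rebinding the global after "compile time"
live_result = times_scale(2)    # sees the new value
frozen_result = frozen(2)       # still sees the value bound at wrap time
```

With this behavior the global is no longer a surprise input: later rebinding simply has no effect on the frozen function.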

mruberry added the jit label Apr 4, 2024
@apaz-cli
Contributor

apaz-cli commented Apr 9, 2024

@riccardofelluga

My opinion is that the interpreter should work how CPython does whenever possible. If we can't, we should invalidate the cache and re-trace. It remains to be seen how far we want to go with this idea (maybe it's slow in some situations), but I think this behavior is fairly easy to emulate. If we can track global variable inputs, we can certainly model any closure cells (a function's __closure__) that are accessed as inputs, and invalidate the cache accordingly. Provenance tracking shouldn't prove too big an issue. After all, the interpreter accesses and uses these values in a particular order, and we're tracking their downstream uses.
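As a rough sketch of what "model them as inputs and invalidate the cache" could look like, here is a toy wrapper in plain Python. trace_with_guards is a hypothetical name, the "trace" is just a counter plus the original function, and the sketch assumes every tracked value is hashable and every closure cell is non-empty. It is not how thunder's cache works.

```python
def trace_with_guards(fn):
    """Hypothetical wrapper: re-"trace" when a global or closure input changes."""
    cache = {}

    def snapshot():
        code = fn.__code__
        # Globals the function names directly.
        globals_read = tuple(
            (name, fn.__globals__[name])
            for name in code.co_names
            if name in fn.__globals__
        )
        # Closure cells, paired with their free-variable names.
        cells = tuple(
            (name, cell.cell_contents)
            for name, cell in zip(code.co_freevars, fn.__closure__ or ())
        )
        return globals_read + cells

    def wrapper(*args):
        key = snapshot()
        if key not in cache:
            wrapper.traces += 1
            cache[key] = fn  # stand-in for a real traced artifact
        return cache[key](*args)

    wrapper.traces = 0
    return wrapper

offset = 1

def make_adder(step):
    def add(v):
        return v + step + offset  # reads one closure cell and one global
    return add

guarded_add = trace_with_guards(make_adder(10))
first = guarded_add(1)     # first call: cache miss
second = guarded_add(2)    # same global and cell values: cache hit
offset = 100               # a tracked input changes
third = guarded_add(1)     # guard fails: cache miss, new entry
```

This sketch keys on where the inputs come from and on their full values. Keying on how each input is actually used would need the provenance tracking described above.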

Notably, though, there are two cases. Either the closure is attached to the function we're compiling (in the co_consts, etc.), or it's attached to some object we end up encountering later. Both need to be modeled as inputs to the function for the purposes of cache invalidation if we want this to work.

I would argue that the first case is a subset of the second case. The second case is the one we should decide whether we want to support. Personally, I'm also unsure why loading nonlocals is a sharp edge. We know exactly how to do it, and it isn't any more difficult than loading globals. If loading nonlocals is a sharp edge, then the same should be true for loading globals and accessing attributes. It's just a matter of provenance tracking, and deciding when to invalidate the cache. But the cache invalidation decision should ultimately be based on how the inputs are used, not on where they come from, whether that is an argument, the code object, a global, a nonlocal, or something else.

An example of the former case:

import thunder

def outer():
  lst = None
  def inner():
    nonlocal lst
    if lst is None: # CONTROL FLOW CHANGES BASED ON NONLOCAL INPUT
      lst = []
    lst.append("#")
    return lst
  tj = thunder.jit(inner) # Loads the function from co_consts
  tj() # Returns ['#']
  return tj() # Returns ['#', '#']

thunder.jit(outer)() # Returns ['#', '#']

And an example of the latter:

global_lst = []
def unrelated():
  lst = None
  def inner():
    nonlocal lst
    if lst is None: # CONTROL FLOW CHANGES BASED ON NONLOCAL INPUT
      lst = global_lst
    lst.append("#")
    return lst
  return inner

def foo():
  inner_fn = unrelated() # Encounters the function with nonlocals as result of a LOAD_GLOBAL, which we can't have known the result of beforehand.
  lst1 = inner_fn()
  print(lst1) # ['#']
  lst2 = inner_fn()
  print(lst2) # ['#', '#']
  print(lst1 is global_lst) # True
  print(lst2 is global_lst) # True

thunder.jit(foo)()
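For comparison, the state involved in the latter example can be inspected in plain CPython, without thunder. This is only a sketch of where the hidden input lives: the closure cell hangs off the function object returned by unrelated(), so it is only discoverable once that object is encountered at runtime.

```python
global_lst = []

def unrelated():
    lst = None
    def inner():
        nonlocal lst
        if lst is None:  # control flow depends on the nonlocal
            lst = global_lst
        lst.append("#")
        return lst
    return inner

inner_fn = unrelated()
cell = inner_fn.__closure__[0]           # the cell backing `lst`

before = cell.cell_contents              # value before the first call
first_call = list(inner_fn())            # copy of the list after one call
rebound_to_global = cell.cell_contents is global_lst
second_call = list(inner_fn())           # copy of the list after two calls
```

The cell starts out holding None and is rebound to global_lst by the first call, so both the branch taken and the object mutated depend on state that is not an argument of inner_fn.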

@mruberry
Collaborator

mruberry commented Apr 9, 2024

I think it's a sharp edge for the functional jit, @apaz-cli, because it's functional — no surprise inputs allowed.
