provide examples of back-propagation #1137

Open
vograno opened this issue Sep 12, 2024 · 5 comments

@vograno

vograno commented Sep 12, 2024

A DAG defines a forward flow of information. At the same time, it implies a back-flow of information when we reverse each edge of the DAG and associate a back-flow node function with each node. Additionally, we must provide a merging function that combines back-flow inputs that connect to a common forward output, but this is a technicality.

Gradient descent on a feed-forward neural network is an example of such a back-flow. Namely, the forward pass computes node outputs and gradients given the model parameters, while the backward pass updates the model parameters according to the computed gradients. I think the merging function is a sum in this case.

The question, then, is whether Hamilton is an appropriate framework for inferring the back-flow DAG from the forward one. Here, inferring means computing the back-flow driver given the forward-flow one.

Use gradient descent as a case study.
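
To make the forward/back-flow pairing concrete, here is a tiny hand-rolled example (no Hamilton involved) where the back-flow functions are the local derivatives and the merging function is a sum:

```python
# Tiny hand-rolled example: the forward DAG is x -> a, x -> b, (a, b) -> c.
# The back-flow reverses every edge, each node gets a back-flow function (its
# local derivative), and the two back-flow values arriving at x are merged with a sum.
def a(x: float) -> float:
    return 2 * x

def b(x: float) -> float:
    return x + 1

def c(a: float, b: float) -> float:
    return a * b

x = 3.0
a_v, b_v = a(x), b(x)
c_v = c(a_v, b_v)

# back-flow: seed with dc/dc = 1 at the sink and walk the reversed edges
dc = 1.0
da = dc * b_v            # local derivative of c w.r.t. a is b
db = dc * a_v            # local derivative of c w.r.t. b is a
dx = da * 2 + db * 1     # merge at x = sum of the two back-flow values
print(dx)                # 14.0, which matches d(2x^2 + 2x)/dx at x = 3
```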

@skrawcz
Collaborator

skrawcz commented Sep 12, 2024

@vograno we haven't tried this. The internal FunctionGraph does have bi-directional linking, so the building blocks are there. To me it sounds like you'd want to change a bit of how the graph is walked and where state is stored, e.g. via a new Driver.

@elijahbenizzy
Collaborator

Adding to what @skrawcz said:

Heh, fascinating idea. I've had a similar thought but never really pursued it. Some thoughts (feel free to disagree on these points!):

  • High-level -- the goal of back-propagation is optimizing a function. In back-propagation we know enough about the composition of the function (the combination of nodes and so on) to do this efficiently (ish). Basically, given we know the derivative of each node (and can chain them), we can compute the derivative of the whole function.
  • Alternatively, there are less efficient ways to optimize a function (a network, in this case) that do not require differentiation -- Gaussian processes, etc...
  • Unless we know the derivative of each node, back-propagation is not feasible -- thus we'd need some way of knowing the derivative.
  • If we knew every node were in pytorch, we could pretty easily just compile Hamilton to pytorch. In fact, this is probably close to trivial. That said, at that point you probably want to just be using pytorch (unless you have large blocks that Hamilton can help make modular, e.g. components of a NN that you want to reuse). OTOH, specifying recurrence relations (non-strict DAGs) can get harder.
  • If only some nodes are in pytorch, then we need to provide the derivative, or compute it in some way. This could be a decorator or some metadata you attach.
  • If you had a guarantee, say, that every function were in numpy, you could use something like autograd (see the sketch after this list).
  • Otherwise you'd need some way of patching it together.
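
To illustrate the autograd route, here's a minimal sketch (assuming every node is written against autograd's numpy wrapper; the function names are just illustrative):

```python
# Minimal autograd sketch: if every node is plain numpy, autograd can differentiate
# the composed function without Hamilton knowing anything about derivatives.
import autograd.numpy as np
from autograd import grad

def layer(x, w):
    # a "node": one linear transform followed by tanh
    return np.tanh(np.dot(x, w))

def loss(w, x, target):
    # the composed forward function whose derivative we want
    return np.sum((layer(x, w) - target) ** 2)

d_loss_d_w = grad(loss)           # grad differentiates w.r.t. the first argument
w = np.ones((3, 2))
x = np.ones((4, 3))
target = np.zeros((4, 2))
print(d_loss_d_w(w, x, target))   # gradient array with the same shape as w
```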

As @skrawcz said -- it would also require rewalking the graph, at least backwards.

What I'd do, if you're interested, is first build a simple two-node graph out of pytorch objects, which already carry the gradient relationship. Then you can POC it out: compute gradients individually, and expand to more ways of specifying nodes/gradients.
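
Roughly something like this sketch (assuming the node functions live in a module, say nn_nodes.py, and that Hamilton passes the tensors through untouched -- verifying that is part of the POC):

```python
# nn_nodes.py -- two torch-based nodes (illustrative names)
import torch

def prediction(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # node 1: a linear transform
    return x @ w

def loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # node 2: mean squared error against the target
    return ((prediction - target) ** 2).mean()
```

```python
# run the forward pass with Hamilton, then let torch walk its own backward graph
import torch
from hamilton import driver

import nn_nodes

dr = driver.Builder().with_modules(nn_nodes).build()

w = torch.randn(3, 1, requires_grad=True)   # the parameter we want gradients for
inputs = {
    "x": torch.randn(4, 3),
    "w": w,
    "target": torch.randn(4, 1),
}
result = dr.execute(["loss"], inputs=inputs)
result["loss"].backward()                    # torch back-propagates through both nodes
print(w.grad)                                # gradient of the loss w.r.t. w
```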

I'd also consider other optimization routines if the goal is flexibility!

@vograno
Author

vograno commented Sep 13, 2024

... The internal FunctionGraph does have bi-directional linking, so the building blocks are there. To me it sounds like you'd want to change a bit of how the graph is walked and where state is stored, e.g. via a new Driver.

I can walk the graph backward all right, but I also need to create new nodes in the back-flow graph along the way, and this is where I'm not sure what to do. I see two options, but first note:

  • all these steps happen inside a backprop function, not at the module level where we typically write node functions
  • for every forward node I need to emit multiple backward nodes: one backprop node per forward output, one split-out node per forward input, and one merging node per forward output.

Option 1 - using a temp module (a rough sketch follows the list).

  • For each node on the forward graph, create a back-flow function that maps the backward input to the backward output. The backward output is a dict, one entry per forward input. Add the back-flow function to the temp module.
  • Split the dict into individual nodes, one per forward input. Add the splitting functions to the temp module.
  • Group the split-out nodes by the forward-graph node they attach to. For each group, create the merging function and add it to the temp module.
  • Create the back-flow driver from the functions in the temp module.
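
To make the shape of Option 1 concrete, here is a hand-written example of what the temp module might contain for a single forward node c(a, b); in reality these functions would be generated programmatically (all names here are hypothetical):

```python
# Hand-written illustration of the temp module contents for one forward node c(a, b).
# Option 1 would generate these functions programmatically, attach them to a temporary
# module, and hand that module to driver.Builder().with_modules(...).
from typing import Dict

def c_backward(c_back: float) -> Dict[str, float]:
    # backprop node for forward node c: one dict entry per forward input
    return {"a": c_back, "b": c_back}        # the per-node back-flow rule goes here

def a_from_c(c_backward: Dict[str, float]) -> float:
    # split-out node: the piece of c's back-flow that belongs to forward input a
    return c_backward["a"]

def b_from_c(c_backward: Dict[str, float]) -> float:
    # split-out node for forward input b
    return c_backward["b"]

def a_back(a_from_c: float) -> float:
    # merging node: sum of all back-flow pieces attached to forward node a
    # (just one piece in this tiny example)
    return a_from_c
```

The part I'm unsure about is generating these functions with the right signatures, since Hamilton resolves dependencies from parameter names and type annotations.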

Option 2.
Start with an empty back-flow graph and add nodes to it as I traverse the forward graph. Here I need to create nodes outside the Builder, and I'm not sure what API to use.

@vograno
Author

vograno commented Sep 13, 2024

Let me propose a less fascinating but simpler example to work with.

  1. There are two kinds of nodes, Red and Blue.
  2. Let y denote the node output, and (x1, ..., xn) denote its inputs.
  3. The forward function of each node is just the sum of its inputs, regardless of the color of the node, i.e. y = sum(x_i).
  4. The backward function of a Red node sends its back-flow input to the first input node and zero to all other input nodes, i.e. x1 = y, x2 = 0, ..., xn = 0.
  5. The backward function of a Blue node splits its back-flow input equally among all input nodes, i.e. x1 = y/n, x2 = y/n, ..., xn = y/n.
  6. The merging function, which combines the back-flow values of all x's attached to upstream_node(x), is the sum.

The goal is to compute the back-flow driver given the forward-flow one.
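
In plain Python (no Hamilton yet), the pieces would look something like this, just to pin down the semantics:

```python
# Plain-Python statement of the Red/Blue rules.
from typing import List

def forward(xs: List[float]) -> float:
    # forward function of every node, regardless of color: y = sum(x_i)
    return sum(xs)

def backward_red(y_back: float, n: int) -> List[float]:
    # Red rule: all of the back-flow value goes to the first input, zero to the rest
    return [y_back] + [0.0] * (n - 1)

def backward_blue(y_back: float, n: int) -> List[float]:
    # Blue rule: the back-flow value is split equally among the n inputs
    return [y_back / n] * n

def merge(pieces: List[float]) -> float:
    # merging function: sum the back-flow pieces attached to the same forward output
    return sum(pieces)

# e.g. if x is the first input of both a 2-input Red node and a 2-input Blue node,
# and each node receives a back-flow value of 1.0, then x receives:
x_back = merge([backward_red(1.0, 2)[0], backward_blue(1.0, 2)[0]])   # 1.0 + 0.5 = 1.5
```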

@skrawcz
Collaborator

skrawcz commented Sep 17, 2024

@vograno sounds cool. Unfortunately we're a little bandwidth-constrained at the moment, so we can't be as helpful as we'd like -- just wanted to mention that. Do continue to add context / questions here - we'll try to help when we can.
