
Training

Open jafioti opened this issue 1 year ago • 13 comments

It should be possible to write an autograd compiler that runs on a primgraph, and derives a backward graph and attaches it to the end of the main graph. With this, we can then run the rest of the optimizations on the full graph, and get good performance for free.

We will also need an API around it for transferring weights between different sets (previous weights, new weights), plus an optimizer library.

The compiler should implement reverse-mode differentiation, since it derives gradients of every node with respect to one output. It assumes the graph is a primgraph; we need to ensure that's enough to create gradient nodes for each model node. The compiler should take in a node set of model weights and automatically derive the gradient graph by walking backward through the existing graph. Does the optimizer live inside the forward graph, or is it totally separate? Perhaps the optimizer is a function; hopefully it can be implemented with primops.

The compiler should not only create the gradient nodes, but also add those gradients to the weight nodes to get new weights. The new weight set should be returned from the compiler.

In a normal training loop, the weights aren't marked as kept, but the updated weights are. After the graph is run, the previous weights are gone and the new weights are present. We should then simply swap the node sets (similar to how KV caching works in mistral).

So training consists of something like this:

let gpus = [...];
let input_tensor = cx.tensor();
let target_tensor = cx.tensor();
let output = model.forward(input_tensor);
let loss = loss_fn(output, target_tensor);
let learning_rate = cx.constant(1e-4);
adam_w(loss, learning_rate);
let old_weights = state_dict(model);

// Distribute training over gpus
let new_weights = cx.compile(Autograd(old_weights, loss));
cx.compile((GenericCompiler, CudaCompiler<bf16>, DistributedDataParallel(gpus)));

for (input, target) in data {
    input_tensor.set(input);
    target_tensor.set(target);
    cx.execute();
    // Transfer new weights back to old weights (weight update)
    transfer_nodes_same_graph(new_weights, old_weights, &mut cx);
    println!("Loss: {}", loss.data());
    loss.drop();
    // Update lr
    learning_rate.set([1e-3]);
}
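For reference, the simplest weight update should be expressible with primops alone. A rough sketch of what the compiler could append per weight in the plain SGD case (names purely illustrative, not a final API):

// Hypothetical sketch: a plain SGD step built from Mul and Add only.
// `weight`, `grad`, and `lr` are assumed to be existing nodes in the graph.
let new_weight = weight + grad * lr * -1.0; // w' = w - lr * g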

Semantics still need to be worked out for how gradients are accumulated before they're applied, etc.
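One possible accumulation scheme (purely a sketch, nothing here exists yet): keep a running gradient buffer marked as kept across executes, and only fold it into the weights every k steps.

// Hypothetical gradient accumulation over k micro-batches
let grad_accum = grad_accum + grad;            // stateful buffer, kept across executes
// on every k-th execute:
let avg_grad = grad_accum * (1.0 / k as f32);  // mean gradient over the window
let new_weight = weight + avg_grad * lr * -1.0;
// then reset grad_accum to zero before the next accumulation window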

Todo:

  • [x] Sketch out concepts fully (what the optimizer is, derivatives of each node, etc)
  • [x] Compiler to derive and attach backward graph to forward graph (https://colah.github.io/posts/2015-08-Backprop/)
  • [x] Optimizers (SGD, Adam)
  • [ ] Train mnist feedforward

jafioti avatar Feb 06 '24 16:02 jafioti

Some further thinking on this.

Essentially we need to devise a local derivative for each primop:

  • Log2 - 1 / (x * ln(2))
  • Exp2 - exp2(x) * ln(2)
  • Sin - cos(x)
  • Sqrt - 1/(2 * sqrt(x))
  • Recip - -1 / x^2
  • Add - 1
  • Mul - dy/da = b, dy/db = a
  • Mod - Discontinuities
  • LessThan - Discontinuities
  • SumReduce - 1
  • MaxReduce - Discontinuities
  • Contiguous - 1

The ops with discontinuities don't really have derivatives, so I think if we just don't backprop through them it'll be fine. It seems like where they are used, we already have other paths that maintain the derivative, e.g. max(x, 0).
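In code, that table could end up as a single match over the primops, something like this (all type and helper names here are placeholders, not actual luminal items):

// Placeholder sketch: local derivative of a primop w.r.t. its first input `a`
// (`b` is the second input where one exists). Ops without a useful derivative
// return None and get skipped during backprop.
fn local_derivative(op: &PrimOp, a: Expr, b: Option<Expr>) -> Option<Expr> {
    match op {
        PrimOp::Log2 => Some(recip(a * LN_2)),       // 1 / (x * ln 2)
        PrimOp::Exp2 => Some(exp2(a) * LN_2),        // exp2(x) * ln 2
        PrimOp::Sin => Some(cos(a)),
        PrimOp::Sqrt => Some(recip(sqrt(a) * 2.0)),  // 1 / (2 * sqrt(x))
        PrimOp::Recip => Some(recip(a * a) * -1.0),  // -1 / x^2
        PrimOp::Add | PrimOp::SumReduce | PrimOp::Contiguous => Some(one()),
        PrimOp::Mul => b,                            // d(a*b)/da = b (and a for d/db)
        PrimOp::Mod | PrimOp::LessThan | PrimOp::MaxReduce => None, // don't backprop
    }
}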

The autograd compiler should essentially just do a reverse DFS from the loss node through the forward graph and iteratively build up the backward graph, starting at the loss node. If it encounters a 0 derivative along the way, it should stop going down that branch (since the chain rule means every other derivative along that path will end up 0 because of the multiply).
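Very roughly, the traversal could look like this (again just a sketch, reusing the hypothetical local_derivative table above; `input_local_derivatives`, `Expr`, `is_zero`, `one`, and `zero` are placeholders):

// Sketch of the reverse DFS: walk back from the loss, chaining local
// derivatives along each path and summing per-path contributions per node.
use std::collections::HashMap;

let mut grads: HashMap<NodeIndex, Expr> = HashMap::new();
let mut stack = vec![(loss_node, one())]; // d(loss)/d(loss) = 1
while let Some((node, upstream)) = stack.pop() {
    // inputs whose local derivative is None are already filtered out here
    for (input, local) in input_local_derivatives(node) {
        let contribution = upstream.clone() * local;
        if is_zero(&contribution) {
            continue; // chain rule: everything further down this branch stays 0
        }
        *grads.entry(input).or_insert_with(zero) += contribution.clone();
        stack.push((input, contribution));
    }
}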

jafioti avatar Mar 16 '24 02:03 jafioti

Hi, just a comment: if I understand correctly, MaxReduce takes the maximum of multiple quantities. If that is the case, it does not have a discontinuity, merely a small set of non-differentiability points, just like the special case max(x,0) does.

For such functions the correct thing to do is to take the local gradient of the active (maximizing) quantity, and if there are multiple, choose any convex combination of the gradients of the active ones.

For Mod and LessThan, the path you propose is correct because, at all points of continuity, they are piecewise constant, so their gradient is zero wherever it is defined.

daniel-vainsencher avatar Mar 20 '24 01:03 daniel-vainsencher

@daniel-vainsencher Great catch, yeah the local gradient should just be 1 for the elements of the input to MaxReduce that equal the corresponding output (i.e. the inputs that actually are the max), and 0 for each input that isn't.
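Conveniently, that mask can probably be built from ops we already have, e.g. LessThan (sketch only; method names here are made up):

// Sketch: MaxReduce backward as a 0/1 mask over the input.
// `max_b` is the reduced max broadcast back to the input's shape,
// `upstream_b` is the upstream gradient broadcast the same way.
// Equality via LessThan: eq(a, b) = (1 - (a < b)) * (1 - (b < a))
let mask = (1.0 - input.less_than(max_b)) * (1.0 - max_b.less_than(input));
let grad_input = mask * upstream_b; // 1 where the input attained the max, 0 elsewhere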

jafioti avatar Mar 20 '24 14:03 jafioti

Commit f65b0292495a66e941edbdb7374de5ddf3668a9c contains an initial implementation for the autograd.

Let's do a 2-graph solution: we have the gradient graph, which is responsible for doing whatever is needed to produce the gradients for an update, and the optimizer graph, which takes in model weights and gradients and produces new weights.

Both graphs can be stateful if needed (for gradient accumulation in the gradient graph and optimizer state in the optimizer graph for instance), and gradients + weights are transferred from the gradient graph to the optimizer graph when a weight update is done, and new weights are transferred back again. We can make nice APIs for this if needed.
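So one weight update would flow roughly like this (API sketch only; the function names are placeholders, not the final API):

// Hypothetical flow for a single update with the 2-graph setup
grad_graph.execute();                                       // forward + backward, gradients produced
transfer(&grad_ids, &mut grad_graph, &mut opt_graph);       // gradients -> optimizer graph
transfer(&weight_ids, &mut grad_graph, &mut opt_graph);     // current weights -> optimizer graph
opt_graph.execute();                                        // optimizer state updated, new weights produced
transfer(&new_weight_ids, &mut opt_graph, &mut grad_graph); // new weights back for the next step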

jafioti avatar Mar 21 '24 02:03 jafioti

Now we have autograd fully implemented and tested for transformers! examples/train_math_net contains the first training example, using SGD.

I'll implement Adam soon and do an MNIST training example, and then this issue can be closed.
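For reference, the Adam step is just the standard update, which should again decompose into primops plus two stateful moment buffers per weight (sketch only; variable names are illustrative, t is the step count):

// Standard Adam update per weight tensor (sketch). m and v would live as
// stateful buffers in the optimizer graph.
let m = m * beta1 + grad * (1.0 - beta1);          // first moment
let v = v * beta2 + grad * grad * (1.0 - beta2);   // second moment
let m_hat = m * (1.0 / (1.0 - beta1.powi(t)));     // bias correction
let v_hat = v * (1.0 / (1.0 - beta2.powi(t)));
let new_weight = weight - m_hat * lr / (v_hat.sqrt() + eps);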

jafioti avatar Apr 03 '24 14:04 jafioti

Apologies if this is a little presumptuous, but I'd like to suggest aiming for a training example on a dataset like CIFAR-10 (or another RGB dataset, although I think that might be the simplest one) instead of MNIST. I've been playing around with various Rust deep learning libraries for a while, and the methods that get good classification results on MNIST often don't translate well to RGB/multi-dimensional data, which requires convolutional layers and more complex architectures to do well. As a result, even though it's super common to use MNIST as a sort of MVP example, I tend to think it falls short of actually being "minimum viable" for moving on to a lot of real-world problems. Luminal looks like it's really starting to come together, and I'm excited to see where things go!

quietlychris avatar Apr 07 '24 20:04 quietlychris

@quietlychris Great point. Let's do CIFAR instead. Do you know how much more time CIFAR would take to train vs MNIST? Like 2x longer or 10x longer? Ideally we want these training jobs to finish in a few minutes, because I want them to run on every commit.

jafioti avatar Apr 08 '24 14:04 jafioti

https://github.com/tysam-code/hlb-CIFAR10 claims CIFAR10 @ 94% in <7 seconds on an A100.

They've apparently done a lot to get that fast, so they're an ambitious comparison point. That said, I think having training working on anything is a great milestone, and reproducing world-class training speed can totally be a separate issue :)

daniel-vainsencher avatar Apr 09 '24 15:04 daniel-vainsencher

Wow, that's nuts. For luminal_training, the training run needs to happen with just primops to test the autograd/training stuff in isolation, so I'm thinking MNIST for that. We can also have other training jobs use CUDA or Metal for platforms that support those (which we'll definitely need to test), so CIFAR should be a good fit.

jafioti avatar Apr 09 '24 20:04 jafioti

Apologies for the delay on this; things have been a bit crazy recently. Off-hand, CIFAR-10 usually does take a while longer than MNIST, typically because the requisite convolutional layers are much more compute-intensive. If you're running the jobs on GitHub Actions, I would expect that to take at least a couple of minutes. That said, the PyTorch tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html says the following about running on CPU:

Why don’t I notice MASSIVE speedup compared to CPU? Because your network is really small.

I don't have the actual benchmarks off-hand, although I seem to recall it being at the "pour yourself a cup of coffee, but don't start a new project" level of time between training runs.

That said, I'm not sure you'd actually need to train until reaching >90% on CIFAR. Just stacking fully-connected layers can get you to about 40-45%, and a little convolution will get you to around 50%, but the >70% range seems to be the inflection point where the library capability/network architecture is clearly fine and it becomes purely a train-for-more-epochs problem. All of which is to say, for a regularly-run CI job, it might be fine to target a lower accuracy threshold and still get most of the same benefits without the increased time trade-off for the last few percent.

quietlychris avatar Apr 13 '24 21:04 quietlychris

Sounds exciting :) I'm curious what your thoughts are on potentially using Enzyme for the autodiff? Its performance advantages look promising so far. Or are there no significant advantages to be expected in the case of luminal, because the derivatives of the primops are quite simple and the automatically derived gradient graph can be heavily optimized before it even gets to the LLVM stage? (Talk at DOE CSGF, Enzyme Rust) Thanks!

janroden avatar Jun 07 '24 17:06 janroden

@janroden Good idea, I don't think there would be a huge benefit since luminal primitive ops are quite bounded, so the problem of autodiff is very closed. You can see in luminal_train that the autograd compiler is ~150 lines, so we don't need to bring in a general-purpose autodiff.

jafioti avatar Jun 08 '24 15:06 jafioti

@janroden Good idea, I don't think there would be a huge benefit since luminal primitive ops are quite bounded, so the problem of autodiff is very closed. You can see in luminal_train that the autograd compiler is ~150 lines, so we don't need to bring in a general-purpose autodiff.

Thanks for the explanation! Makes sense. The concept of the primops seems quite powerful :)

janroden avatar Jun 09 '24 09:06 janroden