
Nested tensors

Open · Birch-san opened this issue on Dec 13 '23 · 6 comments

Hi, thanks for the library; we're using NATTEN currently for image generation and it's really helpful.

I was wondering whether you had any plan to (or already do) support nested tensors?
This would enable image training to use batches with varying dimensions (you wouldn't need aspect-ratio bucketing any more), and would enable language model training to use batches with varying lengths (you wouldn't need padding or packing any more).
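For illustration (my sketch, not from the original issue): a nested tensor packs differently-shaped tensors into one batch-like object, e.g. a batch of images with different resolutions, with no padding:

import torch

# Two RGB images with different spatial sizes, packed into a single nested
# "batch" instead of being padded or bucketed by aspect ratio.
batch = torch.nested.nested_tensor([
    torch.randn(3, 512, 768),   # landscape
    torch.randn(3, 640, 640),   # square
])
# batch.unbind() returns the two constituent tensors, each with its own H x W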

torch's scaled_dot_product_attention supports nested tensors nowadays; perhaps the code provides a clue as to how they approached it:

  • forward support: https://github.com/pytorch/pytorch/commit/5fb687182dba781d9c95388d19f4784b98cb8b20
  • backward support: https://github.com/pytorch/pytorch/commit/5183760ca52ea7dbd892ebc88cabdd43ebc81cbe
  • flash attn support: https://github.com/pytorch/pytorch/commit/80614783e3333bd19eeec5856c60dd2eeaa16431
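For context, a minimal sketch (mine, not from the thread) of what that looks like on the caller's side; it assumes a CUDA device and one of the fused SDPA backends, which is where nested inputs are supported:

import torch
import torch.nn.functional as F

# Nested (ragged) q/k/v: two sequences of different lengths, 8 heads, head dim 64.
# Assumes a CUDA device; the shapes and lengths here are illustrative only.
def make(lengths, device="cuda", dtype=torch.float16):
    return torch.nested.nested_tensor(
        [torch.randn(8, n, 64, device=device, dtype=dtype) for n in lengths]
    )

q, k, v = make([128, 96]), make([128, 96]), make([128, 96])
out = F.scaled_dot_product_attention(q, k, v)  # forward pass on nested inputs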

Birch-san · Dec 13 '23

Hello and thanks for your interest; I'm glad you're finding it useful.

I think it should be relatively easy to add support for nested tensors; I'll take a look at that later today. The changes I made to the library structure over the summer separate out a lot of the torch tensor API dependencies into a different layer, which means we could in theory have the API layer make multiple kernel calls, one for each tensor within the nested tensor.

alihassanijr · Dec 13 '23

that's fantastic news! thanks; I hope it turns out to be as straightforward as you're anticipating. 🙂

Birch-san · Dec 13 '23

Okay; it looks like adding nested tensor support at the C++ API level will be difficult, mostly because PyTorch doesn't expose all the nested-tensor utilities that are present in their source code (i.e. the ATen/native/nested sources don't appear to ship even with the latest nightly build).

I doubt any other plugin or framework with PyTorch integration supports nested tensors yet, but if you happen to know of any, please let us know.

This unfortunately leaves us with only one choice: unpacking the nested tensors in the Python interface and calling the C++ API on the buffers individually. It might be suboptimal, but it will get the job done. It'll just require a minor change to the API that I've been wanting to make for quite some time. I'll try to get it done by the end of the week.
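Roughly, the idea is the following (a hedged sketch of mine; na1d_qk_single is a stand-in name for the existing per-tensor op, not NATTEN's actual API):

import torch

def na1d_qk_nested(query, key, kernel_size, dilation):
    # Unpack the nested tensors into their constituent tensors, run the
    # existing per-tensor op on each pair, and repack the results.
    outs = [
        na1d_qk_single(q, k, kernel_size, dilation)  # stand-in for the real op
        for q, k in zip(query.unbind(), key.unbind())
    ]
    return torch.nested.as_nested_tensor(outs)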

alihassanijr · Dec 28 '23

I have a bit of bad news. I think PyTorch still doesn't support custom autograd functions taking in nested tensors, regardless of what the underlying operations are.

I even tried a simple identity op:

import torch
from torch.autograd import Function

# Trivial identity op: forward, forward-mode AD (jvp), and backward all
# just return their input unchanged.
class Na1dQKNested(Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def jvp(ctx, x):
        # tangent passes through unchanged
        return x

    @staticmethod
    def backward(ctx, x):
        # gradient passes through unchanged
        return x

Forward pass is of course working fine; but passing in any tensors with requires_grad=True, or running under the forward-mode AD context, just fails. I tried a couple of the most recent torch releases, and the latest nightly.
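Roughly how this was exercised (my sketch, not verbatim from the thread); on the versions tested here, the call reportedly fails as soon as requires_grad is enabled on the nested input:

import torch

# Two differently-sized constituents; toggling requires_grad on the nested
# input is what triggered the reported failure on the versions tested above.
nt = torch.nested.nested_tensor(
    [torch.randn(8, 16), torch.randn(12, 16)],
    requires_grad=True,
)
out = Na1dQKNested.apply(nt)  # reportedly fails once autograd is involved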

But if you happen to be interested more in inference than training at the moment, I can still push that out today.

alihassanijr · Dec 29 '23

> only one choice: unpacking the nested tensors in the Python interface and calling the C++ API on the buffers individually

this is probably fine, because for image training the microbatch sizes are pretty small anyway. perhaps there's a way to dispatch all the work asynchronously, then wait for all results to come back? so progress is made on all batch items simultaneously?

> PyTorch still doesn't support custom autograd functions taking in nested tensors
> Forward pass is of course working fine

hmmm darn, thanks for implementing. maybe nested mode forward can still be useful for things like inference services, where multiple users request images of different sizes simultaneously.

I wonder if the pytorch team know that nested-tensor support for custom autograd functions is needed. not just for simplifying multi-aspect image training, but also for simplifying the training of Mistral-style language models. it's also interesting to see that the utilities that could help with invoking C++ on nested tensors are not exposed. maybe we're too early. but if I get a chance to talk to any pytorch people on Discord, I'll let them know there's a use-case for this.

Birch-san · Dec 29 '23

> this is probably fine, because for image training the microbatch sizes are pretty small anyway. perhaps there's a way to dispatch all the work asynchronously, then wait for all results to come back? so progress is made on all batch items simultaneously?

It's all perfectly possible; the only issue is that if we want to use autograd, then our only real choice is to stick to the torch API.

But yes, the change that I made to the NATTEN C++ API is that instead of the C++ API handling the creation of output tensors, we now handle that in Python. This is generally the preferred solution anyway; I recall a very hard-to-reproduce bug where the C++ API, when creating new tensors, started a separate CUDA stream on device index 0, regardless of which device the input tensors were on.

So this means we can easily preallocate all the memory necessary for our output tensors and then just call the NATTEN C++ handles. The kernel calls are asynchronous, so the only thing running synchronously would be the checks.
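As a rough sketch of that flow (my illustration; libnatten.na1d_qk_forward is a hypothetical stand-in for the low-level handle, not the actual binding, and the output shape is an assumption):

import torch

def nested_na1d_qk(query, key, kernel_size, dilation):
    # `libnatten` below is a placeholder for NATTEN's C++/CUDA bindings.
    outputs = []
    for q, k in zip(query.unbind(), key.unbind()):
        # Preallocate this item's output in Python (assuming a
        # [..., length, kernel_size] attention-weights shape for 1D NA).
        attn = q.new_empty((*q.shape[:-1], kernel_size))
        # Hypothetical low-level call; each kernel launch is asynchronous,
        # so the per-item work overlaps on the device.
        libnatten.na1d_qk_forward(attn, q, k, kernel_size, dilation)
        outputs.append(attn)
    # Only the Python-side checks and bookkeeping run synchronously.
    return torch.nested.as_nested_tensor(outputs)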

With this change, I got one op (1D QK) to support nested tensors, and the forward pass appeared to work; but when I toggled requires_grad, it hit the following error. I'm assuming that somewhere in the Autograd C++ API, they're calling an API that's only implemented for non-nested tensors:

RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue.

Forward-mode AD isn't nearly as nice; it just spits out tens of lines of error traces from the C++ library that are difficult to even comprehend unless one is very familiar with the torch C++ APIs.

> I wonder if the pytorch team know that nested-tensor support for custom autograd functions is needed. not just for simplifying multi-aspect image training, but also for simplifying the training of Mistral-style language models. it's also interesting to see that the utilities that could help with invoking C++ on nested tensors are not exposed. maybe we're too early. but if I get a chance to talk to any pytorch people on Discord, I'll let them know there's a use-case for this.

I'm also at the point where I'm convinced it's not something I'm doing wrong, and I haven't seen a single PyTorch issue, open or closed, referring to this exact scenario, so I'll probably open an issue there and see how that goes.

But anyway, I'll try to push forward-pass-only support today.

alihassanijr · Dec 29 '23