
Add a Stack node for ensembling

Open TanmayPani opened this issue 4 months ago • 9 comments

🚀 The feature

A torchdata.nodes.Stack node that would use multiple torch.utils.data.Sampler (and possibly torchdata.nodes.Batcher) instances to generate independent batches from the same dataset and stack them.

A naive implementation (still working on the actual code):

from typing import Optional, Any, Dict

import torch
from torch import Tensor
from torchdata.nodes import BaseNode


class Stack(BaseNode[Tensor]):
    """Draws one item from each source node and stacks them along a new leading dimension."""

    def __init__(self, *sources: BaseNode[Tensor]):
        super().__init__()
        self.sources = sources

    def next(self) -> Tensor:
        # stops as soon as any one of the sources is exhausted
        return torch.stack([node.next() for node in self.sources])

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is None:
            for node in self.sources:
                node.reset()
        else:
            for inode, node in enumerate(self.sources):
                node.reset(initial_state[f"source_{inode}"])

    def get_state(self) -> Dict[str, Any]:
        # persist each source's state under a stable per-source key
        return {f"source_{inode}": node.state_dict() for inode, node in enumerate(self.sources)}

A very minimal example usage:

from typing import Optional, Callable
from collections.abc import Sequence

import torch
from torch.utils.data import RandomSampler
from torchdata.nodes import SamplerWrapper
from torchdata.nodes import Batcher
from torchdata.nodes import ParallelMapper
from torchdata.nodes import Loader, Header

def stacked_batch_loader(
    dataset: Sequence,
    ncopies: int,
    batch_size: int,
    collate_fn: Optional[Callable] = None,
    num_workers: int = 0,
    load_only_first: Optional[int] = None,
):
    nodes_to_stack = []
    num_workers_per_copy = num_workers // ncopies

    map_fn = collate_fn or getattr(dataset, "__getitems__", None)

    if map_fn is None:
        raise ValueError("Either give a collate_fn or define a __getitems__ to sample mini-batches!")

    for icopy in range(ncopies):
        g = torch.Generator()
        g.manual_seed(icopy % 2**32)
        sampler = RandomSampler(dataset, generator=g)
        node = SamplerWrapper(sampler)
        node = Batcher(node, batch_size=batch_size)
        node = ParallelMapper(
            node,
            map_fn=map_fn,
            num_workers=num_workers_per_copy,
            method="process",
            in_order=True,
        )
        nodes_to_stack.append(node)

    stacked_node = Stack(*nodes_to_stack)

    if load_only_first is not None:
        stacked_node = Header(stacked_node, load_only_first)

    return Loader(stacked_node)

I would appreciate any feedback on this. I can do a PR if y'all think it's useful.

Motivation, pitch

I am trying to do ensemble training using independent instances of the model stacked using torch.func.stack_module_state and using torch.func.vmap to vectorize the forward-pass/gradient computation over bootstrap-sampled and stacked batches from data. Similar to this tutorial.

For my use-case: if my main dataset has 10000 samples, I use, say, 10 sub-datasets of 5000 samples each, drawn using different random seeds, to train 10 instances of my model; then I use the mean and standard deviation of the 10 predictions from the models to quantify uncertainties in different regions of my feature space.

An iterator that yields stacked batches from different subsamples of the data would be convenient for this use case.
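
For concreteness, a minimal sketch (toy sizes and a throwaway nn.Linear, not the proposed API) of how such a stacked batch of shape (ncopies, batch_size, nfeatures) would be consumed with torch.func.stack_module_state and torch.func.vmap:

import torch
from torch import nn
from torch.func import stack_module_state, functional_call, vmap

ncopies, batch_size, nfeatures = 4, 8, 2  # toy sizes, purely illustrative

# independently drawn batches stacked along a leading "copy" dimension
stacked_batch = torch.randn(ncopies, batch_size, nfeatures)

models = [nn.Linear(nfeatures, 1) for _ in range(ncopies)]
params, buffers = stack_module_state(models)
base_model = models[0]  # only used as a stateless skeleton by functional_call

def fmodel(p, b, x):
    return functional_call(base_model, (p, b), (x,))

# one forward pass per model copy, each on its own batch
stacked_preds = vmap(fmodel)(params, buffers, stacked_batch)
print(stacked_preds.shape)  # torch.Size([4, 8, 1])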

Alternatives

No response

Additional context

Sorry if I am mixing up any terminology; I am not a software engineer strictly speaking, I just use torch for ML applications in my research.

TanmayPani avatar Aug 21 '25 15:08 TanmayPani

Hey @TanmayPani, this is very interesting. Can you elaborate more on what kind of use cases this would be useful for? Trying to think of when stacking different batches comes in handy.

Side implementation thought: we probably do not want to loop over each of self.sources serially but do it in parallel; need to think of the dataset's thread safety in that case.
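
For illustration only, a thread-based sketch of that parallel draw (assuming each source's next() can safely be called from a worker thread, which is exactly the thread-safety question above):

from concurrent.futures import ThreadPoolExecutor

import torch

def parallel_stack_next(sources):
    # draw one item from every source concurrently, then stack the results
    with ThreadPoolExecutor(max_workers=len(sources)) as pool:
        items = list(pool.map(lambda node: node.next(), sources))
    return torch.stack(items)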

divyanshk avatar Aug 21 '25 18:08 divyanshk

Can you elaborate more on what kind of use cases this would be useful for? Trying to think of when stacking different batches comes in handy.

For me, the biggest use case is efficiently determining the uncertainties of a model's predictions, particularly the epistemic uncertainty that comes from a lack of training data in a particular part of the feature space.

For example, say our data samples look like $\left(\vec{x}, y\right)$, where the input features $\vec{x} \in X$ (the feature space) are distributed according to a bivariate normal distribution with mean $\vec{\mu}=\left(0, 0\right)$ and covariance $\Sigma = \begin{bmatrix}1 & 0 \\ 0 & 1\end{bmatrix}$. Now, given targets $y \in Y$, let's fit the model:

$$M(\vec{a}, b)(\vec{x}) = \vec{a}\cdot\vec{x}+b$$

The predictions with the optimized parameters $\left(\vec{a}^{\rm fit}, b^{\rm fit}\right)$ after training are:

$$y^{\rm pred}(\vec{x}) = M(\vec{a}^{\rm fit}, b^{\rm fit})(\vec{x})$$

But since most optimization steps (SGD, Adam, etc.) are stochastic, $\left(\vec{a}^{\rm fit}, b^{\rm fit}\right)$ will be slightly different based on which part of the feature space the model sees first during training.

So if my training dataset has $N$ samples, we would randomly sample a subset with $\gamma N$ samples ($\gamma<1$) and train the model. After repeating this $k$ times, each time with a different instance of the model (with the model weights initialized to exactly the same values each time), we would have $k$ predictions $\{y^{\rm pred}_i\}_{i=1}^k$ for any input point $\vec{x}$. So, for a given $\vec{x}$, we can calculate the mean and standard deviation as:

$$\langle y^{\rm pred}\rangle(\vec{x}) = \frac{1}{k}\sum_{i=1}^{k}y^{\rm pred}_i$$

$$\delta y^{\rm pred}(\vec{x}) = \sqrt{\frac{1}{k-1}\sum_{i=1}^{k}\left(y^{\rm pred}_i - \langle y^{\rm pred}\rangle\right)^2}$$

Assuming no issues in training, we would expect $\delta y^{\rm pred}\left(\vec{x}\right)$ to be lower for, say, $\vec{x} = \left(-0.1, 0.15\right)$ than for $\vec{x} = \left(2, -1.5\right)$, as there are far more samples in the neighborhood of $\left(-0.1, 0.15\right)$ (closer to $\vec{\mu}=\left(0, 0 \right)$) than in the neighborhood of $\left(2, -1.5\right)$. Thus, the statistical uncertainties of the data can be propagated to the predictions.
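
In code, once the $k$ predictions for a set of points are stacked into a tensor of shape (k, n_points), the two quantities above are just reductions over the leading dimension (toy numbers, for illustration only):

import torch

k, n_points = 10, 5  # toy sizes
stacked_preds = torch.randn(k, n_points)  # y_pred_i(x) for i = 1..k

mean_pred = stacked_preds.mean(dim=0)               # <y_pred>(x)
std_pred = stacked_preds.std(dim=0, unbiased=True)  # delta y_pred(x), with the 1/(k-1) normalization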

Now, IRL, the input $\vec{x}$ has 20+ components, $M$ is at least a moderately sized MLP, and the dataset has millions of samples. And for a robust estimate of the uncertainties, I have to train at least 50 instances of the model. Done sequentially, this takes days, but as demonstrated here, with torch.func.vmap it comes down to hours. So if I can easily stack the batches from sub-samples of my dataset, then applying vmap would allow me to train 10+ instances of the model simultaneously.

Example pseudocode:

from sklearn.model_selection import train_test_split
import numpy as np
from numpy.random import RandomState
from torch.utils.data import SubsetRandomSampler
from torchdata.nodes import SamplerWrapper, Batcher, ParallelMapper, Loader

# `dataset`, `batch_size`, and `collate_fn` are assumed to be defined elsewhere
all_indices = np.arange(len(dataset))
nodes_to_stack = []
for icopy in range(10):
    rng_state = RandomState(seed=icopy)
    # draw a different 70% subsample for every model copy
    sub_sample_indices, _ = train_test_split(
        all_indices,
        train_size=0.7,
        random_state=rng_state,
    )

    sampler = SubsetRandomSampler(sub_sample_indices)
    node = SamplerWrapper(sampler)
    node = Batcher(node, batch_size=batch_size)
    node = ParallelMapper(node, map_fn=collate_fn, num_workers=2)  # worker count is a placeholder

    nodes_to_stack.append(node)

stacked_node = Stack(*nodes_to_stack)  # the Stack node proposed above
loader = Loader(stacked_node)

import copy
from functools import partial

import torch
from torch.func import stack_module_state
from torch.func import functional_call

# Assume some get_model function that identically initializes a model and returns it
models = [get_model() for _ in range(10)]

params, buffers = stack_module_state(models)

def model_call(model, params, buffers, *args):
    return functional_call(model, (params, buffers), args)

# nn.Module has no .clone(); a deepcopy moved to the meta device gives a stateless skeleton
base_model = copy.deepcopy(models[0]).to(device="meta")
fmodel = partial(model_call, base_model)

for batch in loader:
    stacked_predictions = torch.vmap(fmodel)(params, buffers, batch)
    # calculate losses and do the optimizer step next ...

Sorry if this is too long; I am a physicist by trade, so I have to learn the data science on the fly. There's probably a more succinct way to say all this.

TanmayPani avatar Aug 21 '25 20:08 TanmayPani

@TanmayPani Thanks for explaining! TIL about vmap today. Sounds like a really good use case. I'm curious: when stacking M varied batches together to practically train N models concurrently, how do you manage the learning rate? The learning rate depends on the batch size, but since we have N batches here, how is that handled? In other words, how do you ensure that training on the stacked batches is identical to doing them serially one by one? You can extend my learning rate doubt to any hyperparameter.

divyanshk avatar Aug 21 '25 22:08 divyanshk

The learning rate depends on the batch size, but since we have N batches here, how is that handled? In other words, how do you ensure that training on the stacked batches is identical to doing them serially one by one? You can extend my learning rate doubt to any hyperparameter.

There are two options that I have explored:

  1. The params dict that you get from torch.func.stack_module_state contains the stacked parameters from each module. These can be sliced and assigned as param groups to the optimizer. In optimizer.step() the parameter groups are optimized independently, so the code would look something like:
import copy
from functools import partial

from torch.optim.adam import Adam
from torch.func import functional_call, vmap, stack_module_state
from torch.nn.functional import binary_cross_entropy_with_logits

def loss(fmodel, params, buffers, X, y):
    y_pred = fmodel((params, buffers), X)
    return binary_cross_entropy_with_logits(y_pred, y)

def get_model():
    ...  # get model architecture here

def get_loader():
    ...  # get batch loader here

if __name__ == "__main__":
    ncopies = 10

    model = get_model()
    base_model = copy.deepcopy(model).to("meta")  # don't need storage for this
    params, buffers = stack_module_state([copy.deepcopy(model) for _ in range(ncopies)])
    fmodel = partial(functional_call, base_model)
    loss_fn = partial(loss, fmodel)  # bind the stateless model into the loss

    # one param group per model copy; Adam steps each group independently
    params_for_optim = [
        {"params": [v[icopy] for v in params.values()]} for icopy in range(ncopies)
    ]
    optimizer = Adam(params_for_optim, lr=0.001)

    loader = get_loader()

    for epoch in range(100):
        for X, y in loader:
            stacked_losses = vmap(loss_fn)(params, buffers, X, y)
            optimizer.zero_grad()
            stacked_losses.sum().backward()  # reduce to a scalar; gradients remain per-copy
            optimizer.step()

You can use a separate lr scheduler for each param group. The only drawback is that, if you look inside torch.optim.Adam.step(), it still uses a for loop over the param groups, while most of the work is done in the callable "torch.optim.adam.adam", which could potentially be vmapped.

  2. (A lot more involved) I can put the torch.optim optimizers into a functional form, like in JAX. I have used the tensordict package to make handling of the optimizer states easier:
from typing import Self
from typing import Callable

from torch import Tensor

from tensordict import NonTensorDataBase, TensorDict
from tensordict import TensorClass

type Params = dict[str, Tensor]

class OptimState(TensorClass["nocast"]):
    def check_state(self) -> None:
        raise NotImplementedError("Use a subclass defined for a particular optimizer!")

    def __call__(self, params: Params) -> Self:
        raise NotImplementedError("Use a subclass defined for a particular optimizer!")

class GradientTransformations:
    def __init__(
        self, 
        optimizer_call : Callable,
        optim_state : OptimState,
    ):
        self.optimizer_call = optimizer_call
        self.optim_state = optim_state
        self.optim_state.check_state()

    def init(self, params : Params) -> OptimState:
        return self.optim_state(params)

    def __call__(
        self, 
        params : Params, grads : Params, 
        state : OptimState, /,
    ):
        _state = {}
        for key, val in state.items():
            if isinstance(val, (TensorDict, dict)):
                _state[key] = list(val.values())
            elif isinstance(val, NonTensorDataBase):
                _state[key] = val.data
            else:
                _state[key] = val

        self.optimizer_call(
            list(params.values()),
            list(grads.values()),
            **_state,
        )

        return params, state

Then, say for Adam:

from typing import Self
from typing import Union
from typing import Optional

from dataclasses import field

import torch
from torch import Tensor
from torch.optim.adam import adam

from tensordict import TensorDict

type Params = dict[str, Tensor]

class AdamOptState(OptimState):
    exp_avgs : TensorDict = field(default_factory = TensorDict)
    exp_avg_sqs : TensorDict = field(default_factory = TensorDict)
    max_exp_avg_sqs : TensorDict = field(default_factory = TensorDict)
    state_steps : TensorDict = field(default_factory = TensorDict)
    lr: Union[float, Tensor] = 1e-3
    beta1: Union[float, Tensor] = 0.9
    beta2: Union[float, Tensor] = 0.999
    eps: float = 1e-8
    weight_decay: float = 0
    amsgrad: bool = False
    foreach: Optional[bool] = None
    maximize: bool = False
    capturable: bool = False
    differentiable: bool = False
    fused: Optional[bool] = None
    decoupled_weight_decay: bool = False

    def check_state(self):
        if isinstance(self.lr, Tensor):
            if self.foreach and not self.capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if self.lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        
        if not 0.0 <= self.lr:
            raise ValueError(f"Invalid learning rate: {self.lr}")
        if not 0.0 <= self.eps:
            raise ValueError(f"Invalid epsilon value: {self.eps}")
        if not 0.0 <= self.beta1 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {self.beta1}")
        if not 0.0 <= self.beta2 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {self.beta2}")
        if not 0.0 <= self.weight_decay:
            raise ValueError(f"Invalid weight_decay value: {self.weight_decay}")
        if not (
            (isinstance(self.beta1, float) and isinstance(self.beta2, float))
            or (isinstance(self.beta1, Tensor) and isinstance(self.beta2, Tensor))
        ):
            raise ValueError("betas must be either both floats or both Tensors")
        if isinstance(self.beta1, Tensor):
            if not self.capturable and self.foreach:
                raise ValueError(
                    "beta1 as a Tensor is not supported for capturable=False and foreach=True"
                )
            if self.beta1.numel() != 1:
                raise ValueError("Tensor betas[0] must be 1-element")
        if isinstance(self.beta2, Tensor):
            if not self.capturable and self.foreach:
                raise ValueError(
                    "beta2 as a Tensor is not supported for capturable=False and foreach=True"
                )
            if self.beta2.numel() != 1:
                raise ValueError("Tensor betas[1] must be 1-element")

       
    def __call__(self, params: Params) -> Self:
        for param_name, param in params.items():
            self.exp_avgs[param_name] = torch.zeros_like(param)
            self.exp_avg_sqs[param_name] = torch.zeros_like(param)
            self.max_exp_avg_sqs[param_name] = torch.zeros_like(param)
            self.state_steps[param_name] = torch.tensor(
                (), dtype=torch.float32, device=param.device
            ).new_zeros(())

        return self

class FuncAdam(GradientTransformations):
    def __init__(
        self, 
        lr: Union[float, Tensor] = 1e-3,
        betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = True,
        differentiable: bool = False,
        fused: Optional[bool] = None,
        decoupled_weight_decay: bool = False,
    ):

        optim_state = {
            "lr": lr,
            "beta1": betas[0],
            "beta2": betas[1],
            "eps": eps,
            "weight_decay": weight_decay,
            "amsgrad": amsgrad,
            "maximize": maximize,
            "foreach": foreach,
            "capturable": capturable,
            "differentiable": differentiable,
            "fused": fused,
            "decoupled_weight_decay": decoupled_weight_decay,
        }
        super().__init__(adam, AdamOptState(**optim_state))

Then the training loop becomes:

import copy
from functools import partial

import torch
from torch.func import functional_call, stack_module_state
from torch.func import vmap, grad_and_value
from torch.nn.functional import binary_cross_entropy_with_logits

def loss(base_model, params, buffers, X, y):
    y_pred = functional_call(base_model, (params, buffers), X)
    return binary_cross_entropy_with_logits(y_pred, y)

def train_step(loss_fn, optimizer, optimizer_state, params, buffers, X, y):
    grads, loss_val = grad_and_value(loss_fn)(params, buffers, X, y)
    params, optimizer_state = optimizer(params, grads, optimizer_state)
    return loss_val, params, optimizer_state

def get_model():
    ...  # get model architecture here

def get_loader():
    ...  # get batch loader here

if __name__ == "__main__":
    ncopies = 10

    model = get_model()
    base_model = copy.deepcopy(model).to("meta")  # don't need storage for this
    params, buffers = stack_module_state([copy.deepcopy(model) for _ in range(ncopies)])
    loss_fn = partial(loss, base_model)

    optimizer = FuncAdam(lr=0.001)
    optimizer_state = optimizer.init({k: v[0] for k, v in params.items()})

    # we can stack optimizer states like normal tensors thanks to tensordict...
    stacked_optimizer_state = torch.stack([optimizer_state.detach() for _ in range(ncopies)])

    train_step_fn = partial(train_step, loss_fn, optimizer)

    loader = get_loader()

    for epoch in range(100):
        for X, y in loader:
            stacked_losses, params, stacked_optimizer_state = vmap(train_step_fn)(
                stacked_optimizer_state, params, buffers, X, y
            )

While this is in principle faster, I need to figure out a way to implement lr scheduling. I am thinking about making a LearningRate class that inherits from torch.Tensor, with the lr-scheduling functions implemented as classmethods, so that I can stack the lrs and send them into the vmap.
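
To make the "stack the lrs and send them into vmap" idea concrete, here is a toy sketch with a plain SGD-style update instead of the FuncAdam above (the per-copy learning rates are just a stacked tensor; any scheduling would update that tensor between steps):

import torch
from torch.func import vmap

ncopies = 10
lrs = torch.full((ncopies,), 1e-3)  # one (schedulable) learning rate per model copy

def sgd_update(lr, param, grad):
    return param - lr * grad

# params/grads stacked along dim 0, shape (ncopies, ...)
stacked_params = torch.randn(ncopies, 4)
stacked_grads = torch.randn(ncopies, 4)
new_params = vmap(sgd_update)(lrs, stacked_params, stacked_grads)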

TanmayPani avatar Aug 22 '25 16:08 TanmayPani

Side implementation thought: we probably do not want to loop over each of self.sources serially but do it in parallel; need to think of the dataset's thread safety in that case.

I am currently working on a torch.utils.data.Sampler subclass that would generate stacked batch indices; the torchdata.nodes.BaseNode subclass could then vectorize the fetching part by using vmap on the dataset.__getitems__ or a user-provided collate function. I'll post it here as soon as an implementation is ready. I am taking inspiration from this article, which tries to achieve the same thing in JAX. Any thoughts in the meantime, @divyanshk?
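
A rough sketch of what such a sampler could look like (the name StackedBatchSampler and its arguments here are hypothetical, not the implementation being prepared):

from typing import Iterator

import torch
from torch.utils.data import Sampler

class StackedBatchSampler(Sampler[list[list[int]]]):
    """Hypothetical sketch: each step yields `ncopies` independently shuffled index batches."""

    def __init__(self, data_len: int, ncopies: int, batch_size: int, base_seed: int = 0):
        self.data_len = data_len
        self.ncopies = ncopies
        self.batch_size = batch_size
        self.base_seed = base_seed

    def __iter__(self) -> Iterator[list[list[int]]]:
        # one independent permutation per copy, seeded deterministically
        perms = [
            torch.randperm(self.data_len, generator=torch.Generator().manual_seed(self.base_seed + i))
            for i in range(self.ncopies)
        ]
        for ib in range(len(self)):
            sl = slice(ib * self.batch_size, (ib + 1) * self.batch_size)
            yield [perm[sl].tolist() for perm in perms]

    def __len__(self) -> int:
        return self.data_len // self.batch_size

The consuming node could then hand each inner index list to dataset.__getitems__ (or the collate function) and stack the results.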

TanmayPani avatar Aug 22 '25 16:08 TanmayPani

@TanmayPani thanks for the code. I will go through it; in the meantime, I think for this feature we should think of all the pieces that would be needed (stacked batching, 'stacking' samplers) and add them together, as that helps build a cohesive story on what truly is being unblocked here. Which in my mind is "Parallelizing neural network training on a single GPU, simply".

divyanshk avatar Aug 22 '25 19:08 divyanshk

@divyanshk I think I have a working implementation ready; what would be the procedure for iterating on it between us? Or do we go for the pull request directly?

TanmayPani avatar Aug 26 '25 18:08 TanmayPani

@TanmayPani Let's create the PR! Also check out any recently added nodes for patterns / unit testing, etc. cc @ramanishsingh

divyanshk avatar Aug 29 '25 23:08 divyanshk

Thanks @divyanshk, I forked the repo; I should be done pushing the StackedBatchSampler and the associated BaseNode subclass implementations and their tests by Tuesday.

TanmayPani avatar Aug 30 '25 20:08 TanmayPani