lightning-thunder
Cannot correctly compute cross entropy loss (backward) when the reduction used is "sum" and torch executor is used.
🐛 Bug
With the torch executor, Thunder cannot correctly compute the backward of cross entropy loss when the reduction is "sum". It works correctly with "mean".
To Reproduce
Run the code snippet below.
Code sample
import torch
from thunder.tests.framework import TorchExecutor

device = "cuda"
executor = TorchExecutor

def cross_entropy_fn(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels, reduction="sum")

sequence_length, vocab_size = 8192, 32064
# NOTE: dtype assumed to be torch.float32 here; the snippet originally referenced an undefined `thunder_dtype`.
logits = torch.rand((sequence_length, vocab_size), device=device, dtype=torch.float32, requires_grad=True)
labels = torch.randint(0, sequence_length, (sequence_length,), requires_grad=False, device=device)
fn = executor.make_callable(cross_entropy_fn)
a = fn(logits, labels)
a.backward()
print(logits.grad)
Expected behavior
The snippet should run to completion and produce a gradient matching the one computed by torch.compile (or eager PyTorch).
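For concreteness, the check below is what I expect to pass. It is a minimal sketch: using eager PyTorch as the reference (instead of torch.compile) and torch.float32 as the dtype are my assumptions.

import torch
from thunder.tests.framework import TorchExecutor

def cross_entropy_fn(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels, reduction="sum")

sequence_length, vocab_size = 8192, 32064
logits = torch.rand((sequence_length, vocab_size), device="cuda", dtype=torch.float32, requires_grad=True)
labels = torch.randint(0, sequence_length, (sequence_length,), requires_grad=False, device="cuda")

# Reference gradient from eager PyTorch (assumed as the ground truth here)
cross_entropy_fn(logits, labels).backward()
ref_grad = logits.grad.detach().clone()
logits.grad = None

# Gradient from the Thunder-compiled function with the torch executor
fn = TorchExecutor.make_callable(cross_entropy_fn)
fn(logits, labels).backward()

# This comparison currently fails with reduction="sum" but passes with reduction="mean"
torch.testing.assert_close(logits.grad, ref_grad)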
Environment
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context
A similar issue shows up with the following code:
import torch
from thunder.tests.framework import nvFuserExecutor, TorchExecutor
def mini_model(logits, labels):
    labels = torch.nn.functional.pad(labels, (0, 1))
    labels = labels[1 : labels.shape[-1]]
    logits = logits.to(dtype=torch.float32)
    logits = logits.squeeze(dim=0)
    return torch.nn.functional.cross_entropy(logits, labels)
batch_size = 4096
vocab_size = 8192
input = torch.randn(
    1,
    batch_size,
    vocab_size,
    device="cuda",
    dtype=torch.bfloat16,
    requires_grad=True,
)
labels = torch.randint(
    0,
    vocab_size - 1,
    (batch_size,),
    device="cuda",
    requires_grad=False,
)
executor = TorchExecutor
a = mini_model(input, labels)
print(a)
b = executor.make_callable(mini_model)
print(b(input, labels))
If I comment out:
def mini_model(logits, labels):
    # labels = torch.nn.functional.pad(labels, (0, 1))
    # labels = labels[1 : labels.shape[-1]]
    logits = logits.to(dtype=torch.float32)
    logits = logits.squeeze(dim=0)
    return torch.nn.functional.cross_entropy(logits, labels)
Then it works correctly.
Interestingly, it also fails with the nvFuser backend, but if I turn on the custom decomposition for cross entropy loss (for nvFuser) it works correctly.
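For reference, here is a rough sketch of how I sweep both executors on the mini_model repro above; I am only showing the executor switch and have left out the option for enabling the custom nvFuser cross entropy decomposition.

from thunder.tests.framework import TorchExecutor, nvFuserExecutor

# As described above: with the pad/slice lines present, the torch executor's output
# differs from eager, and the nvFuser executor also fails unless its custom cross
# entropy decomposition is enabled.
for ex in (TorchExecutor, nvFuserExecutor):
    compiled = ex.make_callable(mini_model)
    print(ex, compiled(input, labels))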