lightning-thunder

Cannot correctly compute cross entropy loss backward when reduction="sum" and the torch executor is used

Open protonu opened this issue 7 months ago • 2 comments


🐛 Bug

With the torch executor, Thunder cannot correctly compute the backward pass of cross entropy loss when reduction="sum" is used. It works correctly with reduction="mean".
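For reference, the gradients one would expect under PyTorch's definition of cross entropy (no class weights, no ignored labels) differ between the two reductions only by a constant factor. The notation here is introduced for illustration and is not from the original report: logits $z \in \mathbb{R}^{N \times C}$, integer labels $y_i$, and $p_{ij} = \mathrm{softmax}(z_i)_j$.

```latex
\mathcal{L}_{\text{sum}} = -\sum_{i=1}^{N} \log p_{i y_i},
\qquad
\frac{\partial \mathcal{L}_{\text{sum}}}{\partial z_{ij}} = p_{ij} - \mathbb{1}[j = y_i],
\qquad
\frac{\partial \mathcal{L}_{\text{mean}}}{\partial z_{ij}} = \frac{1}{N}\left(p_{ij} - \mathbb{1}[j = y_i]\right).
```

So with reduction="sum", each entry of `logits.grad` should be exactly $N$ times the corresponding entry under reduction="mean".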

To Reproduce

Run the code snippet below.

Code sample

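import torch
from thunder.tests.framework import TorchExecutor

device = "cuda"
executor = TorchExecutor
# Assumed value: the original snippet referenced an undefined `thunder_dtype`.
thunder_dtype = torch.float32


def cross_entropy_fn(logits, labels):
    # Summed (not averaged) cross entropy over all positions.
    return torch.nn.functional.cross_entropy(logits, labels, reduction="sum")


sequence_length, vocab_size = 8192, 32064
logits = torch.rand((sequence_length, vocab_size), device=device, dtype=thunder_dtype, requires_grad=True)
# Class indices must lie in [0, vocab_size); sequence_length < vocab_size, so these are valid.
labels = torch.randint(0, sequence_length, (sequence_length,), requires_grad=False, device=device)

# Compile with Thunder's torch executor, then run the forward and backward passes.
fn = executor.make_callable(cross_entropy_fn)
a = fn(logits, labels)
a.backward()
print(logits.grad)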
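To check `logits.grad` independently of any executor, a framework-free reference can help. The sketch below is hypothetical (`cross_entropy_loss` and `cross_entropy_grad` are not part of Thunder or PyTorch) and assumes unweighted cross entropy with integer class labels and only the "sum" and "mean" reductions.

```python
import math


def _check_reduction(reduction):
    if reduction not in ("sum", "mean"):
        raise ValueError(f"unsupported reduction: {reduction!r}")


def softmax(row):
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_loss(logits, labels, reduction="sum"):
    """-log softmax(row)[label] per row, then summed or averaged over rows."""
    _check_reduction(reduction)
    losses = []
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])
    total = sum(losses)
    return total if reduction == "sum" else total / len(losses)


def cross_entropy_grad(logits, labels, reduction="sum"):
    """d(loss)/d(logits): softmax(row) - one_hot(label), scaled by 1/N for "mean"."""
    _check_reduction(reduction)
    scale = 1.0 if reduction == "sum" else 1.0 / len(logits)
    grad = []
    for row, label in zip(logits, labels):
        probs = softmax(row)
        grad.append([scale * (p - (1.0 if j == label else 0.0)) for j, p in enumerate(probs)])
    return grad


# Tiny example: 2 rows (N = 2), 3 classes (C = 3).
logits = [[1.0, 2.0, 3.0], [0.5, -0.5, 0.25]]
labels = [2, 0]
print(cross_entropy_grad(logits, labels, reduction="sum"))
print(cross_entropy_grad(logits, labels, reduction="mean"))  # half of the "sum" gradient
```

On a small CPU-sized case, these values can be compared against `logits.grad.tolist()` from eager PyTorch or from Thunder; at the sizes in the report (8192 × 32064) a pure-Python loop would be far too slow.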

Expected behavior

The snippet should run to completion, and the resulting gradient should match the one produced by torch.compile.


Posted by protonu, May 09 '25 19:05

The following code shows a similar issue:

import torch

# Only TorchExecutor is used below; swap in nvFuserExecutor (also defined in
# thunder.tests.framework) to try the nvFuser backend.
from thunder.tests.framework import TorchExecutor


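def mini_model(logits, labels):
    # Shift labels left by one (next-token targets); the padded last slot is 0.
    labels = torch.nn.functional.pad(labels, (0, 1))
    labels = labels[1 : labels.shape[-1]]
    # Upcast bfloat16 logits to float32 and drop the leading batch dimension:
    # (1, batch_size, vocab_size) -> (batch_size, vocab_size).
    logits = logits.to(dtype=torch.float32)
    logits = logits.squeeze(dim=0)
    # Default reduction is "mean".
    return torch.nn.functional.cross_entropy(logits, labels)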
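In `mini_model`, the pad-then-slice pair shifts the labels left by one position and fills the vacated last slot with 0, the default constant fill of `torch.nn.functional.pad`. The length is unchanged, so the labels still line up with the `batch_size` rows of logits, and every value remains a valid class index. A plain-Python sketch of the same transformation (`shift_labels_left` is a hypothetical name, not from the report):

```python
def shift_labels_left(labels, pad_value=0):
    """List-based equivalent of the two label lines in mini_model:

        labels = torch.nn.functional.pad(labels, (0, 1))
        labels = labels[1 : labels.shape[-1]]
    """
    # pad(labels, (0, 1)) appends one pad_value after the last element.
    padded = list(labels) + [pad_value]
    # Slicing [1 : len(padded)] drops the first element, so the length is unchanged.
    return padded[1 : len(padded)]


print(shift_labels_left([5, 7, 9]))  # [7, 9, 0]
```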


batch_size = 4096
vocab_size = 8192

input = torch.randn(
    1,
    batch_size,
    vocab_size,
    device="cuda",
    dtype=torch.bfloat16,
    requires_grad=True,
)
labels = torch.randint(
    0,
    vocab_size - 1,
    (batch_size,),
    device="cuda",
    requires_grad=False,
)

executor = TorchExecutor

# Eager PyTorch result (reference).
a = mini_model(input, labels)
print(a)
# Same function compiled with Thunder's torch executor.
b = executor.make_callable(mini_model)
print(b(input, labels))

If I comment out the two lines that shift the labels (the pad and the slice):

def mini_model(logits, labels):
    # labels = torch.nn.functional.pad(labels, (0, 1))
    # labels = labels[1 : labels.shape[-1]]
    logits = logits.to(dtype=torch.float32)
    logits = logits.squeeze(dim=0)
    return torch.nn.functional.cross_entropy(logits, labels)

Then it works correctly.

Posted by protonu, May 21 '25 20:05

Interestingly, it also fails with the nvFuser backend, but if I turn on the custom decomposition for cross entropy loss (for nvFuser), it works correctly.

Posted by protonu, May 21 '25 20:05