Liger-Kernel icon indicating copy to clipboard operation
Liger-Kernel copied to clipboard

Ensure In-place correctness checks work properly

Open Tcc0403 opened this issue 1 year ago • 6 comments
trafficstars

Summary

Fix #272

It's a showcase of how to trigger the error properly. I only applied it to cross_entropy for demonstration; it can be applied to the other kernels if we want.

Testing Done

Same gist as the one in the issue:

import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy


def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    """Compute a cross-entropy loss on softmax outputs, backprop, and print results.

    The loss input is deliberately the *output* of ``torch.nn.Softmax``.
    Softmax's backward reuses that saved output, so if ``cross_entropy_fn``
    writes into its input in place, autograd's version check should make
    ``backward()`` raise a ``RuntimeError`` instead of silently producing
    wrong gradients.

    Args:
        logits_p: Tensor whose clone becomes the leaf fed to softmax. It is
            cloned and detached first, so no gradient reaches ``logits_p``.
        logits_q: Targets, forwarded unchanged as the second argument of
            ``cross_entropy_fn``.
        cross_entropy_fn: Called as ``cross_entropy_fn(p, logits_q)``; must
            return a tensor supporting ``.backward()`` and ``.item()``.

    Returns:
        None. Everything is reported via ``print``.
    """
    # Fresh leaf: gradients accumulate here rather than on the caller's tensor.
    leaf = logits_p.clone().detach().requires_grad_(True)
    leaf.retain_grad()

    probs = torch.nn.Softmax(dim=-1)(leaf)
    # probs is not a leaf, so its .grad is only kept when explicitly requested.
    probs.retain_grad()

    loss = cross_entropy_fn(probs, logits_q)
    loss.backward(retain_graph=True)

    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {leaf}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {probs.grad[0]}")
    print(f"Gradients of _p (batch item 0): {leaf.grad[0]}")


# Fixed seed so the printed tensors are reproducible from run to run.
torch.manual_seed(0)
logits_p = torch.randn((8, 8), device="cuda", requires_grad=True)
logits_q = torch.randint(0, 8, (8,), dtype=torch.long, device="cuda")

# Reference run: PyTorch's own cross entropy on the same inputs.
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)

# Same inputs routed through Liger's kernel.
print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)

The error is now raised properly:

❯ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

LIGER:
Traceback (most recent call last):
  File "/home/tcc/Liger-Kernel/inplace_bug.py", line 36, in <module>
    run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)
  File "/home/tcc/Liger-Kernel/inplace_bug.py", line 18, in run_inplace_experiment
    loss.backward(retain_graph=True)
  File "/home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of SoftmaxBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
  • Hardware Type: <BLANK>
  • [ ] run make test to ensure correctness
  • [ ] run make checkstyle to ensure code style
  • [ ] run make test-convergence to ensure convergence

Tcc0403 avatar Sep 26 '24 03:09 Tcc0403

bench

I did some benchmarks on an H100. Adding any torch in-place op increases the time cost substantially: about 40% in full (forward + backward) mode (original → with_hint: 23 → 32 ms at a 128k vocab size), and about 70–90% in forward-only mode.

So I guess it's not worth it? Full stdout:

**************************************
     BENCHMARKING SPEED for CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "hint",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.7055680155754089,
      1.253440022468567,
      2.433199882507324,
      4.869984149932861,
      10.520671844482422,
      23.0479679107666
    ],
    "y_values_20": [
      0.7003520131111145,
      1.2493120431900024,
      2.4296703338623047,
      4.865350246429443,
      10.509568214416504,
      23.046571731567383
    ],
    "y_values_80": [
      0.7126911878585815,
      1.2599040269851685,
      2.4357247352600098,
      4.873280048370361,
      10.537690162658691,
      23.04932403564453
    ],
    "timestamp": "2024-10-02 23:21:11",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "original",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.41332799196243286,
      0.6783679723739624,
      1.2743680477142334,
      2.535327911376953,
      5.867008209228516,
      13.692416191101074
    ],
    "y_values_20": [
      0.41091200709342957,
      0.6760640144348145,
      1.2711039781570435,
      2.5320703983306885,
      5.845632076263428,
      13.691308975219727
    ],
    "y_values_80": [
      0.41729921102523804,
      0.6832832098007202,
      1.2798080444335938,
      2.539724826812744,
      5.877439975738525,
      13.695513725280762
    ],
    "timestamp": "2024-10-02 23:21:12",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "hint",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1.1566720008850098,
      2.042720079421997,
      3.7615039348602295,
      7.380864143371582,
      15.551360130310059,
      32.693214416503906
    ],
    "y_values_20": [
      1.1196672916412354,
      2.026726245880127,
      3.7539713382720947,
      7.370649337768555,
      15.547072410583496,
      32.686397552490234
    ],
    "y_values_80": [
      1.1723840236663818,
      2.0592191219329834,
      3.7860095500946045,
      7.387167930603027,
      15.649503707885742,
      32.69830322265625
    ],
    "timestamp": "2024-10-02 23:21:13",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "original",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.8923839926719666,
      1.45249605178833,
      2.6448960304260254,
      5.078847885131836,
      10.8754243850708,
      23.530689239501953
    ],
    "y_values_20": [
      0.8872384428977966,
      1.4472639560699463,
      2.63141131401062,
      5.075232028961182,
      10.859647750854492,
      23.527257919311523
    ],
    "y_values_80": [
      0.9067007899284363,
      1.4663935899734497,
      2.6562368869781494,
      5.088294506072998,
      11.039955139160156,
      23.535194396972656
    ],
    "timestamp": "2024-10-02 23:21:14",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  }
]
**************************************
     BENCHMARKING MEMORY for CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "hint",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_20": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_80": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "timestamp": "2024-10-02 23:21:15",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "original",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_20": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_80": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "timestamp": "2024-10-02 23:21:15",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  }
]

Tcc0403 avatar Oct 02 '24 23:10 Tcc0403

Yeah, I was looking into whether we can call bump() from Python... a 50% cost is not worth it.

lancerts avatar Oct 03 '24 04:10 lancerts

I am wondering why the error does not happen for normal case?

ByronHsu avatar Oct 03 '24 04:10 ByronHsu

@ByronHsu

I am wondering why the error does not happen for normal case?

I left an explanation in the issue.

Tcc0403 avatar Oct 04 '24 15:10 Tcc0403

Following @mgrabban's suggestion in #343, I made another implementation using mark_dirty().

Note: I haven't benchmarked this new approach against the current liger_ce yet; I will do it in a few days. Draft gist for the benchmark:

Results with the new approach:

import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy


def run_inplace_experiment(
    logits_p, logits_q, cross_entropy_fn, is_liger=False, use_inplace=False
):
    """Run one cross-entropy experiment on softmax outputs and print the outcome.

    The input to ``cross_entropy_fn`` is the output of ``torch.nn.Softmax``,
    whose backward depends on that saved output. An implementation that
    overwrites its input in place should therefore make ``backward()`` fail
    autograd's version check rather than yield wrong gradients.

    Args:
        logits_p: Tensor cloned and detached into a fresh leaf before softmax,
            so no gradient flows back to ``logits_p`` itself.
        logits_q: Targets passed as the second argument of ``cross_entropy_fn``.
        cross_entropy_fn: The loss function under test. When ``is_liger`` is
            false it is called as ``cross_entropy_fn(p, logits_q)``.
        is_liger: When true, ``cross_entropy_fn`` is called with the extra
            positional arguments ``-100, 0.0, "mean", use_inplace`` and its
            result is unpacked as a 2-tuple whose first element is the loss.
            NOTE(review): these presumably map to ignore_index,
            label_smoothing, reduction and the new in-place flag of
            ``liger_cross_entropy``; confirm against its signature.
        use_inplace: Forwarded as the last positional argument of the Liger
            call; unused otherwise.

    Returns:
        None. Results, or the text of any exception raised while computing
        the loss, running backward, or printing, are written with ``print``.
    """
    leaf = logits_p.clone().detach().requires_grad_(True)
    leaf.retain_grad()
    probs = torch.nn.Softmax(dim=-1)(leaf)
    # Non-leaf tensor: keep its gradient so it can be printed below.
    probs.retain_grad()

    try:
        if not is_liger:
            loss = cross_entropy_fn(probs, logits_q)
        else:
            # Liger's entry point takes positional args and returns a pair;
            # only the first element (the loss) is used here.
            loss, _ = cross_entropy_fn(probs, logits_q, -100, 0.0, "mean", use_inplace)

        loss.backward(retain_graph=True)

        print(f"Cross Entropy Loss: {loss.item()}")
        print(f"Input _p: {leaf}")
        print(f"Input logits_q: {logits_q}")
        print(f"Gradients of p (batch item 0): {probs.grad[0]}")
        print(f"Gradients of _p (batch item 0): {leaf.grad[0]}")
    except Exception as err:
        # Best-effort demo: report the failure so later experiments still run.
        print(err)


# Fixed seed so the printed tensors are reproducible from run to run.
torch.manual_seed(0)
logits_p = torch.randn((8, 8), device="cuda", requires_grad=True)
logits_q = torch.randint(0, 8, (8,), dtype=torch.long, device="cuda")

# Reference run: PyTorch's cross entropy (is_liger defaults to False).
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)

# Liger with use_inplace=True: expected to trip autograd's version check.
print()
print("LIGER use_inplace=True:")
run_inplace_experiment(
    logits_p,
    logits_q,
    cross_entropy_fn=liger_cross_entropy,
    use_inplace=True,
    is_liger=True,
)

# Liger with use_inplace=False: should match the PyTorch reference.
print()
print("LIGER use_inplace=False:")
run_inplace_experiment(
    logits_p,
    logits_q,
    cross_entropy_fn=liger_cross_entropy,
    use_inplace=False,
    is_liger=True,
)
❯ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

LIGER use_inplace=True:
inplace=True
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of SoftmaxBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

LIGER use_inplace=False:
inplace=False
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

cc @ByronHsu @lancerts

Tcc0403 avatar Nov 04 '24 19:11 Tcc0403

Update: here's the benchmark against Liger's current CE. Speed is about 17% slower (TODO: investigate where the overhead comes from), and memory usage is doubled. I guess that's because tensors passed to mark_dirty() must be returned as outputs; maybe there's a way to reduce that?

# speed full
# fix version
"y_values_50": [
      1.0791840553283691,
      1.9665439128875732,
      3.3525118827819824,
      6.478032112121582,
      13.283712387084961,
      27.797151565551758
    ],
# original
"y_values_50": [
      0.9956480264663696,
      2.710239887237549,
      3.014863967895508,
      5.4488959312438965,
      11.20921516418457,
      23.64148712158203
    ],

# memory full
# fix version
"y_values_50": [
      512.12744140625,
      1024.12744140625,
      2048.12744140625,
      4096.12744140625,
      8192.126953125,
      16384.126953125
    ],
# original
 "y_values_50": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],

Tcc0403 avatar Nov 04 '24 20:11 Tcc0403

Not required for now.

Tcc0403 avatar Jun 07 '25 02:06 Tcc0403