
Cached Data Corruption

Gadersd opened this issue on Oct 18, 2023 • 10 comments

Summary

Caching in whisper-burn becomes corrupted when the Whisper logits output is not cloned before computing log probabilities. The corruption occurs on both CUDA and CPU when using burn-tch, but not with WGPU. The root cause seems to be a bug in burn-tch, tch, or libtorch in which an operation through one pointer mutates data that a different pointer still references.

Steps to Replicate

  1. Check out the project branch: https://github.com/Gadersd/whisper-burn/tree/cache_bug

  2. Run whisper-burn with the burn-tch backend

  3. Observe corrupted outputs when cloning is disabled (line 272 of transcribe.rs):

     Depth: 20
     Diff = 43.029815673828125
     Depth: 21
     Diff = 43.029808044433594

  4. No corruption occurs when cloning:

     Depth: 20
     Diff = 0.0002593994140625
     Depth: 21
     Diff = 0.00017833709716796875

Expected vs Actual Behavior

  • Expected: Cached data remains uncorrupted regardless of cloning
  • Actual: Corruption occurs without cloning

Root Cause Analysis

  • Issue only occurs with tch backend, not WGPU
  • Likely a bug in unsafe code in burn-tch, tch, or libtorch

Code Snippet Demonstrating Issue

// BUGGED! Should clone because SOMEONE used unsafe code somewhere
let log_probs = log_softmax(logits_tensor, 0).into_data().value;

//let log_probs = log_softmax(logits_tensor.clone(), 0).into_data().value; 

Gadersd avatar Oct 18 '23 16:10 Gadersd

I've been able to trigger a cache bug on wgpu as well. See branch https://github.com/Gadersd/whisper-burn/tree/wgpu_cache_bug. Running on wgpu with a one-line modification results in transcription failure.

Line 654 of model/mod.rs: changing qkv_attention(q2, k2, v2, None, self.n_head) to qkv_attention(q, k2, v2, None, self.n_head) results in corrupted output, even though the diff below demonstrates that q is exactly equal to q2 in each iteration.

let diff = (q.clone() - q2.clone()).flatten::<1>(0, 2).abs().max().into_scalar().elem::<f64>();
println!("Diff = {}", diff);

let wv = qkv_attention(q, k2, v2, None, self.n_head);

Gadersd avatar Oct 18 '23 17:10 Gadersd

This is surely a bug with an in-place operation; we should investigate which tensor is being modified when it shouldn't be.

nathanielsimard avatar Oct 18 '23 20:10 nathanielsimard

For the latter example, the query cache is modified when it shouldn't be. Replacing let wv = qkv_attention(q, k2, v2, None, self.n_head); with let wv = qkv_attention(q.clone(), k2, v2, None, self.n_head); still results in a corrupted cache. However, using let wv = qkv_attention(Tensor::from_data(q.clone().into_data()), k2, v2, None, self.n_head); works without issue. Evidently, cloning isn't actually producing an independent copy.
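
For reference, the three variants at the qkv_attention call site in model/mod.rs, side by side (taken from the description above; only one line would be active at a time):

// (a) Corrupts the query cache:
// let wv = qkv_attention(q, k2, v2, None, self.n_head);

// (b) Still corrupts the cache, even though q is cloned first:
// let wv = qkv_attention(q.clone(), k2, v2, None, self.n_head);

// (c) Works: round-tripping through raw data forces a genuinely separate buffer.
let wv = qkv_attention(Tensor::from_data(q.clone().into_data()), k2, v2, None, self.n_head);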

Gadersd avatar Oct 18 '23 21:10 Gadersd

I'm having quite a hard time looking through the codebase and trying to see what may be causing the problem. Do you have an intuition for what kind of situation might create this bug (probably related to caching) and how to create a minimal example to reproduce it? This would be a huge help!

nathanielsimard avatar Oct 18 '23 23:10 nathanielsimard

It is not an easy bug to reproduce or create a minimal example for, as it doesn't seem to occur in every case and it behaves differently depending on the backend. My intuition tells me it smells like an unsafe-code problem. I never write unsafe code, so I think it probably originates in burn's tensor memory management code.

My previous example, in which passing a cloned tensor to a function modifies the original tensor, is a strong indication that there is a memory violation in an unsafe code block somewhere. The caching code, which I pulled from burn, clones the tensors, so the caches and the tensors returned from forward_autoregressive should occupy different memory locations; but for whatever reason the memory separation sometimes does not happen properly, and modifying one modifies the other. I encountered the same kind of issue in both the burn-tch (GPU and CPU) and burn-wgpu backends, so I don't think this is a backend-specific problem.
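
To make that expected invariant concrete, here is a minimal, backend-agnostic sketch (plain Rust with a toy tensor type standing in for burn's tensors, and forward_autoregressive reduced to its essence), showing the memory separation the caching relies on:

#[derive(Clone, Debug, PartialEq)]
struct ToyTensor(Vec<f32>);

struct Cache {
    stored: Option<ToyTensor>,
}

impl Cache {
    // Keep a clone in the cache and hand the original back to the caller,
    // so the two values are backed by separate memory.
    fn forward_autoregressive(&mut self, x: ToyTensor) -> ToyTensor {
        self.stored = Some(x.clone());
        x
    }
}

fn main() {
    let mut cache = Cache { stored: None };
    let mut out = cache.forward_autoregressive(ToyTensor(vec![1.0, 2.0, 3.0]));

    // An "in-place" mutation of the returned value...
    out.0[0] = 42.0;

    // ...must not be visible through the cache. The bug described above behaves
    // as if this assertion failed at the tensor-backend level.
    assert_eq!(cache.stored.as_ref().unwrap().0, vec![1.0, 2.0, 3.0]);
}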

Gadersd avatar Oct 19 '23 00:10 Gadersd

I have isolated the bug further by eliminating the caching. The caching is apparently not the cause of the bug. Adding a single clone() to q at line 658 of model/mod.rs in the wgpu_cache_bug branch corrupts the results as shown below.

Command: cargo run --release --bin transcribe --features wgpu-backend tiny_en audio16k.wav en transcription.txt

Line 658 in model/mod.rs:

let wv = qkv_attention(q.clone(), k2, v2, None, self.n_head);
println!("Data: {:?}", wv.clone().slice([0..1, 0..1, 0..20]).into_data());

Result:

Data: Data { value: [0.22148757, 0.1708659, -0.20998125, -0.05612477, 0.2864229, -0.21298157, -0.38640028, 0.016380915, -0.04644726, 0.29196838, -0.005733948, -0.14728771, -0.1574831, -0.078477256, -0.3404586, 0.13537605, -0.20357189, 0.107251376, -0.052113228, 0.114457615], shape: Shape { dims: [1, 1, 20] } }
Depth: 3
Chunk 0:  10.

Transcription finished.

For

let wv = qkv_attention(q, k2, v2, None, self.n_head);
println!("Data: {:?}", wv.clone().slice([0..1, 0..1, 0..20]).into_data());

The result is correct:

Data: Data { value: [0.18244137, 0.08941843, -0.16000068, 0.03786571, 0.18850096, -0.19262771, -0.2876657, -0.0066655697, -0.068978235, 0.30545524, 0.017571663, -0.16025116, -0.18520899, -0.07191827, -0.29022524, 0.115666814, -0.1865571, 0.08898805, 0.061428268, 0.1264473], shape: Shape { dims: [1, 1, 20] } }
Depth: 22
Chunk 0:  Hello, I am the whisper machine learning model. If you see this as text then I am working properly.

Transcription finished.

What might be the cause of this? burn's reference counting going awry?
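
For what it's worth, here is a toy sketch of the kind of optimization being suspected, purely as an illustration of the general pattern rather than burn's actual implementation: a backend mutates a buffer in place when the reference count says it is the sole owner, and copies otherwise. If that ownership check is ever wrong, or an alias to the buffer isn't counted, a nominally read-only operation silently overwrites data that another tensor still points to.

use std::sync::Arc;

// If this tensor is the sole owner of its buffer, mutate it in place;
// otherwise copy before writing. A bug in this decision (or an uncounted
// alias to the buffer) would corrupt data that another tensor still uses.
fn double_in_place_if_unique(mut buf: Arc<Vec<f32>>) -> Arc<Vec<f32>> {
    if Arc::strong_count(&buf) == 1 {
        // Unique owner: reuse the allocation in place.
        let data = Arc::get_mut(&mut buf).expect("sole owner");
        for x in data.iter_mut() {
            *x *= 2.0;
        }
        buf
    } else {
        // Shared: allocate a fresh buffer instead of writing through the alias.
        Arc::new(buf.iter().map(|x| x * 2.0).collect())
    }
}

fn main() {
    let cached = Arc::new(vec![1.0_f32, 2.0, 3.0]);
    let out = double_in_place_if_unique(cached.clone());

    // Because `cached` is still alive, the copy path must be taken and the
    // cached data must remain untouched.
    assert_eq!(*cached, vec![1.0, 2.0, 3.0]);
    assert_eq!(*out, vec![2.0, 4.0, 6.0]);
}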

Gadersd avatar Oct 21 '23 14:10 Gadersd

@Gadersd I don't think the reference counting is creating the problem, but I may be wrong. Perhaps a kernel is wrongly implemented for the in-place operation, or for the read-only one, in a way that isn't verified by our test suite. Ultimately, what we need is a scenario with custom tensors where we can reproduce the bug:

fn test() {
    // Tensor::ones() and a_function are placeholders for concrete inputs and
    // for the operation under test.
    let input1 = Tensor::ones();
    let input2 = Tensor::ones();
    let input3 = Tensor::ones();

    // Owned inputs: the backend is free to take an in-place path.
    let output_no_clone = a_function(input1, input2, input3);

    let input1 = Tensor::ones();
    let input2 = Tensor::ones();
    let input3 = Tensor::ones();

    // Cloned inputs: the backend must treat them as read-only.
    let output_with_clone = a_function(input1.clone(), input2.clone(), input3.clone());

    output_no_clone.into_data().assert_approx_eq(&output_with_clone.into_data(), 3);
}

Having such a scenario would be extremely helpful for fixing the current bug. I also think we should automatically generate such a test suite for all operations on all backends.
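
As a sketch of what one such automated check could look like (illustrative only: check_clone_consistency and the Vec<f32> stand-in are not burn APIs, and a real suite would use an approximate comparison like assert_approx_eq above):

// Run the operation once on owned inputs (the backend may take its in-place
// fast path) and once on cloned inputs (the read-only path must be taken),
// then require both results to agree.
fn check_clone_consistency<T, F>(make_input: impl Fn() -> T, op: F)
where
    T: Clone + PartialEq + std::fmt::Debug,
    F: Fn(T, T) -> T,
{
    let owned = op(make_input(), make_input());

    let a = make_input();
    let b = make_input();
    let cloned = op(a.clone(), b.clone());

    assert_eq!(owned, cloned, "in-place and read-only paths disagree");
}

fn main() {
    // Toy "tensor": element-wise addition of two Vec<f32>.
    check_clone_consistency(
        || vec![1.0_f32, 2.0, 3.0],
        |a, b| a.iter().zip(&b).map(|(x, y)| x + y).collect(),
    );
}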

nathanielsimard avatar Oct 22 '23 15:10 nathanielsimard

@Gadersd Also, make sure you use the main version of wgpu; I remember a flaky test with wgpu that was related to data corruption.

nathanielsimard avatar Oct 22 '23 16:10 nathanielsimard

Running cargo update fixed that specific example, but the data is still corrupted in other places. This is one of those tricky bugs that seems to pop up in random places and manifests in different ways as the code is modified. I could probably spend weeks on this, so I'll have to call it quits for now. Hopefully someone will eventually find a minimal example that is easier to analyze.

Gadersd avatar Oct 22 '23 21:10 Gadersd

I gave it one more go and found that the main line causing the issue is let wv = qkv_attention(q, k2, v2, None, self.n_head);. Simply using let q = Tensor::from_data(self.query.forward(x).into_data()); rather than let q = self.query.forward(x); causes wv to be corrupted. This issue is highly contextual: using Tensor::from_data for q doesn't lead to corruption when the exact same input tensor values are tested in the main function. Due to the contextual nature of this bug, I don't think I will be able to find a minimal example. Unfortunately, this bug may exist in the wild for a while.
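
Concretely, the two variants described above at the point where q is computed in model/mod.rs (only one is active at a time):

// Produces a correct transcription:
let q = self.query.forward(x);

// Corrupts wv at the later qkv_attention(q, k2, v2, None, self.n_head) call,
// even though it only round-trips the same values through raw data:
// let q = Tensor::from_data(self.query.forward(x).into_data());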

Gadersd avatar Oct 23 '23 01:10 Gadersd

@Gadersd Could you check if the new fix is solving your problem? https://github.com/tracel-ai/burn/pull/1434 I believe it should.

nathanielsimard avatar Mar 08 '24 14:03 nathanielsimard

Probably it's fixed via #1434

antimora avatar Mar 29 '24 01:03 antimora