encode_jpeg generates noise when processing 4k image
Hi, I tried the latest torchvision 0.19.0 with PyTorch 2.4 and found that the encode_jpeg function has a problem when processing 4K images. For example, I have a 4K image tensor of size (3, 2160, 3840). When I call torchvision.io.encode_jpeg() on it in a loop, the first iteration encodes the tensor correctly, but the following iterations produce JPEG images that are only noise. Could you help with this please? Thanks!
Hi @Lily-Git-hub, can you please provide a minimal reproducing example? Thank you!
Hi Nicolas,
Please try this example:
```python
import torch
import torchvision
import torch.nn.functional as F

for i in range(2):
    image_data = torch.load('image_data.pt')
    resized_image_tensor = F.interpolate(image_data.unsqueeze(0), size=(2160, 3820), mode='bilinear', align_corners=False)
    image_data_resized = resized_image_tensor[0]
    image_data_encoded = torchvision.io.encode_jpeg(image_data_resized.to(torch.uint8))
    data = image_data_encoded.cpu().numpy().tobytes()
    with open(f'1.jpg', 'wb') as f:
        f.write(data)
    del data, image_data_encoded, resized_image_tensor, image_data_resized, image_data
```
Without the last line of code, which deletes the used variables, the saved image is noise only. Please unzip 'image_data.zip' to get image_data.pt.
[image_data.zip](https://github.com/user-attachments/files/16656971/image_data.zip)
Sorry @Lily-Git-hub, I cannot reproduce your issue.
`del data, image_data_encoded, resized_image_tensor, image_data_resized, image_data`
Hi Nicolas,
Did you remove the above line of code? The error occurs when the used variables are not deleted. Thanks!
Yes, I deleted these lines. Can you please provide a more minimal reproducing example, without a for loop, without resizing, and from a normal image rather than from a .pt file (which I won't load on my machine for security reasons)?
I encountered a similar issue. I resolved it by adding torch.cuda.synchronize() before using encode_jpeg. It seems there might be some synchronization problems between F.interpolate and torchvision.io.encode_jpeg.
```python
resized_image_tensor = F.interpolate(image_data.unsqueeze(0), size=(2160, 3820), mode='bilinear', align_corners=False)
image_data_resized = resized_image_tensor[0].to(torch.uint8)
# add synchronize after modifying the image and before encoding the jpeg
torch.cuda.synchronize()
image_data_encoded = torchvision.io.encode_jpeg(image_data_resized)
```
Bump.
Similar issue here with torch 2.5.1. Not 4K images though, just some 1024x1024 images; about 50% of the saved images become noise.
Also resolved by adding torch.cuda.synchronize() just before encode_jpeg(). Thanks @glazhh.
I'm not using F.interpolate(), but before calling encode_jpeg(), my last op is also .to(torch.uint8):
```python
...
img = img * 255
img = img.to(torch.uint8)
raw_jpeg: torch.Tensor = encode_jpeg(img, quality=80)
...
```
Update:
I can confirm and reproduce that if encode_jpeg() is called right after .to(torch.uint8), there is roughly a 90% chance of noise. If there is a bit more code between .to(torch.uint8) and encode_jpeg(), there is no noise. It seems like a synchronization problem.
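For reference, a minimal sketch of the workaround, assuming the input is a CUDA float tensor in [0, 1] (the shape and quality below are just examples):

```python
import torch
from torchvision.io import encode_jpeg

img = torch.rand(3, 1024, 1024, device="cuda")  # example input in [0, 1]
img = (img * 255).to(torch.uint8)
# Wait for the uint8 conversion to finish before the encoder reads the buffer.
torch.cuda.synchronize()
raw_jpeg = encode_jpeg(img, quality=80)
```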
Code to reproduce the problem consistently:
```python
import torch
from torchvision.io import encode_jpeg, decode_jpeg
import random

device = torch.device('cuda')

def get_random_dimensions(min_size=64, max_size=2048):
    height = random.randint(min_size, max_size)
    width = random.randint(min_size, max_size)
    return height, width

threshold = 5.0  # Mean difference threshold for detecting a bug
iteration = 0

while True:
    # Much harder to trigger the bug with fixed dimensions
    # height, width = 512, 512
    height, width = get_random_dimensions()

    # Create a 3xHxW image with values in [0, 1]
    image = torch.linspace(0, 1, height * width, device=device).view(1, height, width).expand(3, -1, -1)

    # Convert to uint8 and encode as JPEG
    image_uint8 = (image * 255).clamp(0, 255).to(torch.uint8)

    # Synchronize before encoding (potential fix)
    # torch.cuda.synchronize()

    jpeg_bytes = encode_jpeg(image_uint8, quality=100)

    # Decode JPEG back to a tensor
    decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device).float() / 255.0

    # Calculate the mean difference in 255 space
    mean_difference = (image - decoded_image).abs().mean().item() * 255

    # Print progress
    print(f"\rIteration {iteration}: Mean Difference {mean_difference:.2f} (HxW: {height}x{width})", end="")

    # Stop if the mean difference exceeds the threshold
    if mean_difference > threshold:
        print(f"\nBug triggered at iteration {iteration}: Mean Difference {mean_difference:.2f} (HxW: {height}x{width})")

        # Save the broken image
        with open("broken_image.jpg", "wb") as f:
            f.write(jpeg_bytes.cpu().numpy())
        print("Broken image saved to broken_image.jpg")

        # Save the non-broken image
        non_broken_bytes = encode_jpeg(image_uint8.cpu(), quality=100)
        with open("non_broken_image.jpg", "wb") as f:
            f.write(non_broken_bytes.numpy())
        print("Non-broken image saved to non_broken_image.jpg")
        break

    iteration += 1
```
Example output from different runs:

```
Iteration 9: Mean Difference 93.12 (HxW: 1949x1648)
Bug triggered at iteration 9: Mean Difference 93.12 (HxW: 1949x1648)
Broken image saved to broken_image.jpg
Non-broken image saved to non_broken_image.jpg
```

```
Iteration 81: Mean Difference 78.51 (HxW: 2041x1543)
Bug triggered at iteration 81: Mean Difference 78.51 (HxW: 2041x1543)
Broken image saved to broken_image.jpg
Non-broken image saved to non_broken_image.jpg
```
The script writes two images to the current working directory. Here's an example:
The broken image, from GPU/CUDA encoding (broken_image.jpg):
The intended one, from re-encoding with CPU (non_broken_image.jpg):
Uncommenting the torch.cuda.synchronize() call fixes the problem for me. The test script then runs for thousands of iterations without reporting an error.
Always encoding images of the same size (e.g. 512x512) in a loop is less likely to trigger the problem.
Reproduced with torchvision==0.19.0, torchvision==0.20.0, and nightly torchvision==0.22.0.dev20250128+cu126
From v0.20 onwards, the script sometimes fails with RuntimeError: image encoding failed: 8 on encode_jpeg, which I never saw with v0.19. But re-running the script also produces the broken images.
Environment:

- OS: Ubuntu 22.04.5 LTS (x86_64)
- GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
- Nvidia driver version: 535.183.01
o1 suggests that the event https://github.com/pytorch/vision/blob/867521ec82c78160b16eec1c3a02d4cef93723ff/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp#L98 should be recorded after the encode_jpeg, not before
https://chatgpt.com/share/679916c0-7fe8-8000-9ee7-1a95f10f666a
Edit:
I tried this locally. The suggested o1 change (moving the event block after the encode_jpeg) does not appear to fix the problem. Here's the updated conversation: https://chatgpt.com/share/679916c0-7fe8-8000-9ee7-1a95f10f666a, now suggesting a stream mismatch. I'll have a look.
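Purely as an illustration of what a stream mismatch means here (this is not torchvision's internal code, just a hypothetical producer/consumer sketch): if a tensor is written on one stream and read on another without an event or `wait_stream` in between, the reader can observe a partially written buffer.

```python
import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

x = torch.zeros(3, 2160, 3840, device="cuda")

with torch.cuda.stream(producer):
    y = (x + 1).to(torch.uint8)  # y is written on the producer stream

# Without this wait, kernels launched on the consumer stream may start
# before y is fully written:
# consumer.wait_stream(producer)

with torch.cuda.stream(consumer):
    z = y.float().mean()  # y is read on the consumer stream
```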
cc @deekay42
Here's a branch with the code to reproduce the problem as a pytest. And fixes for the two issues raised by o1, neither of which fixes the bug: https://github.com/pytorch/vision/compare/main...w-m:vision:encode_jpeg_cuda_sync?expand=1
Thanks all for the report and especially @w-m for the investigations and the reproducing example. @deekay42 was able to provide a fix in https://github.com/pytorch/vision/pull/8929 which will be available with the next release (probably ~April)
I started using the GPU JPEG encoder after this fix (0.22.0+cu128), but there are still bugs. Sometimes the encoded JPEG data is corrupted and can't be decoded. It doesn't happen on CPU.
This is part of my software. (I'd like to create a minimal reproducible example, though it may be hard to do.)
```python
# The JPEG encoding part of my super complicated code
import io
import torch
from torchvision.io import encode_jpeg

ENABLE_GPU_JPEG = True  # Doesn't happen when this is False.

def to_jpeg_data(frame, quality, tick, gpu_jpeg=True):
    bio = io.BytesIO()
    if ENABLE_GPU_JPEG and gpu_jpeg and frame.device.type == "cuda":
        # torch.cuda.synchronize()
        jpeg_data = encode_jpeg(to_uint8(frame), quality=quality)
        # torch.cuda.synchronize()
        jpeg_data = jpeg_data.cpu()
    else:
        jpeg_data = encode_jpeg(to_uint8(frame).cpu(), quality=quality)
    bio.write(jpeg_data.numpy())
    jpeg_data = bio.getbuffer().tobytes()
    detect_bug(frame, jpeg_data)  # Added for detecting the bug
    return (jpeg_data, tick)

def to_uint8(x):
    return x.mul(255).round_().to(torch.uint8)

# Detect corrupted JPEG data
def detect_bug(frame, jpeg_data):
    from torchvision.io import decode_jpeg
    jpeg_data = torch.tensor(list(jpeg_data), dtype=torch.uint8)
    try:
        decoded_frame = decode_jpeg(jpeg_data).cpu() / 255.0
        diff = (frame.cpu() - decoded_frame).abs().mean()
        print(diff)
    except RuntimeError as e:
        print(e)
```
Debug output:

```
tensor(0.0023)
tensor(0.0022)
Corrupt JPEG data: premature end of data segment
Unsupported marker type 0x6f
tensor(0.0022)
Corrupt JPEG data: premature end of data segment
Unsupported marker type 0x6f
tensor(0.0022)
tensor(0.0022)
....
```
torch.cuda.synchronize() didn't fix it and actually made it happen more often.
This issue was reported by a Windows user, and I was able to reproduce it in my Linux environment as well.
EDIT: The software is multithreaded, and the issue seems to happen more frequently when other threads are running large models or handling high-res images.
@nagadomi to_uint8 may still be writing the encode_jpeg input data while it is being encoded. Consider changing your code to the following until this issue is fixed:
```python
frame_uint8 = to_uint8(frame)
torch.cuda.synchronize()
jpeg_data = encode_jpeg(frame_uint8, quality=quality)
torch.cuda.synchronize()
```
@glazhh Thanks for the suggestion. I tried that, but it didn't solve the issue. My problem isn't that the JPEG has noise; the encoded data itself is corrupted. I suspect the issue is that the path from the output of encode_jpeg() to cpu() isn't properly synchronized. I also think torch.cuda.synchronize() shouldn't be needed anymore after https://github.com/pytorch/vision/pull/8929
Also, I was calling to_jpeg_data() in a separate thread, but I found that the issue doesn’t occur when it’s called from the main thread.
I successfully created a reproducible example.
```python
import torch
import torch.nn as nn
import io
from concurrent.futures import ThreadPoolExecutor
import torchvision.transforms.functional as TF
from torchvision.io import encode_jpeg, decode_jpeg

def gen_frame(size):
    x = torch.rand((3, size // 8, size // 8))
    x = TF.resize(x, (size, size))
    return x

def to_uint8(x):
    return x.mul(255).round_().to(torch.uint8)

def to_jpeg_data(frame):
    frame = to_uint8(frame)
    frame = encode_jpeg(frame, quality=90).cpu()
    bio = io.BytesIO()
    bio.write(frame.numpy())
    return bio.getbuffer().tobytes()

def debug_jpeg_data(frame, jpeg_data):
    jpeg_data = torch.tensor(list(jpeg_data), dtype=torch.uint8)
    try:
        decoded_frame = decode_jpeg(jpeg_data).cpu() / 255.0
        diff = (frame.cpu() - decoded_frame).abs().mean()
        print(diff)
    except RuntimeError as e:
        print(e)

def image_handler(frame):
    jpeg_data = to_jpeg_data(frame)
    debug_jpeg_data(frame, jpeg_data)

def run():
    N = 100
    frame = gen_frame(4096).cuda()
    x = torch.rand((32, 256, 64, 64)).cuda()
    model = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=7, padding=2),
        nn.Conv2d(256, 256, kernel_size=7, padding=2),
        nn.Conv2d(256, 256, kernel_size=7, padding=2),
        nn.Conv2d(256, 256, kernel_size=7, padding=2),
    ).cuda()
    with ThreadPoolExecutor(max_workers=8) as pool:
        futures = []
        for i in range(N):
            # encode_jpeg in a separate thread
            futures.append(pool.submit(image_handler, frame))
            # Random heavy processing
            model(x)
            if i % 10 == 0:
                # sync
                for f in futures:
                    f.result()
                futures = []

if __name__ == "__main__":
    run()
```
Output:

```
tensor(0.0064)
Corrupt JPEG data: premature end of data segment
Unsupported marker type 0xbd
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
Corrupt JPEG data: premature end of data segment
Unsupported marker type 0xbd
tensor(0.0064)
tensor(0.0064)
tensor(0.0064)
...
```
I posted this in https://github.com/pytorch/vision/issues/9060