InternVL
Please Add Support For Triton Flash Attention Inference
Please add support for setting `config.attn_config['attn_impl'] = 'triton'`, so that inference can use Triton-based flash attention instead of requiring the `flash_attn` package.
import torch
from PIL import Image
from transformers import AutoModel, AutoConfig, CLIPImageProcessor
# Hugging Face Hub ID of the vision encoder to load.
model_name = 'OpenGVLab/InternViT-6B-224px'

# Load the model configuration so the attention implementation can be changed
# before the model is instantiated.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# NOTE(review): `attn_config` / `init_device` are MPT-style config fields.
# Whether the InternViT remote code reads either of them is not shown here --
# confirm against the model's configuration/modeling files. Reading
# `config.attn_config` on a config that does not define it raises
# AttributeError, so only set the Triton option when the field exists.
attn_config = getattr(config, 'attn_config', None)
if attn_config is not None:
    attn_config['attn_impl'] = 'triton'
else:
    print(f"{type(config).__name__} has no 'attn_config'; "
          "attention implementation left at its default.")
# Intended to initialize weights directly on the GPU, if the model honors it.
config.init_device = 'cuda:0'

# Load the model with the updated configuration, weights in bfloat16, then move
# it to the GPU and switch to eval mode (disables dropout etc.).
model = AutoModel.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).cuda().eval()

# Load the image and convert it to model input in the same dtype/device as
# the model weights.
image = Image.open('./examples/image1.jpg').convert('RGB')
image_processor = CLIPImageProcessor.from_pretrained(model_name)
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

# Run inference. `.eval()` alone does not disable gradient tracking; without
# inference_mode the forward pass would keep activations for a backward pass
# that never happens, wasting GPU memory on a model this large.
with torch.inference_mode():
    outputs = model(pixel_values)

# Print outputs for debugging purposes.
print(outputs)
When the model is loaded normally (default configuration, as in the snippet below), the `flash_attn` package is required even though Triton is installed:
ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
# Hugging Face Hub ID shared by the model and its image processor.
model_name = 'OpenGVLab/InternViT-6B-224px'

# Load the vision encoder with its default configuration, weights in bfloat16,
# then move it to the GPU and switch to eval mode.
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).cuda().eval()

# Preprocess the image into model input matching the weights' dtype/device.
image = Image.open('./examples/image1.jpg').convert('RGB')
image_processor = CLIPImageProcessor.from_pretrained(model_name)
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

# Forward pass without autograd bookkeeping: `.eval()` does not disable
# gradient tracking, so this avoids retaining activations needlessly.
with torch.inference_mode():
    outputs = model(pixel_values)