
Please Add Support For Triton Flash Attention Inference

Open · radna0 opened this issue 8 months ago · 1 comment

Please add support for config.attn_config['attn_impl'] = 'triton' so that Triton Flash Attention can be used for inference, e.g.:

import torch
from PIL import Image
from transformers import AutoModel, AutoConfig, CLIPImageProcessor

# Define the model name
model_name = 'OpenGVLab/InternViT-6B-224px'

# Load the model configuration and set attention implementation to Triton
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'
config.init_device = 'cuda:0'  # For fast initialization directly on GPU

# Load the model with the updated configuration
model = AutoModel.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,  # Load model weights in bfloat16
    low_cpu_mem_usage=True,
    trust_remote_code=True
).cuda().eval()

# Load and process the image
image = Image.open('./examples/image1.jpg').convert('RGB')
image_processor = CLIPImageProcessor.from_pretrained(model_name)
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

# Run inference
outputs = model(pixel_values)

# Print outputs for debugging purposes
print(outputs)
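
The attn_config['attn_impl'] key above follows the pattern exposed by MPT-style configs. Below is a rough sketch of how such a config-driven backend switch might look; the function name build_attention and the backend strings are illustrative assumptions on my part, not InternVL's actual internals.

import torch.nn.functional as F

def build_attention(attn_impl: str):
    # Illustrative only: map a config string to an attention callable.
    if attn_impl == 'flash':
        # CUDA flash-attention kernels; requires the flash_attn package.
        from flash_attn import flash_attn_func
        return flash_attn_func
    if attn_impl == 'triton':
        # A Triton-based flash-attention kernel would be returned here;
        # left unimplemented in this sketch since the exact kernel depends
        # on what the maintainers choose to integrate.
        raise NotImplementedError('Triton backend not wired up in this sketch')
    # Default: plain PyTorch scaled dot-product attention, no extra dependency.
    return F.scaled_dot_product_attention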

When running the model normally, the flash_attn package is required, even though I have Triton installed. Loading fails with:

ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`

Here is the standard inference code that triggers the error:

import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

# Standard loading path, which currently requires the flash_attn package
model = AutoModel.from_pretrained(
    'OpenGVLab/InternViT-6B-224px',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).cuda().eval()

# Load and preprocess the image
image = Image.open('./examples/image1.jpg').convert('RGB')

image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-224px')

pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

# Run inference
outputs = model(pixel_values)

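For what it's worth, the hard dependency seems avoidable in principle. The sketch below (illustrative only, not InternVL's code; flash_attn_available is a hypothetical helper) checks whether flash_attn is importable and falls back to PyTorch's built-in scaled_dot_product_attention when it is not.

import importlib.util
import torch
import torch.nn.functional as F

# Hypothetical helper: detect whether flash_attn is importable so a model
# could fall back to another backend instead of raising ImportError.
def flash_attn_available() -> bool:
    return importlib.util.find_spec('flash_attn') is not None

# Toy tensors in (batch, heads, seqlen, headdim) layout, just to show that
# attention can be computed without the flash_attn package.
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device='cuda')
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device='cuda')
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device='cuda')

if flash_attn_available():
    from flash_attn import flash_attn_func
    # flash_attn expects (batch, seqlen, heads, headdim)
    out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
else:
    # Plain PyTorch scaled dot-product attention, no extra dependency
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)
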
radna0 · Jun 09 '24 05:06