Slow Inference on LLaMA 3.1 405B Using ollama.generate with Large Code Snippets on Multiple H100 GPUs
I'm experiencing very slow inference when using the ollama.generate function on a machine with multiple H100 GPUs: a single call takes up to 5 minutes, even though this hardware should handle it much faster. The input is a large code snippet, and I expected inference to take significantly less time.
Setup:
- Model: LLaMA 3.1 405B
- Hardware: multiple H100 GPUs
- Library Version: 0.3.3
- CUDA Version: 12.6
- Driver Version: 560.35.03
- Operating System: Ubuntu 24.04
Steps to Reproduce:
- Use the ollama.generate function with a large code snippet as input.
- Observe inference time (up to 5 minutes per call).
Code Example:
import ollama

response = ollama.generate(
    model="llama3.1:405b",
    prompt="Explain the following code:\n[Insert large code snippet here]"
)
print(response)
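To narrow down where the 5 minutes actually go, I've been splitting the time into load, prompt evaluation, and generation using the duration fields that the generate response returns. A minimal sketch, assuming the non-streaming response is a dict with the standard duration fields (reported by the server in nanoseconds):

import ollama

def report_timings(response: dict) -> None:
    """Break a generate() response into load / prompt-eval / generation time."""
    ns = 1e9  # duration fields are in nanoseconds
    print(f"load:        {response.get('load_duration', 0) / ns:.1f} s")
    print(f"prompt eval: {response.get('prompt_eval_duration', 0) / ns:.1f} s "
          f"({response.get('prompt_eval_count', 0)} tokens)")
    print(f"generation:  {response.get('eval_duration', 0) / ns:.1f} s "
          f"({response.get('eval_count', 0)} tokens)")
    print(f"total:       {response.get('total_duration', 0) / ns:.1f} s")

response = ollama.generate(
    model="llama3.1:405b",
    prompt="Explain the following code:\n[Insert large code snippet here]"
)
report_timings(response)

If most of the time lands in prompt_eval_duration, the bottleneck is processing the large code snippet; if it lands in eval_duration, it's token generation itself.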
Expected Behavior: I expected the inference time to be significantly faster, especially on a machine with multiple H100 GPUs. Ideally, the inference should take seconds, not minutes.
Actual Behavior: The inference is taking up to 5 minutes per call, which seems excessively slow for this hardware setup.
Additional Information:
- GPU Utilization: Memory usage across all GPUs sits around 50%, and compute utilization around 25% with occasional spikes, which suggests the GPUs are under-utilized.
- Mixed Precision / Quantization: I'm not sure whether mixed precision or quantization is applied by default; either could improve inference time (see the inspection sketch after this list).
- Parallelism: It's unclear how the model is distributed across the GPUs, or whether any model-parallelism optimizations are applied.
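To at least partially answer my own question about precision and placement, I've been inspecting the model with the client's show() and ps() calls. This is only a sketch; the exact field names (details.quantization_level, size_vram, etc.) are my assumption based on the server's REST responses:

import ollama

# What precision are the weights actually in? The default llama3.1:405b tag
# appears to be a 4-bit quantization rather than FP16/BF16.
info = ollama.show("llama3.1:405b")
details = info.get("details", {})
print("quantization:", details.get("quantization_level"))
print("parameters:  ", details.get("parameter_size"))

# How much of the running model is resident in VRAM? If size_vram is well
# below size, layers are being offloaded to system RAM and run on the CPU,
# which alone could explain multi-minute generations.
for m in ollama.ps().get("models", []):
    print(m.get("name"), "size:", m.get("size"), "in VRAM:", m.get("size_vram"))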
Questions:
- Is there any support for batching inputs, using mixed precision (FP16/BF16), or quantization in the Ollama library to speed up inference?
- Are there any known optimizations for better multi-GPU inference (e.g., reducing communication overhead) when using this library?
- Are there configuration settings that can help fully utilize the multiple H100 GPUs and reduce inference time for large code snippets? (A sketch of the settings I've been experimenting with is below.)
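For context, this is the kind of configuration I've been experimenting with. The environment variables are read by the Ollama server process (not by the Python client), and I'm assuming they behave as the server documentation describes, so treat this as a sketch rather than a confirmed fix:

import ollama

# Server-side settings, exported before starting `ollama serve`
# (setting them in this client process has no effect):
#   OLLAMA_SCHED_SPREAD=1      # spread one model across all GPUs instead of packing
#   OLLAMA_FLASH_ATTENTION=1   # enable flash attention for faster prompt processing
#   OLLAMA_NUM_PARALLEL=4      # allow concurrent requests if inputs are batched

# Client-side: stream tokens so it's obvious whether the time is spent before
# the first token (prompt processing) or spread over generation, and cap the
# context size and output length.
stream = ollama.generate(
    model="llama3.1:405b",
    prompt="Explain the following code:\n[Insert large code snippet here]",
    stream=True,
    options={
        "num_ctx": 8192,      # large enough for the code snippet, no larger
        "num_predict": 1024,  # cap the number of generated tokens
    },
)
for chunk in stream:
    print(chunk["response"], end="", flush=True)

Streaming doesn't make generation faster by itself, but the time-to-first-token makes it easy to see whether prompt processing or decoding dominates.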