add nim to flytekit core
Tracking issue
Fixes https://github.com/flyteorg/flyte/issues/5478
Why are the changes needed?
This PR adds NIM to flytekit/core to enable serving optimized model containers, which can include NVIDIA CUDA software, NVIDIA Triton Inference Server, and NVIDIA TensorRT-LLM software. Because each Flyte task runs in a Kubernetes pod, a NIM container can be deployed as a sidecar service, so users can invoke the model's endpoint as if it were hosted locally, which minimizes network overhead. Flyte can also manage data pre-processing and post-processing, making it possible to build end-to-end batch inference pipelines with optimized models.
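To make the "pre-process, infer, post-process" shape concrete, here is a minimal plain-Python sketch. It is not code from this PR: `preprocess`, `postprocess`, `run_batch`, and `fake_infer` are hypothetical names, and `fake_infer` merely stands in for a request to the sidecar's local endpoint. In a Flyte pipeline each stage would presumably be its own task.

```python
from typing import Callable, Iterable, List


def preprocess(raw_prompts: Iterable[str]) -> List[str]:
    """Collapse runs of whitespace and drop empty prompts before inference."""
    cleaned = (" ".join(p.split()) for p in raw_prompts)
    return [p for p in cleaned if p]


def postprocess(completions: Iterable[str]) -> List[str]:
    """Trim each completion; a real pipeline might parse or validate here."""
    return [c.strip() for c in completions]


def run_batch(raw_prompts: Iterable[str], infer: Callable[[str], str]) -> List[str]:
    """Pre-process, call the model once per prompt, then post-process."""
    prompts = preprocess(raw_prompts)
    return postprocess(infer(p) for p in prompts)


def fake_infer(prompt: str) -> str:
    """Hypothetical stand-in for a call to the model served by the sidecar."""
    return f"  echo: {prompt}\n"


if __name__ == "__main__":
    print(run_batch(["  hello   world ", "", "GPU\tcomputing"], fake_infer))
```

Passing the inference step in as a callable keeps the data handling testable without a running model server.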
What changes were proposed in this pull request?
- A new `ModelInferenceTemplate` has been added, providing a base for creating other inference solutions like Ollama and Hugging Face.
- The `NIM` class, which subclasses `ModelInferenceTemplate`, has been introduced as well.
from flytekit import ImageSpec, Secret, task, Resources
from flytekit.core.inference import NIM, NIMSecrets
from flytekit.extras.accelerators import A10G
from openai import OpenAI

# Assumption: NGC_KEY was not defined in the original snippet. It is taken
# here to be the same secret key that the task requests below.
NGC_KEY = "ngc_api_key"

image = ImageSpec(
    name="nim",
    registry="...",
    packages=["kubernetes", "openai"],
)

nim_instance = NIM(
    image="nvcr.io/nim/meta/llama3-8b-instruct:1.0.0",
    secrets=NIMSecrets(
        ngc_image_secret="nvcrio-cred", ngc_secret_key=NGC_KEY
    ),
)


@task(
    container_image=image,
    pod_template=nim_instance.pod_template,
    accelerator=A10G,
    secret_requests=[
        Secret(
            key="ngc_api_key", mount_requirement=Secret.MountType.ENV_VAR
        )  # must be mounted as an env var
    ],
    requests=Resources(gpu="0"),
)
def model_serving() -> str:
    client = OpenAI(
        base_url=f"{nim_instance.base_url}/v1", api_key="nim"
    )  # api key required but ignored

    completion = client.chat.completions.create(
        model="meta/llama3-8b-instruct",
        messages=[
            {
                "role": "user",
                "content": "Write a limerick about the wonders of GPU computing.",
            }
        ],
        temperature=0.5,
        top_p=1,
        max_tokens=1024,
    )

    return completion.choices[0].message.content
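The example above reads `nim_instance.base_url` to reach the sidecar. As a rough mental model of how a shared template and a thin subclass could supply that value, here is a self-contained sketch. It is not the flytekit implementation: `InferenceTemplateSketch` and `NIMSketch` are hypothetical classes, and the default port and health paths are assumptions made for illustration only.

```python
from dataclasses import dataclass


@dataclass
class InferenceTemplateSketch:
    """Hypothetical stand-in for a sidecar template; not the flytekit class."""

    image: str
    port: int = 8000  # assumed default port
    health_path: str = "/health"  # assumed generic health path

    @property
    def base_url(self) -> str:
        # Containers in one pod share a network namespace, so the sidecar
        # is reachable from the task container on localhost.
        return f"http://localhost:{self.port}"

    @property
    def health_url(self) -> str:
        return f"{self.base_url}{self.health_path}"


@dataclass
class NIMSketch(InferenceTemplateSketch):
    """Hypothetical subclass that only overrides a server-specific default."""

    health_path: str = "/v1/health/ready"  # assumed readiness path


if __name__ == "__main__":
    nim = NIMSketch(image="nvcr.io/nim/meta/llama3-8b-instruct:1.0.0")
    print(nim.base_url)
    print(nim.health_url)
```

Under this reading, adding another server (for example Ollama) would mostly mean another small subclass with its own defaults, which matches the stated goal of the template.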
How was this patch tested?
Setup process
Screenshots
Check all the applicable boxes
- [ ] I updated the documentation accordingly.
- [ ] All new and existing tests passed.
- [x] All commits are signed-off.
Related PRs
Docs link
Codecov Report
Attention: Patch coverage is 83.09859% with 12 lines in your changes missing coverage. Please review.
Project coverage is 76.21%. Comparing base (097e9e8) to head (7e62555). Report is 4 commits behind head on master.
| Files | Patch % | Lines |
|---|---|---|
| flytekit/core/inference.py | 80.39% | 5 Missing and 5 partials :warning: |
| flytekit/core/utils.py | 90.00% | 1 Missing and 1 partial :warning: |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2475      +/-   ##
==========================================
- Coverage   76.22%   76.21%   -0.01%
==========================================
  Files         187      186       -1
  Lines       18938    18903      -35
  Branches     3706     3719      +13
==========================================
- Hits        14435    14407      -28
+ Misses       3870     3851      -19
- Partials      633      645      +12
How about we move this pattern to core flytekit? I don't see any extra requirements.