
Implement ChatModel (pyfunc subclass)

Open • daniellok-db opened this pull request 1 year ago • 1 comment

🛠 DevTools 🛠

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10820/merge

Checkout with GitHub CLI

gh pr checkout 10820

Related Issues/PRs

What changes are proposed in this pull request?

This PR adds the ChatModel pyfunc subclass to make it more seamless for users to implement and serve chat models. The ChatModel class requires users to implement a predict method with the following signature (corresponding to the OpenAI chat request format):

class MyChatModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        # user-defined behavior
        ...

With this, users don't need to implement any request-parsing logic and can work directly with the pydantic objects that are passed in. Additionally, input/output signatures and an input example are provided automatically.
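For illustration, a predict body might look like the sketch below. This is not taken from the PR itself: the attribute names (role, content, temperature) follow the OpenAI chat format used throughout this description, and my_llm_call is a hypothetical helper standing in for the user's own inference code.

def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
    # pydantic attribute access, no dict parsing required
    last_user_message = messages[-1].content
    temperature = params.temperature
    # my_llm_call is a hypothetical helper; it must return a dict matching the
    # OpenAI chat completion response schema so ChatResponse can be built from it
    reply = my_llm_call(last_user_message, temperature=temperature)
    return ChatResponse(**reply)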

To support this, we implement a new custom loader for these models in mlflow.pyfunc.loaders.chat_model. The loader wraps the ChatModel in a _ChatModelPyfuncWrapper class, which accepts the standard chat request format and splits it into messages and params before calling the user's predict method.
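Conceptually, the wrapper behaves roughly like the sketch below. This is an illustration of the idea only, not the actual implementation in mlflow.pyfunc.loaders.chat_model; the split of "messages" versus the remaining keys is inferred from the request format shown in the serving example further down.

class _ChatModelPyfuncWrapper:
    # Sketch only; the real class lives in mlflow.pyfunc.loaders.chat_model
    def __init__(self, chat_model, context):
        self.chat_model = chat_model
        self.context = context

    def predict(self, data: dict):
        # "messages" becomes a list of ChatMessage; remaining keys become ChatParams
        messages = [ChatMessage(**m) for m in data["messages"]]
        params = ChatParams(**{k: v for k, v in data.items() if k != "messages"})
        response = self.chat_model.predict(self.context, messages, params)
        # the serving layer ultimately returns JSON in the OpenAI chat completion format
        return response.model_dump(exclude_none=True)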

How is this PR tested?

  • [x] Existing unit/integration tests
  • [x] New unit/integration tests
  • [x] Manual tests

Ran the following to create a chat model:

import json
from typing import List

import mlflow

# assumed import location of the new chat types used below (ChatMessage, ChatParams, ChatResponse)
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse


class TestChatModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        mock_response = {
            "id": "123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": json.dumps([m.model_dump(exclude_none=True) for m in messages]),
                    },
                    "finish_reason": "stop",
                },
                {
                    "index": 1,
                    "message": {
                        "role": "user",
                        "content": params.model_dump_json(exclude_none=True),
                    },
                    "finish_reason": "stop",
                },
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        }
        return ChatResponse(**mock_response)

mlflow.pyfunc.save_model(
    path="chat-model",
    python_model=TestChatModel(),
)
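For completeness, the saved model can also be loaded and exercised locally before serving. A quick sketch, using the same request body as the curl call below:

loaded = mlflow.pyfunc.load_model("chat-model")
loaded.predict(
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello!"},
        ]
    }
)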

Then on the command line:

$ mlflow models serve -m chat-model

$ curl http://127.0.0.1:5000/invocations -H 'Content-Type: application/json' -d '{ "messages": [ { "role": "system", "content": "You are a helpful assistant" }, { "role": "user", "content": "Hello!" } ] }' | jq

{
  "id": "123",
  "object": "chat.completion",
  "created": 1677652288,
  "model": "MyChatModel",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "[{\"role\": \"system\", \"content\": \"You are a helpful assistant\"}, {\"role\": \"user\", \"content\": \"Hello!\"}]"
      },
      "finish_reason": "stop"
    },
    {
      "index": 1,
      "message": {
        "role": "user",
        "content": "{\"temperature\":1.0,\"n\":1,\"stream\":false}"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 10,
    "total_tokens": 20
  }
}
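The same endpoint can also be called from Python instead of curl; a minimal sketch using the requests library with the payload from the command above:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/invocations",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello!"},
        ]
    },
)
print(resp.json())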

Also validated the model in the MLflow UI:

Validate that the MLmodel file looks as expected (screenshot attached to the PR)

Validate that the signature looks correct:

https://github.com/mlflow/mlflow/assets/148037680/eccd3abe-2a82-4e37-9d97-ec8f576512f7

Does this PR require documentation update?

Requires a tutorial, but we can work on this in a follow-up PR

  • [ ] No. You can skip the rest of this section.
  • [x] Yes. I've updated:
    • [ ] Examples
    • [x] API references
    • [ ] Instructions

Release Notes

Is this a user-facing change?

  • [ ] No. You can skip the rest of this section.
  • [x] Yes. Give a description of this change to be included in the release notes for MLflow users.

Added the ChatModel pyfunc class, which allows for more convenient definition of chat models conforming to the OpenAI request/response format.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • [ ] area/artifacts: Artifact stores and artifact logging
  • [ ] area/build: Build and test infrastructure for MLflow
  • [ ] area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • [ ] area/docs: MLflow documentation pages
  • [ ] area/examples: Example code
  • [ ] area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • [x] area/models: MLmodel format, model serialization/deserialization, flavors
  • [ ] area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • [ ] area/projects: MLproject format, project running backends
  • [ ] area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • [ ] area/server-infra: MLflow Tracking server backend
  • [ ] area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • [ ] area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • [ ] area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • [ ] area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • [ ] area/windows: Windows support

Language

  • [ ] language/r: R APIs and clients
  • [ ] language/java: Java APIs and clients
  • [ ] language/new: Proposals for new client languages

Integrations

  • [ ] integrations/azure: Azure and Azure ML integrations
  • [ ] integrations/sagemaker: SageMaker integrations
  • [ ] integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • [ ] rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • [ ] rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • [x] rn/feature - A new user-facing feature worth mentioning in the release notes
  • [ ] rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • [ ] rn/documentation - A user-facing documentation change worth mentioning in the release notes

daniellok-db • Jan 15 '24 05:01

Documentation preview for 407586018bf097e30fe31c200ca6cf26cceeab75 will be available once the CircleCI documentation job completes successfully.

More info
  • Ignore this comment if this PR does not change the documentation.
  • It takes a few minutes for the preview to be available.
  • The preview is updated when a new commit is pushed to this PR.
  • This comment was created by https://github.com/mlflow/mlflow/actions/runs/7752490505.

github-actions[bot] • Jan 15 '24 05:01