
Fix missing dtype issue for transformer pipeline

Open B-Step62 opened this issue 1 year ago • 2 comments

🛠 DevTools 🛠

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10979/merge

Checkout with GitHub CLI

gh pr checkout 10979

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

When a Transformers pipeline is constructed from a model and a tokenizer, it doesn't inherit the torch_dtype attribute from the model.

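```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.bfloat16)  # t5-small is encoder-decoder
tokenizer = AutoTokenizer.from_pretrained("t5-small")
pipe = pipeline(model=model, task="text2text-generation", tokenizer=tokenizer)

print(pipe.torch_dtype)  # => None on affected Transformers versions (not inherited from the model)
```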

Because we currently check only the pipeline's dtype, the saved model lacks the model's dtype information. As a result, the model is loaded in the default precision, which consumes more memory.

According to Hugging Face's documentation, it should be safe to treat the two as synonyms:

torch_dtype (str or torch.dtype, optional) — Sent directly as model_kwargs (just a simpler shortcut) to use the available precision for this model (torch.float16, torch.bfloat16, … or "auto").
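Treating the two as synonyms amounts to a simple fallback: prefer the pipeline's torch_dtype and, when it is None, use the underlying model's dtype. The following is a minimal sketch under that assumption, using duck-typed objects; resolve_torch_dtype is a hypothetical helper name, not MLflow's actual implementation.

```python
from types import SimpleNamespace  # only used for the stand-in objects below


def resolve_torch_dtype(pipeline):
    """Hypothetical helper: return the pipeline's dtype, falling back to its model's dtype."""
    # A pipeline built from an already-loaded model may report torch_dtype=None.
    dtype = getattr(pipeline, "torch_dtype", None)
    if dtype is not None:
        return dtype
    # Fall back to the model's own dtype attribute, if it has one.
    model = getattr(pipeline, "model", None)
    return getattr(model, "dtype", None)


# Usage with stand-in objects; real code would pass a transformers Pipeline.
pipe = SimpleNamespace(torch_dtype=None, model=SimpleNamespace(dtype="torch.bfloat16"))
print(resolve_torch_dtype(pipe))  # torch.bfloat16
```

An explicitly set pipeline dtype still wins, so the fallback only changes behavior in the case described above.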

Note: I have also submitted a feature request to Transformers, and this handling will become unnecessary once it is implemented. However, we still need this fix in the meantime, as well as for older versions of Transformers.

How is this PR tested?

  • [x] Existing unit/integration tests
  • [x] New unit/integration tests
  • [ ] Manual tests

Does this PR require documentation update?

  • [x] No. You can skip the rest of this section.
  • [ ] Yes. I've updated:
    • [ ] Examples
    • [ ] API references
    • [ ] Instructions

Release Notes

Is this a user-facing change?

  • [x] No. You can skip the rest of this section.
  • [ ] Yes. Give a description of this change to be included in the release notes for MLflow users.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • [ ] area/artifacts: Artifact stores and artifact logging
  • [ ] area/build: Build and test infrastructure for MLflow
  • [ ] area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • [ ] area/docs: MLflow documentation pages
  • [ ] area/examples: Example code
  • [ ] area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • [x] area/models: MLmodel format, model serialization/deserialization, flavors
  • [ ] area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • [ ] area/projects: MLproject format, project running backends
  • [ ] area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • [ ] area/server-infra: MLflow Tracking server backend
  • [ ] area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • [ ] area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • [ ] area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • [ ] area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • [ ] area/windows: Windows support

Language

  • [ ] language/r: R APIs and clients
  • [ ] language/java: Java APIs and clients
  • [ ] language/new: Proposals for new client languages

Integrations

  • [ ] integrations/azure: Azure and Azure ML integrations
  • [ ] integrations/sagemaker: SageMaker integrations
  • [ ] integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • [ ] rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • [ ] rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • [ ] rn/feature - A new user-facing feature worth mentioning in the release notes
  • [x] rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • [ ] rn/documentation - A user-facing documentation change worth mentioning in the release notes

B-Step62, Feb 01 '24 10:02

Documentation preview for 36b441884a187c017e1f9ee28da9a777bcf443c2 will be available here when this CircleCI job completes successfully.

More info
  • Ignore this comment if this PR does not change the documentation.
  • It takes a few minutes for the preview to be available.
  • The preview is updated when a new commit is pushed to this PR.
  • This comment was created by https://github.com/mlflow/mlflow/actions/runs/7793691841.

github-actions[bot], Feb 01 '24 10:02

I just found that the model exposes a dtype property that derives the dtype from its parameters (source), though their source code still uses it as a fallback when the model config is missing. It would be simpler if it works; I'm watching how the cross-version tests go 👀
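To illustrate what "getting the dtype from the parameters" means, here is a simplified, torch-free sketch of the idea: take the dtype of the first floating-point parameter, and otherwise fall back to the last parameter's dtype. The name infer_parameter_dtype is hypothetical, and this is not Transformers' actual code; it only assumes the module offers parameters() yielding objects with dtype and is_floating_point(), as torch.nn.Module does.

```python
def infer_parameter_dtype(module):
    """Hypothetical sketch: dtype of the first floating-point parameter, else the last parameter's dtype."""
    last_dtype = None
    for param in module.parameters():
        last_dtype = param.dtype
        # Floating-point weights reflect the precision the model was loaded in.
        if param.is_floating_point():
            return param.dtype
    # No floating-point parameters: fall back to whatever was seen last (None if no parameters).
    return last_dtype
```

With PyTorch installed, infer_parameter_dtype(torch.nn.Linear(2, 2).to(torch.bfloat16)) should return torch.bfloat16.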

B-Step62, Feb 02 '24 08:02

I changed the dtype extraction logic so that the dtype is not logged when the model's dtype is the default, torch.float32. The issue is that model.dtype returns a valid torch dtype even when the model/pipeline doesn't accept a torch_dtype parameter at construction, so loading the model with the logged dtype fails. Unfortunately, there seems to be no easy way to determine whether a model/pipeline supports the torch_dtype attribute (even models that don't support it can have a torch_dtype param in their model config!!). As a workaround, I log the dtype only when it is non-default, i.e. not float32, assuming that there is no way to change model precision other than torch_dtype. This approach may be a bit fragile, since it depends on the implicit fact that float32 is the default. WDYT? @BenWilson2 @harupy
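The workaround above can be sketched as a small predicate: read the model's dtype and return it only when it differs from the assumed default. This is a minimal sketch assuming, as the comment does, that float32 is the default precision; dtype_to_log is a hypothetical name, and comparing str(dtype) (for example, str(torch.bfloat16) is "torch.bfloat16") keeps the helper free of a torch import.

```python
DEFAULT_DTYPE = "torch.float32"  # assumption: float32 is the implicit default precision


def dtype_to_log(model):
    """Hypothetical helper: return the model's dtype as a string, or None if it should not be logged."""
    dtype = getattr(model, "dtype", None)
    if dtype is None:
        return None
    name = str(dtype)
    # Skip the default: logging it could break loading for models whose
    # constructors do not accept a torch_dtype argument.
    return None if name == DEFAULT_DTYPE else name
```

The fragility noted above is visible here: everything hinges on the DEFAULT_DTYPE constant matching the library's actual default.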

B-Step62, Feb 06 '24 00:02

I think the approach I saw this morning, not logging the dtype when it is the default, is good. No objections here!

BenWilson2, Feb 06 '24 00:02