yocto-gl
Fix missing dtype issue for transformer pipeline
🛠 DevTools 🛠
Install mlflow from this PR
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10979/merge
Checkout with GitHub CLI
gh pr checkout 10979
Related Issues/PRs
#xxx
What changes are proposed in this pull request?
When a Transformers pipeline is constructed from an already-loaded model and tokenizer, it does not inherit the torch_dtype attribute from the model.
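import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# t5-small is an encoder-decoder model, so it is loaded with the Seq2Seq auto class
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("t5-small")
pipe = pipeline(model=model, task="text2text-generation", tokenizer=tokenizer)
print(pipe.torch_dtype)
# => None (whereas model.dtype is torch.bfloat16)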
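Because we currently check only the pipeline's torch_dtype, the saved model does not record the model's dtype. When it is loaded again, it falls back to the default torch.float32 precision and consumes more memory than necessary.
According to the Hugging Face documentation, the pipeline's torch_dtype is passed straight through to the model, so it should be safe to treat the two as synonyms: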
torch_dtype (str or torch.dtype, optional) — Sent directly as model_kwargs (just a simpler shortcut) to use the available precision for this model (torch.float16, torch.bfloat16, … or "auto").
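The fallback described above can be sketched as follows. This is a minimal illustration, not the code in this PR: the helper name get_pipeline_dtype is hypothetical, and the sketch only assumes that a pipeline exposes torch_dtype (possibly None) and model, and that a PyTorch model exposes dtype. It is duck-typed so it can be exercised without loading a real model.

```python
from types import SimpleNamespace


def get_pipeline_dtype(pipeline):
    """Return the dtype to record for a Transformers pipeline, or None.

    Hypothetical helper: prefer the pipeline's own torch_dtype, and fall back
    to the dtype of the wrapped model when the pipeline does not carry one
    (e.g. when it was built from an already-loaded model object).
    """
    dtype = getattr(pipeline, "torch_dtype", None)
    if dtype is None:
        model = getattr(pipeline, "model", None)
        dtype = getattr(model, "dtype", None)
    return dtype


# Stand-ins for real pipelines; with torch installed these would hold torch dtypes.
built_from_model = SimpleNamespace(torch_dtype=None, model=SimpleNamespace(dtype="bfloat16"))
built_with_dtype = SimpleNamespace(torch_dtype="float16", model=SimpleNamespace(dtype="float16"))

print(get_pipeline_dtype(built_from_model))  # bfloat16
print(get_pipeline_dtype(built_with_dtype))  # float16
```

Preferring the pipeline attribute keeps the existing behavior for pipelines created with an explicit torch_dtype; the model's dtype is consulted only when that attribute is missing.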
Note: I have also submitted a feature request to Transformers, and this handling becomes unnecessary once it is implemented. However, we still need this workaround in the meantime, as well as for older versions of Transformers.
How is this PR tested?
- [x] Existing unit/integration tests
- [x] New unit/integration tests
- [ ] Manual tests
Does this PR require documentation update?
- [x] No. You can skip the rest of this section.
- [ ] Yes. I've updated:
- [ ] Examples
- [ ] API references
- [ ] Instructions
Release Notes
Is this a user-facing change?
- [x] No. You can skip the rest of this section.
- [ ] Yes. Give a description of this change to be included in the release notes for MLflow users.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
- [ ] area/artifacts: Artifact stores and artifact logging
- [ ] area/build: Build and test infrastructure for MLflow
- [ ] area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
- [ ] area/docs: MLflow documentation pages
- [ ] area/examples: Example code
- [ ] area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- [x] area/models: MLmodel format, model serialization/deserialization, flavors
- [ ] area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- [ ] area/projects: MLproject format, project running backends
- [ ] area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- [ ] area/server-infra: MLflow Tracking server backend
- [ ] area/tracking: Tracking Service, tracking client APIs, autologging
Interface
- [ ] area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- [ ] area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- [ ] area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- [ ] area/windows: Windows support
Language
- [ ] language/r: R APIs and clients
- [ ] language/java: Java APIs and clients
- [ ] language/new: Proposals for new client languages
Integrations
- [ ] integrations/azure: Azure and Azure ML integrations
- [ ] integrations/sagemaker: SageMaker integrations
- [ ] integrations/databricks: Databricks integrations
How should the PR be classified in the release notes? Choose one:
- [ ] rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- [ ] rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- [ ] rn/feature - A new user-facing feature worth mentioning in the release notes
- [x] rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- [ ] rn/documentation - A user-facing documentation change worth mentioning in the release notes
Documentation preview for 36b441884a187c017e1f9ee28da9a777bcf443c2 will be available here when this CircleCI job completes successfully.
More info
- Ignore this comment if this PR does not change the documentation.
- It takes a few minutes for the preview to be available.
- The preview is updated when a new commit is pushed to this PR.
- This comment was created by https://github.com/mlflow/mlflow/actions/runs/7793691841.
Just found that the model exposes a dtype property that derives the dtype from its parameters (source), though their source code still uses it only as a fallback for a missing model config. It would be simpler if it works; watching how the cross-version tests go 👀
I changed the dtype extraction logic so that the dtype is not logged when the model's dtype is the default one, torch.float32. The issue is that model.dtype returns a valid torch dtype even when the model/pipeline doesn't support the torch_dtype parameter at construction, which results in a failure when loading the model with the logged dtype. Unfortunately, there seems to be no easy way to determine whether a model/pipeline supports the torch_dtype attribute (even those that don't support it can have a torch_dtype param in their model config!). Hence I used a workaround that logs the dtype only when it is non-default, i.e. not float32, assuming that there is no way to change model precision other than torch_dtype. This approach may be a bit fragile, as it depends on the implicit fact that float32 is the default. WDYT? @BenWilson2 @harupy
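The "log only non-default dtypes" workaround can be sketched like this. It is a hypothetical illustration, not MLflow's actual code: the helper name dtype_to_log and the choice to store the dtype as its string form (str(torch.bfloat16) is "torch.bfloat16") are assumptions. Comparing string forms keeps the sketch runnable without torch installed, while still working if a real torch.dtype is passed in.

```python
from types import SimpleNamespace

# Assumed default precision of Transformers models loaded without torch_dtype.
DEFAULT_DTYPE = "torch.float32"


def dtype_to_log(model):
    """Return the dtype string to save with the model, or None to skip logging.

    Hypothetical helper illustrating the workaround: model.dtype is available
    even for models whose loader does not accept torch_dtype, so the dtype is
    recorded only when it is non-default, on the assumption that a non-float32
    model must have been created with torch_dtype in the first place.
    """
    dtype = getattr(model, "dtype", None)
    if dtype is None:
        return None
    dtype_str = str(dtype)  # e.g. "torch.bfloat16" for a real torch.dtype
    return None if dtype_str == DEFAULT_DTYPE else dtype_str


print(dtype_to_log(SimpleNamespace(dtype="torch.bfloat16")))  # torch.bfloat16
print(dtype_to_log(SimpleNamespace(dtype="torch.float32")))   # None
```

The fragility mentioned above shows up in the single DEFAULT_DTYPE constant: the whole check rests on float32 being the implicit precision whenever torch_dtype was not used.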
I think the approach I saw this morning, not logging the dtype when it is the default, is good. No objections here!