
Support for `safetensors` materializers

Open Dev-Khant opened this issue 11 months ago β€’ 29 comments

Describe changes

I implemented support for safetensors for model serialization. It relates to #2532.

Pre-requisites

Please ensure you have done the following:

  • [x] I have read the CONTRIBUTING.md document.
  • [x] If my change requires a change to docs, I have updated the documentation accordingly.
  • [x] I have added tests to cover my changes.
  • [x] I have based my new branch on develop and the open PR is targeting develop. If your branch wasn't based on develop read Contribution guide on rebasing branch to develop.
  • [ ] If my changes require changes to the dashboard, these changes are communicated/requested.

Types of changes

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to change)
  • [ ] Other (add details above)

Dev-Khant avatar Mar 18 '24 14:03 Dev-Khant

[!IMPORTANT]

Auto Review Skipped

Auto reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


coderabbitai[bot] avatar Mar 18 '24 14:03 coderabbitai[bot]

I think the thing to add now would be a section in that docs page on materializers that explains:

  • why you might want to use safetensors for materialization instead of the default
  • how to use it / set it up so you can use these custom materializers (i.e. give a code example showing a step and how you'd specify the safetensors materializer)

strickvl avatar Mar 18 '24 14:03 strickvl

@strickvl Should I create a separate section for these new materializers or add them alongside the existing ones? I'll also prepare a separate section explaining why to use safetensors and provide a code example.

Dev-Khant avatar Mar 18 '24 15:03 Dev-Khant

I think I'd do it as a section of its own, just before the "Custom materializers" section: https://docs.zenml.io/user-guide/advanced-guide/data-management/handle-custom-data-types#custom-materializers


strickvl avatar Mar 18 '24 15:03 strickvl

Also note that tests are failing https://github.com/zenml-io/zenml/actions/runs/8328772930/job/22789533265?pr=2539

strickvl avatar Mar 18 '24 15:03 strickvl

Hi @strickvl, I have added the documentation, but I'm not sure if the code is correct for using materializers, as I could not find any docs on how to use integration-specific materializers. I have also made fixes for the failing tests.

So could you please check and let me know if that's correct or how it can be improved? Thanks.

Dev-Khant avatar Mar 19 '24 13:03 Dev-Khant

@Dev-Khant one thing you'll have to do is make sure that your PR is made off the develop branch. See https://github.com/zenml-io/zenml/blob/develop/CONTRIBUTING.md#-pull-requests-rebase-your-branch-on-develop for more. At the moment this PR is listed as being based on the main branch (see at the top).

strickvl avatar Mar 20 '24 09:03 strickvl

I'm fixing the failing test cases!

Dev-Khant avatar Mar 20 '24 09:03 Dev-Khant

@strickvl The issue is caused because only step_output_type is passed here, but safetensors' load_model expects a model as well in its input arguments. This comes up for both the PyTorch and Hugging Face safetensors materializers.

So how do you think we should handle the safetensors materializers in the test cases?

Dev-Khant avatar Mar 20 '24 10:03 Dev-Khant

Hi @strickvl @avishniakov @bcdurak, please can you guide me on how to fix this issue? Thanks.

> @strickvl The issue is caused because only step_output_type is passed here, but safetensors' load_model expects a model as well in its input arguments. This comes up for both the PyTorch and Hugging Face safetensors materializers.
>
> So how do you think we should handle the safetensors materializers in the test cases?

Dev-Khant avatar Mar 22 '24 10:03 Dev-Khant

@Dev-Khant Not sure I understand the question. The test function takes a model in, as you've currently defined it.

btw, this is currently also failing mypy linting (https://github.com/zenml-io/zenml/actions/runs/8389112411/job/22974663097?pr=2539)

strickvl avatar Mar 22 '24 10:03 strickvl

> @Dev-Khant Not sure I understand the question. The test function takes a model in, as you've currently defined it.
>
> btw, this is currently also failing mypy linting (https://github.com/zenml-io/zenml/actions/runs/8389112411/job/22974663097?pr=2539)

@strickvl I have fixed the lint issue and will push it in the next commit. The issue here is that loaded_data = materializer.load(step_output_type) should also take step_output (which is a PyTorch or HF model in our case) as input.

This is because when the safetensors materializer is called, it uses safetensors' load(), which requires a model and a filename.

So if we call materializer.load(step_output, step_output_type) here, then all the tests pass locally.

So what is the best way to handle materializers that also need a model to load, alongside materializers that only need a filename?

Dev-Khant avatar Mar 22 '24 11:03 Dev-Khant

I made fixes for the failing test cases and the lint issues.

Dev-Khant avatar Mar 23 '24 02:03 Dev-Khant

Linting is still failing, btw. @Dev-Khant

strickvl avatar Mar 25 '24 09:03 strickvl

> Linting is still failing, btw. @Dev-Khant

@strickvl Fixed lint issue. Please check.

Dev-Khant avatar Mar 25 '24 13:03 Dev-Khant

@strickvl There's an import error here: https://github.com/zenml-io/zenml/actions/runs/8421449350/job/23058505433?pr=2539

Dev-Khant avatar Mar 25 '24 14:03 Dev-Khant

Hey @Dev-Khant,

Thank you for your contribution :)

It seems like the implementation is off to a good start; however, there are a few things that need some modification.

Before I talk about these modifications though, it is a good idea to take a look at the concept of materializers. In short, they constitute a mechanism that ZenML uses to manage the inputs and outputs of steps under the hood. Depending on the type, our orchestration logic selects the right materializer for the job and uses it during the execution of a step.
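For readers new to this interface, here is a minimal sketch of the general shape of a materializer (import paths and attribute names follow the ZenML docs of this period and should be treated as approximate, not as this PR's code):

import os
from typing import Any, Type

from zenml.enums import ArtifactType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer


class MyObject:
    """Placeholder type produced and consumed by steps."""

    value: int = 0


class MyObjectMaterializer(BaseMaterializer):
    """Teaches ZenML how to persist and restore MyObject artifacts."""

    ASSOCIATED_TYPES = (MyObject,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA

    def load(self, data_type: Type[Any]) -> MyObject:
        # Called by the orchestrator when a downstream step consumes the artifact.
        obj = MyObject()
        with fileio.open(os.path.join(self.uri, "value.txt"), "r") as f:
            obj.value = int(f.read())
        return obj

    def save(self, data: MyObject) -> None:
        # Called by the orchestrator after the producing step returns.
        with fileio.open(os.path.join(self.uri, "value.txt"), "w") as f:
            f.write(str(data.value))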

In the current example (from your docs page), the materializer acts as an external factor (dealing with saving/loading outside of the orchestration), which goes against the main idea.

With that in mind, you can modify the example you wrote to something like this:

import logging

from torch.nn import Module

from zenml import step, pipeline
from zenml.integrations.pytorch.materializers import PyTorchModuleSTMaterializer


@step(enable_cache=False, output_materializers=PyTorchModuleSTMaterializer)
def my_first_step() -> Module:
    """Step that saves a Pytorch model"""
    from torchvision.models import resnet50

    pretrained_model = resnet50()

    return pretrained_model


@step(enable_cache=False)
def my_second_step(model: Module):
    """Step that loads the model."""
    logging.info("Input loaded correctly.")


@pipeline
def first_pipeline():
    model = my_first_step()
    my_second_step(model)


first_pipeline()

This is the recommended way of using a custom materializer in a ZenML pipeline.

However, if you run this example now, the first step will succeed, but the second step will fail due to:

TypeError: BasePyTorchSTMaterializer.load() missing 1 required positional argument: 'obj'

This is because the new materializers use the abstraction of the save functionality correctly, but they introduce a load function whose signature differs from the base materializers'. The orchestration logic cannot handle this, which eventually leads to a failure. This is also why the test you mentioned in your comment was failing.

I would propose modifying them to use the same abstraction as follows, because the orchestrators do not know how to handle the input obj:

    def load(self, data_type: Type[Any]) -> Any:
        """Write logic here to load the data of an artifact.

        Args:
            data_type: The type of data that the artifact should be loaded as.

        Returns:
            The data of the artifact.
        """
        # read from a location inside self.uri
        # 
        # Example:
        # data_path = os.path.join(self.uri, "abc.json")
        # return yaml_utils.read_json(data_path)

I understand this is a challenging task, because safetensors' load_model function requires you to pass a model to load onto; however, this needs to be solved in a different manner than passing the object to the load function.

I hope this explanation is helpful and feel free to reach out again if you have any additional questions.

I have also added a few more small comments as well.

@bcdurak Thanks for reviewing the code and helping me understand this much better.

I can think of a solution here that would store the architecture of the model, like this:

import torch
from torchvision.models import resnet50
from safetensors.torch import load_file, save_file

pretrained_model = resnet50()
arch = {"model": pretrained_model}
save_file(pretrained_model.state_dict(), "weights.safetensors")
torch.save(arch, "model_arch.json")  # pickles the architecture (despite the .json name)

print("Model Saved!")

new_arch = torch.load("model_arch.json")
weights = load_file("weights.safetensors")
_model = new_arch["model"]

_model.load_state_dict(weights)  # loads the weights in place
loaded_model = _model

print("Model Loaded!")

This way we can save the model's architecture and weights in the materializer's save and then load both in the materializer's load. Do you think this is a good approach?
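To make that concrete, here is a rough sketch (illustrative only, not the PR's actual code; the class name and file layout are assumptions) of how the idea could fit the base load(data_type) signature by writing both the architecture and the weights under self.uri:

import os
from typing import Any, Type

import torch
from safetensors.torch import load_file, save_file

from zenml.materializers.base_materializer import BaseMaterializer


class PyTorchSafetensorsMaterializer(BaseMaterializer):
    ASSOCIATED_TYPES = (torch.nn.Module,)

    def save(self, model: torch.nn.Module) -> None:
        # Weights go into a safetensors file; the architecture is pickled alongside.
        save_file(model.state_dict(), os.path.join(self.uri, "weights.safetensors"))
        torch.save(model, os.path.join(self.uri, "arch.pt"))

    def load(self, data_type: Type[Any]) -> torch.nn.Module:
        # Rebuild the module from the pickled architecture, then load the weights.
        model = torch.load(os.path.join(self.uri, "arch.pt"))
        model.load_state_dict(load_file(os.path.join(self.uri, "weights.safetensors")))
        return model

Note that this still pickles the full model next to the weights, and plain local-path I/O like this would not work against remote artifact stores; both points come up in the review below.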

Dev-Khant avatar Mar 28 '24 05:03 Dev-Khant

@bcdurak I have made the relevant changes. Please review them. Thanks.

Dev-Khant avatar Mar 30 '24 07:03 Dev-Khant

By the way, there is a known issue with our testing suite. We are fixing it at the moment. I will keep you updated as it goes on. For the time being, feel free to ignore the failing tests.

bcdurak avatar Apr 02 '24 15:04 bcdurak

@bcdurak Thanks for this detailed review, I have gone through your comments. And here are my answers:

  1. For tensorflow there is no method to directly save and load models using safetensors. For numpy, I'll add a safetensors materializer.
  2. I'll add a test for pytorch_lightning as well.

Dev-Khant avatar Apr 03 '24 03:04 Dev-Khant

New and removed dependencies detected. Learn more about Socket for GitHub β†—οΈŽ

| Package | New capabilities | Transitives | Size | Publisher |
| --- | --- | --- | --- | --- |
| pypi/[email protected] | filesystem, unsafe | 0 | 5.61 MB | McPotato, Nicolas.Patry, Wauplin, ...1 more |

View full reportβ†—οΈŽ

socket-security[bot] avatar Apr 08 '24 09:04 socket-security[bot]

@bcdurak @avishniakov Can you please guide me on how to change pyproject.toml so that safetensors is installed properly?

Dev-Khant avatar Apr 10 '24 11:04 Dev-Khant

> @bcdurak @avishniakov Can you please guide me on how to change pyproject.toml so that safetensors is installed properly?

Hey @Dev-Khant, IMO, since you modified the numpy materializers to rely on safetensors, it is not an optional dependency anymore but a base one, so it should fall directly under the [tool.poetry.dependencies] section. According to their pyproject.toml, there are no mandatory dependencies, which should be quite good for a base dependency here. @bcdurak WDYT, is it fine to push safetensors to the default deps?

avishniakov avatar Apr 10 '24 15:04 avishniakov

> Hey @Dev-Khant, IMO, since you modified the numpy materializers to rely on safetensors, it is not an optional dependency anymore but a base one, so it should fall directly under the [tool.poetry.dependencies] section. According to their pyproject.toml, there are no mandatory dependencies, which should be quite good for a base dependency here. @bcdurak WDYT, is it fine to push safetensors to the default deps?

Understood, thanks @avishniakov. @bcdurak, let me know if I should make it a default dependency.

Dev-Khant avatar Apr 10 '24 15:04 Dev-Khant

@Dev-Khant let me discuss the dependency issue with the team internally, I will update this thread asap.

bcdurak avatar Apr 11 '24 15:04 bcdurak

⚠️ GitGuardian has uncovered 1 secret following the scan of your pull request.

Please consider investigating the findings and remediating the incidents. Failure to do so may lead to compromising the associated services or software components.

πŸ”Ž Detected hardcoded secret in your pull request
GitGuardian id GitGuardian status Secret Commit Filename
- Username Password e42348416ec75067cf92829835b160da935a4be6 src/zenml/cli/init.py View secret
πŸ›  Guidelines to remediate hardcoded secrets
  1. Understand the implications of revoking this secret by investigating where it is used in your code.
  2. Replace and store your secret safely. Learn here the best practices.
  3. Revoke and rotate this secret.
  4. If possible, rewrite git history. Rewriting git history is not a trivial act. You might completely break other contributing developers' workflow and you risk accidentally deleting legitimate data.

πŸ¦‰ GitGuardian detects secrets in your source code to help developers and security teams secure the modern development process. You are seeing this because you or someone else with access to this repository has authorized GitGuardian to scan your pull request.

Our GitHub checks need improvement? Share your feedback!

gitguardian[bot] avatar Apr 18 '24 03:04 gitguardian[bot]

@bcdurak Any update on this?

Dev-Khant avatar Apr 19 '24 06:04 Dev-Khant

Hey @Dev-Khant, I can give you a quick update regarding the status. I think I can sum it up in three different roadblocks that still remain:

  1. The save and load methods are quite inefficient right now. As I mentioned above, in their current state the materializers are saving/loading the models twice (for instance, once through the regular torch package and once through safetensors). I think the first version of the implementation was closer to an actual solution, where you used the save_model and load_model functions from safetensors. However, in that case, you would need to figure out how to store the model type when you call the save method, so you can access it during the load method and call the load_model function properly.

  2. In their current state, the materializers do not work with remote artifact stores, because the load_... and save_... calls of safetensors do not inherently work with remote storage systems. In general, ZenML handles this issue for other materializers by using our fileio functions around the save and load methods. You can see an example of it right here. I have seen that you already implemented it for some of the materializers; however, this needs to be applied to all of the materializers (a sketch of this pattern follows after this list).

  3. Lastly, there is the question of how to handle the safetensors dependency. We had a discussion within the team regarding this topic, and we thought the best way forward is to implement a safetensors integration instead of adding it to the main package or the respective integrations (like torch or huggingface). However, this is not the main roadblock right now, and before trying it I would recommend fixing the materializers themselves.
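As a hedged sketch of the temp-file pattern described in point 2 (assuming ZenML's fileio.copy helper behaves as documented; the function below is hypothetical, not the PR's code):

import os
import tempfile

import torch
from safetensors.torch import save_file

from zenml.io import fileio


def save_weights_to_artifact_store(model: torch.nn.Module, uri: str) -> None:
    # safetensors only writes to local paths, so stage the file locally first...
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, "weights.safetensors")
        save_file(model.state_dict(), local_path)
        # ...then copy it into the (possibly remote) artifact store via fileio.
        fileio.copy(local_path, os.path.join(uri, "weights.safetensors"), overwrite=True)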

bcdurak avatar Apr 30 '24 15:04 bcdurak

Alright @bcdurak. For the first point, I'll switch back to the previous method and see how to store the model type. And for the second point, I'll try different approaches to see which one works for storing the file in the correct location.

Dev-Khant avatar May 03 '24 09:05 Dev-Khant

@Dev-Khant sorry for the delay here, but I will be closing this PR due to staleness. Feel free to reopen it when you work on it again!

htahir1 avatar Jul 09 '24 10:07 htahir1