
Automatically map deduplicated safetensors weights to their original values


What does this PR do?

This PR automatically points tensors that were removed due to deduplication to their still-existing twins.

In server/text_generation_server/utils/convert.py#convert_file, tensors whose values are equal to another tensor's are removed from the list of weights. Their names, along with the names of their still-existing "twins", are logged to the "metadata" dictionary. However, this dictionary was not yet used during loading, so loading such models requires explicit in-code remapping (as mentioned in the docstring).

This PR adds some simple code that checks, during loading, whether a weight is one of those removed weights and, if so, automatically retrieves the value of its still-existing "twin" instead; a rough sketch follows.
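
Conceptually, the load-time lookup works like the sketch below (a standalone simplification for this description, not the actual server API; the real change lives in the server's weight-loading utilities):

from safetensors import safe_open

def get_tensor(filename: str, name: str):
    """Load `name`, following the dedup mapping stored in the file metadata."""
    with safe_open(filename, framework="pt") as f:
        metadata = f.metadata() or {}
        # convert_file logged "removed name -> surviving twin" entries here,
        # so a tensor that was deduplicated away can be redirected to its twin.
        resolved = metadata.get(name, name)
        return f.get_tensor(resolved)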

What does this fix?

We currently cannot load h2oai/h2ogpt-oig-oasst1-falcon-40b with the unmodified server, since its transformer.word_embeddings.weight is equal to lm_head.weight and is therefore automatically removed during conversion. The Falcon modeling code, however, still expects this weight to exist. I could also have added extra checks to the model itself, though that would only be a workaround.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

Vinno97 avatar Jun 28 '23 17:06 Vinno97

Hey, thanks for the PR!

Unfortunately, that metadata is kept only for hard debugging and is missing crucial information: namely, it doesn't record whether the tensor was a slice or not. The metadata will also not necessarily be present.

I suggest a different fix:

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 5f963bf..33079ac 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -48,7 +48,13 @@ class FlashRWSharded(FlashCausalLM):

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )

         config.quantize = quantize

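For illustration only, here is roughly how such an aliases mapping could be resolved when a tensor is requested (a simplified, hypothetical sketch; the real Weights class also routes requests to the right safetensors file and handles sharded slices):

class Weights:
    def __init__(self, tensors, aliases=None):
        self.tensors = tensors        # tensor name -> tensor
        self.aliases = aliases or {}  # missing name -> candidate twin names

    def get_tensor(self, name):
        if name not in self.tensors:
            # Fall back to the first twin that actually exists in the files.
            for alias in self.aliases.get(name, []):
                if alias in self.tensors:
                    return self.tensors[alias]
        return self.tensors[name]
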
Would that work?

Narsil avatar Jun 30 '23 08:06 Narsil

Without having tested the fix (I don't have access to the GPU server during the weekends), this seems like it would also fix my problem.

I initially also thought about doing this. However, it is still only a fix for this specific model. The metadata may not be a trustworthy source for tensor aliases, but couldn't it still be a valid fallback? It could also be made safer by adding a namespacing prefix like alias-- to the keys, to prevent conflicts; a rough sketch follows.
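
As a sketch of that namespacing idea (the alias-- prefix is only a proposal here, not an existing convention):

ALIAS_PREFIX = "alias--"

def aliases_from_metadata(metadata):
    # Pick out only the namespaced entries, e.g.
    # "alias--transformer.word_embeddings.weight" -> "lm_head.weight",
    # so alias entries cannot collide with other metadata keys.
    metadata = metadata or {}
    return {
        key[len(ALIAS_PREFIX):]: value
        for key, value in metadata.items()
        if key.startswith(ALIAS_PREFIX)
    }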

Curious to hear what you think

Vinno97 avatar Jul 01 '23 22:07 Vinno97

> However, it is still only a fix for this specific model.

Indeed, but the other approach can lead to potentially catastrophic failure (loading the wrong weights), which is even worse IMO.

Are you OK with me updating this PR? (If I can; otherwise I'll just create a new one with you as co-author.)

Narsil avatar Jul 03 '23 09:07 Narsil

Alright, fair point. Silently loading the wrong weights is definitely undesired. Sure, update this PR!

Vinno97 avatar Jul 04 '23 08:07 Vinno97

I tested the new fix "on my machine" and it worked. A colleague then used it and it didn't work for him. The difference? In his version of model-00001-of-00018.safetensors, transformer.word_embeddings.weight was present, but lm_head.weight was gone.

Unless the conversion can be made consistent, this means we should also alias the weights the other way around, either explicitly (the same way the current patch already does it) or automatically (two-way aliases by default); see the sketch below.
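
A sketch of what two-way aliases by default could look like:

def symmetrize(aliases):
    # Extend {name: [twins]} so that each twin also maps back to the
    # original name, whichever of the two survived the conversion.
    two_way = {name: list(twins) for name, twins in aliases.items()}
    for name, twins in aliases.items():
        for twin in twins:
            two_way.setdefault(twin, [])
            if name not in two_way[twin]:
                two_way[twin].append(name)
    return two_way

# symmetrize({"transformer.word_embeddings.weight": ["lm_head.weight"]})
# == {"transformer.word_embeddings.weight": ["lm_head.weight"],
#     "lm_head.weight": ["transformer.word_embeddings.weight"]}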

I don't know why his safetensors conversion is different; I guess the order of dictionary keys is not guaranteed to be consistent, which then affects safetensors.torch._remove_duplicate_names. I'm pretty sure that using its preferred_names argument would fix this weight-mapping issue, but it'd be a model-specific fix in a generic part of the code-base. Maybe that could work, but it'd be nicer to keep the fix in flash_rw.py.
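
For reference, the preferred_names idea would look roughly like this (a sketch only: _remove_duplicate_names is a private safetensors helper, so its exact signature and return shape may differ between versions):

import torch
from safetensors.torch import _remove_duplicate_names

# Two names sharing one storage, as in the Falcon checkpoint.
weight = torch.zeros(4, 4)
state_dict = {
    "lm_head.weight": weight,
    "transformer.word_embeddings.weight": weight,
}

# Pin which twin survives deduplication, making the conversion deterministic.
to_remove = _remove_duplicate_names(state_dict, preferred_names=["lm_head.weight"])
# Expected: {"lm_head.weight": ["transformer.word_embeddings.weight"]}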

Vinno97 avatar Jul 06 '23 09:07 Vinno97

Any update on this PR? I also encountered the issue of missing lm_head.weight when I try to load Falcon model with text-generation-inference.

lppllppl920 avatar Jul 31 '23 01:07 lppllppl920

Hi @lppllppl920, thanks for the ping. I'm not sure why it wasn't merged.

Narsil avatar Aug 02 '23 17:08 Narsil

> Hi @lppllppl920, thanks for the ping. I'm not sure why it wasn't merged.

Thank you!

lppllppl920 avatar Aug 02 '23 20:08 lppllppl920

The fix actually doesn't work: I discovered this while testing. A fix is coming soon: https://github.com/huggingface/text-generation-inference/pull/762/files#diff-2111bae5f77d998a3fe39888906b3c7be122313241ed6b69b0b0baf5abb735bbL57

Narsil avatar Aug 03 '23 11:08 Narsil