Weights sharding for Keras saving

Open nkovela1 opened this issue 1 year ago • 3 comments

This PR adds initial weight-sharding functionality to the Keras saving/loading APIs, enabled by passing the sharded=True flag to the corresponding saving/loading calls.

nkovela1 avatar Mar 11 '24 20:03 nkovela1

Codecov Report

Attention: Patch coverage is 67.76860% with 39 lines in your changes missing coverage. Please review.

Project coverage is 75.61%. Comparing base (c8700f4) to head (cfbb761). Report is 583 commits behind head on master.

Files Patch % Lines
keras/saving/saving_lib.py 66.66% 26 Missing and 13 partials :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19286      +/-   ##
==========================================
- Coverage   80.14%   75.61%   -4.53%     
==========================================
  Files         341      365      +24     
  Lines       36163    39909    +3746     
  Branches     7116     7747     +631     
==========================================
+ Hits        28982    30177    +1195     
- Misses       5578     8054    +2476     
- Partials     1603     1678      +75     
Flag Coverage Δ
keras 75.46% <67.76%> (-4.53%) :arrow_down:
keras-jax 59.71% <67.76%> (-3.35%) :arrow_down:
keras-numpy 54.30% <66.11%> (-2.79%) :arrow_down:
keras-tensorflow 61.21% <67.76%> (-3.44%) :arrow_down:
keras-torch 60.29% <53.71%> (-3.58%) :arrow_down:

codecov-commenter avatar Mar 11 '24 21:03 codecov-commenter

Talked with Neel a bit about this, but one idea, building off the recent change Francois made with __getitem__/__setitem__...

  • Create h5 groups on variable write, instead of eagerly when creating an H5Entry. (Side note: this would clean up our h5 entries, with no empty groups for layers without weights.)
  • H5Entry is just a dict-like object that proxies get/set calls to the parent H5Store.
  • On write, H5Store could just keep a running total of how big the current shard is, and "roll over" to a new shard as soon as the next variable would push it past the shard limit.
  • On read, H5Store could just check every shard file for the weight it is trying to load (as checking is cheap, reading is slow).

Pseudocode:

write(path, key, value):
    if self.current_shard_size + value.nbytes > self.shard_size:
        close current shard
        open new shard file
        self.current_shard_size = 0
    group = create parent groups if needed
    self.current_shard_size += value.nbytes
    group[key] = value

read(path, key):
    for file in shards:
        if path in file:
            group = file[path]
            if key in group:
                return group[key]

This could be fairly simple. It avoids the need for a separate class if we want (though we could still add one), and it allows splitting an individual layer's weights across shards (important if you have one big layer).
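For anyone who wants to poke at the roll-over logic, here is a minimal runnable sketch of the pseudocode above. It is only an illustration under loud assumptions: the ShardedStore class is hypothetical, each shard is a plain dict standing in for an h5 file, and values are bytes objects whose length stands in for value.nbytes. It also adds one guard the pseudocode leaves implicit: it does not roll over while the current shard is still empty, so a single oversized variable does not leave an empty shard behind.

```python
class ShardedStore:
    """Toy stand-in for the proposed H5Store; each shard is a plain dict."""

    def __init__(self, shard_size):
        self.shard_size = shard_size   # maximum bytes per shard
        self.shards = [{}]             # each dict plays the role of one shard file
        self.current_shard_size = 0

    def write(self, path, key, value):
        nbytes = len(value)            # assumption: stands in for value.nbytes
        # Roll over only if the current shard already holds data, so an
        # oversized variable still gets a shard of its own.
        if self.current_shard_size and self.current_shard_size + nbytes > self.shard_size:
            self.shards.append({})
            self.current_shard_size = 0
        # Groups are created lazily, on first write.
        group = self.shards[-1].setdefault(path, {})
        group[key] = value
        self.current_shard_size += nbytes

    def read(self, path, key):
        # Membership checks are cheap, so probe every shard in turn.
        for shard in self.shards:
            group = shard.get(path)
            if group is not None and key in group:
                return group[key]
        raise KeyError(f"{path}/{key} not found in any shard")


store = ShardedStore(shard_size=10)
store.write("dense", "kernel", b"12345678")  # 8 bytes, goes to shard 0
store.write("dense", "bias", b"1234")        # 8 + 4 > 10, rolls over to shard 1
store.write("norm", "gamma", b"12")          # 4 + 2 <= 10, stays in shard 1
```

Note that the "dense" layer ends up split across two shards here, which is the case the per-variable roll-over is meant to allow.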

This could even allow avoiding the json file entirely I think? Supporting something like this:

# If shard_size is set, pass a format string as path?
filenames = model.save_weights("./model_{}.weights.h5", shard_size="10GB")
# Load weights handles loading a list of files, and checking all files for the variables.
model.load_weights(filenames)

This last bit is optional, I just thought it was interesting. What do people think?
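To make the format-string idea concrete, a rough sketch of the two small helpers it would need is below. Both function names (parse_shard_size, shard_filenames) are hypothetical and not part of Keras, and the sketch assumes decimal units (1 GB = 10^9 bytes) and zero-based shard numbering, neither of which is specified above.

```python
import re

# Assumption: decimal units; the proposal does not say whether "GB" means 10**9 or 2**30.
_UNITS = {"B": 1, "KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12}


def parse_shard_size(size):
    """Convert a string such as "10GB" to a byte count."""
    match = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*([KMGT]?B)\s*", size.upper())
    if match is None:
        raise ValueError(f"Unrecognized shard size: {size!r}")
    number, unit = match.groups()
    return int(float(number) * _UNITS[unit])


def shard_filenames(template, count):
    """Fill the "{}" placeholder in the path with a zero-based shard index."""
    return [template.format(i) for i in range(count)]


limit = parse_shard_size("10GB")
names = shard_filenames("./model_{}.weights.h5", 3)
```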

mattdangerw avatar Mar 13 '24 02:03 mattdangerw

Actually thinking about this more, let's keep the json file. When downloading from hubs, we want to be able to download one file that tells us exactly what other paths to download.
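As a sketch of what such a json file could contain: a map from each variable to the shard file that holds it, so a downloader can read one small file and know which others to fetch. The layout and the "weight_map" key are assumptions for illustration, not a format anyone has agreed on here.

```python
import json


def build_index(shards):
    """Map each "path/key" variable name to the shard file that holds it."""
    weight_map = {}
    for filename, groups in shards.items():
        for path, variables in groups.items():
            for key in variables:
                weight_map[f"{path}/{key}"] = filename
    return {"weight_map": weight_map}  # hypothetical layout


# Hypothetical shard contents: file name -> group path -> variable names.
shards = {
    "model_0.weights.h5": {"dense": ["kernel"]},
    "model_1.weights.h5": {"dense": ["bias"], "norm": ["gamma"]},
}
index = build_index(shards)
index_json = json.dumps(index, indent=2)

# A hub client only needs the index to know which shard files to download.
files_to_download = sorted(set(index["weight_map"].values()))
```

With an index like this, a loader can also open only the shard that contains a given variable instead of probing every file.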

mattdangerw avatar Sep 03 '24 20:09 mattdangerw

Hey @nkovela1 @mattdangerw @divyashreepathihalli

Divya mentioned that we want to add this feature but I'm not very familiar with it. Could you clarify if the ultimate goal is to automatically split large weights into smaller chunks? (similar to transformers lib)

Refs:

  • PreTrainedModel.save_pretrained: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2733
  • split_state_dict_into_shards_factory: https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_base.py#L49

james77777778 avatar Mar 03 '25 14:03 james77777778

@james77777778 yes! Something like max_shard_size for HF. We want a good way to store weights in our keras format for, say, 100b+ models. If we assume two bytes per weight, that's hundreds of gigabytes of weights.
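A quick back-of-the-envelope check of those numbers, taking exactly 100 billion parameters, two bytes each, and (as an assumption) the 10GB shard size used as an example earlier in the thread, read as decimal gigabytes:

```python
import math

num_params = 100 * 10**9   # "100b" parameters
bytes_per_weight = 2       # e.g. a 16-bit float
shard_bytes = 10 * 10**9   # assumption: "10GB" read as 10**10 bytes

total_bytes = num_params * bytes_per_weight
num_shards = math.ceil(total_bytes / shard_bytes)

print(f"{total_bytes / 10**9:.0f} GB total, {num_shards} shards")
```

That comes to 200 GB of weights, or 20 shard files at that shard size.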

It's definitely not practical to store such a model as a single .keras zip file, and even in our "directory of assets" approach for keras-hub, a single multi-hundred-gigabyte .weights.h5 file isn't great. So the goal is just to give us a shard size in model.save_weights that we can use from keras-hub when saving our presets, so that we can better handle interruptions to the download process, etc.

The goal is not to do any fancy distribution of weights that would play nice with model parallel training (e.g. each device only downloads the weights it needs). That's a much more complicated problem; let's punt for now.

mattdangerw avatar Mar 08 '25 00:03 mattdangerw

This was completed as part of https://github.com/keras-team/keras/pull/21022

hertschuh avatar Aug 08 '25 18:08 hertschuh