Weights sharding for Keras saving
This PR adds initial weights-sharding functionality to the Keras saving/loading APIs, accessed by passing the `sharded=True` flag to the corresponding saving/loading calls.
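For context, the interface described above would be used roughly as in the snippet below; the exact entry points, argument placement, and file naming are assumptions based on this description, not necessarily the final merged API.

```python
import keras

model = keras.Sequential([keras.layers.Dense(8), keras.layers.Dense(1)])
model.build((None, 16))

# Hypothetical usage of the sharded=True flag described above; the exact
# call signatures are an assumption, not the final API.
model.save_weights("model.weights.h5", sharded=True)

restored = keras.Sequential([keras.layers.Dense(8), keras.layers.Dense(1)])
restored.build((None, 16))
restored.load_weights("model.weights.h5", sharded=True)
```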
Codecov Report
Attention: Patch coverage is 67.76860% with 39 lines in your changes missing coverage. Please review.
Project coverage is 75.61%. Comparing base (c8700f4) to head (cfbb761). Report is 583 commits behind head on master.
| Files | Patch % | Lines |
|---|---|---|
| keras/saving/saving_lib.py | 66.66% | 26 Missing and 13 partials :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## master #19286 +/- ##
==========================================
- Coverage 80.14% 75.61% -4.53%
==========================================
Files 341 365 +24
Lines 36163 39909 +3746
Branches 7116 7747 +631
==========================================
+ Hits 28982 30177 +1195
- Misses 5578 8054 +2476
- Partials 1603 1678 +75
| Flag | Coverage Δ |
|---|---|
| keras | 75.46% <67.76%> (-4.53%) :arrow_down: |
| keras-jax | 59.71% <67.76%> (-3.35%) :arrow_down: |
| keras-numpy | 54.30% <66.11%> (-2.79%) :arrow_down: |
| keras-tensorflow | 61.21% <67.76%> (-3.44%) :arrow_down: |
| keras-torch | 60.29% <53.71%> (-3.58%) :arrow_down: |
Flags with carried forward coverage won't be shown.
Talked with Neel a bit about this, but one idea, building off the recent change Francois made with `__getitem__`/`__setitem__`:

- Create h5 groups on variable write, instead of eagerly when creating an `H5Entry`. (Side note: this would clean up our h5 entries, no empty groups for layers without weights.) `H5Entry` is just a dict-like object that proxies get/set calls to the parent `H5Store`.
- On write, `H5Store` could just keep a running count of how big the current shard is, and "roll over" to a new shard as soon as the next variable would push it past the shard limit.
- On read, `H5Store` could just check every shard file for the weight it is trying to load (checking is cheap, reading is slow).
Pseudocode:
write(path, key, value):
    if self.current_shard_size + value.nbytes > self.shard_size:
        close current shard
        open new shard file
        self.current_shard_size = 0
    group = create parent groups if needed
    self.current_shard_size += value.nbytes
    group[key] = value

read(path, key):
    for file in shards:
        if path in file:
            group = file[path]
            if key in group:
                return group[key]
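To make that concrete, here is a minimal runnable sketch of the rollover logic written directly against h5py; the class and argument names (`ShardedH5Store`, `path_template`, `shard_size`) are illustrative, not the actual `saving_lib` internals.

```python
import h5py
import numpy as np


class ShardedH5Store:
    def __init__(self, path_template, shard_size):
        self.path_template = path_template  # e.g. "./model_{}.weights.h5"
        self.shard_size = shard_size        # shard limit in bytes
        self.current_shard_size = 0
        self.files = []
        self._open_new_shard()

    def _open_new_shard(self):
        # "Roll over": start a fresh shard file and reset the running size.
        path = self.path_template.format(len(self.files))
        self.files.append(h5py.File(path, "w"))
        self.current_shard_size = 0

    def write(self, path, key, value):
        value = np.asarray(value)
        # Roll over only if the current shard is non-empty; a single variable
        # larger than the limit still gets its own shard.
        if self.current_shard_size and (
            self.current_shard_size + value.nbytes > self.shard_size
        ):
            self._open_new_shard()
        # Groups are created lazily, only when a variable is actually written.
        group = self.files[-1].require_group(path)
        group[key] = value
        self.current_shard_size += value.nbytes

    def read(self, path, key):
        # Checking membership in every shard is cheap; only the match is read.
        for f in self.files:
            if path in f and key in f[path]:
                return f[path][key][...]
        raise KeyError(f"{path}/{key} not found in any shard")

    def close(self):
        for f in self.files:
            f.close()
```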
This could be fairly simple. It would avoid the need for a separate class if we want (though we still could have one), and it would allow splitting an individual layer's weights across shards (important if you have one big layer).
This could even let us avoid the JSON file entirely, I think? Supporting something like this:
# If shard_size is set, pass a format string as path?
filenames = model.save_weights("./model_{}.weights.h5", shard_size="10GB")
# Load weights handles loading a list of files, checking all files for each variable.
model.load_weights(filenames)
This last bit is optional, just thought it was interesting. What do people think?
Actually, thinking about this more, let's keep the JSON file. When downloading from hubs, we want to be able to download one file that tells us exactly what other paths to download.
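For what it's worth, that index could be a very small JSON file mapping each weight path to the shard that holds it (in the spirit of the `*.index.json` files in transformers); the layout below is purely hypothetical, not the format this PR ships.

```python
import json

# Hypothetical index layout: one small file listing every shard and which
# shard holds each weight path. Not the actual format used by Keras.
index = {
    "weight_map": {
        "layers/dense/vars/0": "model_00000.weights.h5",
        "layers/dense/vars/1": "model_00000.weights.h5",
        "layers/dense_1/vars/0": "model_00001.weights.h5",
    },
}
index["shards"] = sorted(set(index["weight_map"].values()))

with open("model.weights.json", "w") as f:
    json.dump(index, f, indent=2)
```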
Hey @nkovela1 @mattdangerw @divyashreepathihalli
Divya mentioned that we want to add this feature but I'm not very familiar with it.
Could you clarify whether the ultimate goal is to automatically split large weights into smaller chunks (similar to the transformers lib)?
Refs:
- `PreTrainedModel.save_pretrained`: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2733
- `split_state_dict_into_shards_factory`: https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_base.py#L49
@james77777778 yes! Something like `max_shard_size` in HF. We want a good way to store weights in our Keras format for, say, 100B+ models. If we assume two bytes per weight, that's hundreds of gigabytes of weights.
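Back-of-the-envelope, assuming a 10 GB shard limit purely for illustration:

```python
params = 100e9                             # 100B parameters
bytes_per_param = 2                        # e.g. bfloat16
total_gb = params * bytes_per_param / 1e9  # 200 GB of raw weights
num_shards = int(-(-total_gb // 10))       # ~20 files at a 10 GB shard limit
```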
It's definitely not practical to store such a model as a single .keras zip file, and even with our "directory of assets" approach for keras-hub, a single multi-hundred-gigabyte .weights.h5 file isn't great. So the goal is just to give us a shard size option in model.save_weights that we can use from keras-hub when saving our presets, and to better handle interruptions to the download process, etc.
The goal is not to do any fancy distribution of weights that would play nice with model-parallel training (e.g. each device only downloads the weights it needs). That's a much more complicated problem; let's punt on it for now.
This was completed as part of https://github.com/keras-team/keras/pull/21022