
Strange behavior of saving sharded trainstate in GCP.

Status: Open. chiamp opened this issue (3 comments).

A user reported in the Flax discussions that Orbax writes checkpoints differently in different GCE zones. Do different zones have different Orbax versions?

==================================================================

What happened

When I save my sharded state with Orbax in asia-northeast3-a on GCE, Orbax creates a /tmp/orbax_ckpt/0/_sharding file that starts with:

{"dropout_rng":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","opt_state.0.0.count":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...

My sharded state has a "dropout_rng" entry, so the file above makes sense.

However, when I run the same script in another region, such as asia-southeast1-b, Orbax creates a _sharding file without proper layer names, for example:

{"ZHJvcG91dF9ybmc=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","b3B0X3N0YXRlLjAuMC5jb3VudA==":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...
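The keys in this second file do still carry the layer names: "ZHJvcG91dF9ybmc=" is the base64 encoding of "dropout_rng", and "b3B0X3N0YXRlLjAuMC5jb3VudA==" is the encoding of "opt_state.0.0.count". The sketch below decodes the keys of such a file to make that visible. It assumes, based only on the two excerpts in this report, that the file is a flat JSON object whose keys are URL-safe base64 strings and whose values are themselves JSON-encoded strings; `decode_sharding_keys` is a hypothetical helper, not part of Orbax.

```python
import base64
import json
from typing import Any, Dict


def decode_sharding_keys(sharding_json: str) -> Dict[str, Any]:
    """Return {decoded key name: parsed sharding spec} for a _sharding file's text.

    Assumption: every top-level key is URL-safe base64 of a UTF-8 key name,
    and every value is a JSON string, as in the excerpts in this report.
    """
    entries = json.loads(sharding_json)
    decoded = {}
    for key, value in entries.items():
        name = base64.urlsafe_b64decode(key.encode("ascii")).decode("utf-8")
        decoded[name] = json.loads(value)  # each value is JSON nested inside a JSON string
    return decoded


if __name__ == "__main__":
    # Rebuild the two entries shown in the asia-southeast1-b excerpt (the "..." tail is omitted).
    spec = json.dumps({"sharding_type": "NamedSharding", "shape": [1, 1],
                       "axis_names": ["data", "model"], "partition_spec": []})
    excerpt = json.dumps({"ZHJvcG91dF9ybmc=": spec,
                          "b3B0X3N0YXRlLjAuMC5jb3VudA==": spec})
    for name, parsed in decode_sharding_keys(excerpt).items():
        print(name, parsed["shape"])
    # prints:
    # dropout_rng [1, 1]
    # opt_state.0.0.count [1, 1]
```

To inspect a real checkpoint, the same function can be applied to the text of /tmp/orbax_ckpt/0/_sharding. It only renames keys; it does not rebuild the pytree structure.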

Theory

I suspect this is related to OCDBT: the only difference between the terminal outputs is that OCDBT is initialized in asia-northeast3-a, while runs in the other regions never print the log line "type_handlers.py:223] OCDBT is initialized successfully."

I confirmed that tensorstore==0.1.51 is installed in all regions.

Can anyone help, please?

Thank you.

Originally posted by @sw32-seo in https://github.com/google/flax/discussions/3538

Posted by chiamp, Jan 11 '24 18:01

@sw32-seo Can you please check the Orbax version in all regions?
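One way to answer this in each region is to print the installed package versions from the same Python environment that runs the training script. A minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption (orbax-checkpoint and tensorstore come from this thread, jax is added here as a likely relevant extra).

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed list of distributions worth comparing across regions.
PACKAGES = ("orbax-checkpoint", "tensorstore", "jax")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Running this in each zone and diffing the output gives a direct comparison; identical output points away from a version mismatch.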

Posted by niketkumar, Jan 12 '24 18:01

@niketkumar I used the same Docker image, which has orbax-checkpoint==0.4.7, in all regions.

Posted by sw32-seo, Jan 16 '24 00:01

Hi @chiamp, thanks for raising the issue. We had to change how pytree key names are encoded, since special characters like '~' were not encoded properly by the previous scheme.

The names are still correct; they are now encoded with base64 urlsafe_encode. Could you re-run your script in asia-northeast3-a and check whether the sharding file changes? My suspicion is that it will, and that it will match the second example you provided (the one with encoded keys).
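The property this reply relies on can be illustrated with the standard library alone: URL-safe base64 maps any key string, including one containing '~' or '/', to a string of ASCII letters, digits, '-', '_' and '=', and decodes back losslessly. This is a sketch, not Orbax's own encoding function; that Orbax applies `base64.urlsafe_b64encode` to the UTF-8 bytes of each key is an assumption inferred from the reply and from the fact that the keys in the report decode this way. The key "params~layer/kernel" is hypothetical.

```python
import base64


def encode_key(name: str) -> str:
    """URL-safe base64 of a key name's UTF-8 bytes (illustrative helper)."""
    # urlsafe_b64encode uses '-' and '_' in place of '+' and '/', so the
    # result contains only ASCII letters, digits, '-', '_' and '=' padding.
    return base64.urlsafe_b64encode(name.encode("utf-8")).decode("ascii")


def decode_key(encoded: str) -> str:
    """Inverse of encode_key."""
    return base64.urlsafe_b64decode(encoded.encode("ascii")).decode("utf-8")


if __name__ == "__main__":
    # "params~layer/kernel" is a hypothetical key containing special characters.
    for name in ["dropout_rng", "opt_state.0.0.count", "params~layer/kernel"]:
        encoded = encode_key(name)
        assert decode_key(encoded) == name  # lossless round trip
        print(f"{name} -> {encoded}")
```

The trade-off is readability: the encoded keys are safe for any character set but no longer human-readable in the raw _sharding file, which is what made them look like missing layer names.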

Thanks!

Posted by liangyaning33, Jan 19 '24 18:01