Warnings When Restoring the Params
My code is shown below. It runs, but restoring the params prints this warning:

/opt/conda/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1544: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(
The model params are a dictionary-type pytree of float32 arrays, structured like this (array values abbreviated):

{'params': {
    'Dense_0': {'bias': Array([...], dtype=float32), 'kernel': Array([[...]], dtype=float32)},
    'Dense_1': {'bias': Array([...], dtype=float32), 'kernel': Array([[...]], dtype=float32)},
    'Dense_2': {'bias': Array([...], dtype=float32), 'kernel': Array([[...]], dtype=float32)},
    'Dense_3': {'bias': Array([...], dtype=float32), 'kernel': Array([[...]], dtype=float32)}}}
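import os
import shutil
import orbax.checkpoint

# Save the params (`params` is the model's parameter pytree, defined elsewhere)
ckpt_dir = '/user/working/model_params'
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # remove any previous checkpoint at this path first
checkpoint = orbax.checkpoint.PyTreeCheckpointer()
checkpoint.save(ckpt_dir, params)

# Restore the params (this call emits the warning)
checkpoint = orbax.checkpoint.PyTreeCheckpointer()
restored_params = checkpoint.restore(ckpt_dir)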
How should I fix this? Thanks.
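The warning appears because `restore` is called without any `RestoreArgs`, so Orbax has no sharding information for the arrays and falls back to reading it from the sharding file saved with the checkpoint. Passing restore args built from a target pytree (one with the same structure as the saved params) supplies that information directly. This documentation should be useful: https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html#checkpointing-pytrees-of-arrays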