Getting JaxRuntimeError key awaitable_signals_contract not found
Over the last few days, using Orbax checkpoint installed directly from the repo, I've seen an increasing number of gRPC "key awaitable_signals_contract not found" errors. This is on midsize TPU clusters writing to GCS. So far the checkpoints themselves seem to be OK and everything continues as normal, but I wanted to flag it anyway. The error looks like this, with the key number before the slash changing:
WARNING JaxRuntimeError raised while trying to get key '295/awaitable_signals_contract'.
Traceback (most recent call last):
File "/home/redacted/orbax/checkpoint/orbax/checkpoint/_src/futures/signaling_client.py", line 141, in key_value_try_get
return str(self._client.key_value_try_get(key))
jaxlib.xla_extension.XlaRuntimeError: NOT_FOUND: Config key 295/awaitable_signals_contract not found.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/TryGetKeyValue:
:{"created":"@1746905336.599367778","description":"Error received from peer ipv4:ip_redacted:8476","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Config key 295/awaitable_signals_contract not found.","grpc_status":5}
I am not sure whether the error is actually new or whether it is only being highlighted now because of the changes in 8bbe3a4.
This error is actually "expected" in some cases, so it should not be logged as a warning in those cases. It has been highlighted since commit https://github.com/google/orbax/commit/8bbe3a4178dd146694f6f4d53d7809d4941be169 landed.
https://github.com/google/orbax/pull/1925 should reduce this log to a vlog until we figure out a way to log "unexpected" errors correctly.
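
To illustrate the idea of separating "expected" from "unexpected" failures on a try-get, here is a minimal sketch. It is not the Orbax implementation: `FakeCoordinationClient` and `try_get` are hypothetical names, the standard `logging` module stands in for absl's vlog, and matching on the `NOT_FOUND` prefix of the message is an assumption based on the traceback above.

```python
import logging

logger = logging.getLogger("signaling_sketch")


class FakeCoordinationClient:
    """Hypothetical stand-in for the distributed key-value client."""

    def __init__(self, store):
        self._store = dict(store)

    def key_value_try_get(self, key):
        # Mimic the NOT_FOUND message shape from the traceback above.
        if key not in self._store:
            raise RuntimeError(f"NOT_FOUND: Config key {key} not found.")
        return self._store[key]


def try_get(client, key):
    """Return the value for `key`, or None if the key is not set yet."""
    try:
        return str(client.key_value_try_get(key))
    except RuntimeError as e:
        if "NOT_FOUND" in str(e):
            # Expected while no signal has been published: log quietly.
            logger.debug("Key %r not found yet: %s", key, e)
            return None
        # Anything else is unexpected: log loudly and re-raise.
        logger.warning("Unexpected error while getting key %r: %s", key, e)
        raise
```

With a client that only holds `1/awaitable_signals_contract`, calling `try_get(client, "295/awaitable_signals_contract")` returns `None` and logs at debug level only, so a missing key no longer shows up as a warning.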