
API to read weights from ckpt file

Open · kushal-g opened this issue 3 years ago · 5 comments

System information

  • TensorFlow version (you are using): Nightly
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state. This API would help in reading the weights from a ckpt file. Currently there is no such feature in the Java SDK, but the same can be achieved in Python via CheckpointReader.

Will this change the current API? How? This shouldn't affect the existing API. It would be a simple addition to it.

Who will benefit with this feature? I was developing a federated learning application and got stuck because I could not get the weights from the ckpt file (there is no API for this in the Java SDK), nor could I write a TFLite signature method to do it (TFLite doesn't support dynamic tensor shapes). This feature would therefore help promote active development of FL on Android clients.

kushal-g, Feb 15 '22 16:02

session.restore(String path) should already do that: https://github.com/tensorflow/java/blob/master/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java#L709.
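One detail that trips people up: session.restore takes a checkpoint prefix (for example .../ckpt-5), not the path of an individual .index or .data-00000-of-00001 file. Checkpoints written by tf.train.Checkpoint or tf.train.CheckpointManager typically sit next to a small text file named checkpoint that records the latest prefix. The following stdlib-only sketch, assuming that directory layout, resolves the prefix to pass to session.restore. The class and method names are hypothetical, and the parser deliberately ignores escape sequences inside the quoted string.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

/** Hypothetical helper: finds the checkpoint prefix to hand to Session.restore. */
final class CheckpointPrefix {
    private CheckpointPrefix() {}

    /**
     * Reads the text-format "checkpoint" state file in ckptDir and returns the
     * prefix named by model_checkpoint_path, resolved against ckptDir if relative.
     * Escape sequences inside the quoted string are not handled (sketch only).
     */
    static String latestPrefix(Path ckptDir) throws IOException {
        Path stateFile = ckptDir.resolve("checkpoint");
        List<String> lines = Files.readAllLines(stateFile, StandardCharsets.UTF_8);
        for (String line : lines) {
            String trimmed = line.trim();
            if (!trimmed.startsWith("model_checkpoint_path:")) {
                continue; // skip all_model_checkpoint_paths, timestamps, etc.
            }
            int open = trimmed.indexOf('"');
            int close = trimmed.lastIndexOf('"');
            if (open < 0 || close <= open) {
                throw new IOException("Malformed line in " + stateFile + ": " + line);
            }
            Path prefix = Paths.get(trimmed.substring(open + 1, close));
            if (!prefix.isAbsolute()) {
                prefix = ckptDir.resolve(prefix); // relative entries are relative to the directory
            }
            // A V2 checkpoint prefix has a matching <prefix>.index file next to its data shards.
            if (!Files.isRegularFile(Paths.get(prefix + ".index"))) {
                throw new IOException("Missing index file for prefix " + prefix);
            }
            return prefix.toString();
        }
        throw new IOException("No model_checkpoint_path entry in " + stateFile);
    }
}
```

Something like session.restore(CheckpointPrefix.latestPrefix(Paths.get(ckptDirPath))) would then load the saved values, but only into a graph that already defines variables whose names match the checkpoint keys. Note that java.nio.file is only available on Android API level 26 and higher.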

Craigacp, Feb 22 '22 23:02

How would I get the graph for creating the session? I need to get the weights from the following ckpt file.

https://drive.google.com/file/d/1ZH1SiF4Z-YYGZjwPbtd4cpCbvGXYIuu2/view?usp=sharing

This will run in an Android app, which sends these weights to a server for aggregation using federated learning.

kushal-g, Feb 24 '22 19:02

The graph should be defined somewhere as a protobuf, but you can also specify the model yourself. How are you loading the model architecture at the moment?

Craigacp, Feb 24 '22 19:02

I am not. I assumed that the model architecture could be loaded from the ckpt itself. I have the model architecture in Python code. From what I understand from your reply, I have two options:

  1. Export the model architecture as a protobuf from Python, load it in my Java code to build the graph, and then use the session to get the weights.
  2. Reimplement the model in Java, then use that graph to create the session, restore the checkpoint, and get the weights.

Is that correct? If so, could you point me to some documentation for this?

kushal-g, Feb 25 '22 06:02

Variable checkpoints are stored separately from the model structure in TensorFlow, so you need both components to load a model. The SavedModel format is a directory containing the model structure protobuf and a variable checkpoint, so you can load that and continue training. There are other ways of exporting the computation graph, but those are harder to do in TF 2.x. Otherwise, you'll need to implement the model in Java (which might prove a bit tricky, as you must ensure that every variable has the same name as its counterpart in Python).
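To make the two components concrete: a SavedModel export directory conventionally holds the model structure as saved_model.pb (or saved_model.pbtxt in text form) and the variable checkpoint under the prefix variables/variables, i.e. a variables.index file plus one or more variables.data-... shards. The stdlib-only sketch below only locates both parts and checks that they exist; it does not parse either file. The class and method names are hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical helper: locates the two halves of a SavedModel export directory. */
final class SavedModelLayout {
    private SavedModelLayout() {}

    /** The serialized model structure: saved_model.pb, or saved_model.pbtxt if text format. */
    static Path graphProto(Path exportDir) throws IOException {
        for (String name : new String[] {"saved_model.pb", "saved_model.pbtxt"}) {
            Path candidate = exportDir.resolve(name);
            if (Files.isRegularFile(candidate)) {
                return candidate;
            }
        }
        throw new IOException("No saved_model.pb or saved_model.pbtxt in " + exportDir);
    }

    /** The variable checkpoint prefix: <exportDir>/variables/variables (checked via its .index file). */
    static String variablesPrefix(Path exportDir) throws IOException {
        Path prefix = exportDir.resolve("variables").resolve("variables");
        if (!Files.isRegularFile(Paths.get(prefix + ".index"))) {
            throw new IOException("No variable checkpoint under " + exportDir);
        }
        return prefix.toString();
    }
}
```

In practice the TensorFlow Java loader, e.g. SavedModelBundle.load(exportDir, "serve"), reads both parts and hands back a session whose variables are already restored, so a check like this is mainly useful for validating an export before shipping it to the device.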

Craigacp, Feb 25 '22 14:02