Possibility to Save/Restore Checkpoints
Hi,
One thing I need to do is train a model (either once or online) and then store it at regular intervals, so that when the server goes down I can query my trained model again without having to hold everything in memory after training.
I don't think this is exposed in rust yet, right? Does the C-API already support it?
We don't have anything like that at the moment. You can just read your variables and save them to a file, though, and then later read them from the file and initialize them. It's annoying and manual, but it should work. I'll leave this open because we do need to add easy save/restore functionality.
The C API supports loading from SavedModels now. Language bindings should support that instead.
Support for loading saved models was added in #68.
@jhseu, is there going to be C API for saving models to storage?
There are two ways of saving:
- Checkpoints are used for saving and restoring during training. Checkpoint saving and restoring are done by ops in the graph, so it technically works. There is a bunch of boilerplate code needed to make that work well that doesn't exist yet, though.
- SavedModel is the common format for distributing models. Saving isn't yet exposed anywhere except Python.
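For the second case, loading a SavedModel from Rust is already possible via `SavedModelBundle::load` in the tensorflow crate. A minimal sketch (the export path and the `"serve"` tag are assumptions; `"serve"` is the tag Python's SavedModel export uses by default):

```rust
use tensorflow::{Graph, SavedModelBundle, SessionOptions};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load a SavedModel that was exported from Python.
    // "path/to/saved_model" is a placeholder for your export directory.
    let mut graph = Graph::new();
    let bundle = SavedModelBundle::load(
        &SessionOptions::new(),
        &["serve"],
        &mut graph,
        "path/to/saved_model",
    )?;

    // The bundle contains a session with the variables already restored;
    // the MetaGraphDef's signatures tell you the input/output op names.
    let _session = &bundle.session;
    Ok(())
}
```

From there you can look up operations in `graph` by the names recorded in the signature and run them with the bundle's session.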
For those who want to save checkpoints using Saver ops:

```python
# Define your model
# placeholders, ops, etc.
# ...

# Declare saver ops
saver = tf.train.Saver(tf.global_variables())

# Export the graph (including the saver ops) as a GraphDef protobuf
definition = tf.Session().graph_def
tf.train.write_graph(definition, 'modeldir', 'modelname.pb', as_text=False)
```
```rust
use std::fs::File;
use std::io::Read;
use tensorflow::{Graph, ImportGraphDefOptions, Session, SessionOptions, StepWithGraph, Tensor};

let mut graph = Graph::new();
let mut proto = Vec::new();
File::open(modelpath)?.read_to_end(&mut proto)?;
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
let mut session = Session::new(&SessionOptions::new(), &graph)?;

// Do training, or whatever you want to
// ...

// Then save the model
let op_file_path = graph.operation_by_name_required("save/Const")?;
let op_save = graph.operation_by_name_required("save/control_dependency")?;
let file_path_tensor: Tensor<String> = Tensor::from(String::from(ckpt_file_path));
let mut step = StepWithGraph::new();
step.add_input(&op_file_path, 0, &file_path_tensor);
step.add_target(&op_save);
session.run(&mut step)?;

// Later, restore the model
let op_load = graph.operation_by_name_required("save/restore_all")?;
let mut step = StepWithGraph::new();
step.add_input(&op_file_path, 0, &file_path_tensor);
step.add_target(&op_load);
session.run(&mut step)?;
```
Check https://stackoverflow.com/a/37671613 for details.
@bekker Would you be willing to add your code to the examples directory? This would be a great addition!
@adamcrume Sure, just opened a PR #159.
This is a very important feature to me.
The use case of tensorflow-rust is in production systems. So saving and restoring checkpoints makes a lot of sense.
In the example above, the Saver ops are used. But the documentation warns that Saver should not be used for saving/restoring in TensorFlow 2. The code in SavedModelBuilder uses SaveV2 instead. Should we use that code as an example?
@ramon-garcia Hi, the example above is quite old and loads the legacy Saver operations by name. It most likely does not work with SaveV2.
I suggest using the SavedModelBuilder you mentioned.
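For reference, a rough sketch of what saving with the crate's `SavedModelBuilder` looks like, based on its builder-style API. The tag, export directory, and the shape of the surrounding code are assumptions; check the crate's `regression_savedmodel` example for the exact signature-definition details:

```rust
use tensorflow::{SavedModelBuilder, Scope, Session};

// Sketch only: assumes `scope` already holds the model's graph and
// `session` has the trained variable values.
fn export(scope: &mut Scope, session: &Session) -> Result<(), Box<dyn std::error::Error>> {
    let mut builder = SavedModelBuilder::new();
    // "serve" is the conventional serving tag; signatures describing the
    // model's inputs/outputs would be added here as well.
    builder.add_tag("serve");

    // inject() adds the SaveV2/RestoreV2 ops to the graph and returns a
    // saver handle that can write the SavedModel to disk.
    let saver = builder.inject(scope)?;
    saver.save(session, &scope.graph(), "exported_model")?;
    Ok(())
}
```

This produces a standard SavedModel directory that can later be loaded with `SavedModelBundle::load`, from Rust or any other language binding.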