
Exporting to TF SavedModel/TFLite

Open · maciel310 opened this issue 2 years ago · 2 comments

I'm trying to export the trained model to TFLite so I can run inference on another device, but I'm hitting some issues. I found instructions for exporting a model in the Stable Baselines documentation and adapted them for PPO1 instead of PPO2. However, when I try to load the resulting SavedModel, I get an exception saying that a tensor does not exist.

Here's the code:

```python
import tensorflow as tf

# load_model and env come from SIMPLE's own code
ppo_model = load_model(env, 'best_model.zip')

tf.saved_model.simple_save(ppo_model.sess, "TEST_OUTPUT",
                           inputs={"obs": ppo_model.policy_pi.obs_ph},
                           outputs={"action": ppo_model.policy_pi._policy_proba})

converter = tf.lite.TFLiteConverter.from_saved_model("TEST_OUTPUT")
tflite_model = converter.convert()
```

The full error message:

```
KeyError: "The name 'input/Ob:0' refers to a Tensor which does not exist. The operation, 'input/Ob', does not exist in the graph."
```

I've verified that ppo_model loads correctly by running inference with ppo_model.action_probability(), so I don't believe the problem is there. The tf.saved_model.simple_save step does create the SavedModel directory, but I suspect the export is incomplete because the directory is very small.
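The symptoms (a very small SavedModel and a missing `input/Ob:0` tensor) are consistent with the export writing out the wrong graph. Stable Baselines builds each model in its own `tf.Graph`, while TF 1.x's `simple_save` appears to export whatever graph is currently the *default* rather than `session.graph`. The minimal sketch below, assuming the `tf.compat.v1` API (TF 1.13+ or TF 2.x), uses a toy placeholder whose shape `(None, 4)` is a hypothetical stand-in for the real observation space. It shows that a tensor created inside a private graph cannot be found in the default graph, which raises the same kind of `KeyError`:

```python
import tensorflow as tf

tf1 = tf.compat.v1

# A private graph, standing in for ppo_model.graph (the real model creates its own).
model_graph = tf.Graph()
with model_graph.as_default():
    with tf1.name_scope("input"):
        # Hypothetical observation placeholder; (None, 4) is an arbitrary shape.
        obs_ph = tf1.placeholder(tf.float32, shape=(None, 4), name="Ob")


def tensor_exists(graph, name):
    """Return True if `name` (e.g. 'input/Ob:0') is a tensor in `graph`."""
    try:
        graph.get_tensor_by_name(name)
        return True
    except KeyError:
        return False


print(obs_ph.name)                                           # input/Ob:0
print(tensor_exists(model_graph, "input/Ob:0"))              # True
print(tensor_exists(tf1.get_default_graph(), "input/Ob:0"))  # False
```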

I'm fairly new to the ML side of things, so I may be missing something obvious. Any help would be greatly appreciated!

Thanks for putting together this great library!

maciel310 · Sep 08 '22 04:09

@maciel310 I ran into this exact same problem. I was able to export to TFLite like this:

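```python
import tensorflow as tf

# load_model and env come from SIMPLE's own code
ppo_model = load_model(env, 'best_model.zip')

# Stable Baselines builds each model in its own tf.Graph; simple_save exports
# whatever graph is the default, so make the model's graph the default here.
with ppo_model.graph.as_default():
    tf.saved_model.simple_save(ppo_model.sess, "TEST_OUTPUT",
        inputs={"obs": ppo_model.policy_pi.obs_ph},
        outputs={"action": ppo_model.policy_pi._policy_proba})

converter = tf.lite.TFLiteConverter.from_saved_model("TEST_OUTPUT")
tflite_model = converter.convert()

# Write the converted flatbuffer to disk for use on the target device.
with open('best_model.tflite', 'wb') as f:
    f.write(tflite_model)
```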
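To check the `graph.as_default()` fix without SIMPLE or a trained agent, here is a self-contained sketch under stated assumptions: a toy linear-softmax "policy" (the names `Ob`, `weights`, `probs` and all shapes are hypothetical) is built in its own graph, exported with `simple_save` while that graph is the default, and then reloaded into a fresh graph to confirm that `input/Ob:0` now resolves. It relies on the `tf.compat.v1` API, so it is assumed to run on TF 1.15 as well as TF 2.x; it stops at the SavedModel and does not cover the TFLite conversion step.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# Toy "policy" in its own graph, mimicking how Stable Baselines isolates each model.
graph = tf.Graph()
with graph.as_default():
    with tf1.name_scope("input"):
        obs_ph = tf1.placeholder(tf.float32, shape=(None, 4), name="Ob")
    weights = tf1.Variable(tf.ones((4, 2)), name="weights")
    probs = tf.identity(tf.nn.softmax(tf.matmul(obs_ph, weights)), name="probs")
    sess = tf1.Session(graph=graph)
    sess.run(tf1.global_variables_initializer())

# simple_save requires that the export directory does not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), "TEST_OUTPUT")

# The fix: make the model's graph the default while exporting.
with graph.as_default():
    tf1.saved_model.simple_save(sess, export_dir,
                                inputs={"obs": obs_ph},
                                outputs={"action": probs})

# Reload into a fresh graph; "serve" is the SERVING tag that simple_save writes.
with tf.Graph().as_default() as check_graph:
    with tf1.Session(graph=check_graph) as check_sess:
        tf1.saved_model.loader.load(check_sess, ["serve"], export_dir)
        result = check_sess.run(
            "probs:0", feed_dict={"input/Ob:0": np.zeros((1, 4), np.float32)})

print(result)  # [[0.5 0.5]]
```

With all-zero observations the logits are zero, so the toy softmax returns equal probabilities; the point is only that the tensor names stored in the SavedModel now exist in the exported graph.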

dbravender · Jan 06 '23 18:01

I added a script to do this in https://github.com/davidADSP/SIMPLE/pull/34

dbravender · Jun 11 '23 01:06