tensor2tensor
How to decode with TensorFlow 2.x?
I have followed this notebook (https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/Transformer_translate.ipynb) to do a translation task, but I am using TensorFlow 2.3 and I can't find any alternative for `restore_variables_on_create`:
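```python
# From the Transformer_translate notebook (TF 1.x); there, tfe = tf.contrib.eager
def translate(inputs):
    encoded_inputs = encode(inputs)
    # Variables created inside this context are initialized from ckpt_path
    with tfe.restore_variables_on_create(ckpt_path):
        model_output = translate_model.infer(encoded_inputs)["outputs"]
    return decode(model_output)
```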
How should I load the model from the checkpoint?
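Since `tf.contrib` no longer exists in TensorFlow 2.x, one possible workaround is to split what `restore_variables_on_create` did into two steps: first let the model create its variables (for example by running `translate_model.infer` once), then copy the values from the checkpoint by variable name. The sketch below does the second step with `tf.train.load_checkpoint`. The helper `restore_by_name` is hypothetical (not part of tensor2tensor or TensorFlow), and it assumes the variable names created under TF 2.x match the keys in the name-based checkpoint, which may not hold for every tensor2tensor model.

```python
import numpy as np
import tensorflow as tf


def restore_by_name(variables, ckpt_path):
    """Copy checkpoint values into already-created variables, matched by name.

    Hypothetical helper. Assumes each variable's name (minus the ":0"
    suffix) equals its key in a name-based checkpoint. `ckpt_path` may be
    a checkpoint prefix or a directory (the latest checkpoint is used).
    """
    reader = tf.train.load_checkpoint(ckpt_path)
    shapes = reader.get_variable_to_shape_map()  # checkpoint key -> list of dims
    restored, missing = [], []
    for var in variables:
        key = var.name.split(":")[0]
        # Only assign when the key exists and the shapes agree.
        if key in shapes and list(shapes[key]) == var.shape.as_list():
            var.assign(reader.get_tensor(key))
            restored.append(key)
        else:
            missing.append(key)
    return restored, missing
```

Under the same assumptions, usage would look like `restored, missing = restore_by_name(translate_model.variables, ckpt_path)` after one warm-up call to `translate_model.infer`; that `translate_model` exposes `.variables` like a Keras layer is itself an assumption. A non-empty `missing` list is a quick sign that the names do not line up and the outputs would come from untrained weights.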