Debugging in Mesh TensorFlow
Hey guys,
thanks so much for releasing all the T5.1.1 and mT5 weights! I'm currently working on porting all these models to Hugging Face's Transformers. Is there any way to run Mesh TensorFlow in eager mode?
E.g. if I run the following predict command:
import t5
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

t5_model = t5.models.MtfModel(
    model_dir="./checkpoint",
    batch_size=16,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=None,
    iterations_per_loop=100,
    tpu=None
)

vocab_model_path = '<path/to/spm_vocab>'
vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)

t5_model.predict(
    input_file="input.txt",
    output_file="output.txt",
    vocabulary=vocab,
    temperature=0
)
Is there any way to run the prediction in eager mode so that I can print the actual values of the tensors? E.g. the tensor values of the input to the cross-attention layer: https://github.com/tensorflow/mesh/blob/165d3dc7b4186ee5b6d31c9b17b3df4f7571cf42/mesh_tensorflow/transformer/transformer_layers.py#L729
I had a hard time finding tests in the repo that run a small transformer network.
I'd be super happy for some pointers :-)
Also pinging @craffel in case you have any pointers to good debugging tools :-)
Hey Patrick, unfortunately I believe that because the mesh tf transformer uses tf.Estimator it is not eager-friendly. In the past when we've needed to do similar things, I'm sad to say that we just used Print ops.
FWIW we will soon be releasing a JAX implementation of T5(.1.1) which should make this kind of debugging and inspection a lot easier.
Thanks a lot for your answer! :-) Print ops it is then
How do we use Print ops? When I call mtf.Print, it logs "WARNING:tensorflow:Warning - mtf.Print not implemented for this mesh type"
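For reference, here is a minimal sketch of how a Print op behaves in plain TensorFlow graph mode, which is the general idea behind the "Print ops" suggestion above. It uses tf.compat.v1.Print rather than mtf.Print, so it does not address the warning itself; as the message indicates, whether mtf.Print actually prints depends on the mesh type in use. The function name run_with_print_op and the toy graph are made up for illustration and are not part of Mesh TensorFlow or t5.

```python
import numpy as np
import tensorflow as tf


def run_with_print_op(values):
    """Build a small TF1-style graph containing a Print op and run it once."""
    graph = tf.Graph()
    with graph.as_default():
        inputs = tf.compat.v1.placeholder(tf.float32, shape=[2, 3], name="inputs")
        # tf.compat.v1.Print is an identity op: it returns `inputs` unchanged
        # and, as a side effect, logs the tensors in `data` to stderr whenever
        # the op is executed.
        traced = tf.compat.v1.Print(
            inputs,
            data=[tf.shape(inputs), inputs],
            message="inputs (shape, values): ",
            summarize=6,  # log up to 6 entries per tensor
        )
        # Downstream ops must consume `traced`, not `inputs`; otherwise the
        # Print op is not part of the executed subgraph and nothing is logged.
        output = tf.reduce_sum(traced * 2.0)
    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(output, feed_dict={inputs: values})


result = run_with_print_op(np.arange(6, dtype=np.float32).reshape(2, 3))
```

The key point is that the Print op only fires if its output feeds into something that is actually run, so the tensor of interest has to be replaced by the op's return value in the rest of the graph.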