Machine Translation With Transformers
Can't run this example on the JAX or PyTorch backend; it only works on the TensorFlow backend.
https://keras.io/examples/nlp/neural_machine_translation_with_keras_nlp/
Also, inference is significantly slower than a comparable PyTorch implementation, roughly 8x slower.
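For anyone reproducing this, a minimal sketch, assuming the Keras 3 convention of selecting the backend through the `KERAS_BACKEND` environment variable before the first Keras import:

```python
# Minimal repro sketch (assumption: Keras 3 picks its backend from the
# KERAS_BACKEND environment variable, which must be set before importing keras).
import os

os.environ["KERAS_BACKEND"] = "jax"  # reported failing; "torch" too, "tensorflow" works

import keras  # noqa: E402 (imported after setting the backend)
import keras_nlp  # noqa: E402

print(keras.backend.backend())  # confirm which backend is active
```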
Thanks! We will take a look.
#1189 for context.
@shivance and how's the decoding part?
```python
import keras_nlp
import tensorflow as tf

# `input_packer`, `tokenizer`, `transformer`, and the token constants all
# come from the keras.io example linked above.
def decode_sequences(input_sentences):
    batch_size = tf.shape(input_sentences)[0]

    # Tokenize the encoder input.
    encoder_input_tokens = input_packer(tokenizer(input_sentences))

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def next(prompt, cache, index):
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of TARGET_MAX_SEQUENCE_LENGTH with a start token
    # followed by padding tokens.
    length = TARGET_MAX_SEQUENCE_LENGTH
    start = tf.fill((batch_size, 1), START_VALUE)
    pad = tf.fill((batch_size, length - 1), PAD_VALUE)
    prompt = tf.concat((start, pad), axis=-1)

    generated_tokens = keras_nlp.samplers.GreedySampler()(
        next,
        prompt,
        end_token_id=END_VALUE,
        index=1,  # Start sampling after the start token.
    )
    generated_sentences = tokenizer.detokenize(generated_tokens)
    return generated_sentences
```
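For reference, a toy sketch of the sampler callback contract this relies on: the sampler repeatedly calls `next(prompt, cache, index)` and writes each sampled token back into `prompt`. The 5-token vocabulary, the logits, and the `toy_next` name below are made up for illustration, not part of the example:

```python
import tensorflow as tf
import keras_nlp

def toy_next(prompt, cache, index):
    batch_size = tf.shape(prompt)[0]
    # Toy logits that always favor token id 2, so greedy decoding should
    # emit a run of 2s after the start position.
    logits = tf.one_hot(tf.fill((batch_size,), 2), depth=5) * 10.0
    return logits, None, cache  # (logits, hidden_states, cache)

prompt = tf.zeros((1, 8), dtype=tf.int32)  # position 0 acts as the start token
tokens = keras_nlp.samplers.GreedySampler()(toy_next, prompt, index=1)
print(tokens)  # expected: [[0, 2, 2, 2, 2, 2, 2, 2]]
```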
```python
from tqdm import tqdm

test_eng_texts = [pair[0] for pair in test_pairs]
progress = tqdm(enumerate(test_pairs))  # renamed from `iter` to avoid shadowing the builtin
corrects = 0
for i, pair in progress:
    input_sentence = pair[0]
    target_sentence = pair[1]
    translated = decode_sequences(tf.constant([input_sentence]))
    translated = translated.numpy()[0].decode("utf-8")
    # Strip special tokens and spaces before comparing against the reference.
    translated = (
        translated.replace(PAD_TOKEN, "")
        .replace(START_TOKEN, "")
        .replace(END_TOKEN, "")
        .replace(" ", "")
        .strip()
    )
    if translated == target_sentence:
        corrects += 1
    progress.set_postfix(corrects=corrects, accuracy=corrects / (i + 1))
    print(f"** Example {i} **")
    print(input_sentence)
    print(translated)
    print()
```
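On the speed gap: part of it may simply be that the loop above decodes a single sentence per call. A hedged sketch of batched decoding (the `BATCH_SIZE` value and loop names are hypothetical) that amortizes the per-call overhead, since `decode_sequences` already handles batches:

```python
# Sketch: decode in batches rather than one sentence per call.
BATCH_SIZE = 32  # hypothetical choice

corrects = 0
total = 0
for start_idx in range(0, len(test_pairs), BATCH_SIZE):
    batch = test_pairs[start_idx : start_idx + BATCH_SIZE]
    inputs = tf.constant([pair[0] for pair in batch])
    outputs = decode_sequences(inputs).numpy()
    for (src, target), raw in zip(batch, outputs):
        text = (
            raw.decode("utf-8")
            .replace(PAD_TOKEN, "")
            .replace(START_TOKEN, "")
            .replace(END_TOKEN, "")
            .replace(" ", "")
            .strip()
        )
        corrects += int(text == target)
        total += 1
print(f"accuracy: {corrects / total:.3f}")
```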