Update GPT2 with the loss from a black box
I want to implement an idea using TensorFlow and GPT-2 (the Hugging Face Transformers version).
In each iteration, I let GPT-2 produce some sentences, and these sentences are fed into a black box, which returns a loss, for example based on the quality of the synthetic sentences produced by the current GPT-2. I then want to use this loss to update the parameters of GPT-2 to improve it. The idea is quite simple, but I find it difficult to realize, possibly because the connection between the loss and GPT-2 is not differentiable.
Here is my code in TensorFlow:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# when generating, we will use the logits of the right-most token to predict
# the next token, so the padding should be on the left
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token  # to avoid an error

gpt2 = TFGPT2LMHeadModel.from_pretrained('gpt2')
gpt2.trainable = True

num_return_sequences = 1
max_len = get_tokens_len(ds, 0.99)  # helper from my own code; ds is my dataset

cce = tf.keras.losses.CategoricalCrossentropy()
optimizer = keras.optimizers.Adam(learning_rate=0.0001)

def loss_fn(output_sequences, labels):
    # decode the generated token ids back into plain text
    syn_sents = tokenizer.batch_decode(output_sequences,
                                       clean_up_tokenization_spaces=True,
                                       skip_special_tokens=True)
    # strip the prompt from each generated sentence
    syn_sents_pure = []
    for sent, sent_syn in zip(prompts, syn_sents):
        syn_sents_pure.append(sent_syn.replace(sent, '').replace('\n', ' ').strip())
    # `model` is the black box: another model (a classifier) that scores the
    # synthetic sentences; `num_classes` and `label_idx` belong to it
    preds = model(np.array(syn_sents_pure))
    assert preds.shape[0] == len(prompts) and preds.shape[1] == num_classes
    label_oht = tf.keras.utils.to_categorical(
        np.array([label_idx[l] for l in labels]),
        num_classes=num_classes, dtype='int')
    label_oht_tf = tf.convert_to_tensor(label_oht)
    assert label_oht.shape == preds.shape
    loss_value = cce(label_oht_tf, preds)
    return loss_value

rows = ds.df_test.sample(5)
prompts = rows['content'].tolist()
labels = rows['label'].tolist()

with tf.GradientTape() as tape:
    # Run the forward pass. The operations applied to the inputs are
    # supposed to be recorded on the GradientTape.
    inputs = tokenizer(prompts, padding='max_length', truncation=True,
                       max_length=max_len, return_tensors="tf")
    output_sequences = gpt2.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=max_len * 2,
        temperature=1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )
    # Compute the loss value for this minibatch.
    loss_value = loss_fn(output_sequences, labels)
    # e.g. <tf.Tensor: shape=(), dtype=float32, numpy=0.062384058>

# Use the gradient tape to retrieve the gradients of the trainable
# variables with respect to the loss.
grads = tape.gradient(loss_value, gpt2.trainable_weights)
# Run one step of gradient descent by updating the variables to minimize the loss.
optimizer.apply_gradients(zip(grads, gpt2.trainable_weights))
When I run this, the grads are all None.
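For example:

print(all(g is None for g in grads))  # prints True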
Actually, the black box loss_fn decodes the outputs from GPT-2 into plain text and uses another model to calculate the loss.
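I suspect that both gpt2.generate() (which samples discrete token ids) and tokenizer.batch_decode() are non-differentiable steps, so the tape cannot trace a path from the loss back to GPT-2's weights. One direction I am considering is a REINFORCE-style (score-function / policy-gradient) update, where the black-box loss is treated as a negative reward and the gradient flows through the log-probabilities of the already-sampled tokens instead. A rough, untested sketch of what I mean (black_box_loss is a hypothetical per-sequence version of my loss_fn above, and padding is ignored for brevity):

with tf.GradientTape() as tape:
    # re-run the sampled sequences through GPT-2 to get differentiable logits
    logits = gpt2(output_sequences, training=True).logits[:, :-1, :]
    targets = output_sequences[:, 1:]  # token t+1 is predicted from token t
    # per-token log-probabilities of the sampled tokens
    log_probs = -tf.keras.losses.sparse_categorical_crossentropy(
        targets, logits, from_logits=True)
    seq_log_prob = tf.reduce_sum(log_probs, axis=1)  # log p(sequence)
    # the black box stays outside the graph; its output is just a number
    # (black_box_loss is a hypothetical per-sequence variant of loss_fn)
    rewards = -black_box_loss(output_sequences, labels)  # higher is better
    # score-function estimator: grad E[R] ~ R * grad log p(sequence)
    loss = -tf.reduce_mean(tf.stop_gradient(rewards) * seq_log_prob)

grads = tape.gradient(loss, gpt2.trainable_weights)
optimizer.apply_gradients(zip(grads, gpt2.trainable_weights))

As far as I understand, this only gives an estimate of the gradient, since the black box itself never enters the computation graph.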
So is there a proper way to implement my idea, or how can I correct my code?
Thanks.