attention-is-all-you-need-keras

pure language model

Open · XiaoLiuAI opened this issue 6 years ago · 1 comment

Hello, inspired by openai/finetune-transformer-lm, I am trying to build a language model based on your code, and I have a question about the implementation.

self.model = Model([src_seq_input, tgt_seq_input], loss)
self.model.compile(optimizer, None)

Why don't you pass the loss function through the compile API? I am not quite sure what the add_loss API does.
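For reference, Keras (2.x and tf.keras, as far as I know) minimizes the sum of the losses passed to compile(), each multiplied by its loss_weights entry, plus every term registered with model.add_loss() (regularization penalties end up there too). So compiling with loss=None and attaching the loss tensor through add_loss() makes that tensor the whole objective. A minimal numpy sketch of this bookkeeping; keras_total_loss is a hypothetical helper for illustration, not a Keras function:

```python
import numpy as np


def keras_total_loss(compiled_losses, loss_weights=None, added_losses=()):
    """Sketch of how Keras combines its loss terms (assumes Keras 2 / tf.keras semantics).

    compiled_losses: scalar value of each per-output loss given to compile().
    loss_weights:    optional per-output weights, as in compile(loss_weights=...); default 1.0.
    added_losses:    scalar terms registered with model.add_loss(); these are added unweighted.
    """
    if loss_weights is None:
        loss_weights = [1.0] * len(compiled_losses)
    # Weighted sum of the compile() losses...
    total = sum(w * l for w, l in zip(loss_weights, compiled_losses))
    # ...plus every add_loss() term, with no weighting applied.
    return float(total + np.sum(added_losses))
```

With a CRF output compiled normally and the LM loss attached through add_loss(), the objective is crf_loss + lm_loss. Because loss_weights does not touch add_loss() terms, rebalancing the two would mean scaling the LM tensor before calling add_loss().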

By the way, I built a language-model encoder based on your Encoder, adding GetSubMask as you did in the Decoder. Next, I would like to add a CRF layer after the encoder (for sequence labelling, whereas OpenAI's model is for text classification), and then train the model on the sum of the language-model loss and the CRF loss. Do you have any implementation suggestions? In particular, any ideas for verifying that the code is correct?
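One way to verify the causal masking is a perturbation test: change every token after position t and check that the outputs at positions up to t do not move. Below is a numpy sketch on a toy single-head attention with Q = K = V = x and no learned weights. sub_mask, masked_self_attention and is_causal are hypothetical names, and sub_mask assumes that GetSubMask yields a lower-triangular mask (position i may attend to positions j <= i).

```python
import numpy as np


def sub_mask(seq_len):
    # Lower-triangular 0/1 mask: entry [i, j] is 1 when j <= i.
    # (Assumed numpy analogue of a GetSubMask-style causal mask.)
    return np.cumsum(np.eye(seq_len), axis=0)


def masked_self_attention(x, mask):
    # Toy scaled dot-product attention with Q = K = V = x.
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(mask > 0, scores, -1e9)  # blocked positions get ~zero weight
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


def is_causal(attend, seq_len=6, d=4, seed=0):
    # For each t, replace everything after t and require outputs 0..t to be unchanged.
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(seq_len, d))
    base = attend(x)
    for t in range(seq_len - 1):
        x2 = x.copy()
        x2[t + 1:] = rng.normal(size=(seq_len - t - 1, d))
        if not np.allclose(base[: t + 1], attend(x2)[: t + 1]):
            return False
    return True
```

The same test applies to the real model: run encoder_model.predict on two batches that differ only after position t and compare the outputs up to t with np.allclose (predict() runs with dropout disabled).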

I saw your example data with pinyin and Chinese. Are you Chinese?

XiaoLiuAI · Aug 22 '18 08:08

My current implementation is:

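import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Embedding, Lambda, TimeDistributed, Dense
from keras.models import Model
from keras_contrib.layers import CRF
from transformer import GetPosEncodingMatrix  # assumed: this repo's transformer.py
# FullSeqEncoder and TiedEmbeddingsTransposed are assumed to be defined elsewhere (not shown)

class TransformerEncoderCrf(object):  # builds Keras models; not a layer, so no Wrapper base class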
    def __init__(self, config):
        self.len_limit = config.len_limit

    def get_pos_seq(self, x):
        mask = K.cast(K.not_equal(x, 0), 'int32')
        pos = K.cumsum(K.ones_like(x, 'int32'), 1)
        return pos * mask # TODO add length limit

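    @staticmethod
    def get_loss(args):
        # next-token LM loss: the logits at step t predict the token at step t + 1
        y_pred, y_true = args
        y_true = K.cast(y_true, 'int32')
        targets = y_true[:, 1:]
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=y_pred[:, :-1])
        # build the padding mask from the shifted targets so it matches loss: batch_size * (len - 1)
        mask = K.cast(tf.not_equal(targets, 0), 'float32')
        loss = tf.reduce_sum(loss * mask, -1) / tf.reduce_sum(mask, -1)  # batch_size
        loss = K.mean(loss)
        return loss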

    def load_model(self, config):
        nvocab = config.nvocab
        len_limit = config.len_limit
        d_embed = config.d_embed
        share_word_emb = config.share_word_emb
        n_head = config.n_head
        d_k = config.d_k
        d_v = config.d_v
        d_inner_hid = config.d_inner_hid
        n_layers = config.n_layers
        dropout = config.dropout

        if config.pos_trainable:
            pos_emb = Embedding(len_limit, d_embed, trainable=True)
        else:  # fixed sinusoidal position encoding
            pos_emb = Embedding(len_limit, d_embed, trainable=False, weights=[GetPosEncodingMatrix(len_limit, d_embed)])

        word_emb = Embedding(nvocab, d_embed)

        self.encoder = FullSeqEncoder(d_embed, d_inner_hid, n_head, d_k, d_v, n_layers, dropout,
                                     word_emb=word_emb, pos_emb=pos_emb)

        self.tok_input = Input(shape=(None,), dtype='int32')
        self.tok_output = Input(shape=(None,), dtype='int32')

        self.position_input = Lambda(self.get_pos_seq)(self.tok_input)

        enc_output = self.encoder(self.tok_input, self.position_input)
        self.encoder_model = Model(inputs=self.tok_input, outputs=enc_output)  # for possible pre-training

        lm_output = TimeDistributed(TiedEmbeddingsTransposed(tied_to=word_emb))(enc_output)
        lm_loss = Lambda(self.get_loss)([lm_output, self.tok_input])

        if config.use_crf:
            fully_connected_layer = TimeDistributed(Dense(config.num_fully_connect, activation='tanh'))
            crf_layer = CRF(config.ntags, sparse_target=False)
            ner_output = crf_layer(fully_connected_layer(enc_output))  # apply to the encoder tensor, not the Model

            self.model = Model(inputs=self.tok_input, outputs=ner_output)
            # The model has a single output, so compile() only gets the CRF loss;
            # lm_loss is already a tensor, so it is attached with add_loss() instead.
            self.model.add_loss(lm_loss)
            self.loss = crf_layer.loss_function
            self.metrics = crf_layer.accuracy
        else:
            # TO BE DONE
            output_layer = TimeDistributed(Dense(config.ntags, activation='softmax'))
            self.model = Model(inputs=self.tok_input, outputs=output_layer(enc_output))
            self.loss = 'categorical_crossentropy'
            self.metrics = 'accuracy'


    def compile(self, *args, **kwargs):
        # TO BE DONE
        if 'metrics' in kwargs:
            kwargs['metrics'] = [self.metrics]
        self.model.compile(*args, loss=self.loss, **kwargs)

XiaoLiuAI · Aug 22 '18 08:08
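A numpy reference for the masked next-token loss in get_loss can help unit-test the Keras version on small random inputs. The logits at step t are scored against the token at step t + 1, and the padding mask has to be built from those shifted targets so that it has the same (batch, len - 1) shape as the per-token loss. A sketch assuming token id 0 is padding; masked_lm_loss is a hypothetical name:

```python
import numpy as np


def masked_lm_loss(logits, tokens, pad_id=0):
    """Next-token cross-entropy, averaged over non-padding targets, then over the batch.

    logits: (batch, T, vocab); logits[:, t] scores the token at position t + 1.
    tokens: (batch, T) integer ids; pad_id marks padding.
    """
    targets = tokens[:, 1:]  # the token each step should predict
    pred = logits[:, :-1]    # the last step has no next token
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = pred - pred.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = (targets != pad_id).astype(float)  # same (batch, T - 1) shape as nll
    per_seq = (nll * mask).sum(-1) / mask.sum(-1)
    return per_seq.mean()
```

Running the same arrays through the Keras loss (e.g. via K.function) should agree up to float32 precision. A sequence with no non-padding targets divides by zero here, just as it would in the TensorFlow version.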
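For the CRF part, a dynamic-programming loss is easiest to check against brute force on tiny inputs. A linear-chain CRF's negative log-likelihood is log Z minus the score of the gold tag path, where the forward algorithm computes log Z; with a few tags and time steps, Z can also be obtained by enumerating all n_tags ** T paths. The sketch below is a simplified CRF (emission plus transition scores, no start/end boundary scores, no padding mask), so its values are not expected to match keras_contrib's CRF.loss_function exactly; all function names are hypothetical.

```python
import itertools

import numpy as np


def path_score(emissions, transitions, tags):
    # Emission scores plus transition scores along one tag path.
    score = emissions[0, tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1], tags[t]] + emissions[t, tags[t]]
    return score


def log_partition(emissions, transitions):
    # Forward algorithm: alpha[j] = log-sum-exp of all prefix scores ending in tag j.
    alpha = emissions[0].copy()
    for t in range(1, emissions.shape[0]):
        scores = alpha[:, None] + transitions + emissions[t][None, :]  # [prev, next]
        m = scores.max(axis=0)
        alpha = m + np.log(np.exp(scores - m).sum(axis=0))
    m = alpha.max()
    return m + np.log(np.exp(alpha - m).sum())


def crf_nll(emissions, transitions, tags):
    return log_partition(emissions, transitions) - path_score(emissions, transitions, tags)


def brute_force_log_partition(emissions, transitions):
    # Enumerate every tag path; only feasible for tiny T and n_tags.
    T, n_tags = emissions.shape
    scores = [path_score(emissions, transitions, p)
              for p in itertools.product(range(n_tags), repeat=T)]
    m = max(scores)
    return m + np.log(sum(np.exp(s - m) for s in scores))
```

If log_partition and brute_force_log_partition agree on random inputs, and exp(-crf_nll) summed over all paths is 1, the forward recursion is right; the same comparison can be repeated against whichever CRF implementation is actually used.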