RWKV-LM

Access/train to use the embeddings

Open · jn2clark opened this issue Jan 28 '23

Hi @BlinkDL! Really interested in your work here. I am looking to test out some of the models for embedding-based tasks. What is the best way to access the embeddings? I would also be looking to use them for training (e.g. a contrastive loss with a Siamese training setup). Any information on this would be greatly appreciated.

jn2clark avatar Jan 28 '23 02:01 jn2clark

Thanks :) From the README here: read the inference code in src/model.py and try using the final hidden state (.xx, .aa, .bb) as a faithful sentence embedding for other tasks. Probably you shall begin with .xx and .aa/.bb (.aa divided by .bb).

BlinkDL avatar Jan 28 '23 10:01 BlinkDL
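
For anyone landing here later: a minimal, untested sketch of what that suggestion could look like with the `rwkv` pip package (ChatRWKV-style inference) rather than src/model.py directly. The checkpoint path, tokenizer file, and the exact state layout are assumptions; verify against rwkv/model.py in your installed version.

import os
os.environ['RWKV_JIT_ON'] = '1'
os.environ['RWKV_CUDA_ON'] = '0'

import torch
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# Placeholder checkpoint path and tokenizer file -- substitute your own.
model = RWKV(model='/path/to/RWKV-4-Pile-430M-20220808-8066', strategy='cpu fp32')
pipeline = PIPELINE(model, '20B_tokenizer.json')

def embed(text: str) -> torch.Tensor:
    tokens = pipeline.encode(text)
    _, state = model.forward(tokens, None)   # state: flat list, assumed 5 tensors per layer
    n_layer = len(state) // 5
    xx = state[(n_layer - 1) * 5 + 0]        # att.xx of the last block
    aa = state[(n_layer - 1) * 5 + 1]        # att.aa
    bb = state[(n_layer - 1) * 5 + 2]        # att.bb
    # BlinkDL's suggestion above: start with .xx and .aa/.bb
    return torch.cat([xx, aa / bb], dim=-1)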

Can you explain further? I still don't get what they mean and which value should be used: xx, or aa/bb?

tiendung avatar Mar 29 '23 18:03 tiendung

@tiendung the hidden state has 5 tensors per block (att + ffn): xx, aa, bb, pp for the att part and xx for the ffn part

BlinkDL avatar Apr 05 '23 11:04 BlinkDL
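
Unpacked, that layout looks roughly like this, continuing the sketch above (this assumes the flat per-layer state list used by the rwkv pip package; check rwkv/model.py for your version):

# Assumed ordering of the 5 state tensors per block; names follow src/model.py.
for i in range(n_layer):
    att_xx = state[i * 5 + 0]   # last token fed to the time-mix (att) part
    att_aa = state[i * 5 + 1]   # WKV numerator accumulator
    att_bb = state[i * 5 + 2]   # WKV denominator accumulator
    att_pp = state[i * 5 + 3]   # running max exponent, kept for numerical stability
    ffn_xx = state[i * 5 + 4]   # last token fed to the channel-mix (ffn) part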

@BlinkDL what would the implementation of a def embed(text: str) -> List[float]: method on the RWKV class look like? Such a method would be very useful. I asked GPT-4 about it and this is what it wrote:

# To paste into src/model.py; needs `from typing import List` at the top of the file.
# Note: this assumes a tokenizer object has been attached as args.tokenizer,
# which the training code does not do by default.
class RWKV(pl.LightningModule):
    # (...)
    def embed(self, text: str) -> List[float]:
        args = self.args
        input_ids = args.tokenizer.encode(text)
        input_ids = torch.tensor(input_ids).unsqueeze(0).cuda()  # (1, seq_len)

        with torch.no_grad():
            x = self.emb(input_ids)          # token embeddings
            x_emb = x

            if args.tiny_att_dim > 0:        # blocks with tiny attention also take x_emb
                for block in self.blocks:
                    x = block(x, x_emb)
            else:
                for block in self.blocks:
                    x = block(x)

            x = self.ln_out(x)               # final LayerNorm, (1, seq_len, n_embd)

        x = x[:, -1, :].detach().cpu()       # keep only the last position
        return x.squeeze().tolist()

ricardopinto avatar Apr 11 '23 13:04 ricardopinto
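
One note on the snippet above: it keeps only the last position after ln_out. For the Siamese/contrastive setup from the original question, mean pooling over the whole sequence is a common alternative (an editorial suggestion, not something proposed in this thread):

# Hypothetical variant of the tail of embed(): mean-pool instead of last-token.
x = self.ln_out(x)                  # (1, seq_len, n_embd)
x = x.mean(dim=1).detach().cpu()    # average over all token positions
return x.squeeze().tolist()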

@ricardopinto any progress with this? I'm also interested in working with embeddings generated with RWKV, but I still don't have a clear understanding of how to get embeddings the way the sentence-transformers models do.

sgaseretto avatar May 09 '23 20:05 sgaseretto

I'm now doing this in HF transformers. The 430M model seems faithful on writing style, but not content.

KnutJaegersberg avatar Jul 04 '23 05:07 KnutJaegersberg
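
A rough sketch of that HF transformers route, mean-pooling the final hidden states (the checkpoint name and pooling choice are assumptions, not details given in this thread):

# Hedged sketch: RWKV sentence embeddings via Hugging Face transformers.
import torch
from transformers import AutoTokenizer, AutoModel

tok = AutoTokenizer.from_pretrained("RWKV/rwkv-4-430m-pile")    # assumed checkpoint
model = AutoModel.from_pretrained("RWKV/rwkv-4-430m-pile")

def embed(text: str) -> list:
    inputs = tok(text, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    # out.last_hidden_state: (1, seq_len, hidden_size); mean-pool over tokens
    return out.last_hidden_state.mean(dim=1).squeeze(0).tolist()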

There is a function in gptcache that does this, too. I'm using that code in my HF transformers code; it's just a few lines.

https://gptcache.readthedocs.io/en/latest/_modules/gptcache/embedding/rwkv.html?highlight=rwkv#

KnutJaegersberg avatar Jul 04 '23 05:07 KnutJaegersberg
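
For reference, a hedged usage sketch of that gptcache helper; the class and method names are taken from the linked docs, so verify them against your installed gptcache version.

from gptcache.embedding import Rwkv

encoder = Rwkv()                        # wraps an RWKV checkpoint from the HF Hub
vec = encoder.to_embeddings("Hello, world.")
print(encoder.dimension, len(vec))      # embedding size and the returned vector length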