RWKV-LM
Access/train to use the embeddings
Hi @BlinkDL! Really interested in your work here. I am looking to test some of the models on embedding-based tasks. What is the best way to access the embeddings? I would also be looking to use them for training (i.e. contrastive loss with a Siamese training setup). Any information on this would be greatly appreciated.
Thanks :) From the README here:

> Read the inference code in src/model.py and try using the final hidden state (.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you shall begin with .xx and .aa/.bb (.aa divided by .bb).
Can you explain further? I still don't get what they mean and which value should be used: .xx, or .aa / .bb?
@tiendung the hidden state has 5 tensors per block (att + ffn): the attention sub-block's xx, aa, bb, pp, plus the FFN sub-block's xx.
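To make that concrete, here is a minimal sketch (not an official API) of pulling a sentence embedding out of the recurrent state. It assumes an RWKV_RNN-style inference setup, e.g. the `rwkv` pip package, where `model.forward(tokens, state)` returns `(logits, state)` and the state holds 5 tensors per layer in the order `[att_xx, att_aa, att_bb, att_pp, ffn_xx]`; verify that layout against the inference code you actually run.

```python
# Sketch only: the state layout and forward() signature are assumptions,
# check them against src/model.py / the inference code you are using.
import torch

def state_embedding(model, tokenizer, text: str) -> torch.Tensor:
    tokens = tokenizer.encode(text)
    state = None
    with torch.no_grad():
        _, state = model.forward(tokens, state)  # feed the whole sequence
    n_layer = len(state) // 5
    parts = []
    for i in range(n_layer):
        att_xx = state[i * 5 + 0]
        aa = state[i * 5 + 1]
        bb = state[i * 5 + 2]
        # pp (i*5 + 3) is a numerical-stability term, skipped here
        ffn_xx = state[i * 5 + 4]
        parts += [att_xx, ffn_xx, aa / (bb + 1e-8)]  # .xx plus .aa / .bb
    return torch.cat(parts)  # shape: (n_layer * 3 * n_embd,)
```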
@BlinkDL what would an implementation of a `def embed(text: str) -> List[float]:` method on the `RWKV` class look like? Such a method would be very useful. I asked GPT-4 about it and this is what it wrote:
```python
from typing import List

import torch
import pytorch_lightning as pl


class RWKV(pl.LightningModule):
    # (...)

    def embed(self, text: str) -> List[float]:
        args = self.args
        # Assumes a tokenizer has been attached to the training args.
        input_ids = args.tokenizer.encode(text)
        input_ids = torch.tensor(input_ids).unsqueeze(0).cuda()
        with torch.no_grad():
            x = self.emb(input_ids)
            x_emb = x
            if args.tiny_att_dim > 0:
                # Blocks with tiny attention also take the raw embeddings.
                for block in self.blocks:
                    x = block(x, x_emb)
            else:
                for block in self.blocks:
                    x = block(x)
            x = self.ln_out(x)
            # Take the final token's hidden state as the embedding.
            x = x[:, -1, :].detach().cpu()
        return x.squeeze().tolist()
```
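Note this is untested GPT-4 output: it takes the last token's hidden state after `ln_out`, and it assumes `args.tokenizer` exists on the training args (it does not by default, you would have to attach one yourself). Usage would look something like:

```python
# Hypothetical usage, assuming a trained model with args.tokenizer attached.
vec = model.embed("The quick brown fox")  # list of n_embd floats
```

Mean-pooling over all token positions instead of taking the last token is a common alternative for sentence embeddings.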
@ricardopinto any progress with this? I'm also interested in working with embeddings generated by RWKV, but I still don't have a clear picture of how to make it produce embeddings the way the models from sentence-transformers do.
I'm now doing this in HF transformers. The 430M model seems faithful on writing style, not content.
There is a function in gptcache that does this, too. I'm using that code in my HF transformers setup; it's just a few lines:
https://gptcache.readthedocs.io/en/latest/_modules/gptcache/embedding/rwkv.html?highlight=rwkv#
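For reference, the core of that gptcache helper boils down to a few lines of HF transformers. A sketch, assuming the `RWKV/rwkv-4-430m-pile` checkpoint on the Hub (any checkpoint supported by `transformers.RwkvModel` should work):

```python
import torch
from transformers import AutoTokenizer, RwkvModel

# Example checkpoint; substitute whichever RWKV model you are using.
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-430m-pile")
model = RwkvModel.from_pretrained("RWKV/rwkv-4-430m-pile")

def embed(text: str) -> list:
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(inputs["input_ids"])
    # Last token's hidden state as the sentence embedding,
    # the same choice the gptcache code makes.
    return outputs.last_hidden_state[0, -1, :].tolist()
```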