[WIP] RWKV4Neo the RNN and GPT Hybrid Model
What does this PR do?
Adds the model requested in issue #20737. Fixes https://github.com/huggingface/transformers/issues/20737
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@younesbelkada @ArthurZucker
Hi @ArEnSc! Thanks for starting over the PR 💪 Let us know whenever you need help, with @ArthurZucker!
Will do. Still doing some research; I just figured out how the training notebook works, and the model executes in the notebook, so that's a positive.
Update: traced the model and came up with a state-based API for the RNN inference mode on my own code base to experiment with.
Thanks a lot for the status update! Feel free to ping whenever you need help
I look at working on this a little from time to time. Here are my notes and possible tasks, started 2023-01-16.
- The template appears to be from a T5-style model. The RWKV state could be the encoder hidden state (a little intuitive) and/or the past key values (normative generation). It will take some algebra and tests to add input state to the GPT training form from the RNN inference form.
- [ ] The tensorflow loading code appears complicated to me. I might move it out to another file for now.
- [ ] The embeddings can likely be adjusted to reflect parts "i" and "ii" of the high level outline below.
- [ ] It could be helpful to organize the file to retain layout similarity with BlinkDL's files.
- [ ] For the outline below, the next step is reviewing timemix. Draft of the architecture (maybe leave out optional parts to start).
High level:
- word embeddings `emb`
- layernorm `ln0`
- optional 2-axis trained position embeddings, seen in training code for image modeling: `pos_emb_x`, `pos_emb_y`. This is converted to a 1-axis `pos_emb` and used prior to `ln0` in inference.
- layers of blocks:
  1. layernorm `ln1`
  2. timemix self attention: `time_mix_k`, `time_mix_v`, `time_mix_r`, `time_first`, `time_decay`, `key`, `value`, `receptance`, `output`. `time_first` and `time_decay` are kept as float32 in inference.
  3. layernorm `ln2`
  4. feedforward channelmix: `time_mix_k`, `time_mix_r`, `key`, `value`, `receptance` (see the ChannelMix section below)
  - timemix self attention is optionally replaced with feedforward channelmix for block 0 in training code
  - for one optional block, tiny attention: `tiny_ln`, `tiny_q`, `tiny_k`, `tiny_v`, `tiny_mask`, seen in training code; inference code in development
  - optionally, inference code uses what looks like a numeric stability trick to extract a factor of 2 from the weights every 6 layers
- layernorm `ln_out`
- optional "copy" attention: `head_q`, `head_k`, `copy_mask`, then summed into the head in training code; inference code in development
- linear language modeling head `head`
- for training loss, BlinkDL presently has a function after cross entropy called `L2Wrap` to reduce magnitudes
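To make that layout concrete, here is a minimal PyTorch sketch of the skeleton described above. The attribute names (`emb`, `ln0`, `ln1`, `ln2`, `ln_out`, `head`) follow BlinkDL's checkpoint naming; the class names, constructor arguments, and the `nn.Identity()` stand-ins for timemix/channelmix (sketched in the sections below) are my own placeholders, not the final implementation:

```python
import torch
import torch.nn as nn

class RwkvBlock(nn.Module):
    """One block: ln1 -> timemix (att), ln2 -> channelmix (ffn), both residual."""
    def __init__(self, n_embd):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.att = nn.Identity()  # placeholder; see the TimeMix sketch below
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Identity()  # placeholder; see the ChannelMix sketch below

    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class RwkvModel(nn.Module):
    """emb -> ln0 -> blocks -> ln_out -> head, per the outline above."""
    def __init__(self, vocab_size, n_embd, n_layer):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.ln0 = nn.LayerNorm(n_embd)
        self.blocks = nn.ModuleList(RwkvBlock(n_embd) for _ in range(n_layer))
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx):
        x = self.ln0(self.emb(idx))
        for block in self.blocks:
            x = block(x)
        return self.head(self.ln_out(x))
```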
GPT (training) and RNN (inference) equivalence:
- I think special training initialization values may be used in timemix and channelmix.
- For inference, `time_decay = -exp(time_decay)` is factored out when loaded, but for training this is done in the forward pass.
- 5 state elements per layer (a sketch of the layout follows this list):
  - 0 = ChannelMix/FF `xx`
  - 1 = TimeMix/SA `xx`
  - 2 = `aa`
  - 3 = `bb`
  - 4 = `pp` in inference, `o` in training
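As a working assumption, that 5-slot state could be laid out as one flat `(n_layer * 5, n_embd)` tensor and initialized like this; the function name and layout are mine, and the `-1e38` for `pp` matches the exponent initialization noted in the kernel section below:

```python
import torch

def init_rnn_state(n_layer, n_embd, device="cpu"):
    # 5 state rows per layer: 0 = ffn xx, 1 = att xx, 2 = aa, 3 = bb, 4 = pp
    state = torch.zeros(n_layer * 5, n_embd, device=device)
    state[4::5] = -1e38  # pp tracks the running max exponent; effectively -inf
    return state
```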
TimeMix:
- the previous state is shifted into the `x` vector to make `xx`. In training this is done by "time shifting" with `nn.ZeroPad2d((0, 0, 1, -1))`; in single token inference it is passed as state element 1, which is then replaced by `x`.
- linear interpolation between the old state `xx` and the new state `x`, weighting `x` by a ratio of `time_mix_k`, `time_mix_v`, and `time_mix_r` to make `xk`, `xv`, and `xr` respectively.
- k = key @ xk
- v = value @ xv
- sr = sigmoid(receptance @ xr) # called simply `r` in inference code
- the GPT training form of this is now handed off to a hand-written CUDA kernel, compiled on first run, from cuda/wkv_cuda.cu
  - kernel parameters: `B` = batch size; `T` = sequence length; `C` = channel count; `_w` = `time_decay`; `_u` = `time_first`; `_k` = `k`; `_v` = `v`; `_y` = `wkv`
  - I think this used to be a convolution; I'm not sure whether it still is
  - `o` and `no` appear to be running values for magnitude management in exponential space, initialized to -1e38; `p` and `q` are initialized to 0
  - `k` and `v` are indexed by thread, so the `token` offset may represent different subregions. I'm not quite clear on that and should test or ask.
  - kernel loop, per token (a Python sketch of this recurrence follows the TimeMix section):
    - no = max(o, time_first[channel] + k[token])
    - A = exp(o - no) # this is e1 in the RNN form
    - B = exp(time_first[channel] + k[token] - no) # this is e2 in RNN
    - wkv[token] = (A * p + B * v[token]) / (A * q + B)
    - no = max(time_decay[channel] + o, k[token])
    - A = exp(time_decay[channel] + o - no)
    - B = exp(k[token] - no)
    - p = A * p + B * v[token]
    - q = A * q + B
    - o = no; token += 1
- ... here would be the remaining core algebra and code inspection
- WIP unified summary of the wkv kernel between inference and training:
  - ww = time_first + k[token]
  - next_pp = max(pp, ww)
  - A = exp(pp - next_pp ...
- rwkv = sr * wkv
- return output @ rwkv
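Putting the pieces above together, here is a minimal Python sketch of the single-token (RNN) form of TimeMix, following the kernel algebra above and the flat state layout sketched earlier. The parameter names follow the outline; the function signature is mine, and `time_decay` is assumed to already be `-exp(time_decay)` per the loading note above:

```python
import torch

def time_mix(x, state, i, time_mix_k, time_mix_v, time_mix_r,
             time_first, time_decay, key, value, receptance, output):
    # state rows for layer i: 5*i+0 = ffn xx, 5*i+1 = att xx,
    # 5*i+2 = aa, 5*i+3 = bb, 5*i+4 = pp
    xx = state[5 * i + 1]
    xk = x * time_mix_k + xx * (1 - time_mix_k)
    xv = x * time_mix_v + xx * (1 - time_mix_v)
    xr = x * time_mix_r + xx * (1 - time_mix_r)
    state[5 * i + 1] = x  # the new token becomes the shift state

    sr = torch.sigmoid(receptance @ xr)
    k = key @ xk          # time_first/time_decay assumed float32, per the notes
    v = value @ xv

    aa, bb, pp = state[5 * i + 2], state[5 * i + 3], state[5 * i + 4]
    # wkv = (e1 * aa + e2 * v) / (e1 * bb + e2), with running max pp for stability
    ww = time_first + k
    qq = torch.maximum(pp, ww)
    e1 = torch.exp(pp - qq)
    e2 = torch.exp(ww - qq)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)

    # decay the accumulators and fold in the current token
    ww = pp + time_decay
    qq = torch.maximum(ww, k)
    e1 = torch.exp(ww - qq)
    e2 = torch.exp(k - qq)
    state[5 * i + 2] = e1 * aa + e2 * v
    state[5 * i + 3] = e1 * bb + e2
    state[5 * i + 4] = qq

    return output @ (sr * wkv)
```

In the GPT (training) form, the per-token shift is instead done across the whole sequence with `nn.ZeroPad2d((0, 0, 1, -1))`, and this recurrence is handled by the CUDA kernel.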
ChannelMix:
- the previous state is shifted into the `x` vector to make `xx`. In training this is done by "time shifting" with `nn.ZeroPad2d((0, 0, 1, -1))`; in single token inference it is passed as state element 0, which is then replaced by `x`.
- linear interpolation between the old state `xx` and the new state `x`, weighting `x` by a ratio of `time_mix_k` and `time_mix_r` to make `xk` and `xr` respectively.
- r = sigmoid(receptance @ xr)
- k = square(relu(key @ xk))
- kv = value @ k
- rkv = r * kv
- return rkv
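And the matching single-token ChannelMix sketch, under the same assumed state layout and with the same caveat that the signature is mine:

```python
import torch

def channel_mix(x, state, i, time_mix_k, time_mix_r, key, value, receptance):
    xx = state[5 * i + 0]
    xk = x * time_mix_k + xx * (1 - time_mix_k)
    xr = x * time_mix_r + xx * (1 - time_mix_r)
    state[5 * i + 0] = x  # the new token becomes the shift state

    r = torch.sigmoid(receptance @ xr)
    k = torch.square(torch.relu(key @ xk))  # squared-ReLU activation
    return r * (value @ k)
```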
- [ ] review or improve model file further
@ArEnSc do you need any help?
If you want to help, PM me on Discord! Otherwise I should have a minor update by end of week.
Hi @ArEnSc, can you share with us your Discord handle? Thanks!
ARENSC#5905. Yeah, still working on it haha, it will be a while.
Working on having the GPT encoder generate the context, with RNN-mode inference sharing the weights.
Deleted a bunch of stuff that wasn't needed.
Added the [WIP] Label to prevent the bot from coming back 😉
@ArEnSc Please let us know if you won't have time to finish this PR. The model is heavily requested, as you may see from the linked issue. Do you want us to take over this PR and finish it?
Sure, yes. Sorry, I've been busy at the hospital these days! I think it's probably important that you guys take this on =)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.