
RWKV-4 169m/430m in browser with ORT Web / TF.js / tfjs-tflite?

Open josephrocca opened this issue 1 year ago • 32 comments

Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.

josephrocca avatar Aug 20 '22 00:08 josephrocca

Exporting to ONNX is something that I've been tinkering with and I can report that the 169m RWKV-4 model does run in browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx

There are two things missing:

  • JavaScript implementation of the tokenizer
  • JavaScript implementation of sample_logits().

Running python -i -u export_onnx.py and then rnn_export() will export the model as rwkw.onnx, which can then be tested with test_onnx.py and loaded from index.html. The demo in index.html uses greedy sampling, and you just sort of have to visit https://goose.ai/tokenizer to encode/decode the text. It works with the wasm backend, but unfortunately throws an error if you try the webgl backend.
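For anyone who wants to poke at the exported model from Python before touching JS, here's a rough sketch of the RNN-style generation loop with greedy sampling. This is not the repo's test_onnx.py: the token input name "idx" and its dtype are guesses, the zero-initialized state is a simplification (some state tensors may need a different initial value), and the "x" / "*_r" output naming is taken from the index.html snippet further down this thread.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("rwkw.onnx", providers=["CPUExecutionProvider"])

TOKEN_INPUT = "idx"  # the real name depends on the export; "idx" is an assumption

# Zero-initialize every state input; symbolic dims (if any) are replaced by 1.
feeds = {}
for inp in sess.get_inputs():
    if inp.name != TOKEN_INPUT:
        shape = [d if isinstance(d, int) else 1 for d in inp.shape]
        feeds[inp.name] = np.zeros(shape, dtype=np.float32)

token = 510  # some starting token id
for _ in range(20):
    feeds[TOKEN_INPUT] = np.array([token], dtype=np.int64)  # dtype/shape assumed
    out = dict(zip([o.name for o in sess.get_outputs()], sess.run(None, feeds)))
    token = int(np.argmax(out["x"]))  # greedy sampling over the logits output "x"
    for name in list(feeds):          # feed each "<name>_r" output back as next state
        if name + "_r" in out:
            feeds[name] = out[name + "_r"]
    print(token)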

AXKuhta avatar Aug 20 '22 08:08 AXKuhta

Exporting to ONNX is something that I've been tinkering with and I can report that the 169m RWKV-4 model does run in browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx

Great work :)

Did you get this error with webgl? cannot resolve operator 'Max' with opsets: ai.onnx v13

You can remove RWKV_HEAD_QK and RWKV-ffnPre, which are not required for Pile models; that will probably fix it.

P.S. Upgrade onnxruntime to the latest version and then you can test CUDAExecutionProvider in Python. I think you might be using an older onnxruntime, because all new versions require explicitly setting providers when initializing InferenceSession().
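For reference, a minimal sketch of the explicit-providers call in Python (assuming the onnxruntime-gpu package is installed; file name carried over from the export above):

import onnxruntime as ort

# Recent onnxruntime versions want the providers listed explicitly;
# CUDA is tried first here, with the CPU provider as the fallback.
sess = ort.InferenceSession(
    "rwkw.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print(sess.get_providers())  # shows which providers were actually loaded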

BlinkDL avatar Aug 20 '22 14:08 BlinkDL

@AXKuhta Nice! I got a web demo going here (for 169m and 430m):

  • Demo: https://josephrocca.github.io/rwkv-v4-web/demo/
  • Code: https://github.com/josephrocca/rwkv-v4-web

But it seems like something is going wrong - the model isn't "coherent" in using the context. For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris". I checked that the tokenizer is working properly, so I think it's something to do with the inference / context-handling code.

Some other random notes:

  • The models were twice their size after porting to ONNX - e.g. the 169m model goes from 339MB to 679MB. I quantized it down to 171MB, but that makes inference half the speed (~5 tokens/sec quantized versus ~13 tokens/sec for the original). I'm guessing the non-quantized versions have been converted from f16 to f32, hence the size doubling? The demo includes both the normal and quantized versions.

  • @BlinkDL Yes, I got TypeError: cannot resolve operator 'Max' with opsets: ai.onnx v13 when trying to use the WebGL backend. How would I go about removing RWKV_HEAD_QK and RWKV-ffnPre? I made a conversion notebook here: https://colab.research.google.com/github/josephrocca/rwkv-v4-web/blob/main/RWKV_v4_ONNX_conversion.ipynb Is it as simple as adding a few commands to that, or is there more work involved?

  • The WebGL backend doesn't work with quantized models. It gives this error: TypeError: cannot resolve operator 'DequantizeLinear' with opsets: ai.onnx v13, com.microsoft.experimental v1, ai.onnx.preview.training v1, ai.onnx.training v1, com.ms.internal.nhwc v17, org.pytorch.aten v1, com.microsoft.nchwc v1, ai.onnx.ml v3, com.microsoft v1

  • I used a very overkill approach to getting the tokenizer working... https://github.com/josephrocca/tokenizers-pyodide I haven't looked into how different the tokenizer is from GPT-2/3, but if it's similar, then I guess it shouldn't be too hard to make an edited version of this: https://github.com/josephrocca/gpt-2-3-tokenizer ?
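For comparison, the Python side of the tokenizer is only a couple of lines with the tokenizers library. This assumes the 20B_tokenizer.json file used by the RWKV-4 Pile code; a JS port would need to reproduce the same vocab/merges from that JSON.

from tokenizers import Tokenizer

tok = Tokenizer.from_file("20B_tokenizer.json")  # path assumed; ships with the RWKV-4 Pile code

ids = tok.encode("The capital of France is").ids
print(ids)
print(tok.decode(ids))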

josephrocca avatar Aug 21 '22 05:08 josephrocca

For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris"

That seems familiar!

The => first
The capital => of
The capital of => the
The capital of France => ,
The capital of France is => the

It looks like you display the outputs during the prompt-feeding stage, which happens one token at a time.

That should fix it:

         let token = greedySampling(results.x.data);

         if (promptTokens.length == 0) {
+          if(streamingCallback) streamingCallback(token);
           ctx.push( token );
         } else {
           ctx.push( promptTokens.shift() );
         }
-
-        if(streamingCallback) streamingCallback(token);

         feeds.xx_att = results.xx_att_r;
         feeds.aa_att = results.aa_att_r;

AXKuhta avatar Aug 21 '22 07:08 AXKuhta

@josephrocca I had to host the demo locally because huggingface keeps terminating the model downloads for some reason, but otherwise I can confirm that it works on my machine. Good job with getting the tokenizer and the quantization working!

I'm guessing the non-quantized versions have been converted from f16 to f32, hence the size doubling?

Yeah, that's what's happening. RWKV-v4 is bf16 which can't be losslessly converted to fp16, so fp32 is the next best option. The fp32-converted model also compresses really well since half the bytes in it are zero.

AXKuhta avatar Aug 21 '22 08:08 AXKuhta

  • @BlinkDL Yes, I got TypeError: cannot resolve operator 'Max' with opsets: ai.onnx v13 when trying to use the WebGL backend. How would I go about removing RWKV_HEAD_QK and RWKV-ffnPre? I made a conversion notebook here: https://colab.research.google.com/github/josephrocca/rwkv-v4-web/blob/main/RWKV_v4_ONNX_conversion.ipynb Is it as simple as adding a few commands to that, or is there more work involved?

Take a look at src/model_run.py. For the Pile model, self.model_type == 'RWKV' and RWKV_HEAD_QK_DIM = 0, so you can remove some unused code.

Moreover, use https://github.com/daquexian/onnx-simplifier to optimize the ONNX model.
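For reference, a small sketch of running onnx-simplifier from Python (file names carried over from the export step above; adjust as needed):

import onnx
from onnxsim import simplify

model = onnx.load("rwkw.onnx")
model_simp, ok = simplify(model)  # folds constants and removes redundant nodes
assert ok, "simplified model failed the checker"
onnx.save(model_simp, "rwkw_sim.onnx")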

BlinkDL avatar Aug 21 '22 08:08 BlinkDL

And the ONNX version might work for AMD & Intel GPUs. The DirectML backend supports them (on Windows 10).

I tried that for RWKV-1.

BlinkDL avatar Aug 21 '22 08:08 BlinkDL

Yeah, that's what's happening. RWKV-v4 is bf16 which can't be losslessly converted to fp16, so fp32 is the next best option. The fp32-converted model also compresses really well since half the bytes in it are zero.

You can losslessly "transform" bf16 to fp16; the idea is to reuse the same raw binary value. The float values will be totally different, but you can do an inverse transform in JS to losslessly recover the original bf16.
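Here is a small numpy/torch illustration of the raw-bit idea. It keeps the 16-bit pattern as-is and expands it back to fp32 on the consumer side; parking the same bits inside an fp16 tensor is the same trick with a different container.

import numpy as np
import torch

w = torch.randn(4, dtype=torch.bfloat16)  # stand-in for a checkpoint tensor

# "transform": reinterpret the raw 16-bit pattern, half the bytes of an fp32 export
bits = w.view(torch.int16).numpy().view(np.uint16)

# inverse transform (the same logic works in JS with Uint16Array / Uint32Array / Float32Array):
# bf16 is the top 16 bits of fp32, so appending 16 zero bits recovers the value exactly
restored = (bits.astype(np.uint32) << 16).view(np.float32)

assert np.array_equal(restored, w.float().numpy())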

BlinkDL avatar Aug 21 '22 08:08 BlinkDL

@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it - thanks for your help!! Strange that huggingface is terminating the download for you... 🤔

@BlinkDL Thanks for the tips! I'll look into the stuff you've mentioned.

josephrocca avatar Aug 21 '22 08:08 josephrocca

Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.

The python code for RWKV-2 weight conversion to .bin (for tf.js):

import os
import torch

MODEL_NAME = '...'  # path to the RWKV-2 checkpoint, without the .pth extension
os.makedirs('20220425', exist_ok=True)  # output directory for the .bin files

w = torch.load(MODEL_NAME + '.pth', map_location='cpu')  # load on CPU so .numpy() works
for x in w.keys():
    if 'copy_mask' in x: # this is for headQK which is not used in pile models
        continue
    print(x, w[x].shape)

    # we are doing some pre-computations here. change them to match RWKV-4. or you can just skip all of them and do everything in js first.
    if '.time_' in x:
        w[x] = w[x].squeeze()
    if '.time_decay' in x:
        w[x] = torch.exp(-torch.exp(w[x]))
    if '.time_first' in x:
        w[x] = torch.exp(w[x])

    w[x].numpy().tofile(f'20220425/{x}.bin')

You can gradually port it to RWKV-4 by matching the outputs for each layer.

The Chinese RWKV-2 has a better UI: https://github.com/BlinkDL/AI-Writer/blob/main/docs/index.html

The English RWKV-2: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng

BlinkDL avatar Aug 21 '22 09:08 BlinkDL

@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it

Add top-p, top-k, and temperature, and then it's very usable :)
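Something like this (a Python sketch of the sampler the JS side would need; the thresholds and the exact order of the filters are illustrative, not the repo's sample_logits):

import numpy as np

def sample_logits(logits, temperature=1.0, top_k=0, top_p=0.9):
    # temperature scaling, then softmax
    logits = np.asarray(logits, dtype=np.float64).reshape(-1) / max(temperature, 1e-8)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    if top_k > 0:  # keep only the k most likely tokens
        probs[probs < np.sort(probs)[-top_k]] = 0.0
        probs /= probs.sum()

    if 0.0 < top_p < 1.0:  # nucleus sampling: keep the smallest set covering >= top_p
        order = np.argsort(probs)[::-1]
        cumulative = np.cumsum(probs[order])
        probs[order[cumulative > top_p][1:]] = 0.0
        probs /= probs.sum()

    return int(np.random.choice(len(probs), p=probs))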

BlinkDL avatar Aug 21 '22 09:08 BlinkDL

It looks like the webgl backend has a lot of limitations. I did some testing by stripping out different parts of the model in order to see if I can get anything at all to work on the webgl backend. I think I got like four different error messages with different combinations. The bottom line is that I can't even get a matmul to work.

matmul

https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/matmul_test.py https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/index.html

It does work on the wasm backend!

EDIT: It actually works on webgl if you do this: https://github.com/AXKuhta/RWKV-LM/commit/75ad1609f3dfd6ad13d7333c459e9e75712432d2

AXKuhta avatar Aug 23 '22 05:08 AXKuhta

I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken:

https://github.com/AXKuhta/RWKV-LM/tree/onnx_webgl

@BlinkDL The "cannot resolve operator 'Max' with opsets: ai.onnx v13" error was caused by torch.maximum(pp, ww) and I was able to suppress it by using torch.max(torch.stack([pp, ww]), 0).values instead. I also had to add a bunch of .view([768,1]) around matmul operations and then fix layer_norm() from producing NaNs. ~~Now it looks like self.FF() always produces zeroes on webgl, but I'm not sure yet as to why~~ nope, self.FF() does produce something.
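A standalone check that the replacement really is equivalent (minimal sketch):

import torch

pp = torch.randn(768)
ww = torch.randn(768)

a = torch.maximum(pp, ww)                       # exports as the Max op, which the WebGL backend rejects
b = torch.max(torch.stack([pp, ww]), 0).values  # same result, exported without a Max node

assert torch.equal(a, b)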

AXKuhta avatar Aug 23 '22 12:08 AXKuhta

I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken:

https://github.com/AXKuhta/RWKV-LM/tree/onnx_webgl

@BlinkDL The "cannot resolve operator 'Max' with opsets: ai.onnx v13" error was caused by torch.maximum(pp, ww) and I was able to suppress it by using torch.max(torch.stack([pp, ww]), 0).values instead. I also had to add a bunch of .view([768,1]) around matmul operations and then fix layer_norm() from producing NaNs. ~Now it looks like self.FF() always produces zeroes on webgl, but I'm not sure yet as to why~ nope, self.FF() does produce something.

That's great. Could you check whether https://github.com/daquexian/onnx-simplifier can help? Use https://github.com/lutzroeder/netron to visualize models.

And then you can print() the outputs of interesting layers to find the culprit... gradually matching the results of webgl vs wasm.

BlinkDL avatar Aug 23 '22 20:08 BlinkDL

@BlinkDL After some painstaking debugging I got it to produce coherent output on webgl. The fix was really bizarre: add + 0.0 in a bunch of places. Some nodes on the ONNX graph that follow matmul+reshape operations kept getting bugged inputs that looked like a single value across all 768 elements. Performing +0.0 with the bugged input fixes it.

Here's the changes: https://github.com/AXKuhta/RWKV-LM/commits/onnx_webgl

Could you check whether https://github.com/daquexian/onnx-simplifier can help?

I did try onnx-simplifier with RWKV-3, but it didn't find much to simplify. The graph was almost unchanged. I will retest with RWKV-4 though.

AXKuhta avatar Aug 24 '22 15:08 AXKuhta

@AXKuhta Nice! Can you upload the webgl-compatible 169m/430m models to hugging face so I can add them to the web demo?

Also, I wonder if the +0.0 bug is something that would be worth reporting to the ONNX runtime team?

josephrocca avatar Aug 24 '22 23:08 josephrocca

@josephrocca I think it's better to keep all the web models in one place so I made two PRs in your huggingface repository. Oh, and by the way, I also improved my initial index.html a little to not create new tensors inside the loop and to remove leading_pad(). I think you should integrate these changes into your demo too.

I ran some performance tests with the hardware that I have available:

All tests performed in Chromium
169m model

========= WASM =========
Intel Core i7 2760QM:			280ms per token
Intel Core i7 6650U:			204ms per token
AMD A10-7800:				331ms per token
Snapdragon 865:				233ms per token

========= WebGL =========
Intel Core i7 2760QM iGPU 		600ms per token
Nvidia GeForce 520MX 			305ms per token
Intel Core i7 6650U iGPU 		192ms per token
AMD A10-7800 iGPU 			232ms per token
Snapdragon 865 iGPU:			Produces NaNs

These numbers are not very impressive :joy_cat:

I didn't try it on a real GPU with a wide memory bus, but I suspect it won't perform massively better.

There are three different webgl bug reports to be made to onnxruntime:

  • Matmuls like [768, 768] @ [768] complain about dimension mismatch, must be converted to [768, 768] @ [768, 1]
  • NaNs produced by layer_norm() if there are negative inputs
  • This strange +0.0 stuff if I can reproduce it in a standalone fashion

AXKuhta avatar Aug 25 '22 14:08 AXKuhta

@AXKuhta Maybe there are some hidden bottlenecks :) Check the time consumption of all major functions and code fragments.

BlinkDL avatar Aug 25 '22 21:08 BlinkDL

@AXKuhta Thanks! Great work. I've always struggled with the WebGL backend - I'm guessing that it doesn't get as much attention as wasm because it isn't a port of C++, but must be written from scratch IIUC. I'm hoping that WebGPU will change that situation and we'll get really serious GPU ML on the web.

Another factor that could be relevant to performance here is that wasm can simply be faster for some models, though I'd have thought this would only be the case for very small models. There's some discussion in this article about tf.js: https://blog.tensorflow.org/2020/09/supercharging-tensorflowjs-webassembly.html


josephrocca avatar Aug 26 '22 09:08 josephrocca

@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the 16384 texture size limit, so onnxruntime does some magic to remap it into a 6214x6214 square texture instead (768 × 50277 ≈ 38.6M elements, and √38.6M ≈ 6214), possibly making it slow.

Gradually removing parts of the model until there is nothing left except input->output passthrough
Nvidia GeForce 520MX
169m model

Baseline full model			344ms		N/A
Removed state store/restore		326ms		-18ms
Removed final matmul 			145ms 		-181ms
Removed 12 x self.FF() 			60ms 		-85ms
Removed 12 x self.SA() 			30ms 		-30ms
Removed 26 x self.LN() 			16ms 		-14ms
Removed w.emb.weight[ctx[-1]] 		0.7ms 		-15.3ms

@josephrocca Yeah, I think it's better to wait for WebGPU instead of pursuing WebGL any further. It seems to work well for graphics, but not so much for compute.

AXKuhta avatar Aug 26 '22 09:08 AXKuhta

@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the texture size limit of 16384 so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.

You could probably try tf.js for just the final matmul and see if it's faster.

BlinkDL avatar Aug 26 '22 13:08 BlinkDL

@AXKuhta @josephrocca And actually you can skip the final matmul when scanning the prompt (because we just need the hidden states).
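With native onnxruntime in Python, the equivalent idea is simply not asking for the logits while scanning the prompt. A rough sketch, reusing the same assumed names as the loop sketch earlier in the thread (whether the unused branch is actually skipped depends on the only_execute_path_to_fetches setting discussed below):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("rwkw.onnx", providers=["CPUExecutionProvider"])

state_outputs = [o.name for o in sess.get_outputs() if o.name != "x"]  # everything but the logits

feeds = {i.name: np.zeros([d if isinstance(d, int) else 1 for d in i.shape], dtype=np.float32)
         for i in sess.get_inputs() if i.name != "idx"}  # "idx" token input name is assumed

prompt_tokens = [510, 5347, 273]  # illustrative token ids
for tok in prompt_tokens:
    feeds["idx"] = np.array([tok], dtype=np.int64)
    state = dict(zip(state_outputs, sess.run(state_outputs, feeds)))  # the "x" logits are not fetched
    for name in list(feeds):
        if name + "_r" in state:
            feeds[name] = state[name + "_r"]

# generation then continues from this state, now also fetching "x"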

I will provide some more efficient code soon to quickly generate the initial hidden states from the prompt.

BlinkDL avatar Aug 26 '22 17:08 BlinkDL

Oh and please check the speed of onnxruntime in pytorch :) I wonder if it will be faster.

You can actually install pytorch in Android too.

BlinkDL avatar Aug 26 '22 17:08 BlinkDL

And actually you can skip the final matmul when scanning the prompt

@BlinkDL Ooh, somehow I didn't think of that before!

There is an "only_execute_path_to_fetches" switch in onnxruntime that can be used to make this work even with existing .onnx files. It looks like they forgot to expose it to JavaScript, so I had to make a custom build of ort-wasm-simd.wasm with that flag toggled in the source. I found that it actually works:

Intel Core i7 2760QM
169m model
WASM only_execute_path_to_fetches = true

Don't want the x output 	158ms per token
Want the x output 		258ms per token

I put the custom-built ort-wasm-simd.wasm and the index.html updated with fetches logic here if anyone wants to try this too.

I think it should be possible to pack both the RNN-style model and the GPT-style model into a single .onnx graph. Since the weights are shared between the two, there would only be a minimal increase in file size. I'll wait for the new GPT code (The current one doesn't run without CUDA btw).

AXKuhta avatar Aug 27 '22 14:08 AXKuhta

Oh and please check the speed of onnxruntime in pytorch

Here are some performance numbers for RWKV-4 with pytorch and native onnxruntime:

Native pytorch + onnxruntime
169m model

Intel Core i7 2760QM 	Pytorch 	79.3 ms/token
Intel Core i7 2760QM 	ONNX 		152 ms/token 	Note: ONNX forced to use 8 threads to hit full CPU utilization

Intel Core i7 6650U 	Pytorch 	62.1 ms/token
Intel Core i7 6650U 	ONNX 		129 ms/token 	Note: ONNX forced to use 4 threads to hit full CPU utilization

Snapdragon 865 		Pytorch 	71.0 ms/token
Snapdragon 865		ONNX 		180ms/token

But I think I made a bit of a mistake by not excluding sample_logits() from the pytorch version. It seems to take about ~10ms itself. I need to rerun those tests with more caution.

EDIT: I totally forgot that my test_onnx.py had sample_logits() too, so these comparisons are fair after all.

AXKuhta avatar Aug 27 '22 15:08 AXKuhta

Finally tested the webgl backend on a real GPU:

GTX 1060 6GB
webgl

169m model		68.6 ms/token
430m model 		119 ms/token

As seen above, the 430m model also works on webgl now. It turns out my state store/restore code was breaking it: with a 24-layer model, it would attempt to stack 24 tensors at once, which exceeds the 16-input-texture limit in WebGL. I worked around this by stacking 12 tensors at a time, twice, then using torch.cat() to glue the two stacks together.
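The workaround in isolation looks roughly like this (plain zero tensors standing in for the real per-layer state):

import torch

states = [torch.zeros(768) for _ in range(24)]  # 24 layers' worth of one state tensor (430m model)

# torch.stack over all 24 at once would become a single node with 24 inputs,
# which exceeds WebGL's 16-input-texture limit, so stack two halves and concatenate:
half_a = torch.stack(states[:12])
half_b = torch.stack(states[12:])
stacked = torch.cat([half_a, half_b])  # shape [24, 768], same result as stacking all 24

assert torch.equal(stacked, torch.stack(states))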

The stacking code could be removed, but then the 430m model would have 120 individual state inputs/outputs, which sounds scary.

I guess this kind of vindicates the webgl backend? It does outperform wasm when used on a real GPU, and it can also run the non-quantized 430m model, while wasm can't. Of course, it is still significantly slower than native.

@josephrocca I opened two new PRs in your huggingface repo, one with the updated 430m webgl model and the other removing the outdated model.

AXKuhta avatar Aug 28 '22 05:08 AXKuhta

@AXKuhta Thanks! I've accepted the pull request and updated the demo.

it can also run the non-quantized 430m model, while wasm can't

Note that the wasm runtime should be able to run the non-quantized, 1.7GB model with no problems if it had enough memory available. There's currently an arbitrary 2GB limit that needs to be raised: https://github.com/microsoft/onnxruntime/issues/10957#issuecomment-1074397486

The memory limits should be gone completely once we get Memory64: https://github.com/WebAssembly/memory64

josephrocca avatar Aug 29 '22 02:08 josephrocca

The memory limits should be gone completely once we get Memory64

So there is work ongoing to lift that limit. That's good to know :+1:

AXKuhta avatar Aug 29 '22 14:08 AXKuhta

@AXKuhta Thanks! I've accepted the pull request and updated the demo.

Please try the raw binary BF16 trick too :) https://github.com/BlinkDL/RWKV-LM/issues/7#issuecomment-1221499836

And please show the progress (1/32 etc.) in the webpage

BlinkDL avatar Aug 30 '22 22:08 BlinkDL

@AXKuhta Another idea: the w.emb.weight should just be a plain Float32Array on the CPU.
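A tiny sketch of that idea (file name and shape assumed, following the per-tensor .bin export shown earlier in the thread):

import numpy as np

# The embedding lookup is just row indexing, so it can stay outside the ONNX graph:
emb = np.fromfile("emb.weight.bin", dtype=np.float32).reshape(-1, 768)  # [vocab, n_embd]

token_id = 510
x_in = emb[token_id]  # feed this 768-float row to the model instead of a token id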

BlinkDL avatar Sep 02 '22 11:09 BlinkDL