
Executorch Slows Down after second or third response in Torchchat

Open · infil00p opened this issue 1 year ago · 8 comments

🐛 Describe the bug

As discussed earlier in https://github.com/pytorch/executorch/issues/3674, to increase max_seq_len we have to raise it in the export scripts and also bump up the hardcoded max_seq_len in runner.cpp. We're using ExecuTorch in a proof-of-concept demo that we're planning to release at NeurIPS, and we discovered this bug when using the LlamaRunner with Ktor. We also see it in torchchat, but because torchchat runs locally, it won't simply time out when Llama fails to generate in time.
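Why both values have to be raised can be sketched in a few lines. This assumes, without having checked runner.cpp, that the usable context is capped by whichever limit is smaller: the max_seq_len baked into the exported model or the value hardcoded in the runner. The helper names and the numbers are hypothetical and only illustrative.

```python
def effective_context(export_max_seq_len: int, runner_max_seq_len: int) -> int:
    """Usable context length, assuming the smaller of the two limits wins (hypothetical)."""
    return min(export_max_seq_len, runner_max_seq_len)


def fits_in_context(prompt_tokens: int, max_new_tokens: int,
                    export_max_seq_len: int, runner_max_seq_len: int) -> bool:
    """True if the prompt plus the tokens to generate fit within the usable context."""
    limit = effective_context(export_max_seq_len, runner_max_seq_len)
    return prompt_tokens + max_new_tokens <= limit


if __name__ == "__main__":
    # Raising only the export-side value leaves the runner's smaller cap in place.
    print(fits_in_context(900, 400, export_max_seq_len=2048, runner_max_seq_len=128))   # False
    # Raising both lets the same request through.
    print(fits_in_context(900, 400, export_max_seq_len=2048, runner_max_seq_len=2048))  # True
```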

Step 1. Update export.py, as done in this fork: https://github.com/baseweight/torchchat/blob/hardcoded_default/torchchat/export.py#L393
Step 2. Update the runner, as done in this fork: https://github.com/baseweight/executorch/blob/baseweight_demo/examples/models/llama/runner/runner.cpp#L53
Step 3. Follow the instructions to export the model and build the AAR. I used Llama-3.2-3b-instruct, since it produces genuinely good demo results about Vancouver (where NeurIPS is being held).
Step 4. Copy the model onto a phone and load it in torchchat. I used a Pixel 9 running Android 15, and I also confirmed the issue on a OnePlus 12R.
Step 5. Type a prompt (e.g. "Tell me about Nardwuar").
Step 6. Type a follow-up prompt (e.g. "And the Evaporators?").
Step 7. Attempt to type another follow-up prompt.

It seems that this MIGHT be the limit for actual chat on an Android phone with ExecuTorch, since the device starts to overheat. Or maybe that's not the case and I'm just missing something?

Versions

Here's the info from the gaming PC I'm using to build ExecuTorch. I have a conda environment set up for this.

Collecting environment information...
PyTorch version: 2.6.0.dev20241007+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.39

Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 13:27:36) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-48-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti SUPER
Nvidia driver version: 555.58.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 7800X3D 8-Core Processor
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
CPU(s) scaling MHz: 52%
CPU max MHz: 5050.0000
CPU min MHz: 545.0000
BogoMIPS: 8383.77
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 96 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==0.5.0a0+72b3bb3
[pip3] flake8==6.0.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==23.6.5
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241007+cpu
[pip3] torchao==0.5.0
[pip3] torchaudio==2.5.0.dev20241007+cpu
[pip3] torchsr==1.0.4
[pip3] torchtune==0.4.0.dev20241010+cu121
[pip3] torchvision==0.20.0.dev20241007+cpu
[conda] executorch 0.5.0a0+aa67cd9 pypi_0 pypi
[conda] numpy 2.0.2 pypi_0 pypi
[conda] torch 2.6.0.dev20241112+cpu pypi_0 pypi
[conda] torch-stoi 0.2.3 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20241112+cpu pypi_0 pypi
[conda] torchgen 0.0.1 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20241112+cpu pypi_0 pypi

infil00p · Dec 04 '24 09:12

I transferred this over to the torchchat repo since, on the surface, it seems torchchat (TC)-related.

dbort · Dec 04 '24 22:12

I haven't looked at the ExecuTorch (ET) repo's runner in a while, but do our apps actually have a chat function, or do they just call generate each time and have to repopulate the cache with every new message? I remember having to fix that issue in the torchchat CLI chat command.

edit:

It looks like the JNI layer has this for the multimodal runner: https://github.com/baseweight/executorch/blob/baseweight_demo/extension/android/jni/jni_layer_llama.cpp#L290

Now I'm going to check whether that's the same runner used for text-only models. The runner.cpp you linked above doesn't have a way to generate at start_pos > 0, which is why I'm concerned.

JacobSzwejbka · Dec 04 '24 22:12

Thanks for spinning this up @infil00p. Did you get a chance to test out exporting with the ET script btw?

The export in TC is based on that of ET, so my gut says one of two things:
(a) Exporting in ET is bugged (and torchchat by extension) => fix in ET and port to TC
(b) Exporting in ET works, but fails in TC => there's a bug in TC

cc: @kirklandsign

Jack-Khuu · Dec 04 '24 22:12

OK, it looks like the demo app effectively starts from scratch on every chat message: it treats the entire chat history as a new context starting from position zero, instead of prefilling only the new user message from start_pos == length of the chat history so far. https://github.com/baseweight/executorch/blob/baseweight_demo/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java#L705 This is probably the first thing someone should fix.
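The difference between the two strategies can be illustrated with a toy KV cache. This is a minimal sketch, not the ExecuTorch runner API: `ToyKVRunner`, `prefill`, and the integer "tokens" are hypothetical stand-ins, and each turn is assumed to represent a user message together with the model's reply. Restarting at position 0 re-prefills the whole history on every turn, while passing start_pos equal to the number of already-cached tokens prefills only the new turn.

```python
class ToyKVRunner:
    """Hypothetical stand-in for a runner with a fixed-capacity KV cache."""

    def __init__(self, max_seq_len: int):
        self.max_seq_len = max_seq_len
        self.cache: list[int] = []   # one entry per cached position
        self.prefilled_tokens = 0    # total tokens pushed through prefill so far

    def prefill(self, tokens: list[int], start_pos: int) -> int:
        """Write `tokens` into the cache starting at `start_pos`; return the next free position."""
        if start_pos > len(self.cache):
            raise ValueError("start_pos would skip uncached positions")
        end = start_pos + len(tokens)
        if end > self.max_seq_len:
            raise ValueError("context exceeds max_seq_len")
        # Positions at or beyond start_pos are overwritten by the new tokens.
        self.cache = self.cache[:start_pos] + tokens
        self.prefilled_tokens += len(tokens)
        return end


def chat_from_scratch(turns: list[list[int]], max_seq_len: int) -> int:
    """Re-prefill the entire history from position 0 on every turn."""
    runner = ToyKVRunner(max_seq_len)
    history: list[int] = []
    for turn in turns:
        history = history + turn
        runner.prefill(history, start_pos=0)
    return runner.prefilled_tokens


def chat_incremental(turns: list[list[int]], max_seq_len: int) -> int:
    """Prefill only the new turn, starting where the cached history ends."""
    runner = ToyKVRunner(max_seq_len)
    pos = 0
    for turn in turns:
        pos = runner.prefill(turn, start_pos=pos)
    return runner.prefilled_tokens


if __name__ == "__main__":
    turns = [[1] * 100, [2] * 100, [3] * 100]  # three 100-token turns
    print(chat_from_scratch(turns, max_seq_len=4096))  # 600 (100 + 200 + 300)
    print(chat_incremental(turns, max_seq_len=4096))   # 300
```

Both strategies end with the same 300 tokens in the cache; only the amount of prefill work differs, and that gap widens with every additional turn.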

JacobSzwejbka · Dec 04 '24 22:12

@infil00p

If you run the model with generate instead of chat do you still hit the same performance throttling?

Is generate(4096) significantly faster than N chats summing up to 4096 tokens?
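A back-of-the-envelope way to frame this question, assuming every chat turn re-prefills the full history from position 0, all N turns have equal length, and only prefilled tokens are counted (decode steps and the growth of attention cost with position are ignored). Under those assumptions, turn k re-prefills k·(L/N) tokens, so N turns totalling L tokens prefill (L/N)·N(N+1)/2 tokens versus L for a single generate. The function names are hypothetical, and this is a rough estimate, not a measurement.

```python
def reprefill_tokens(total_len: int, num_turns: int) -> int:
    """Tokens prefilled when each of `num_turns` equal turns re-prefills the whole history.

    Turn k (1-based) re-prefills k * (total_len // num_turns) tokens, so the total is
    (total_len // num_turns) * num_turns * (num_turns + 1) / 2.
    """
    turn_len = total_len // num_turns
    return sum(k * turn_len for k in range(1, num_turns + 1))


def incremental_tokens(total_len: int, num_turns: int) -> int:
    """Tokens prefilled when each turn only adds its own tokens (same as one long prompt)."""
    turn_len = total_len // num_turns
    return turn_len * num_turns


if __name__ == "__main__":
    for n in (1, 2, 4, 8):
        print(n, reprefill_tokens(4096, n), incremental_tokens(4096, n))
    # 1 4096 4096
    # 2 6144 4096
    # 4 10240 4096
    # 8 18432 4096
```

With re-prefilling, the prefill work grows roughly linearly with the number of turns for a fixed total length, which is consistent with the slowdown appearing after a few exchanges.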

JacobSzwejbka · Dec 04 '24 22:12

Yes, I'm actually experiencing this in our own app, which just calls generate every time. I strongly suspect the lack of a generate_from_pos() on the runner is the issue here. I haven't tested this with a multimodal model yet to confirm.

infil00p · Dec 08 '24 19:12

@kirklandsign Would you be the right person to do this either on the ET or TC side?

(or if @infil00p figures it out we'd love the contribution on ExecuTorch/torchchat)

Jack-Khuu · Dec 09 '24 17:12

Tracked by https://github.com/pytorch/executorch/issues/8290

kirklandsign · Feb 27 '25 19:02