slow inference with very long input embeddings
I've been able to run a large multimodal model with your library, but this means my input tensors are easily very large (148x4096, for example). This led me to find out that the inference time is proportional to the input length, but why is that? I would think that in the ideal case the bottleneck should be memory bandwidth: for every matrix multiplication the CPU needs to fetch a large chunk of weight data from memory and then multiply it very quickly in the upper cache levels, and none of this should exceed the L3 cache size for efficiency, but I'm no expert in tensor multiplication. For example, in PyTorch the inference time is only slightly slower with larger chunks of data; it doesn't scale linearly. Is there something that can be done to improve cache locality for very long input embeddings?
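For reference, this is a minimal sketch of the kind of timing comparison I mean, assuming the HuggingFace transformers library with a small GPT-Neo checkpoint as a stand-in (the 125M model and the random embeddings are placeholders for my actual setup):

```python
# Minimal timing sketch: forward-pass time vs. input embedding length.
# Assumptions: HuggingFace transformers + a small GPT-Neo checkpoint as a
# stand-in; random embeddings replace the real multimodal adapter output.
import time
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
model.eval()
hidden = model.config.hidden_size  # 768 for the 125M checkpoint

for seq_len in (16, 64, 148):
    embeds = torch.randn(1, seq_len, hidden)  # stand-in input embeddings
    with torch.no_grad():
        start = time.perf_counter()
        model(inputs_embeds=embeds)
        print(f"seq_len={seq_len:4d}: {time.perf_counter() - start:.3f}s")
```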
For context, the model is GPT-Neo with FP16 weights and custom adapters. Upcast to FP32 in PyTorch, the first inference with the large input embeddings takes a few seconds, but the same first step takes around a minute with ggml on my i7-8750H!
I think this behavior can easily be replicated on GPT-J with a very long text input, around 150 tokens.
> This led me to find out that the inference time is proportional to the input length, but why is that?
Because transformers like GPT-Neo and GPT-J have to run for every token of the input, so the time to process a prompt grows with its length. If you want a model that does not do this, you will have to use an RNN like RWKV (ggml version here), which is constant-time per token.
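To make that concrete, here is a rough back-of-envelope sketch of how prompt-evaluation cost grows with length. The layer count, hidden size, and the ~12·d² estimate per token per layer are assumptions for a GPT-J-sized model, not measured numbers:

```python
# Rough cost model for prompt evaluation (illustrative assumptions:
# d_model=4096, 28 layers, ~12*d_model^2 MACs per token per layer for the
# attention/MLP weight matmuls; real numbers will differ).
d_model = 4096
n_layers = 28
per_token_weights = 12 * d_model ** 2 * n_layers  # weight matmuls per token

def transformer_prompt_cost(n_tokens):
    # every prompt token goes through all weight matmuls, and attention
    # adds a term that grows with the square of the context length
    attention = n_tokens ** 2 * d_model * n_layers
    return n_tokens * per_token_weights + attention

def recurrent_prompt_cost(n_tokens):
    # an RNN-style model (e.g. RWKV) pays a fixed cost per token,
    # independent of how much context has already been consumed
    return n_tokens * per_token_weights

for n in (16, 64, 148):
    t = transformer_prompt_cost(n) / 1e9
    r = recurrent_prompt_cost(n) / 1e9
    print(f"{n:4d} tokens: ~{t:.0f} GMAC (transformer)  ~{r:.0f} GMAC (RNN-like)")
```

At ~150 tokens the weight matmuls dominate over the attention term, which is why total prompt time ends up roughly proportional to the number of input tokens.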