H2O
KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
How do I reproduce?
import torch
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer, AutoConfig, LlamaTokenizer, LlamaConfig
from transformers.modeling_utils import load_sharded_checkpoint
from accelerate import init_empty_weights, load_checkpoint_in_model, load_checkpoint_and_dispatch
from utils_hh.modify_llama import convert_kvcache_llama_heavy_recent, LlamaAttention_heavy_hitter
ENABLE_Heavy_Hitter_FUNCTIONS = {
    "llama": convert_kvcache_llama_heavy_recent,
}
model_name = 'meta-llama/Llama-2-7b-hf'
cache_dir = 'checkpoint/models--meta-llama--Llama-2-7b-hf/snapshots/8a0442e81540efaeb1a0fe3e95477b5e0edfd423'
heavy_ratio = 0.1
recent_ratio = 0.1
length = 64
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: {}".format(device))
tokenizer = LlamaTokenizer.from_pretrained(model_name)
config = LlamaConfig.from_pretrained(model_name)
config.heavy_ratio = heavy_ratio
config.recent_ratio = recent_ratio
# instantiate the model without weights on the meta device
with init_empty_weights():
    model = LlamaForCausalLM(config)
model = convert_kvcache_llama_heavy_recent(model, config)
print(model)
# load the checkpoint into the empty model
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=cache_dir,
    device_map="auto",
    offload_folder=cache_dir,
    dtype=torch.float16,
    offload_state_dict=True,
)
prompt_text = 'Hello.'
input_ids = tokenizer(prompt_text, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)
generate_ids_hh = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=True,
)  # error here
Using transformers==4.33.0 solved it.
Hi @Kyriection, as I mentioned, one workaround to avoid this error is to downgrade transformers to something < 4.36.
BUT, I am trying to test H2O on some of the latest LLMs. For that, I need a newer version of transformers (>=4.37.2). Do you have any suggestions?
The error is raised here, in the forward function of H2OLlamaAttention_streaming:
# remake causal mask
attention_mask = _make_causal_mask(
    bsz=bsz,
    tgt_len=q_len,
    past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
    dtype=query_states.dtype,
    device=query_states.device,
)
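For context, the failing expression is past_key_value[0].shape[-2]: on transformers >= 4.36, past_key_value is a Cache object, and indexing layer 0 before anything has been written to it raises exactly the KeyError reported above. One version-agnostic way to get the past length is to branch on the cache type; the helper below is only a sketch (the name past_length is hypothetical, not something from the H2O code):

from transformers.cache_utils import Cache

def past_length(past_key_value, layer_idx=0):
    # Return the number of tokens already cached for this layer.
    if past_key_value is None:
        return 0
    if isinstance(past_key_value, Cache):
        # transformers >= 4.36: Cache API; returns 0 on the first (prefill) step
        # instead of raising a KeyError like past_key_value[0] would.
        return past_key_value.get_seq_length(layer_idx)
    # transformers < 4.36: legacy per-layer (key, value) tuple of tensors
    return past_key_value[0].shape[-2]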
Hi @hasanibnarif, Hugging Face changed their cache implementation in version 4.36. Previously, past_key_value was a list of tensors containing the key and value embeddings; now a Cache instance maintains the KV cache. The cache classes are defined in https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L76.
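To make the change concrete, here is a small illustration of the new API (assuming transformers >= 4.36, where DynamicCache is the default cache class); it also shows why the reported KeyError appears when a layer is indexed before its first update():

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
print(cache.get_seq_length(0))  # 0 -- nothing cached yet
# cache[0] here would raise:
# KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'

# Attention layers register key/value states through update():
key = torch.zeros(1, 32, 5, 128)    # (batch, num_heads, seq_len, head_dim)
value = torch.zeros(1, 32, 5, 128)
cache.update(key, value, layer_idx=0)

print(cache.get_seq_length(0))  # 5
k, v = cache[0]                 # indexing now works and returns the layer-0 (key, value) pair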
I have an initial version of the new H2O KV cache implementation based on the Cache class (https://github.com/Kyriection/llama-recipes/blob/main/research/long-context-llama/H2O/utils/cache.py#L342). Please note that this version is still under development; I will release it once it is finished.
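In the meantime, for anyone who wants to experiment, a very rough sketch of the idea (my own, not the linked implementation) is to subclass DynamicCache and evict down to a heavy-hitter-plus-recent budget after each attention step. Everything below, including the class name H2OCacheSketch and the evict() hook that the attention module would need to call with its post-softmax attention weights, is a hypothetical illustration under those assumptions:

import torch
from transformers.cache_utils import DynamicCache

class H2OCacheSketch(DynamicCache):
    """Keeps `heavy_budget` high-attention tokens plus the `recent_budget` newest tokens."""

    def __init__(self, heavy_budget: int, recent_budget: int):
        super().__init__()
        self.heavy_budget = heavy_budget
        self.recent_budget = recent_budget
        self.acc_scores = {}  # layer_idx -> accumulated attention mass per cached token

    def evict(self, layer_idx: int, attn_weights: torch.Tensor):
        # attn_weights: (batch, num_heads, q_len, kv_len), taken after softmax.
        # Accumulate per-token attention mass as the heavy-hitter score.
        scores = attn_weights.sum(dim=(0, 1, 2))  # (kv_len,)
        prev = self.acc_scores.get(layer_idx)
        if prev is not None:
            scores[: prev.shape[0]] += prev
        kv_len = scores.shape[0]
        budget = self.heavy_budget + self.recent_budget
        if kv_len <= budget:
            self.acc_scores[layer_idx] = scores
            return
        # Always keep the most recent tokens; pick heavy hitters among the older ones.
        recent_idx = torch.arange(kv_len - self.recent_budget, kv_len, device=scores.device)
        heavy_idx = scores[: kv_len - self.recent_budget].topk(self.heavy_budget).indices
        keep = torch.cat([heavy_idx.sort().values, recent_idx])
        self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, keep, :]
        self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, keep, :]
        self.acc_scores[layer_idx] = scores[keep]

A real implementation also has to keep the cache's token/position bookkeeping consistent after eviction, which this sketch ignores and which the linked cache.py presumably handles.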