airllm

For me, this model is extremely underperforming

Open · SadafShafi opened this issue on Jan 18, 2024 · 1 comment

This is my prompt: "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n ### Instruction:\n do whatever is given in the input \n\n### Input:\n write an essay on intelligence \n\n### Response: "
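For reference, the Stanford Alpaca "prompt with input" template uses the same wording as the prompt above but different whitespace: each `###` header starts its own line and the prompt ends at `### Response:` with no trailing space. The sketch below builds such a prompt programmatically so the spacing stays consistent. Whether Anima-7B-100K was tuned on this template, and whether the stray spaces matter for it, are assumptions this issue does not settle; `build_alpaca_prompt` is a hypothetical helper.

```python
# Stanford Alpaca "prompt with input" layout.
# Assumption: this model responds to the Alpaca format (not confirmed here).
ALPACA_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:"
)


def build_alpaca_prompt(instruction: str, input_text: str) -> str:
    """Hypothetical helper: fill the template after trimming stray whitespace."""
    return ALPACA_WITH_INPUT.format(
        instruction=instruction.strip(),
        input=input_text.strip(),
    )


prompt = build_alpaca_prompt(
    " do whatever is given in the input ",
    " write an essay on intelligence ",
)
print(repr(prompt))
```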

And this is the response: "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n ### Instruction:\n do whatever is given in the input \n\n### Input:\n write an essay on intelligence \n\n### Response: 1000 words\nLong Answer: Write an essay on intelligence.\nGold Document ID: 10"

I put a huge amount of effort into dockerising this project, and I am getting insanely poor results. What could I possibly be doing wrong? All help is appreciated.

The following is my code:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class AnimaChat:

    def __init__(self, max_new_tokens=300):
        print("initializing ################################")
        base_model = "lyogavin/Anima-7B-100K"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=torch.float16,
            trust_remote_code=True,  # load the model repo's custom modeling code
            device_map="auto",
        )
        self.model.eval()
        self.max_new_tokens = max_new_tokens

    def changeParams(self, max_new_tokens):
        self.max_new_tokens = max_new_tokens

    def chat(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")

        inputs['input_ids'] = inputs['input_ids'].cuda()
        inputs['attention_mask'] = inputs['attention_mask'].cuda()

        # Generate (use the configured limit instead of a hard-coded 800,
        # so that __init__ and changeParams actually take effect)
        generate_ids = self.model.generate(**inputs,
                                           max_new_tokens=self.max_new_tokens,
                                           only_last_logit=True,  # to save memory
                                           use_cache=False,  # disabling the KV cache saves memory if you run into OOM
                                           xentropy=True)
        output = self.tokenizer.batch_decode(generate_ids,
                                             skip_special_tokens=True,
                                             clean_up_tokenization_spaces=False)[0]

        return output
```
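For decoder-only models, `generate` returns the prompt tokens followed by the newly generated ones, so decoding `generate_ids` in full reproduces the prompt, as in the response quoted above. One common fix is to drop the first `inputs['input_ids'].shape[1]` positions before decoding, i.e. `generate_ids[:, inputs['input_ids'].shape[1]:]`. The sketch below shows the same slicing on plain Python lists so it runs without a GPU; `strip_prompt_tokens` is a hypothetical helper and the token IDs are made up for illustration.

```python
from typing import List


def strip_prompt_tokens(sequences: List[List[int]], prompt_len: int) -> List[List[int]]:
    """Keep only the tokens generated after the prompt.

    Mirrors `generate_ids[:, prompt_len:]` on a torch tensor, where
    prompt_len = inputs["input_ids"].shape[1].
    """
    return [seq[prompt_len:] for seq in sequences]


# Made-up IDs: a 4-token prompt followed by 3 generated tokens.
prompt_ids = [1, 13866, 338, 385]
generated = [[1, 13866, 338, 385, 29871, 29896, 2]]

new_only = strip_prompt_tokens(generated, prompt_len=len(prompt_ids))
print(new_only)  # [[29871, 29896, 2]]
```

In the `chat` method this would mean passing the sliced tensor to `self.tokenizer.batch_decode`. It only removes the echoed prompt; it does not by itself change what the model generates.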

`nvcc` and `nvidia-smi` output:

```
root@5d6d2eed64f8:/app# nvcc --version 
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

root@5d6d2eed64f8:/app# nvidia-smi
Thu Jan 18 12:24:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A5000               Off | 00000000:73:00.0 Off |                  Off |
| 30%   34C    P8              17W / 230W |      1MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
```
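The two tools report different things: `nvcc --version` shows the installed CUDA toolkit (12.3), while the "CUDA Version" in `nvidia-smi` is the newest CUDA version the 535 driver supports (12.2). The PyTorch wheels in the freeze below use their own CUDA 12.1 runtime (`nvidia-cuda-runtime-cu12==12.1.105`). A runtime no newer than the driver's version is the plainly supported case; anything compiled in the container with the 12.3 `nvcc` would instead rely on CUDA minor-version compatibility. The sketch below encodes only the conservative check; both helper names are hypothetical.

```python
from typing import Tuple


def parse_cuda_version(text: str) -> Tuple[int, int]:
    """Turn '12.1' or '12.1.105' into a comparable (major, minor) tuple."""
    major, minor = text.split(".")[:2]
    return int(major), int(minor)


def runtime_supported(runtime: str, driver_max: str) -> bool:
    """Conservative check: a runtime no newer than the driver's CUDA version is fine.

    Newer runtimes within the same major version may still work through
    CUDA minor-version compatibility, but this sketch does not model that.
    """
    return parse_cuda_version(runtime) <= parse_cuda_version(driver_max)


# Values taken from the outputs in this issue.
print(runtime_supported("12.1.105", "12.2"))  # True  (PyTorch's bundled runtime)
print(runtime_supported("12.3", "12.2"))      # False (nvcc toolkit version)
```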

Below are the installed libraries:

```
root@5d6d2eed64f8:/app# pip freeze 
/usr/bin/pip:6: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import load_entry_point
accelerate==0.26.0
aiohttp==3.9.1
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
blinker==1.7.0
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
datasets==2.16.1
dill==0.3.7
docker-pycreds==0.4.0
einops==0.6.1
evaluate==0.4.0
filelock==3.13.1
flash-attn==2.4.2
flask==3.0.0
Flask-Cors==4.0.0
frozenlist==1.4.1
fsspec==2023.12.2
gitdb==4.0.11
GitPython==3.1.41
huggingface-hub==0.20.2
idna==3.6
importlib-metadata==7.0.1
itsdangerous==2.1.2
Jinja2==3.1.3
joblib==1.3.2
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.1
ninja==1.11.1.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
packaging==23.2
pandas==2.0.3
pathtools==0.1.2
peft==0.4.0
pillow==10.2.0
protobuf==4.25.2
psutil==5.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
responses==0.18.0
rotary-emb==0.1
safetensors==0.4.1
scikit-learn==1.2.2
scipy==1.10.1
sentencepiece==0.1.99
sentry-sdk==1.39.2
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.13.3
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.31.0
triton==2.1.0
typing-extensions==4.9.0
tzdata==2023.4
urllib3==2.1.0
wandb==0.15.3
werkzeug==3.0.1
xentropy-cuda-lib==0.1
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0
```
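When comparing setups with others who see the same behaviour, it can be easier to share only the packages most involved in loading and generating with this model rather than the full freeze. A minimal sketch using the standard library's `importlib.metadata`; the package selection in `RELEVANT` and the helper `installed_versions` are assumptions for illustration, not a list the project prescribes.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed selection of packages from the freeze above.
RELEVANT = ("torch", "transformers", "tokenizers", "accelerate", "flash-attn", "sentencepiece")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(RELEVANT).items():
    print(f"{name}=={ver}" if ver else f"{name}: not installed")
```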

SadafShafi · Jan 18 '24 11:01

I had the same problem.

andyli386 · Mar 18 '24 06:03