[BUG]: Assertion error when forward passing `nn.Embedding` without gradient.
🐛 Describe the bug
I have a frozen embedding table, i.e. none of its parameters require gradients. When this embedding table is used in the forward pass during training, an AssertionError is raised. (Note that other parameters in the model do require gradients.)
Here is a minimal script to reproduce:
test.py:

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch({})
plugin = GeminiPlugin(precision="bf16", initial_scale=2**16)
booster = Booster(plugin=plugin)


class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(100, 1024)
        self.embedding.requires_grad_(False)  # freeze the embedding table
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        loss = (transform ** 2).sum()
        return loss


model = Model()
optimizer = HybridAdam(model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0)
model, optimizer = booster.boost(model, optimizer)[:2]
inputs = torch.tensor([1, 2, 3], device="cuda")
loss = model(inputs)
booster.backward(loss, optimizer)
```

Run with:

```bash
torchrun --nproc_per_node 4 test.py
```
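A possible workaround, sketched under the assumption that the assertion is triggered specifically by `requires_grad=False` parameters under Gemini: leave the embedding trainable (i.e. remove the `requires_grad_(False)` call from `Model.__init__`) and neutralize its updates with a dedicated optimizer parameter group instead. With `lr=0` and `weight_decay=0`, Adam-style optimizers never change those weights, which approximates freezing; this is a hypothetical stand-in, not a confirmed fix:

```python
# Workaround sketch: assumes `self.embedding.requires_grad_(False)` has been
# removed from Model.__init__, so every parameter requires grad and the
# assertion is never triggered.
model = Model()

# Separate the embedding parameters from the rest by identity.
embedding_param_ids = {id(p) for p in model.embedding.parameters()}
other_params = [p for p in model.parameters() if id(p) not in embedding_param_ids]

optimizer = HybridAdam(
    [
        # lr=0 together with weight_decay=0 means these weights are never
        # updated, mimicking a frozen embedding table (assumption to verify).
        {"params": list(model.embedding.parameters()), "lr": 0.0},
        {"params": other_params, "lr": 5e-5},
    ],
    betas=(0.9, 0.999),
    weight_decay=0,
)
model, optimizer = booster.boost(model, optimizer)[:2]
```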
Environment
- CUDA:
- GPU:
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- NVIDIA A100-SXM4-40GB
- available: True
- version: 11.8
- Lightning:
- torch: 2.1.0
- torch-scatter: 2.1.2
- torchvision: 0.16.0
- Packages:
- accelerate: 0.23.0
- aiohttp: 3.8.6
- aiosignal: 1.3.1
- anyio: 4.1.0
- asttokens: 2.4.0
- async-timeout: 4.0.3
- attrs: 23.1.0
- backcall: 0.2.0
- bcrypt: 4.1.1
- beautifulsoup4: 4.12.2
- blis: 0.7.11
- cachetools: 5.3.1
- catalogue: 2.0.10
- certifi: 2023.7.22
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.0
- click: 8.1.7
- cloudpathlib: 0.16.0
- coloredlogs: 15.0.1
- colossalai: 0.3.4
- comm: 0.1.4
- confection: 0.1.4
- contexttimer: 0.3.3
- contourpy: 1.1.1
- cryptography: 41.0.7
- cycler: 0.12.1
- cymem: 2.0.8
- cython: 3.0.6
- datasets: 2.14.5
- debugpy: 1.8.0
- decorator: 5.1.1
- deepspeed: 0.11.1
- deprecated: 1.2.14
- dill: 0.3.7
- distlib: 0.3.8
- distro: 1.8.0
- einops: 0.7.0
- exceptiongroup: 1.1.3
- executing: 2.0.0
- fabric: 3.2.2
- faiss: 1.7.2
- filelock: 3.13.1
- flagembedding: 1.1.5
- flash-attn: 2.3.4
- flatbuffers: 23.5.26
- fonttools: 4.43.1
- frozenlist: 1.4.0
- fsspec: 2023.6.0
- fuzzywuzzy: 0.18.0
- gmpy2: 2.1.2
- google: 3.0.0
- h11: 0.14.0
- hjson: 3.1.0
- httpcore: 1.0.2
- httpx: 0.25.2
- huggingface-hub: 0.17.3
- humanfriendly: 10.0
- identify: 2.5.33
- idna: 3.4
- instructorembedding: 1.0.1
- invoke: 2.2.0
- ipykernel: 6.25.2
- ipython: 8.16.1
- ipywidgets: 8.1.1
- jedi: 0.19.1
- jieba: 0.42.1
- jinja2: 3.1.2
- joblib: 1.3.2
- jsonschema: 4.20.0
- jsonschema-specifications: 2023.11.2
- jupyter-client: 8.4.0
- jupyter-core: 5.4.0
- jupyterlab-widgets: 3.0.9
- keybert: 0.8.3
- kiwisolver: 1.4.5
- langcodes: 3.3.0
- levenshtein: 0.23.0
- lightgbm: 4.1.0
- marisa-trie: 1.1.0
- markdown-it-py: 3.0.0
- markupsafe: 2.1.1
- matplotlib: 3.8.0
- matplotlib-inline: 0.1.6
- mdurl: 0.1.2
- mpmath: 1.3.0
- msgpack: 1.0.7
- multidict: 6.0.4
- multiprocess: 0.70.15
- murmurhash: 1.0.10
- nest-asyncio: 1.5.8
- networkx: 3.1
- ninja: 1.11.1.1
- nltk: 3.8.1
- nmslib: 2.1.1
- nodeenv: 1.8.0
- numpy: 1.26.1
- nvidia-ml-py: 12.535.108
- nvitop: 1.3.1
- onnxruntime: 1.16.3
- openai: 1.3.9
- packaging: 23.2
- pandas: 2.1.1
- paramiko: 3.3.1
- parso: 0.8.3
- peft: 0.6.1
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 10.1.0
- pip: 23.3
- platformdirs: 3.11.0
- pre-commit: 3.6.0
- preshed: 3.0.9
- prompt-toolkit: 3.0.39
- protobuf: 4.25.0
- psutil: 5.9.6
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- pyarrow: 13.0.0
- pybind11: 2.6.1
- pycparser: 2.21
- pydantic: 1.10.13
- pygments: 2.16.1
- pyjnius: 1.6.1
- pynacl: 1.5.0
- pyparsing: 3.1.1
- python-dateutil: 2.8.2
- python-levenshtein: 0.23.0
- pytrec-eval: 0.5
- pytz: 2023.3.post1
- pyyaml: 6.0
- pyzmq: 25.1.1
- rapidfuzz: 3.5.1
- ray: 2.8.1
- referencing: 0.32.0
- regex: 2023.10.3
- requests: 2.31.0
- rich: 13.6.0
- rouge: 1.0.1
- rpds-py: 0.13.2
- safetensors: 0.4.0
- scikit-learn: 1.3.1
- scipy: 1.11.3
- seaborn: 0.13.0
- sentence-transformers: 2.2.2
- sentencepiece: 0.1.99
- setuptools: 68.0.0
- six: 1.16.0
- smart-open: 6.4.0
- sniffio: 1.3.0
- soupsieve: 2.5
- spacy: 3.7.2
- spacy-legacy: 3.0.12
- spacy-loggers: 1.0.5
- srsly: 2.4.8
- stack-data: 0.6.3
- sympy: 1.11.1
- termcolor: 2.3.0
- thinc: 8.2.1
- threadpoolctl: 3.2.0
- tiktoken: 0.5.2
- tokenizers: 0.14.1
- torch: 2.1.0
- torch-scatter: 2.1.2
- torchvision: 0.16.0
- tornado: 6.3.3
- tqdm: 4.66.1
- traitlets: 5.11.2
- transformers: 4.34.1
- triton: 2.1.0
- typer: 0.9.0
- typing-extensions: 4.7.1
- tzdata: 2023.3
- urllib3: 2.0.7
- virtualenv: 20.25.0
- wasabi: 1.1.2
- wcwidth: 0.2.8
- weasel: 0.3.4
- wheel: 0.41.2
- widgetsnbextension: 4.0.9
- wrapt: 1.16.0
- xformers: 0.0.22.post7
- xxhash: 3.4.1
- yarl: 1.9.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.13
- release: 5.4.0-147-generic
- version: #164-Ubuntu SMP Tue Mar 21 14:23:17 UTC 2023
Hi, this seems to be an incompatibility with the Gemini strategy; the frozen embedding table is indeed what triggers the issue.
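If Gemini itself is not a hard requirement, another hedged sketch is to sidestep its chunk machinery entirely by boosting with a different plugin such as `LowLevelZeroPlugin` (ZeRO stage 1/2). Whether that plugin tolerates frozen parameters is an assumption to verify, not a documented guarantee:

```python
# Sketch: the repro script from above, with GeminiPlugin swapped for
# LowLevelZeroPlugin. Assumption: the requires_grad=False assertion comes
# from Gemini's parameter/chunk management and is not raised by this plugin.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch({})
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", initial_scale=2**16)
booster = Booster(plugin=plugin)

model = Model()  # the Model class from the repro script, frozen embedding included
optimizer = HybridAdam(
    # Hand only trainable parameters to the optimizer so the frozen
    # embedding never enters optimizer state.
    [p for p in model.parameters() if p.requires_grad],
    lr=5e-5,
    betas=(0.9, 0.999),
    weight_decay=0,
)
model, optimizer = booster.boost(model, optimizer)[:2]
```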
Okay, thank you. Are there any plans to fix it?