ColossalAI

[BUG]: AssertionError when running the forward pass of an `nn.Embedding` that does not require gradients

Open • namespace-Pt opened this issue 1 year ago • 3 comments

🐛 Describe the bug

I have a frozen embedding table, i.e. none of its parameters require gradients. When this embedding table is used in the forward pass during training, an AssertionError is raised. (Note that other parameters in the model do require gradients.)
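To confirm the frozen/trainable split described above, you can list parameters by `requires_grad` in plain PyTorch, without ColossalAI. The layer shapes match the reproduction script. Wrapping the two layers in an `nn.ModuleDict` is only for illustration.

```python
import torch.nn as nn

# Same layer shapes as in the reproduction script
embedding = nn.Embedding(100, 1024)
embedding.requires_grad_(False)  # freezes every parameter of this module in place
linear = nn.Linear(1024, 1024)   # left trainable

# Illustrative container so named_parameters() yields qualified names
model = nn.ModuleDict({"embedding": embedding, "linear": linear})

frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print("frozen:", frozen)        # frozen: ['embedding.weight']
print("trainable:", trainable)  # trainable: ['linear.weight', 'linear.bias']
```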

Here is a minimal script to reproduce:

`test.py`:

```python
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Initialize the distributed environment from the variables set by torchrun
colossalai.launch_from_torch({})

plugin = GeminiPlugin(precision="bf16", initial_scale=2**16)
booster = Booster(plugin=plugin)

class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Frozen embedding table: none of its parameters require gradients
        self.embedding = nn.Embedding(100, 1024)
        self.embedding.requires_grad_(False)

        # Trainable layer, so the model still has parameters to update
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        loss = (transform ** 2).sum()
        return loss

model = Model()
optimizer = HybridAdam(model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0)
# Keep only the boosted model and optimizer from the returned tuple
model, optimizer = booster.boost(model, optimizer)[:2]

inputs = torch.tensor([1, 2, 3], device="cuda")
loss = model(inputs)  # the reported AssertionError is raised during this forward pass
booster.backward(loss, optimizer)
```

Run with `torchrun --nproc_per_node 4 test.py`.
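To isolate the failure to the Gemini plugin, the same model can be run as a single-process, CPU-only baseline in plain PyTorch. This is a sketch. `torch.optim.Adam` replaces `HybridAdam`, the optimizer receives only the trainable parameters, and ColossalAI is not involved at all.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(100, 1024)
        self.embedding.requires_grad_(False)  # frozen table
        self.linear = nn.Linear(1024, 1024)   # trainable layer

    def forward(self, x):
        embed = self.embedding(x)
        transform = self.linear(embed)
        return (transform ** 2).sum()

torch.manual_seed(0)
model = Model()
# torch.optim.Adam stands in for HybridAdam; frozen parameters are filtered out
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=5e-5, betas=(0.9, 0.999), weight_decay=0,
)

inputs = torch.tensor([1, 2, 3])  # CPU tensor, no distributed launch
loss = model(inputs)
loss.backward()
optimizer.step()

embedding_grad_is_none = model.embedding.weight.grad is None  # frozen: no grad
linear_has_grad = model.linear.weight.grad is not None        # trainable: has grad
print(embedding_grad_is_none, linear_has_grad)
```

If this runs cleanly, the frozen embedding causes no problem for autograd by itself (its `.grad` simply stays `None`). That points at the Gemini plugin rather than the model definition.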

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
      • NVIDIA A100-SXM4-40GB
    • available: True
    • version: 11.8
  • Lightning:
    • torch: 2.1.0
    • torch-scatter: 2.1.2
    • torchvision: 0.16.0
  • Packages:
    • accelerate: 0.23.0
    • aiohttp: 3.8.6
    • aiosignal: 1.3.1
    • anyio: 4.1.0
    • asttokens: 2.4.0
    • async-timeout: 4.0.3
    • attrs: 23.1.0
    • backcall: 0.2.0
    • bcrypt: 4.1.1
    • beautifulsoup4: 4.12.2
    • blis: 0.7.11
    • cachetools: 5.3.1
    • catalogue: 2.0.10
    • certifi: 2023.7.22
    • cffi: 1.16.0
    • cfgv: 3.4.0
    • charset-normalizer: 3.3.0
    • click: 8.1.7
    • cloudpathlib: 0.16.0
    • coloredlogs: 15.0.1
    • colossalai: 0.3.4
    • comm: 0.1.4
    • confection: 0.1.4
    • contexttimer: 0.3.3
    • contourpy: 1.1.1
    • cryptography: 41.0.7
    • cycler: 0.12.1
    • cymem: 2.0.8
    • cython: 3.0.6
    • datasets: 2.14.5
    • debugpy: 1.8.0
    • decorator: 5.1.1
    • deepspeed: 0.11.1
    • deprecated: 1.2.14
    • dill: 0.3.7
    • distlib: 0.3.8
    • distro: 1.8.0
    • einops: 0.7.0
    • exceptiongroup: 1.1.3
    • executing: 2.0.0
    • fabric: 3.2.2
    • faiss: 1.7.2
    • filelock: 3.13.1
    • flagembedding: 1.1.5
    • flash-attn: 2.3.4
    • flatbuffers: 23.5.26
    • fonttools: 4.43.1
    • frozenlist: 1.4.0
    • fsspec: 2023.6.0
    • fuzzywuzzy: 0.18.0
    • gmpy2: 2.1.2
    • google: 3.0.0
    • h11: 0.14.0
    • hjson: 3.1.0
    • httpcore: 1.0.2
    • httpx: 0.25.2
    • huggingface-hub: 0.17.3
    • humanfriendly: 10.0
    • identify: 2.5.33
    • idna: 3.4
    • instructorembedding: 1.0.1
    • invoke: 2.2.0
    • ipykernel: 6.25.2
    • ipython: 8.16.1
    • ipywidgets: 8.1.1
    • jedi: 0.19.1
    • jieba: 0.42.1
    • jinja2: 3.1.2
    • joblib: 1.3.2
    • jsonschema: 4.20.0
    • jsonschema-specifications: 2023.11.2
    • jupyter-client: 8.4.0
    • jupyter-core: 5.4.0
    • jupyterlab-widgets: 3.0.9
    • keybert: 0.8.3
    • kiwisolver: 1.4.5
    • langcodes: 3.3.0
    • levenshtein: 0.23.0
    • lightgbm: 4.1.0
    • marisa-trie: 1.1.0
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.1
    • matplotlib: 3.8.0
    • matplotlib-inline: 0.1.6
    • mdurl: 0.1.2
    • mpmath: 1.3.0
    • msgpack: 1.0.7
    • multidict: 6.0.4
    • multiprocess: 0.70.15
    • murmurhash: 1.0.10
    • nest-asyncio: 1.5.8
    • networkx: 3.1
    • ninja: 1.11.1.1
    • nltk: 3.8.1
    • nmslib: 2.1.1
    • nodeenv: 1.8.0
    • numpy: 1.26.1
    • nvidia-ml-py: 12.535.108
    • nvitop: 1.3.1
    • onnxruntime: 1.16.3
    • openai: 1.3.9
    • packaging: 23.2
    • pandas: 2.1.1
    • paramiko: 3.3.1
    • parso: 0.8.3
    • peft: 0.6.1
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 10.1.0
    • pip: 23.3
    • platformdirs: 3.11.0
    • pre-commit: 3.6.0
    • preshed: 3.0.9
    • prompt-toolkit: 3.0.39
    • protobuf: 4.25.0
    • psutil: 5.9.6
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • py-cpuinfo: 9.0.0
    • pyarrow: 13.0.0
    • pybind11: 2.6.1
    • pycparser: 2.21
    • pydantic: 1.10.13
    • pygments: 2.16.1
    • pyjnius: 1.6.1
    • pynacl: 1.5.0
    • pyparsing: 3.1.1
    • python-dateutil: 2.8.2
    • python-levenshtein: 0.23.0
    • pytrec-eval: 0.5
    • pytz: 2023.3.post1
    • pyyaml: 6.0
    • pyzmq: 25.1.1
    • rapidfuzz: 3.5.1
    • ray: 2.8.1
    • referencing: 0.32.0
    • regex: 2023.10.3
    • requests: 2.31.0
    • rich: 13.6.0
    • rouge: 1.0.1
    • rpds-py: 0.13.2
    • safetensors: 0.4.0
    • scikit-learn: 1.3.1
    • scipy: 1.11.3
    • seaborn: 0.13.0
    • sentence-transformers: 2.2.2
    • sentencepiece: 0.1.99
    • setuptools: 68.0.0
    • six: 1.16.0
    • smart-open: 6.4.0
    • sniffio: 1.3.0
    • soupsieve: 2.5
    • spacy: 3.7.2
    • spacy-legacy: 3.0.12
    • spacy-loggers: 1.0.5
    • srsly: 2.4.8
    • stack-data: 0.6.3
    • sympy: 1.11.1
    • termcolor: 2.3.0
    • thinc: 8.2.1
    • threadpoolctl: 3.2.0
    • tiktoken: 0.5.2
    • tokenizers: 0.14.1
    • torch: 2.1.0
    • torch-scatter: 2.1.2
    • torchvision: 0.16.0
    • tornado: 6.3.3
    • tqdm: 4.66.1
    • traitlets: 5.11.2
    • transformers: 4.34.1
    • triton: 2.1.0
    • typer: 0.9.0
    • typing-extensions: 4.7.1
    • tzdata: 2023.3
    • urllib3: 2.0.7
    • virtualenv: 20.25.0
    • wasabi: 1.1.2
    • wcwidth: 0.2.8
    • weasel: 0.3.4
    • wheel: 0.41.2
    • widgetsnbextension: 4.0.9
    • wrapt: 1.16.0
    • xformers: 0.0.22.post7
    • xxhash: 3.4.1
    • yarl: 1.9.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.13
    • release: 5.4.0-147-generic
    • version: #164-Ubuntu SMP Tue Mar 21 14:23:17 UTC 2023

namespace-Pt • Dec 15 '23 10:12

Hi, this appears to be an incompatibility with the Gemini strategy. We can indeed reproduce this issue with a frozen embedding table.

flybird11111 • Dec 16 '23 08:12
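One possible workaround is to keep the frozen table out of `model.parameters()` entirely. You can do this by storing it as a buffer and calling `torch.nn.functional.embedding` directly. This is only a sketch and has not been tested with `GeminiPlugin`. `FrozenEmbedding` is a hypothetical helper, not a ColossalAI or PyTorch API. It is also an assumption that Gemini handles buffers without hitting the same assertion, so that needs to be checked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FrozenEmbedding(nn.Module):
    """Hypothetical helper (untested with Gemini): a lookup table held as a buffer."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        # Reuse nn.Embedding's default initialization, then detach it from autograd
        weight = nn.Embedding(num_embeddings, embedding_dim).weight.detach().clone()
        # A buffer is saved in state_dict but is not returned by parameters()
        self.register_buffer("weight", weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.embedding(x, self.weight)

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = FrozenEmbedding(100, 1024)
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        return (self.linear(self.embedding(x)) ** 2).sum()

model = Model()
param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
loss = model(torch.tensor([1, 2, 3]))
loss.backward()
print(param_names)   # ['linear.weight', 'linear.bias']
print(buffer_names)  # ['embedding.weight']
```

The buffer is registered under the name `weight`, so the state-dict key stays `embedding.weight`. As a result, a checkpoint saved from the original `nn.Embedding` version should still load into this module.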

Okay, thank you.

namespace-Pt • Dec 17 '23 16:12

Are there any plans to fix it?

luckyyangrun • Dec 19 '23 05:12