generative-ai-docs
Fix Unknown Architecture Error
Description of the change
Explicitly initialize the model architecture to gemma_config.Architecture.GEMMA_1.
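As a rough sketch of what this change amounts to, the snippet below sets `architecture` on a config object from the same `gemma_config` namespace the model code is expected to check against. `StubConfig` and `stub_gemma_config` are hypothetical stand-ins so the snippet runs without gemma_pytorch installed; in the notebook the corresponding objects would be `model_config` and the imported `gemma_config` module, and the exact line added by this PR may differ.

```python
import enum
import types


# Hypothetical stand-in for the Architecture enum in gemma's config.py.
class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2


# Hypothetical stand-in for the `gemma_config` module imported by the notebook.
stub_gemma_config = types.SimpleNamespace(Architecture=Architecture)


# Hypothetical stand-in for the notebook's `model_config` object.
class StubConfig:
    architecture = None


model_config = StubConfig()

# The change described above: set the architecture explicitly before the
# model is constructed, instead of relying on the dataclass default.
model_config.architecture = stub_gemma_config.Architecture.GEMMA_1
```

Assigning the value through the imported module means the config holds the same enum member that code reached through that module compares against.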
Motivation
When run natively, the code in the notebook throws this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-8-e00c04154560> in <cell line: 11>()
9 torch.set_default_dtype(model_config.get_dtype())
10 device = torch.device(MACHINE_TYPE)
---> 11 model = GemmaForCausalLM(model_config)
12 model.load_weights(ckpt_path)
13 model = model.to(device).eval()
1 frames
/content/gemma_pytorch/gemma/model.py in __init__(self, config)
479 self.layers.append(Gemma2DecoderLayer(config, attn_type))
480 else:
--> 481 raise ValueError(f'Unknown architecture: {config.architecture}')
482 self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
483
ValueError: Unknown architecture: Architecture.GEMMA_1
This is caused by the code in model.py not recognizing the config's Architecture.GEMMA_1 as gemma_config.Architecture.GEMMA_1.
config.py defines Architecture.GEMMA_1 and makes it the default value of GemmaConfig.architecture:
class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2

@dataclasses.dataclass
class GemmaConfig:
    # The architecture of the model.
    architecture: Architecture = Architecture.GEMMA_1
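The kind of mismatch described above can be reproduced in isolation. The sketch below assumes, without confirmation from the traceback, that two distinct `Architecture` class objects end up in play (for example, if config.py were loaded under two different module names). `make_architecture` is a hypothetical helper that simulates such a duplicate definition; it is not part of gemma_pytorch.

```python
import enum


def make_architecture():
    """Return a freshly created Architecture enum class (hypothetical helper)."""
    class Architecture(enum.Enum):
        GEMMA_1 = 1
        GEMMA_2 = 2
    return Architecture


# Two class objects with identical names and members, standing in for the
# enum seen by the config default and the enum seen by the model code.
ArchSeenByConfig = make_architecture()
ArchSeenByModel = make_architecture()

default_value = ArchSeenByConfig.GEMMA_1

# Members of different enum classes never compare equal, even when their
# names and values are the same.
is_recognized = default_value == ArchSeenByModel.GEMMA_1

# The printed form carries only the class name, so the error message cannot
# reveal which of the two classes the value came from.
message = f"Unknown architecture: {default_value}"
print(is_recognized)
print(message)
```

Under this assumption, setting the architecture from the module that the model code itself uses sidesteps the comparison failure, which is consistent with the fix in this PR.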
Type of change
Bug fix
Checklist
- [x] I have performed a self-review of my code.
- [x] I have added detailed comments to my code where applicable.
- [x] I have verified that my change does not break existing code.
- [x] My PR is based on the latest changes of the main branch (if unsure, please run git pull --rebase upstream main).
- [x] I am familiar with the Google Style Guide for the language I have coded in.
- [x] I have read through the Contributing Guide and signed the Contributor License Agreement.