[Feature] Add support for loading pre-trained custom models
Feature request
Oumi currently supports training with custom-defined models (e.g., NanoGPT), but the ability to load a checkpoint or weights for such a model is missing.
Motivation / references
Implementing this feature will allow users to load a pre-trained custom model for inference, fine-tuning, and other use cases.
Your contribution
Code-review, pair-programming
There are a few things I've run into here. It's possible the pieces all exist and I just haven't found them yet.
First is model registration. I can use @registry.register("GPT", registry.RegistryType.MODEL) to register a model. This works when everything is in memory (I can train in a notebook this way). It's not clear how I can use the command-line train command with only a config, which also means I can't run jobs.
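For concreteness, the in-memory registration that works in a notebook looks roughly like this (a minimal sketch; GPT stands in for the custom nn.Module):

import torch.nn as nn
from oumi.core import registry

@registry.register("GPT", registry.RegistryType.MODEL)
class GPT(nn.Module):
    """Custom nanoGPT-style model; implementation details omitted."""
    def __init__(self, **kwargs):
        super().__init__()
        ...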
Second is loading the model and weights. I can do something like this (where I added a load_pretrained_weights hook to the model):
import torch
from transformers import AutoTokenizer

from oumi.builders import build_model, build_tokenizer
from oumi.core import registry
from oumi.core.configs import ModelParams

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Register the custom model class (GPT is defined elsewhere)
registry.register("GPT", registry.RegistryType.MODEL)(GPT)
# Specify parameters to customize your model
model_params = ModelParams(
    model_name="GPT",
    tokenizer_name="gpt2",
    load_pretrained_weights=None,
    model_kwargs={"load_pretrained_weights": "output/my_gpt2.pt/model.safetensors", "tokenizer": tokenizer},
)
# Build the model
device = torch.device("cuda")  # or "cpu", "mps", etc.
model = build_model(model_params).to(device)
# Build a corresponding tokenizer
tokenizer = build_tokenizer(model_params)
input_data = tokenizer("What are the benefits of open source coding?", return_tensors="pt")
# Use the same interface regardless of model type for generation
outputs = model.generate(input_data["input_ids"].to(device), max_length=64)
print(tokenizer.decode(outputs[0]))
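For reference, the load_pretrained_weights hook on the model is roughly this (a minimal sketch, not my exact code; it assumes a safetensors checkpoint):

import torch.nn as nn
from safetensors.torch import load_file

class GPT(nn.Module):
    ...
    def load_pretrained_weights(self, path: str) -> None:
        # Load a safetensors checkpoint saved by a previous training run.
        state_dict = load_file(path)
        self.load_state_dict(state_dict)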
But we still have the problem of how it works with config files.
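What I'd like is for something along these lines to work from a config file (purely hypothetical; the field names just mirror the ModelParams call above, not a schema I know oumi supports today):

# hypothetical config snippet (illustrative only)
model:
  model_name: "GPT"            # registered custom model name
  tokenizer_name: "gpt2"
  model_kwargs:
    load_pretrained_weights: "output/my_gpt2.pt/model.safetensors"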
Third is evaluation. I'm focused on lm_harness at the moment because I'm trying to run experiments with a custom GPT-style model, and that's what I need first. Looking at how the lm-evaluation-harness wraps models (external library usage), I need something like:
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM

class MyCustomLM(LM):
    def __init__(self, model, batch_size):
        super().__init__()
        self.model = model
        self.batch_size = batch_size

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        return self.model.loglikelihood(requests, self.batch_size)

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        return self.model.loglikelihood_rolling(requests, self.batch_size)

    def generate_until(self, requests: list[Instance]) -> list[str]:
        return self.model.generate_until(requests, self.batch_size)
Then create a wrapped version of the model:
lm_obj = MyCustomLM(model=model, batch_size=16)
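For what it's worth, the harness itself can consume that wrapper directly through its Python entry point (a sketch; the task name is just an example):

import lm_eval

# Run lm-evaluation-harness directly against the wrapped model.
results = lm_eval.simple_evaluate(
    model=lm_obj,
    tasks=["hellaswag"],
    num_fewshot=0,
)
print(results["results"])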
Then it's not clear how to pass it into evaluate(eval_config). And we have the same problem as with the model itself: how does this work with just an eval_config.yaml file?
I hope this helps outline the problems I've run into so far. I'm happy to help in any way I can, but I don't understand the oumi architecture well enough to know how best to approach this.
Just for fun, I asked DeepSeek-R1 how to solve this problem:
To enable local extensibility in your Python framework, allowing users to register custom models via YAML configuration, follow this structured approach:
1. Model Registration Mechanism
Create a registry within your framework to track custom models using a decorator or function.
# framework/registry.py
MODEL_REGISTRY = {}
def register_model(name):
    """Decorator to register a model class under a given name."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator
2. Dynamic Module Loading
Allow the YAML configuration to specify paths to custom modules, which the framework will dynamically import.
# user_config.yaml
imports:
  - models/my_model.py  # Path relative to the YAML file
model:
  name: my_custom_model
  params:
    ...
3. Import Custom Modules
When parsing the YAML, load the specified modules to trigger registration. Use importlib to handle dynamic imports from arbitrary paths.
import importlib.util
from pathlib import Path

def load_module_from_path(module_path: str, yaml_path: str):
    """Dynamically load a Python module from a file path."""
    # Resolve path relative to the YAML file's directory
    yaml_dir = Path(yaml_path).parent
    full_path = (yaml_dir / module_path).resolve()
    module_name = full_path.stem  # Use filename as module name
    spec = importlib.util.spec_from_file_location(module_name, full_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
4. YAML Processing Workflow
Parse the YAML, import custom modules first, then instantiate the model using the registry.
import yaml

def parse_yaml_config(yaml_path: str):
    with open(yaml_path, 'r') as f:
        config = yaml.safe_load(f)
    # Import custom modules
    for module_path in config.get('imports', []):
        load_module_from_path(module_path, yaml_path)
    # Retrieve model from registry
    model_name = config['model']['name']
    model_cls = MODEL_REGISTRY.get(model_name)
    if not model_cls:
        raise ValueError(f"Model '{model_name}' not registered.")
    # Instantiate the model with parameters
    model = model_cls(**config['model']['params'])
    return model
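Usage would then be a single call (path illustrative):

model = parse_yaml_config("user_config.yaml")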
5. User's Custom Model
Users define their model and register it using the framework's decorator.
# models/my_model.py
from framework.registry import register_model
@register_model('my_custom_model')
class CustomModel:
    def __init__(self, **params):
        # Initialize with the provided parameters
        ...
Key Considerations:
- Path Handling: Resolve paths relative to the YAML file's location to avoid dependency on the working directory.
- Security: Only load trusted modules, as dynamic imports can execute arbitrary code.
- Namespace Management: Encourage unique model names to prevent clashes (e.g., prefixing with a project name).
- Error Handling: Provide clear exceptions if a model isn't registered or modules fail to load.
Alternatives:
- Entry Points: If users package their extensions, use setuptools entry points for discovery.
- Class Paths: Allow direct class references in YAML (e.g., model: my_module.MyClass), dynamically importing the class.
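For the class-path alternative, the dynamic import could be sketched as:

import importlib

def load_class_from_dotted_path(class_path: str):
    """Resolve a dotted 'module.ClassName' reference from the YAML config."""
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# e.g. the YAML contains: model: my_module.MyClass
model_cls = load_class_from_dotted_path("my_module.MyClass")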
This approach keeps the framework extensible while maintaining a clean separation between core code and user extensions.
@ggilley One quick comment about using your custom classes. Have you tried setting the OUMI_EXTRA_DEPS_FILE environment variable? This should let you load files at runtime for oumi (either via the CLI or raw python): https://oumi.ai/docs/en/latest/user_guides/customization.html#enable-your-classes-for-the-cli
@taenin I had just found that in the code and was trying to decipher it (I was grepping for occurrences of importlib :-). That does seem to solve the loading problem for training. Thanks for the pointer! Now to try evaluation (which is what this ticket is actually for).
Thanks for doing an initial analysis and bringing up the issue! I'll take a look into custom evaluation as it's definitely important to support.
UPDATE: Unassigning myself for now due to other priorities. Others may feel free to take this issue, but if not, I can come back to take a look
Hello @wizeng23, I would like to take a look at this issue, if that's alright.
Go for it! Thanks for taking a look.
Sure!! Coming up with the PR ASAP.