
[OP][Feature] Add support for loading pre-trained custom models

Open · oelachqar opened this issue 9 months ago · 10 comments

Feature request

Oumi currently supports training with custom-defined models, e.g. NanoGPT, but the ability to load a checkpoint or weights is missing.

Motivation / references

Implementing this feature will allow users to load a pre-trained custom model for inference, fine-tuning, and other use cases.

Your contribution

Code-review, pair-programming

oelachqar · Feb 19 '25

There are a few things I've run into here. It's possible solutions for all of them already exist and I just haven't found them yet.

First is model registration. I can use @registry.register("GPT", registry.RegistryType.MODEL) to register a model. This works when everything is in memory (I can train in a notebook this way), but it's not clear how I can use the train CLI command with only a config. This also means I can't run jobs.
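For reference, the in-memory registration looks roughly like this (a minimal sketch; the class body is a placeholder for a real custom model):

import torch.nn as nn

from oumi.core import registry

# Minimal sketch: register a custom model class so oumi can look it up by name.
# The class body is a placeholder, not a real GPT implementation.
@registry.register("GPT", registry.RegistryType.MODEL)
class GPT(nn.Module):
    def __init__(self, vocab_size: int = 50257, **kwargs):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 768)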

Second is loading the model and weights. I can do something like this (where I added a load_pretrained_weights kwarg to the model):

    # NOTE: import paths assumed; adjust to your oumi version
    import torch
    from transformers import AutoTokenizer

    from oumi.builders import build_model, build_tokenizer
    from oumi.core import registry
    from oumi.core.configs import ModelParams

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Register the custom model class (the decorator form applied explicitly)
    registry.register("GPT", registry.RegistryType.MODEL)(GPT)

    # Specify parameters to customize your model
    model_params = ModelParams(
        model_name="GPT",
        tokenizer_name="gpt2",
        load_pretrained_weights=None,
        model_kwargs={
            "load_pretrained_weights": "output/my_gpt2.pt/model.safetensors",
            "tokenizer": tokenizer,
        },
    )

    # Build the model
    device = torch.device("cuda")  # or "cpu", "mps", etc.
    model = build_model(model_params).to(device)

    # Build a corresponding tokenizer
    tokenizer = build_tokenizer(model_params)
    input_data = tokenizer("What are the benefits of open source coding?", return_tensors="pt")

    # Use the same interface regardless of model type for generation
    outputs = model.generate(input_data["input_ids"].to(device), max_length=64)
    print(tokenizer.decode(outputs[0]))

But we still have the problem of how this works with config files.
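For illustration, this is the kind of config section I'd want to be able to write (hypothetical, not an existing oumi schema; the fields mirror the ModelParams call above):

# Hypothetical sketch; fields mirror the ModelParams call above.
model:
  model_name: "GPT"
  tokenizer_name: "gpt2"
  model_kwargs:
    load_pretrained_weights: "output/my_gpt2.pt/model.safetensors"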

Third is evaluation. I'm focused on lm_harness at the moment because I'm trying to do experiments with a custom gpt-style model and that's what I need first. Looking at how the lm-evaluation-harness wraps models (external library usage), I need something like:

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM


class MyCustomLM(LM):
    def __init__(self, model, batch_size):
        super().__init__()
        self.model = model
        self.batch_size = batch_size

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        return self.model.loglikelihood(requests, self.batch_size)

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        return self.model.loglikelihood_rolling(requests, self.batch_size)

    def generate_until(self, requests: list[Instance]) -> list[str]:
        return self.model.generate_until(requests, self.batch_size)

Then create a wrapped version of the model:

lm_obj = MyCustomLM(model=model, batch_size=16)
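For what it's worth, the harness's own entry point can consume such a wrapper directly (a minimal sketch; the task name is just an example):

import lm_eval

# Minimal sketch: evaluate the wrapped model with the harness directly.
results = lm_eval.simple_evaluate(model=lm_obj, tasks=["hellaswag"])
print(results["results"])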

But it's not clear how to get it into evaluate(eval_config). And we have the same problem as with the model itself: how does this work with just an eval_config.yaml file?

I hope this helps outline the problems I've run into so far. I'm happy to help in any way I can, but I don't understand the oumi architecture well enough to know how these pieces should fit together.

ggilley · Feb 21 '25

Just for fun, I asked DeepSeek-R1 how to solve this problem:

To enable local extensibility in your Python framework, allowing users to register custom models via YAML configuration, follow this structured approach:

1. Model Registration Mechanism

Create a registry within your framework to track custom models using a decorator or function.

# framework/registry.py
MODEL_REGISTRY = {}

def register_model(name):
    """Decorator to register a model class under a given name."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

2. Dynamic Module Loading

Allow the YAML configuration to specify paths to custom modules, which the framework will dynamically import.

# user_config.yaml
imports:
  - models/my_model.py  # Path relative to the YAML file

model:
  name: my_custom_model
  params:
    ...

3. Import Custom Modules

When parsing the YAML, load the specified modules to trigger registration. Use importlib to handle dynamic imports from arbitrary paths.

import importlib.util
from pathlib import Path

def load_module_from_path(module_path: str, yaml_path: str):
    """Dynamically load a Python module from a file path."""
    # Resolve path relative to the YAML file's directory
    yaml_dir = Path(yaml_path).parent
    full_path = (yaml_dir / module_path).resolve()

    module_name = full_path.stem  # Use filename as module name
    spec = importlib.util.spec_from_file_location(module_name, full_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {full_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

4. YAML Processing Workflow

Parse the YAML, import custom modules first, then instantiate the model using the registry.

import yaml

from framework.registry import MODEL_REGISTRY

def parse_yaml_config(yaml_path: str):
    with open(yaml_path, 'r') as f:
        config = yaml.safe_load(f)

    # Import custom modules to trigger model registration
    for module_path in config.get('imports', []):
        load_module_from_path(module_path, yaml_path)

    # Retrieve model from registry
    model_name = config['model']['name']
    model_cls = MODEL_REGISTRY.get(model_name)
    if not model_cls:
        raise ValueError(f"Model '{model_name}' is not registered.")

    # Instantiate the model with parameters
    model = model_cls(**config['model']['params'])
    return model
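Tying steps 2–4 together (using the example user_config.yaml above):

# Loads the YAML, imports models/my_model.py (which triggers registration),
# and instantiates the registered class from MODEL_REGISTRY.
model = parse_yaml_config("user_config.yaml")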

5. User's Custom Model

Users define their model and register it using the framework's decorator.

# models/my_model.py
from framework.registry import register_model

@register_model('my_custom_model')
class CustomModel:
    def __init__(self, **params):
        # Initialize with the parameters unpacked from the YAML 'params' block
        ...

Key Considerations:

  • Path Handling: Resolve paths relative to the YAML file's location to avoid dependency on the working directory.
  • Security: Only load trusted modules, as dynamic imports can execute arbitrary code.
  • Namespace Management: Encourage unique model names to prevent clashes (e.g., prefixing with a project name).
  • Error Handling: Provide clear exceptions if a model isn't registered or modules fail to load.

Alternatives:

  • Entry Points: If users package their extensions, use setuptools entry points for discovery.
  • Class Paths: Allow direct class references in YAML (e.g., model: my_module.MyClass), dynamically importing the class (see the sketch below).
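A minimal sketch of that class-path alternative, using importlib.import_module on a dotted path:

import importlib

def load_class_from_dotted_path(dotted_path: str):
    """Import a class from a dotted path like 'my_module.MyClass'."""
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)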

This approach keeps the framework extensible while maintaining a clean separation between core code and user extensions.

ggilley · Feb 21 '25

@ggilley One quick comment about using your custom classes. Have you tried setting the OUMI_EXTRA_DEPS_FILE environment variable? This should let you load files at runtime for oumi (either via the CLI or raw python): https://oumi.ai/docs/en/latest/user_guides/customization.html#enable-your-classes-for-the-cli
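For example, assuming the deps file format is one Python file path per line (check the linked docs for the exact format):

# my_oumi_deps.txt -- assumed format: one Python file path per line
/path/to/models/my_model.py

Then point the environment variable at it when invoking the CLI, e.g. OUMI_EXTRA_DEPS_FILE=/path/to/my_oumi_deps.txt oumi train -c train_config.yaml (the config filename is a placeholder).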

taenin · Feb 21 '25

@taenin I had just found that in the code and was trying to decipher it (I was grepping for occurrences of importlib :-). That does seem to solve the loading problem for training. Thanks for the pointer! Now to try evaluation (which is what this ticket is actually for).

ggilley · Feb 21 '25

Thanks for doing an initial analysis and bringing up the issue! I'll take a look into custom evaluation as it's definitely important to support.

UPDATE: Unassigning myself for now due to other priorities. Others should feel free to take this issue; if no one does, I'll come back to take a look.

wizeng23 · Feb 22 '25

Hello @wizeng23, I would like to take a look at this issue, if that's alright.

ritankarsaha · Nov 10 '25

Go for it! Thanks for taking a look.

wizeng23 · Nov 10 '25

> Go for it! Thanks for taking a look.

Sure!! Coming up with a PR ASAP.

ritankarsaha · Nov 11 '25