
whisper load_model

spiderlx opened this issue 9 months ago • 3 comments

I have a question about how the Whisper model is loaded. See https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/whisper/train.py, line 752:

model = whisper.load_model(params.model_name, "cpu")

Looking at the official code provided by Whisper, there are two ways to load a model: downloading one of the official models by name, or loading a local checkpoint from a file path. When only the model type name is passed in, the model is downloaded. From https://github.com/openai/whisper/blob/main/whisper/__init__.py, lines 103-160:

def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by "whisper.available_models()", or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory
    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device) 

As the code above shows, as long as a model name is passed in, the model is downloaded (or loaded from the download cache). However, a few lines later the local checkpoint is loaded on top of it. See https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/whisper/train.py, lines 770-772:

checkpoints = load_checkpoint_if_available(
    params=params, model=model, model_avg=model_avg
)

So the model is effectively loaded twice. This pattern occurs not only in the training phase but also in the decoding phase, which causes code redundancy and wastes storage space. But it could also be that my understanding of the code is off. I have tried loading the local model directly and it worked without problems (a sketch of what I mean follows below): model = whisper.load_model(params.exp_path, "cpu")
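
For reference, here is a minimal sketch of what I mean by loading only the local checkpoint, assuming the checkpoint file is stored in Whisper's own format (a dict with "dims" and "model_state_dict" keys); the file path is just a placeholder:

import torch
from whisper.model import ModelDimensions, Whisper

def load_local_whisper(ckpt_path: str, device: str = "cpu") -> Whisper:
    # Load the checkpoint directly from disk; nothing is downloaded.
    checkpoint = torch.load(ckpt_path, map_location=device)
    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model.to(device)

# Equivalent to whisper.load_model("exp/whisper/epoch-10.pt", "cpu"),
# since load_model also accepts a file path instead of a model name.
model = load_local_whisper("exp/whisper/epoch-10.pt", "cpu")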

I also have two other questions, though I suspect they may be caused by the official code:

  • Turbo models cannot be loaded (a possible change is sketched after this list):
usage: decode.py [-h] [--epoch EPOCH] [--rank RANK] [--avg AVG] [--method METHOD] [--beam-size BEAM_SIZE] [--exp-dir EXP_DIR]
                 [--model-name {large-v2,large-v3,medium,small,base,tiny}] [--data-type DATA_TYPE]
                 [--remove-whisper-encoder-input-length-restriction REMOVE_WHISPER_ENCODER_INPUT_LENGTH_RESTRICTION] [--manifest-dir MANIFEST_DIR]
                 [--max-duration MAX_DURATION] [--bucketing-sampler BUCKETING_SAMPLER] [--num-buckets NUM_BUCKETS] [--concatenate-cuts CONCATENATE_CUTS]
                 [--duration-factor DURATION_FACTOR] [--gap GAP] [--on-the-fly-feats ON_THE_FLY_FEATS] [--shuffle SHUFFLE] [--return-cuts RETURN_CUTS]
                 [--num-workers NUM_WORKERS] [--enable-spec-aug ENABLE_SPEC_AUG] [--spec-aug-time-warp-factor SPEC_AUG_TIME_WARP_FACTOR]
                 [--enable-musan ENABLE_MUSAN] [--training-subset TRAINING_SUBSET]
decode.py: error: argument --model-name: invalid choice: 'large-v3-turbo' (choose from 'large-v2', 'large-v3', 'medium', 'small', 'base', 'tiny')
  • large-v3 models do not match the number of input channels (see the second sketch after this list): RuntimeError: Given groups=1, weight of size [1280, 128, 3], expected input[110, 80, 407] to have 128 channels, but got 80 channels instead
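
For the first question, this is a minimal sketch of the kind of change I would expect, assuming the restriction really is just the argparse choices list in decode.py/train.py (the option name is taken from the usage message above, and "large-v3-turbo" is only valid if the installed openai-whisper version already knows that model name):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-name",
    type=str,
    # Extend the allowed model names with the turbo variant.
    choices=["large-v2", "large-v3", "large-v3-turbo", "medium", "small", "base", "tiny"],
    help="The Whisper model name to use.",
)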
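
For the second question, the error looks like a feature-dimension mismatch: large-v3 (and the turbo variant) expect 128 mel bins, while the pipeline here seems to compute 80-dimensional features. Here is a minimal sketch of reading the expected number of mel bins from the loaded model, assuming a recent openai-whisper where log_mel_spectrogram accepts n_mels, and using a placeholder audio file name:

import whisper

model = whisper.load_model("large-v3", "cpu")

# large-v2 uses 80 mel bins; large-v3 and turbo use 128.
n_mels = model.dims.n_mels

# Compute features with the matching number of mel bins.
audio = whisper.load_audio("test.wav")  # "test.wav" is a placeholder
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio, n_mels=n_mels)
print(mel.shape)  # (n_mels, 3000)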

I sincerely hope to receive your response. I hope my English does not affect your reading experience.

spiderlx · Feb 26 '25