whisper load_model
I have a question about the model loading of Whisper. https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/whisper/train.py line 752:

```python
model = whisper.load_model(params.model_name, "cpu")
```

In the official Whisper code, there are two ways to load a model: downloading an official model by name, or loading a local checkpoint. When only a model type name is passed in, the model is downloaded. From https://github.com/openai/whisper/blob/main/whisper/__init__.py lines 103-160:
```python
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
```
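To illustrate, the two branches behave like this (a minimal sketch based on the code above; the local path is a hypothetical example):

```python
import whisper

# Branch 1: a name found in _MODELS -> the checkpoint is downloaded to
# ~/.cache/whisper (or $XDG_CACHE_HOME/whisper) before being loaded.
model = whisper.load_model("large-v3", "cpu")

# Branch 2: an existing file path -> the local checkpoint is loaded
# directly and nothing is downloaded ("exp/finetuned.pt" is a
# hypothetical path for illustration).
model = whisper.load_model("exp/finetuned.pt", "cpu")
```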
As the code above shows, whenever a model name is passed in, the model is downloaded (or read from the download path) and loaded. However, the training script then also loads a local checkpoint right afterwards. https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/whisper/train.py lines 770-772:

```python
checkpoints = load_checkpoint_if_available(
    params=params, model=model, model_avg=model_avg
)
```

So the model is loaded twice. This pattern occurs not only in the training phase but also in the decoding phase. It causes code redundancy and wastes storage space, since the pretrained weights are downloaded only to be overwritten immediately. But it could also be that my understanding of the code is off.
I have tried loading the local model directly, and it worked without problems:

```python
model = whisper.load_model(params.exp_path, "cpu")
```
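This works because load_model's local-file branch expects the checkpoint to contain both the model dimensions and the weights. A minimal sketch of that layout, derived from the load_model code above ("ckpt.pt" is a hypothetical path for illustration):

```python
import torch
import whisper

# Any model serves for illustration; "tiny" keeps the download small.
model = whisper.load_model("tiny", "cpu")

# load_model's local-file branch reads checkpoint["dims"] and
# checkpoint["model_state_dict"], so a directly loadable checkpoint
# must be saved in exactly this layout.
torch.save(
    {
        "dims": model.dims.__dict__,             # ModelDimensions fields
        "model_state_dict": model.state_dict(),  # model weights
    },
    "ckpt.pt",  # hypothetical path
)

# The path now hits the os.path.isfile branch: no download happens.
model = whisper.load_model("ckpt.pt", "cpu")
```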
I also have two other questions, though I suspect these may be due to the official code:
- Turbo models cannot be loaded (a sketch of a possible fix follows this list):

  ```
  usage: decode.py [-h] [--epoch EPOCH] [--rank RANK] [--avg AVG] [--method METHOD] [--beam-size BEAM_SIZE] [--exp-dir EXP_DIR]
                   [--model-name {large-v2,large-v3,medium,small,base,tiny}] [--data-type DATA_TYPE]
                   [--remove-whisper-encoder-input-length-restriction REMOVE_WHISPER_ENCODER_INPUT_LENGTH_RESTRICTION] [--manifest-dir MANIFEST_DIR]
                   [--max-duration MAX_DURATION] [--bucketing-sampler BUCKETING_SAMPLER] [--num-buckets NUM_BUCKETS] [--concatenate-cuts CONCATENATE_CUTS]
                   [--duration-factor DURATION_FACTOR] [--gap GAP] [--on-the-fly-feats ON_THE_FLY_FEATS] [--shuffle SHUFFLE] [--return-cuts RETURN_CUTS]
                   [--num-workers NUM_WORKERS] [--enable-spec-aug ENABLE_SPEC_AUG] [--spec-aug-time-warp-factor SPEC_AUG_TIME_WARP_FACTOR]
                   [--enable-musan ENABLE_MUSAN] [--training-subset TRAINING_SUBSET]
  decode.py: error: argument --model-name: invalid choice: 'large-v3-turbo' (choose from 'large-v2', 'large-v3', 'medium', 'small', 'base', 'tiny')
  ```
- large-v3 models do not match the number of input channels (see the feature-extraction sketch after this list):

  ```
  RuntimeError: Given groups=1, weight of size [1280, 128, 3], expected input[110, 80, 407] to have 128 channels, but got 80 channels instead
  ```
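For the turbo issue, the error comes from the hard-coded `choices` list of `--model-name` in decode.py. A minimal sketch of a possible fix, assuming the argument is defined via argparse as the usage output suggests (deriving the choices from the installed whisper package is my assumption, not the icefall code):

```python
import argparse
import whisper

parser = argparse.ArgumentParser()
# Assumption: taking the choices from whisper.available_models() keeps
# newer models such as "large-v3-turbo" selectable, provided the
# installed openai-whisper version ships them.
parser.add_argument(
    "--model-name",
    type=str,
    default="large-v2",
    choices=whisper.available_models(),
)
args = parser.parse_args(["--model-name", "large-v3-turbo"])
print(args.model_name)
```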
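For the channel mismatch, large-v3 uses 128 mel filterbank channels while earlier models use 80, and the error shows 80-channel features being fed to a conv layer that expects 128. A minimal sketch of deriving the feature size from the model itself (assuming a recent openai-whisper whose `log_mel_spectrogram` accepts `n_mels`; "test.wav" is a hypothetical file):

```python
import whisper

model = whisper.load_model("large-v3", "cpu")

# model.dims.n_mels is 128 for large-v3 and 80 for older models, so
# computing the mel features from the model's own dimensions avoids
# the channel mismatch.
audio = whisper.load_audio("test.wav")  # hypothetical input file
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels)
print(mel.shape)  # (n_mels, 3000)
```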
I sincerely hope to receive your response. I hope my poor English does not affect your reading experience.