vision
vision copied to clipboard
How to add a new model in flowvision
向flowvision中添加模型的几个步骤
整体分为以下几个过程:
1. 添加模型
添加新的模型,需要注意几个细节
- 如果借鉴参考了别人的model, 需要在文件开头声明
"""
Modified from xxx.py
"""
- 导包顺序必须按照
python自带的包 - 额外安装的包 - 自身仓库的module的顺序(注意:重复的代码模块需要写入单独的文件。如Drop Path,Patch Embedding等)
import math
import oneflow as flow
import oneflow.nn as nn
from .registry import ModelCreator
from .utils import load_state_dict_from_url
- 必须定义
model_urls变量
model_urls = {
"convnext_tiny_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ConvNeXt/convnext_tiny_1k_224_ema.zip",
}
- 定义并注册好模型,写好相关的docstring(注意:docstring风格需要统一,如冒号,大小写,结尾句号等)
class ConvNeXt(nn.Module)
def __init__(self, **kwargs):
pass
def _create_convnext(arch, pretrained=False, progress=True, **model_kwargs):
model = ConvNeXt(**model_kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
@ModelCreator.register_model
def convnext_tiny_224(pretrained=False, progress=True, **kwargs):
"""
Constructs the ConvNext-Tiny model trained on ImageNet2012.
.. note::
ConvNext-Tiny model from `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>` _.
The required input size of the model is 224x224.
Args:
pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
progress (bool): If True, displays a progress bar of the download to stderrt. Default: ``True``
For example:
.. code-block:: python
>>> import flowvision
>>> convnext_tiny_224 = flowvision.models.convnext_tiny_224(pretrained=False, progress=True)
"""
model_kwargs = dict(
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
**kwargs
)
return _create_convnext(
"convnext_tiny_224", pretrained=pretrained, progress=progress, **model_kwargs
)
- 在 https://github.com/Oneflow-Inc/vision/blob/main/flowvision/models/init.py 中 import一下
2. 转换对应的模型权重
- 一个简单的函数, 自己可以拓展,但是基本是按照这个函数来改
from flowvision.models import ModelCreator
import oneflow as flow
import torch
def convert_torch_to_flow(model, torch_weight_path, save_path):
parameters = torch.load(torch_weight_path)
new_parameters = dict()
for key, value in parameters.items():
if "num_batches_tracked" not in key:
val = value.detach().cpu().numpy()
new_parameters[key] = val
model.load_state_dict(new_parameters)
flow.save(model.state_dict(), save_path)
print("successfully save model to %s" % (save_path))
model = ModelCreator.create_model("efficientnet_b7")
torch_weight = "/home/rentianhe/code/OneFlow-Models/vision/weights/efficientnet_b7_lukemelas-dcc49843.pth"
convert_torch_to_flow(model, torch_weight, save_path="./weights/efficientnet_b7")
3. 测试权重结果并记录
利用 https://github.com/Oneflow-Inc/vision/blob/main/projects/benchmark/classification/eval.sh 进行测试
4. 更新MODEL_ZOO
更新结果至 https://github.com/Oneflow-Inc/vision/blob/main/results/results_imagenet.md
5. 更新docstring
更新至 https://github.com/Oneflow-Inc/vision/blob/main/docs/source/flowvision.models.rst
6. 更新README中的表格
更新至 https://github.com/Oneflow-Inc/vision#overview-of-flowvision-structure
7. 添加与torch的速度对比,测试数据要列在pr里面,同时在测速脚本里添加对应的模型
有了这个,能否公开,给别人一些类似志愿者的机会。
有了这个,能否公开,给别人一些类似志愿者的机会。
感觉可以
有了这个,能否公开,给别人一些类似志愿者的机会。
可以的,放在vision下就是感觉可以给大家看,然后社区有人愿意帮忙的话~也可以按照这个步骤来添加模型
测速脚本使用方式
测试已有的模型
运行脚本之前需要安装一些依赖项:
pip 安装:
- oneflow
- torch & torchvision
- timm
交互式安装:
- flowvision, 在项目根目录下
pip3 install -e . --user
然后:
cd ci/check
bash run_speed_test.sh
速度对比结果会输出到当前运行脚本的目录的 result 文件
新增模型
如果是新增模型,就需要去读懂 compare_speed_with_pytorch.py 脚本的内容。
这个脚本测速的方式,简单来说就是直接加载某个模型的 .py 文件,实例化里面的模型,然后进行前后向算时间。
又因为想通过同一份代码既测oneflow也测torch的耗时,所以采用了 读取文件每一行,然后经过修改之后再保存成临时文件,然后用 importlib 加载的方式来加载模块。
测试 oneflow
如果是测试oneflow会进入脚本的 72行的分支。
首先会删除一些开头的 import:
from .registry import ModelCreator
from .utils import load_state_dict_from_url
同时删除代码中的 ModelCreator.register_model 注册代码,这里对应 compare_speed_with_pytorch.py 脚本 77~93 行
接着就是将 from .helpers import xxx 的相对路径 import 都改成 from flowvision.models.helpers import * ,因为当修改后的代码保存至临时路径之后,是找不到 helpers 模块的,这里对应脚本 95~106 行。
测试 pytorch
会进入 115行的分支
首先也是删除 ModelCreator 和 load_state_dict_from_url 的 import,同时删除代码中的 ModelCreator.register_model 注册代码,对应脚本 119~138 行。
接着就是替换 oneflow 相关的 import 为 import torch as flow 等等,对应脚本 140~155 行。
最后就是将 flowvision.layers 和 helpers 模块的 import 改为 timm 库同名的 import ,
这里需要注意的是,如果后续出现 timm 库中都没有的模块,则就需要类似 157 行 的处理,手动添加代码字符串的方式往文件中添加模块代码。