mmskeleton
Testing the realtime demo of the ST-GCN project with an mmskeleton-trained model
Hi,
I have used the mmskeleton code to train the network on my own dataset. Now I want to test real-time recognition with the demo on my dataset. I managed to make it work with the Kinetics model and dataset using the ST-GCN code; now I want to test it with my own model.
It seems that the model saved by mmskeleton is not compatible with the old st-gcn code. I get this error when trying to load the model weights: AttributeError: 'dict' object has no attribute 'cpu'
Traceback (most recent call last):
File "main.py", line 31, in <module>
p = Processor(sys.argv[2:])
File "/home/deep01/st-gcn/processor/io.py", line 28, in __init__
self.load_weights()
File "/home/deep01/st-gcn/processor/io.py", line 75, in load_weights
self.arg.ignore_weights)
File "/home/deep01/anaconda3/lib/python3.7/site-packages/torchlight-1.0-py3.7.egg/torchlight/io.py", line 66, in load_weights
SEEK_CUR = 1
File "/home/deep01/anaconda3/lib/python3.7/site-packages/torchlight-1.0-py3.7.egg/torchlight/io.py", line 66, in <listcomp>
SEEK_CUR = 1
AttributeError: 'dict' object has no attribute 'cpu'
Any suggestions?
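The error suggests that the st-gcn loader (torchlight's `load_weights`, according to the traceback) calls `.cpu()` on every value of the object returned by `torch.load`. A checkpoint written by mmcv's unmodified `save_checkpoint` is a dict with `meta`, `state_dict` and `optimizer` entries, so those values are themselves dicts and have no `.cpu()`. The sketch below is a hypothetical diagnostic helper (`find_non_tensor_entries` is not part of mmcv, mmskeleton or st-gcn); it only assumes the loaded object is a mapping.

```python
def find_non_tensor_entries(weights):
    """List the top-level keys whose values have no callable .cpu() method.

    A file written with torch.save(model.state_dict(), path) should give [];
    an mmcv-style checkpoint gives its 'meta', 'state_dict', 'optimizer' keys.
    """
    return [key for key, value in weights.items()
            if not callable(getattr(value, 'cpu', None))]


# Usage (the path is a placeholder):
#   import torch
#   weights = torch.load('path/to/checkpoint.pth', map_location='cpu')
#   print(find_non_tensor_entries(weights))
```

An empty list means every entry can go through the `.cpu()` call the st-gcn loader makes.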
I have the same problem. Have you solved it? Or have you tested the demo of st-gcn with your own video? Thanks a lot!
Hi, yes, I made some modifications to the mmcv library. Now I can train on my dataset with the mmskeleton project and test the real-time demo of the old st-gcn project with the resulting model.
@mejdidallel can you show us the details of your mmcv modifications for running in real time? Did you use the st-gcn demo file to run in real time?
@yosagaf can you show us the details of your mmcv modifications for running in real time? Did you use the st-gcn demo file to run in real time?
Use this command to locate your mmcv package:
python -c "import mmcv as _; print(_.__path__)"
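If you prefer to get the exact file path instead of the package directory, the standard library's `importlib.util.find_spec` can resolve a dotted module name. This is a small sketch: `module_file` is a hypothetical helper name, and it assumes your installed mmcv version still ships `mmcv/runner/checkpoint.py` (later mmcv releases may have moved it).

```python
import importlib.util


def module_file(dotted_name):
    """Return the path of the source file that defines a module.

    find_spec imports the parent packages (e.g. mmcv and mmcv.runner) in
    order to locate the submodule.
    """
    spec = importlib.util.find_spec(dotted_name)
    if spec is None or spec.origin is None:
        raise ImportError('cannot locate {}'.format(dotted_name))
    return spec.origin


# Usage:
#   print(module_file('mmcv.runner.checkpoint'))
```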
Then open mmcv/runner/checkpoint.py inside that directory and replace its contents with this:
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
import torch
import torchvision
from torch.utils import model_zoo
import mmcv
from .dist_utils import get_dist_info
open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth', # noqa: E501
'resnext50_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50-32x4d-0ab1a123.pth', # noqa: E501
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth', # noqa: E501
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth', # noqa: E501
'contrib/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth', # noqa: E501
'detectron/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn-9186a21c.pth', # noqa: E501
'detectron/resnet101_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn-cac0ab98.pth', # noqa: E501
'jhu/resnet50_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_ws-15beedd8.pth', # noqa: E501
'jhu/resnet101_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth', # noqa: E501
'jhu/resnext50_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth', # noqa: E501
'jhu/resnext101_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth', # noqa: E501
'jhu/resnext50_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth', # noqa: E501
'jhu/resnext101_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth', # noqa: E501
'msra/hrnetv2_w18': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w18-00eb2006.pth', # noqa: E501
'msra/hrnetv2_w32': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth', # noqa: E501
'msra/hrnetv2_w40': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w40-ed0b031c.pth', # noqa: E501
'bninception_caffe': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth', # noqa: E501
'kin400/i3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth', # noqa: E501
'kin400/nl3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth', # noqa: E501
} # yapf: disable
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.
    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.
    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata
    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')
    load(module)
    load = None  # break load->load reference cycle
    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]
    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys)))
    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
def load_url_dist(url):
    """In distributed setting, this function only downloads the checkpoint at
    local rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url)
    return checkpoint
def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module('torchvision.models.{}'.format(name))
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.
    Args:
        model (Module): Module to load checkpoint.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the params of the
            model and the checkpoint match.
        logger (:mod:`logging.Logger` or None): The logger for error message.
    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    # load checkpoint from modelzoo or file or url
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_name = filename[13:]
        checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip prefix of state_dict; iterate over state_dict itself, because a
    # bare state_dict (as written by save_checkpoint below) has no
    # 'state_dict' key
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if hasattr(model, 'module'):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.
    Args:
        state_dict (OrderedDict): Model weights on GPU.
    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.
    Modified version: only the model's ``state_dict`` is saved, so that the
    st-gcn loader can read the file. The original mmcv version saved a dict
    with the 3 fields ``meta``, ``state_dict`` and ``optimizer``.
    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Ignored in this version.
        meta (dict, optional): Ignored in this version.
    """
    # if meta is None:
    #     meta = {}
    # elif not isinstance(meta, dict):
    #     raise TypeError('meta must be a dict or None, but got {}'.format(
    #         type(meta)))
    # meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
    mmcv.mkdir_or_exist(osp.dirname(filename))
    if hasattr(model, 'module'):
        model = model.module
    # checkpoint = {
    #     'meta': meta,
    #     'state_dict': weights_to_cpu(model.state_dict())
    # }
    # if optimizer is not None:
    #     checkpoint['optimizer'] = optimizer.state_dict()
    # torch.save(checkpoint, filename)
    # save the bare state_dict (parameter name -> tensor)
    torch.save(model.state_dict(), filename)
Then re-train your model on your dataset (with the mmskeleton project); the resulting checkpoints can then be used in the real-time recognition code of the ST-GCN project.
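As an alternative to patching mmcv and re-training, it may be possible to convert a checkpoint that mmskeleton has already written. This is a sketch under assumptions: it expects the layout produced by mmcv's unmodified `save_checkpoint` (a dict with a `state_dict` entry whose keys may carry a `module.` prefix), and `extract_state_dict` / `convert_checkpoint` are hypothetical helper names, not functions of mmcv, mmskeleton or st-gcn.

```python
from collections import OrderedDict


def extract_state_dict(checkpoint):
    """Return a flat parameter dict from an mmcv-style checkpoint.

    Assumes the mmcv layout {'meta': ..., 'state_dict': ..., 'optimizer': ...};
    a checkpoint that is already a flat state_dict is passed through.
    """
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # Strip the 'module.' prefix added by (Distributed)DataParallel wrappers.
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items())


def convert_checkpoint(src_path, dst_path):
    """Rewrite an mmcv-style checkpoint as a bare state_dict file."""
    import torch  # imported here so extract_state_dict does not need torch
    checkpoint = torch.load(src_path, map_location='cpu')
    torch.save(extract_state_dict(checkpoint), dst_path)
```

For example, `convert_checkpoint('work_dir/latest.pth', 'st_gcn_weights.pth')`, where both paths are placeholders. Converting the file does not change the network's shapes, so the st-gcn model that loads it must still be built with the same graph and number of classes as the trained one.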
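Hi, yes, I made some modifications to the mmcv library; now I can use a model trained with the mmskeleton project to train on the dataset, and test the real-time demo with the old st-gcn project.
Hello, do you know how to get real-time visual recognition results on NTU-RGB+D with the old version of st-gcn? Looking forward to your reply!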
Hi,
I tried the real-time demo on my own dataset, not on NTU-RGB+D.
I modified it as you said, but now I get the following error. Do you know how to solve it?
Traceback (most recent call last):
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torchlight-1.0-py3.6.egg/torchlight/io.py", line 82, in load_weights
__doc__ = _io._TextIOBase.__doc__
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
size mismatch for A: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for data_bn.weight: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.bias: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_mean: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_var: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for edge_importance.0: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.1: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.2: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.3: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.4: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.5: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.6: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.7: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.8: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.9: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for fcn.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 256, 1, 1]).
size mismatch for fcn.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([400]).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "main.py", line 34, in <module>
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
size mismatch for A: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for data_bn.weight: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.bias: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_mean: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_var: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for edge_importance.0: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.1: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.2: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.3: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.4: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.5: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.6: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.7: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.8: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.9: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for fcn.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 256, 1, 1]).
size mismatch for fcn.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([400]).
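Reading the shapes in this log: the checkpoint's graph has 17 nodes (`A` is 3×17×17, and `data_bn` has 51 = 3 input channels × 17 joints) and its classifier has 3 outputs (`fcn.bias` is `[3]`), while the model the demo builds has 18 nodes and 400 outputs, which matches the OpenPose-layout Kinetics configuration. So the demo model most likely has to be constructed with the same graph layout and `num_class` used in training, and fed poses with the same joints. 17 is the size of the COCO keypoint layout; st-gcn's own graph, as far as I recall, only defines the 18-joint OpenPose and the NTU layouts, so a matching layout may have to be added. The sketch below is a hypothetical helper that reads these numbers from a loaded state_dict; it assumes the parameter names seen in the log (`A`, `data_bn.weight`, `fcn.bias`) and that `data_bn` normalizes `in_channels × num_node` features.

```python
def infer_st_gcn_dims(state_dict):
    """Infer graph and classifier sizes from an ST-GCN state_dict.

    Assumes the parameter names shown in the error log: 'A' with shape
    (K, V, V), 'data_bn.weight' with shape (C * V,) and 'fcn.bias' with
    shape (num_class,).
    """
    num_partitions, num_node, _ = tuple(state_dict['A'].shape)
    bn_features = state_dict['data_bn.weight'].shape[0]
    return {
        'num_partitions': num_partitions,        # K: spatial kernel size
        'num_node': num_node,                    # V: joints in the graph
        'in_channels': bn_features // num_node,  # C: e.g. x, y, confidence
        'num_class': state_dict['fcn.bias'].shape[0],
    }


# Usage (the path is a placeholder):
#   import torch
#   weights = torch.load('path/to/checkpoint.pth', map_location='cpu')
#   print(infer_st_gcn_dims(weights))
```

For the checkpoint in this log, the helper would report 17 nodes, 3 input channels and 3 classes.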
load_state_dict(model, state_dict, strict, logger) return checkpoint def weights_to_cpu(state_dict): """Copy a model state_dict to cpu. Args: state_dict (OrderedDict): Model weights on GPU. Returns: OrderedDict: Model weights on GPU. """ state_dict_cpu = OrderedDict() for key, val in state_dict.items(): state_dict_cpu[key] = val.cpu() return state_dict_cpu def save_checkpoint(model, filename, optimizer=None, meta=None): """Save checkpoint to file. The checkpoint will have 3 fields: ``meta``, ``state_dict`` and ``optimizer``. By default ``meta`` will contain version and time info. Args: model (Module): Module whose params are to be saved. filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint. """ #if meta is None: #meta = {} #elif not isinstance(meta, dict): #raise TypeError('meta must be a dict or None, but got {}'.format( #type(meta))) #meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) mmcv.mkdir_or_exist(osp.dirname(filename)) if hasattr(model, 'module'): model = model.module #checkpoint = { #'meta': meta, #'state_dict': weights_to_cpu(model.state_dict()) #} #if optimizer is not None: #checkpoint['optimizer'] = optimizer.state_dict() #torch.save(checkpoint, filename) torch.save(model.state_dict(), filename)
Then re-train your model on your dataset (with the mmskeleton project), and you can then use the generated models in the real-time recognition code of the ST-GCN project.
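The AttributeError at the start of this thread appears to come from torchlight's loader calling .cpu() on every top-level value of the loaded file (the traceback points at a list comprehension in torchlight/io.py). A bare state_dict holds tensors, but the unmodified mmcv save_checkpoint writes a dict with meta, state_dict and optionally optimizer entries, and the nested meta dict has no .cpu(). If you already have checkpoints saved by the unmodified mmcv, one possible alternative to patching mmcv and retraining is to pull out the state_dict and save it on its own. This is a minimal sketch, assuming the mmcv checkpoint layout described in the save_checkpoint docstring above. The helper names extract_state_dict and convert_checkpoint are hypothetical, and torch is imported only inside the function that touches files.

```python
from collections import OrderedDict

# Prefix added to parameter names by DataParallel / DistributedDataParallel.
MODULE_PREFIX = 'module.'


def extract_state_dict(checkpoint):
    """Return a flat OrderedDict of parameters from a loaded checkpoint.

    Accepts either an mmcv-style dict ({'meta', 'state_dict', 'optimizer'})
    or a checkpoint that already is a bare state_dict.
    """
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # Drop the 'module.' prefix so names match a plain (non-parallel) model.
    return OrderedDict(
        (k[len(MODULE_PREFIX):] if k.startswith(MODULE_PREFIX) else k, v)
        for k, v in state_dict.items())


def convert_checkpoint(src_path, dst_path):
    """Re-save an mmcv checkpoint as a bare state_dict (hypothetical helper)."""
    import torch  # imported here so extract_state_dict works without torch
    checkpoint = torch.load(src_path, map_location='cpu')
    torch.save(extract_state_dict(checkpoint), dst_path)
```

For example, convert_checkpoint('epoch_50.pth', 'st_gcn_weights.pt') (both file names are placeholders), then point ST-GCN's --weights option at the new file. This only fixes the file format. The model definition on the ST-GCN side still has to match the trained network (number of joints and classes).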
I revised it as you said, but the following error occurs. Do you know how to solve it?
Traceback (most recent call last):
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torchlight-1.0-py3.6.egg/torchlight/io.py", line 82, in load_weights
__doc__ = _io._TextIOBase.__doc__
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
size mismatch for A: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for data_bn.weight: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.bias: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_mean: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_var: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for edge_importance.0: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.1: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.2: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.3: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.4: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.5: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.6: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.7: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.8: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.9: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for fcn.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 256, 1, 1]).
size mismatch for fcn.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([400]).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "main.py", line 34, in <module>
p = Processor(sys.argv[2:])
File "/media/haige/14482B9E482B7E1C/project/st-gcn-master/processor/io.py", line 28, in __init__
self.load_weights()
File "/media/haige/14482B9E482B7E1C/project/st-gcn-master/processor/io.py", line 75, in load_weights
self.arg.ignore_weights)
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torchlight-1.0-py3.6.egg/torchlight/io.py", line 89, in load_weights
File "/home/haige/anaconda3/envs/pytracking/lib/python3.6/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
size mismatch for A: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for data_bn.weight: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.bias: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_mean: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for data_bn.running_var: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([54]).
size mismatch for edge_importance.0: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.1: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.2: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.3: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.4: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.5: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.6: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.7: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.8: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for edge_importance.9: copying a param with shape torch.Size([3, 17, 17]) from checkpoint, the shape in current model is torch.Size([3, 18, 18]).
size mismatch for fcn.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 256, 1, 1]).
size mismatch for fcn.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([400]).
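You are loading a model with 17 joints into a model with 18 joints. The checkpoint was trained with 17 joints (data_bn has 51 = 3 × 17 channels) and 3 classes (fcn has 3 outputs), while the model built by the ST-GCN config expects 18 joints (54 = 3 × 18) and 400 classes. Check your model's shape, and make sure the model arguments on the ST-GCN side (graph layout and number of classes) match the ones used for training.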
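The joint and class counts can be read directly from the parameter shapes, either from the size-mismatch message or from the checkpoint file itself. This is a minimal sketch, assuming the ST-GCN parameter names that appear in the error (A, data_bn.weight, fcn.weight) and 3 input channels per joint (x, y, confidence), which is what 51 = 3 × 17 and 54 = 3 × 18 suggest. The function names infer_model_args and checkpoint_shapes are hypothetical.

```python
def infer_model_args(shapes, in_channels=3):
    """Derive graph size and class count from ST-GCN parameter shapes.

    `shapes` maps parameter names to shape tuples, e.g. copied from the
    size-mismatch error or returned by checkpoint_shapes() below.
    """
    # A has shape (num_partitions, num_joints, num_joints).
    num_partitions, num_joints, num_joints_2 = shapes['A']
    if num_joints != num_joints_2:
        raise ValueError('adjacency tensor A must be square per partition')
    # data_bn normalizes in_channels values for every joint.
    bn_channels = shapes['data_bn.weight'][0]
    if bn_channels != in_channels * num_joints:
        raise ValueError('data_bn has {} channels, expected {} x {}'.format(
            bn_channels, in_channels, num_joints))
    # The final 1x1 conv (fcn) has one output channel per class.
    num_class = shapes['fcn.weight'][0]
    return {'num_joints': num_joints,
            'num_partitions': num_partitions,
            'num_class': num_class}


def checkpoint_shapes(path):
    """Load a checkpoint on CPU and return {param_name: shape_tuple}."""
    import torch  # imported lazily; only needed to read real files
    checkpoint = torch.load(path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']  # mmcv-style checkpoint
    prefix = 'module.'
    return {(k[len(prefix):] if k.startswith(prefix) else k): tuple(v.shape)
            for k, v in checkpoint.items()}
```

With the shapes from the error above, the checkpoint gives 17 joints and 3 classes, while the current ST-GCN model has 18 joints and 400 classes. In ST-GCN these values come from the model_args section of the YAML config (num_class, and graph_args for the layout), so that is where they would need to be aligned with the trained model.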
I have the same problem too. Did you solve it? Thanks.
I have this problem too. Do I need to use OpenPose to generate the dataset? That would mean going back to the old version.
Hi, these days I am also interested in the mmskeleton algorithm. I have installed st-gcn and mmskeleton on my server, and both of them can run the demo code.
Now I want to train on my own dataset.
My data is organized with the videos of each action type in a separate folder.
So what should I do next?
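With one folder per action, a first step is usually to turn the folder names into class labels. The exact annotation format mmskeleton expects is not shown in this thread. The structure below, a category list plus a per-video category id, is an assumption and should be compared with the example annotation file shipped with your mmskeleton version. This is a minimal sketch; build_annotations, scan_video_folders, write_annotations and VIDEO_EXTS are hypothetical names.

```python
import json
import os

# Assumption: adjust to the extensions actually present in your data.
VIDEO_EXTS = ('.mp4', '.avi', '.mkv', '.mov')


def build_annotations(rel_paths):
    """Map 'action_name/video.ext' paths to category ids.

    Categories are sorted alphabetically so the id assignment is
    reproducible across runs.
    """
    categories = sorted({p.split('/')[0] for p in rel_paths})
    index = {name: i for i, name in enumerate(categories)}
    annotations = {
        p: {'category_id': index[p.split('/')[0]]} for p in sorted(rel_paths)
    }
    return {'categories': categories, 'annotations': annotations}


def scan_video_folders(root):
    """Collect 'action_name/video.ext' paths, one sub-folder per action."""
    rel_paths = []
    for action in sorted(os.listdir(root)):
        action_dir = os.path.join(root, action)
        if not os.path.isdir(action_dir):
            continue
        for name in sorted(os.listdir(action_dir)):
            if name.lower().endswith(VIDEO_EXTS):
                rel_paths.append(action + '/' + name)
    return rel_paths


def write_annotations(root, out_path):
    """Scan `root` and write the annotation dict to `out_path` as JSON."""
    annotations = build_annotations(scan_video_folders(root))
    with open(out_path, 'w') as f:
        json.dump(annotations, f, indent=2)
    return annotations
```

Keys are relative paths ("walk/a.mp4") rather than bare file names, so two actions can contain videos with the same name. If your tooling expects a flat directory, rename the files accordingly. The skeletons themselves still have to be extracted from the videos by the pose-estimation step before training.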