GaitGraph
GaitGraph copied to clipboard
Making an Inference Script
I'm attempting to adapt `evaluate.py` into my own `infer.py`. Currently I'm just trying to push a dummy tensor through the model, but I'm getting an error. Any help would be appreciated!
Here's my `infer.py`:
import numpy as np
import torch
from common import get_model_resgcn
from datasets.graph import Graph
# Define configuration options with a plain class instead of using the CLI & argument parser.
class opt:
    """Static stand-in for the parsed command-line options.

    Attributes are read as ``opt.<name>``. The script reads ``weights_path``
    directly and passes the class itself to ``get_model_resgcn``.
    """

    # Checkpoint file that the script loads with torch.load().
    weights_path = "/app/Gait/GaitGraph/src/models/gaitgraph_resgcn-n39-r8_coco_seq_60.pth"
    network_name = "resgcn-n39-r8"
    embedding_layer_size = 128
    temporal_kernel_size = 9
    dropout = 0.4
    # NOTE(review): with this set to True, the traceback quoted after this
    # script reports missing "input_branches.1.*" / "input_branches.2.*" keys
    # and 3-vs-6 channel size mismatches when the checkpoint is loaded; the
    # follow-up script further down sets it to False.
    use_multi_branch = True
# Config for dataset
graph = Graph("coco")

# Init model
model, model_args = get_model_resgcn(graph, opt)
if torch.cuda.is_available():
    model.cuda()

# Load weights
checkpoint = torch.load(opt.weights_path)
# NOTE: this is the call that raises the RuntimeError shown in the traceback
# quoted after this script; nothing below it ran in the reported failure.
model.load_state_dict(checkpoint["model"])
model.eval()

# Load data
# NOTE(review): np.ones() yields a 3-D float64 array here; the follow-up
# script further down feeds a 5-D float32 tensor of shape (1, 1, 3, 60, 17)
# instead, so this dummy input presumably does not match what the model
# expects -- confirm against the model's forward().
data_arr = np.ones((51,60,1))
data_ten = torch.from_numpy(data_arr)

# Calculate embeddings
with torch.no_grad():
    # Remember the original batch size so the output can be split again below.
    bsz = data_ten.shape[0]
    # Append a copy flipped along dim 1 and run both halves in one forward pass.
    data_ten_flipped = torch.flip(data_ten, dims=[1])
    data_ten = torch.cat([data_ten, data_ten_flipped], dim=0)
    if torch.cuda.is_available():
        data_ten = data_ten.cuda(non_blocking=True)
    output = model(data_ten)
    # Average the outputs of the original and the flipped half.
    f1, f2 = torch.split(output, [bsz, bsz], dim=0)
    output = torch.mean(torch.stack([f1, f2]), dim=0)
Here's the error message I get:
Traceback (most recent call last):
File "/app/Gait/GaitGraph/src/infer.py", line 27, in <module>
model.load_state_dict(checkpoint["model"])
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResGCN:
Missing key(s) in state_dict: "input_branches.1.A", "input_branches.1.bn.weight", "input_branches.1.bn.bias", "input_branches.1.bn.running_mean", "input_branches.1.bn.running_var", "input_branches.1.layers.0.edge", "input_branches.1.layers.0.scn.conv.gcn.weight", "input_branches.1.layers.0.scn.conv.gcn.bias", "input_branches.1.layers.0.scn.bn.weight", "input_branches.1.layers.0.scn.bn.bias", "input_branches.1.layers.0.scn.bn.running_mean", "input_branches.1.layers.0.scn.bn.running_var", "input_branches.1.layers.0.tcn.conv.weight", "input_branches.1.layers.0.tcn.conv.bias", "input_branches.1.layers.0.tcn.bn.weight", "input_branches.1.layers.0.tcn.bn.bias", "input_branches.1.layers.0.tcn.bn.running_mean", "input_branches.1.layers.0.tcn.bn.running_var", "input_branches.1.layers.1.edge", "input_branches.1.layers.1.scn.conv_down.weight", "input_branches.1.layers.1.scn.conv_down.bias", "input_branches.1.layers.1.scn.bn_down.weight", "input_branches.1.layers.1.scn.bn_down.bias", "input_branches.1.layers.1.scn.bn_down.running_mean", "input_branches.1.layers.1.scn.bn_down.running_var", "input_branches.1.layers.1.scn.conv.gcn.weight", "input_branches.1.layers.1.scn.conv.gcn.bias", "input_branches.1.layers.1.scn.bn.weight", "input_branches.1.layers.1.scn.bn.bias", "input_branches.1.layers.1.scn.bn.running_mean", "input_branches.1.layers.1.scn.bn.running_var", "input_branches.1.layers.1.scn.conv_up.weight", "input_branches.1.layers.1.scn.conv_up.bias", "input_branches.1.layers.1.scn.bn_up.weight", "input_branches.1.layers.1.scn.bn_up.bias", "input_branches.1.layers.1.scn.bn_up.running_mean", "input_branches.1.layers.1.scn.bn_up.running_var", "input_branches.1.layers.1.tcn.conv_down.weight", "input_branches.1.layers.1.tcn.conv_down.bias", "input_branches.1.layers.1.tcn.bn_down.weight", "input_branches.1.layers.1.tcn.bn_down.bias", "input_branches.1.layers.1.tcn.bn_down.running_mean", "input_branches.1.layers.1.tcn.bn_down.running_var", 
"input_branches.1.layers.1.tcn.conv.weight", "input_branches.1.layers.1.tcn.conv.bias", "input_branches.1.layers.1.tcn.bn.weight", "input_branches.1.layers.1.tcn.bn.bias", "input_branches.1.layers.1.tcn.bn.running_mean", "input_branches.1.layers.1.tcn.bn.running_var", "input_branches.1.layers.1.tcn.conv_up.weight", "input_branches.1.layers.1.tcn.conv_up.bias", "input_branches.1.layers.1.tcn.bn_up.weight", "input_branches.1.layers.1.tcn.bn_up.bias", "input_branches.1.layers.1.tcn.bn_up.running_mean", "input_branches.1.layers.1.tcn.bn_up.running_var", "input_branches.1.layers.2.edge", "input_branches.1.layers.2.scn.residual.0.weight", "input_branches.1.layers.2.scn.residual.0.bias", "input_branches.1.layers.2.scn.residual.1.weight", "input_branches.1.layers.2.scn.residual.1.bias", "input_branches.1.layers.2.scn.residual.1.running_mean", "input_branches.1.layers.2.scn.residual.1.running_var", "input_branches.1.layers.2.scn.conv_down.weight", "input_branches.1.layers.2.scn.conv_down.bias", "input_branches.1.layers.2.scn.bn_down.weight", "input_branches.1.layers.2.scn.bn_down.bias", "input_branches.1.layers.2.scn.bn_down.running_mean", "input_branches.1.layers.2.scn.bn_down.running_var", "input_branches.1.layers.2.scn.conv.gcn.weight", "input_branches.1.layers.2.scn.conv.gcn.bias", "input_branches.1.layers.2.scn.bn.weight", "input_branches.1.layers.2.scn.bn.bias", "input_branches.1.layers.2.scn.bn.running_mean", "input_branches.1.layers.2.scn.bn.running_var", "input_branches.1.layers.2.scn.conv_up.weight", "input_branches.1.layers.2.scn.conv_up.bias", "input_branches.1.layers.2.scn.bn_up.weight", "input_branches.1.layers.2.scn.bn_up.bias", "input_branches.1.layers.2.scn.bn_up.running_mean", "input_branches.1.layers.2.scn.bn_up.running_var", "input_branches.1.layers.2.tcn.conv_down.weight", "input_branches.1.layers.2.tcn.conv_down.bias", "input_branches.1.layers.2.tcn.bn_down.weight", "input_branches.1.layers.2.tcn.bn_down.bias", 
"input_branches.1.layers.2.tcn.bn_down.running_mean", "input_branches.1.layers.2.tcn.bn_down.running_var", "input_branches.1.layers.2.tcn.conv.weight", "input_branches.1.layers.2.tcn.conv.bias", "input_branches.1.layers.2.tcn.bn.weight", "input_branches.1.layers.2.tcn.bn.bias", "input_branches.1.layers.2.tcn.bn.running_mean", "input_branches.1.layers.2.tcn.bn.running_var", "input_branches.1.layers.2.tcn.conv_up.weight", "input_branches.1.layers.2.tcn.conv_up.bias", "input_branches.1.layers.2.tcn.bn_up.weight", "input_branches.1.layers.2.tcn.bn_up.bias", "input_branches.1.layers.2.tcn.bn_up.running_mean", "input_branches.1.layers.2.tcn.bn_up.running_var", "input_branches.2.A", "input_branches.2.bn.weight", "input_branches.2.bn.bias", "input_branches.2.bn.running_mean", "input_branches.2.bn.running_var", "input_branches.2.layers.0.edge", "input_branches.2.layers.0.scn.conv.gcn.weight", "input_branches.2.layers.0.scn.conv.gcn.bias", "input_branches.2.layers.0.scn.bn.weight", "input_branches.2.layers.0.scn.bn.bias", "input_branches.2.layers.0.scn.bn.running_mean", "input_branches.2.layers.0.scn.bn.running_var", "input_branches.2.layers.0.tcn.conv.weight", "input_branches.2.layers.0.tcn.conv.bias", "input_branches.2.layers.0.tcn.bn.weight", "input_branches.2.layers.0.tcn.bn.bias", "input_branches.2.layers.0.tcn.bn.running_mean", "input_branches.2.layers.0.tcn.bn.running_var", "input_branches.2.layers.1.edge", "input_branches.2.layers.1.scn.conv_down.weight", "input_branches.2.layers.1.scn.conv_down.bias", "input_branches.2.layers.1.scn.bn_down.weight", "input_branches.2.layers.1.scn.bn_down.bias", "input_branches.2.layers.1.scn.bn_down.running_mean", "input_branches.2.layers.1.scn.bn_down.running_var", "input_branches.2.layers.1.scn.conv.gcn.weight", "input_branches.2.layers.1.scn.conv.gcn.bias", "input_branches.2.layers.1.scn.bn.weight", "input_branches.2.layers.1.scn.bn.bias", "input_branches.2.layers.1.scn.bn.running_mean", 
"input_branches.2.layers.1.scn.bn.running_var", "input_branches.2.layers.1.scn.conv_up.weight", "input_branches.2.layers.1.scn.conv_up.bias", "input_branches.2.layers.1.scn.bn_up.weight", "input_branches.2.layers.1.scn.bn_up.bias", "input_branches.2.layers.1.scn.bn_up.running_mean", "input_branches.2.layers.1.scn.bn_up.running_var", "input_branches.2.layers.1.tcn.conv_down.weight", "input_branches.2.layers.1.tcn.conv_down.bias", "input_branches.2.layers.1.tcn.bn_down.weight", "input_branches.2.layers.1.tcn.bn_down.bias", "input_branches.2.layers.1.tcn.bn_down.running_mean", "input_branches.2.layers.1.tcn.bn_down.running_var", "input_branches.2.layers.1.tcn.conv.weight", "input_branches.2.layers.1.tcn.conv.bias", "input_branches.2.layers.1.tcn.bn.weight", "input_branches.2.layers.1.tcn.bn.bias", "input_branches.2.layers.1.tcn.bn.running_mean", "input_branches.2.layers.1.tcn.bn.running_var", "input_branches.2.layers.1.tcn.conv_up.weight", "input_branches.2.layers.1.tcn.conv_up.bias", "input_branches.2.layers.1.tcn.bn_up.weight", "input_branches.2.layers.1.tcn.bn_up.bias", "input_branches.2.layers.1.tcn.bn_up.running_mean", "input_branches.2.layers.1.tcn.bn_up.running_var", "input_branches.2.layers.2.edge", "input_branches.2.layers.2.scn.residual.0.weight", "input_branches.2.layers.2.scn.residual.0.bias", "input_branches.2.layers.2.scn.residual.1.weight", "input_branches.2.layers.2.scn.residual.1.bias", "input_branches.2.layers.2.scn.residual.1.running_mean", "input_branches.2.layers.2.scn.residual.1.running_var", "input_branches.2.layers.2.scn.conv_down.weight", "input_branches.2.layers.2.scn.conv_down.bias", "input_branches.2.layers.2.scn.bn_down.weight", "input_branches.2.layers.2.scn.bn_down.bias", "input_branches.2.layers.2.scn.bn_down.running_mean", "input_branches.2.layers.2.scn.bn_down.running_var", "input_branches.2.layers.2.scn.conv.gcn.weight", "input_branches.2.layers.2.scn.conv.gcn.bias", "input_branches.2.layers.2.scn.bn.weight", 
"input_branches.2.layers.2.scn.bn.bias", "input_branches.2.layers.2.scn.bn.running_mean", "input_branches.2.layers.2.scn.bn.running_var", "input_branches.2.layers.2.scn.conv_up.weight", "input_branches.2.layers.2.scn.conv_up.bias", "input_branches.2.layers.2.scn.bn_up.weight", "input_branches.2.layers.2.scn.bn_up.bias", "input_branches.2.layers.2.scn.bn_up.running_mean", "input_branches.2.layers.2.scn.bn_up.running_var", "input_branches.2.layers.2.tcn.conv_down.weight", "input_branches.2.layers.2.tcn.conv_down.bias", "input_branches.2.layers.2.tcn.bn_down.weight", "input_branches.2.layers.2.tcn.bn_down.bias", "input_branches.2.layers.2.tcn.bn_down.running_mean", "input_branches.2.layers.2.tcn.bn_down.running_var", "input_branches.2.layers.2.tcn.conv.weight", "input_branches.2.layers.2.tcn.conv.bias", "input_branches.2.layers.2.tcn.bn.weight", "input_branches.2.layers.2.tcn.bn.bias", "input_branches.2.layers.2.tcn.bn.running_mean", "input_branches.2.layers.2.tcn.bn.running_var", "input_branches.2.layers.2.tcn.conv_up.weight", "input_branches.2.layers.2.tcn.conv_up.bias", "input_branches.2.layers.2.tcn.bn_up.weight", "input_branches.2.layers.2.tcn.bn_up.bias", "input_branches.2.layers.2.tcn.bn_up.running_mean", "input_branches.2.layers.2.tcn.bn_up.running_var".
size mismatch for input_branches.0.bn.weight: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
size mismatch for input_branches.0.bn.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
size mismatch for input_branches.0.bn.running_mean: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
size mismatch for input_branches.0.bn.running_var: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
size mismatch for input_branches.0.layers.0.scn.conv.gcn.weight: copying a param with shape torch.Size([192, 3, 1, 1]) from checkpoint, the shape in current model is torch.Size([192, 6, 1, 1]).
size mismatch for main_stream.0.scn.residual.0.weight: copying a param with shape torch.Size([128, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 96, 1, 1]).
size mismatch for main_stream.0.scn.conv_down.weight: copying a param with shape torch.Size([16, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 96, 1, 1]).
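I eventually figured out my issues. Compared with the script above, the changes are: `use_multi_branch = False` (the traceback shows the checkpoint has no `input_branches.1` / `input_branches.2` weights), a corrected weights path, `strict=False` when loading the state dict, and a 5-D float32 dummy input instead of the 3-D float64 array. Here's my final script to push a dummy tensor through GaitGraph for inference: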
import numpy as np
import torch
from common import get_model_resgcn
from datasets.graph import Graph
# Define configuration options with a plain class instead of using the CLI & argument parser.
class opt:
    """Static stand-in for the parsed command-line options.

    Attributes are read as ``opt.<name>``. The script reads ``weights_path``
    directly and passes the class itself to ``get_model_resgcn``.
    """

    # Checkpoint file that the script loads with torch.load().
    weights_path = "/app/Gait/GaitGraph/models/gaitgraph_resgcn-n39-r8_coco_seq_60.pth"
    network_name = "resgcn-n39-r8"
    embedding_layer_size = 128
    temporal_kernel_size = 9
    dropout = 0.4
    # Single-branch input: the dummy tensor built by this script has size 1
    # on its second axis.
    use_multi_branch = False
# Config for dataset
graph = Graph("coco")

# Choose the device once so the script runs on CPU-only machines as well as
# on GPU (the original hard-coded torch.cuda.FloatTensor for the input).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Init model
model, model_args = get_model_resgcn(graph, opt)
model.to(device)

# Load weights
# map_location lets a checkpoint that was saved on GPU load on a CPU-only host.
checkpoint = torch.load(opt.weights_path, map_location=device)
# strict=False tolerates key differences between checkpoint and model; report
# them instead of ignoring them silently so a wrong config is still noticed.
load_result = model.load_state_dict(checkpoint["model"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Warning: checkpoint and model keys differ; "
        f"missing={load_result.missing_keys}, "
        f"unexpected={load_result.unexpected_keys}"
    )
model.eval()

# Load data
# Dim 1: batch components
# Dim 2: unknown (presumably the input-branch axis, size 1 because
#        use_multi_branch is False -- TODO confirm against the model's forward())
# Dim 3: coordinates (x, y, conf)
# Dim 4: sequence of 60 frames
# Dim 5: coordinate list (nose, left_eye, right_eye, etc)
data_arr = np.ones((1, 1, 3, 60, 17), dtype=np.float32)
data_ten = torch.from_numpy(data_arr)

# Calculate embeddings
with torch.no_grad():
    # Remember the original batch size so the output can be split again below.
    bsz = data_ten.shape[0]
    # NOTE(review): axis 1 has size 1 here, so this flip is a no-op and both
    # halves of the batch are identical -- confirm which axis is meant to be
    # flipped.
    data_ten_flipped = torch.flip(data_ten, dims=[1])
    data_ten = torch.cat([data_ten, data_ten_flipped], dim=0)
    data_ten = data_ten.to(device, non_blocking=True)
    output = model(data_ten)
    # Average the outputs of the original and the flipped half.
    f1, f2 = torch.split(output, [bsz, bsz], dim=0)
    output = torch.mean(torch.stack([f1, f2]), dim=0)
请问一下,你最终是否顺利运行了代码?我目前也遇到了一些问题,想知道代码是否可行。希望能得到回答,不胜感激!