Transforming spirl weights
Hi, I noticed that the pretrained kitchen spirl model linked here has a different format from the original spirl checkpoint format. E.g., the keys in this repo are:
dict_keys(['horizon', 'z_dim', 'spirl_low_policy', 'spirl_prior_policy'])
while the keys in the spirl repo are:
dict_keys(['epoch', 'global_step', 'state_dict', 'optimizer'])
How can I transform the original weights into the format readable by SiMPL? I need this since I am trying to train spirl on a new env and use the checkpoint in SiMPL.
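In case it helps reproduce, this is how I'm comparing the two formats (the paths below are placeholders for the actual checkpoint files):

import torch

# Placeholder paths; point these at your actual checkpoint files.
spirl_ckpt = torch.load('path/to/spirl_weights.pth', map_location='cpu')
print(spirl_ckpt.keys())  # dict_keys(['epoch', 'global_step', 'state_dict', 'optimizer'])

simpl_ckpt = torch.load('path/to/spirl_pretrained.pt', map_location='cpu')
print(simpl_ckpt.keys())  # dict_keys(['horizon', 'z_dim', 'spirl_low_policy', 'spirl_prior_policy'])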
Hi, can you try this for now?
import torch
from collections import OrderedDict

spirl_low_policy = {load CLModelPolicy from spirl repo}

# Remap the SPiRL decoder weights to SiMPL's low policy key layout.
state_dict = spirl_low_policy.net.decoder.state_dict()
spirl_low_policy_state_dict = OrderedDict()
spirl_low_policy_state_dict['log_sigma'] = {needs update here}
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> 0
        spirl_low_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_low_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> last pyramid layer_n + 2
        spirl_low_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

# Same remapping for the prior network.
state_dict = spirl_low_policy.net.p[0].state_dict()
spirl_prior_policy_state_dict = OrderedDict()
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> 0
        spirl_prior_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> last pyramid layer_n + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

torch.save({
    'spirl_low_policy': spirl_low_policy_state_dict,
    'spirl_prior_policy': spirl_prior_policy_state_dict,
    'horizon': 10, 'z_dim': 10,
}, 'spirl_pretrained.pt')
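For context, the index arithmetic above assumes SiMPL's net.net is a sequential MLP laid out roughly as in the minimal sketch below. The dimensions, activation, and norm type here are illustrative assumptions, not the actual SiMPL source; the sketch only shows where the 0 / 3*i+2 / 3*i+3 indices come from:

import torch.nn as nn

# Hypothetical layout (sizes, ReLU, and LayerNorm are assumptions).
in_dim, hidden_dim, out_dim, n_pyramid = 40, 128, 9, 2
layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]  # indices 0, 1 (input block)
for i in range(n_pyramid):                           # one block per pyramid-i
    layers += [
        nn.Linear(hidden_dim, hidden_dim),           # index 3*i + 2
        nn.LayerNorm(hidden_dim),                    # index 3*i + 3
        nn.ReLU(),                                   # index 3*i + 4
    ]
layers.append(nn.Linear(hidden_dim, out_dim))        # head: last pyramid layer_n + 2
net = nn.Sequential(*layers)
for idx, module in enumerate(net):
    print(idx, module)

Note that the head index relies on layer_n still holding the value from the last pyramid iteration of the loop, which is why the head branch reuses it.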
Hi @namsan96,
Thanks for your reply, but it seems that it still doesn't work. The repo requires spirl_low_policy and spirl_prior_policy to be simpl.alg.spirl.spirl_policy.SpirlLowPolicy and simpl.alg.spirl.spirl_policy.SpirlPriorPolicy objects. Can you show me how I can initialize these two objects from the modified state dicts?
Thank you for your help!
Sorry for the inconvenience. I updated the script to initialize the policies with network dimensions inferred from the pretrained weights. Please let me know if you still have problems!
import torch
from collections import OrderedDict
from simpl.alg.spirl.spirl_policy import SpirlLowPolicy, SpirlPriorPolicy

spirl_low_policy = {load CLModelPolicy from spirl repo}

# Remap the SPiRL decoder weights to SiMPL's low policy key layout.
state_dict = spirl_low_policy.net.decoder.state_dict()
spirl_low_policy_state_dict = OrderedDict()
spirl_low_policy_state_dict['log_sigma'] = {needs update here}
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> 0
        spirl_low_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_low_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> last pyramid layer_n + 2
        spirl_low_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

# Same remapping for the prior network.
state_dict = spirl_low_policy.net.p[0].state_dict()
spirl_prior_policy_state_dict = OrderedDict()
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> 0
        spirl_prior_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> last pyramid layer_n + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

# Infer the network dimensions from the pretrained weights.
# The decoder takes [state; z] as input while the prior takes only the state,
# so z_dim is the difference between their input widths.
z_dim = spirl_low_policy.net.decoder.input.linear.weight.shape[-1] - spirl_low_policy.net.p[0].input.linear.weight.shape[-1]
state_dim = spirl_low_policy.net.p[0].input.linear.weight.shape[-1]
action_dim = spirl_low_policy.net.decoder.head.linear.weight.shape[0]
low_policy_hidden_dim = spirl_low_policy.net.decoder.input.linear.weight.shape[0]
low_policy_n_hidden = len(spirl_low_policy.net.decoder) - 1
prior_policy_hidden_dim = spirl_low_policy.net.p[0].input.linear.weight.shape[0]
prior_policy_n_hidden = len(spirl_low_policy.net.p[0]) - 1

# Build SiMPL policy objects and load the remapped weights into them.
simpl_spirl_low_policy = SpirlLowPolicy(state_dim, z_dim, action_dim, low_policy_hidden_dim, low_policy_n_hidden)
simpl_spirl_prior_policy = SpirlPriorPolicy(state_dim, z_dim, prior_policy_hidden_dim, prior_policy_n_hidden)
simpl_spirl_low_policy.load_state_dict(spirl_low_policy_state_dict)
simpl_spirl_prior_policy.load_state_dict(spirl_prior_policy_state_dict)

torch.save({
    'spirl_low_policy': simpl_spirl_low_policy,
    'spirl_prior_policy': simpl_spirl_prior_policy,
    'horizon': spirl_low_policy.horizon, 'z_dim': z_dim,
}, 'spirl_pretrained.pt')
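One caveat: unlike the first version, this saves the policy objects themselves rather than their state dicts, so simpl must be importable when you load the file. A quick sanity check on the result:

import torch

# Verify the saved file matches the format SiMPL expects.
ckpt = torch.load('spirl_pretrained.pt', map_location='cpu')
print(ckpt.keys())  # dict_keys(['spirl_low_policy', 'spirl_prior_policy', 'horizon', 'z_dim'])
print(type(ckpt['spirl_low_policy']))  # <class 'simpl.alg.spirl.spirl_policy.SpirlLowPolicy'>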