SiMPL icon indicating copy to clipboard operation
SiMPL copied to clipboard

Transforming spirl weights

Open JACKHAHA363 opened this issue 2 years ago • 3 comments

Hi, I noticed that the pretrained kitchen SPiRL model linked in this repo has a different format from the original SPiRL checkpoints. E.g. the keys in this repo:

dict_keys(['horizon', 'z_dim', 'spirl_low_policy', 'spirl_prior_policy'])

The keys in the spirl repo

dict_keys(['epoch', 'global_step', 'state_dict', 'optimizer'])

How can I transform the original weights into the format readable by SiMPL? I need this since I am trying to train SPiRL on a new env and use the checkpoint in SiMPL.

JACKHAHA363 avatar Sep 25 '22 23:09 JACKHAHA363

Hi, can you try this for now?

# Draft conversion script: repack an original SPiRL checkpoint into the
# {'horizon', 'z_dim', 'spirl_low_policy', 'spirl_prior_policy'} layout that
# SiMPL loads. NOTE: contains {...} placeholders — not runnable as written.
import torch
from collections import OrderedDict

# Placeholder: load a CLModelPolicy instance using the original spirl repo.
spirl_low_policy = {load CLModelPolicy from spirl repo}

# --- Low-level (skill decoder) policy: remap parameter names ---------------
state_dict = spirl_low_policy.net.decoder.state_dict()
spirl_low_policy_state_dict = OrderedDict()
# Placeholder: SiMPL's low policy also carries a 'log_sigma' parameter that
# has no counterpart in the spirl decoder — it must be supplied separately.
spirl_low_policy_state_dict['log_sigma'] = {needs update here}
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> flat index 0
        spirl_low_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])  # NOTE(review): single trailing digit — assumes <10 pyramid blocks
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_low_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> (last pyramid index) + 2
        # NOTE(review): reuses layer_n from earlier iterations, so this relies
        # on the 'head' keys appearing after all 'pyramid' keys in state_dict.
        spirl_low_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

# --- Prior policy: remap the first prior network's parameters ---------------
state_dict = spirl_low_policy.net.p[0].state_dict()
spirl_prior_policy_state_dict = OrderedDict()
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> flat index 0
        spirl_prior_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> (last pyramid index) + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

# Save in SiMPL's expected checkpoint layout. horizon/z_dim are hard-coded
# to 10/10 here — presumably the kitchen defaults; adjust for other envs.
torch.save({
    'spirl_low_policy': spirl_low_policy_state_dict,
    'spirl_prior_policy': spirl_prior_policy_state_dict,
    'horizon': 10, 'z_dim': 10
}, 'spirl_pretrained.pt')

namsan96 avatar Sep 26 '22 05:09 namsan96

Hi @namsan96,

Thanks for your reply, but it seems that it still doesn't work.

The repo requires spirl_low_policy and spirl_prior_policy to be simpl.alg.spirl.spirl_policy.SpirlLowPolicy and simpl.alg.spirl.spirl_policy.SpirlPriorPolicy. Can you show me how I can initialize these two objects from the modified state dicts?

Thank you for your help!

JACKHAHA363 avatar Sep 27 '22 21:09 JACKHAHA363

Sorry for the inconvenience — I updated the script to initialize the policies with inferred network dimensions. Please let me know if you still have problems!

# Updated conversion script: repack an original SPiRL checkpoint and wrap the
# remapped weights in SiMPL's SpirlLowPolicy / SpirlPriorPolicy objects, with
# network dimensions inferred from the loaded weights.
# NOTE: contains {...} placeholders — not runnable as written.
import torch
from collections import OrderedDict

# Project-local import from the SiMPL repository.
from simpl.alg.spirl.spirl_policy import SpirlLowPolicy, SpirlPriorPolicy


# Placeholder: load a CLModelPolicy instance using the original spirl repo.
spirl_low_policy = {load CLModelPolicy from spirl repo}


# --- Low-level (skill decoder) policy: remap parameter names ---------------
state_dict = spirl_low_policy.net.decoder.state_dict()
spirl_low_policy_state_dict = OrderedDict()
# Placeholder: SiMPL's low policy also carries a 'log_sigma' parameter that
# has no counterpart in the spirl decoder — it must be supplied separately.
spirl_low_policy_state_dict['log_sigma'] = {needs update here}
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> flat index 0
        spirl_low_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])  # NOTE(review): single trailing digit — assumes <10 pyramid blocks
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_low_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> (last pyramid index) + 2
        # NOTE(review): reuses layer_n from earlier iterations, so this relies
        # on the 'head' keys appearing after all 'pyramid' keys in state_dict.
        spirl_low_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v

# --- Prior policy: remap the first prior network's parameters ---------------
state_dict = spirl_low_policy.net.p[0].state_dict()
spirl_prior_policy_state_dict = OrderedDict()
for k, v in state_dict.items():
    names = k.split('.')
    if names[0] == 'input':  # input.linear -> flat index 0
        spirl_prior_policy_state_dict[f'net.net.0.{names[-1]}'] = v
    elif names[0].startswith('pyramid'):  # pyramid-i.linear -> 3*i+2, pyramid-i.norm -> 3*i+3
        pyramid_n = int(names[0][-1])
        is_norm = (names[1] == 'norm')
        layer_n = 3*pyramid_n + int(is_norm) + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n}.{names[-1]}'] = v
    elif names[0] == 'head':  # head.linear -> (last pyramid index) + 2
        spirl_prior_policy_state_dict[f'net.net.{layer_n+2}.{names[-1]}'] = v


# --- Infer network dimensions from the loaded weights -----------------------
# The decoder's input layer takes state + latent (width state_dim + z_dim),
# while the prior's input layer takes only the state (width state_dim).
z_dim = spirl_low_policy.net.decoder.input.linear.weight.shape[-1] - spirl_low_policy.net.p[0].input.linear.weight.shape[-1]
state_dim = spirl_low_policy.net.p[0].input.linear.weight.shape[-1]
action_dim = spirl_low_policy.net.decoder.head.linear.weight.shape[0]

low_policy_hidden_dim = spirl_low_policy.net.decoder.input.linear.weight.shape[0]
low_policy_n_hidden = len(spirl_low_policy.net.decoder)-1  # NOTE(review): assumes the decoder module supports len() (sub-block count) — confirm

prior_policy_hidden_dim = spirl_low_policy.net.p[0].input.linear.weight.shape[0]
prior_policy_n_hidden = len(spirl_low_policy.net.p[0])-1

# Build SiMPL policy objects with the inferred dimensions and load the
# remapped weights into them.
simpl_spirl_low_policy = SpirlLowPolicy(state_dim, z_dim, action_dim, low_policy_hidden_dim, low_policy_n_hidden)
simpl_spirl_prior_policy = SpirlPriorPolicy(state_dim, z_dim, prior_policy_hidden_dim, prior_policy_n_hidden)

simpl_spirl_low_policy.load_state_dict(spirl_low_policy_state_dict)
simpl_spirl_prior_policy.load_state_dict(spirl_prior_policy_state_dict)

# Save whole policy objects (not bare state dicts) in SiMPL's expected
# checkpoint layout; horizon is taken from the loaded spirl policy.
torch.save({
    'spirl_low_policy': simpl_spirl_low_policy,
    'spirl_prior_policy': simpl_spirl_prior_policy,
    'horizon': spirl_low_policy.horizon, 'z_dim': z_dim
}, 'spirl_pretrained.pt')

namsan96 avatar Sep 28 '22 09:09 namsan96