PyTorch-Multi-Style-Transfer

Update for PyTorch 0.4.0

Status: Open. mratsim opened this issue 7 years ago (11 comments).

PyTorch 0.4.0 was released on April 24, 2018, and unfortunately the pre-trained weights saved with earlier versions are not compatible with it.

In the notebook I get:

style_model = Net(ngf=128)
style_model.load_state_dict(torch.load('21styles.model'), False)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-ce41c62c2272> in <module>()
      1 style_model = Net(ngf=128)
----> 2 style_model.load_state_dict(torch.load('21styles.model'), False)

/usr/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                                self.__class__.__name__, "\n\t".join(error_msgs)))
    722 
    723     def parameters(self):

RuntimeError: Error(s) in loading state_dict for Net:
	Unexpected running stats buffer(s) "model1.1.running_mean" and "model1.1.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.3.conv_block.0.running_mean" and "model1.3.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.3.conv_block.3.running_mean" and "model1.3.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.3.conv_block.6.running_mean" and "model1.3.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.4.conv_block.0.running_mean" and "model1.4.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.4.conv_block.3.running_mean" and "model1.4.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.4.conv_block.6.running_mean" and "model1.4.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.1.running_mean" and "model.0.1.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.3.conv_block.0.running_mean" and "model.0.3.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.3.conv_block.3.running_mean" and "model.0.3.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.3.conv_block.6.running_mean" and "model.0.3.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.4.conv_block.0.running_mean" and "model.0.4.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.4.conv_block.3.running_mean" and "model.0.4.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.4.conv_block.6.running_mean" and "model.0.4.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.2.conv_block.0.running_mean" and "model.2.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.2.conv_block.3.running_mean" and "model.2.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.2.conv_block.6.running_mean" and "model.2.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.3.conv_block.0.running_mean" and "model.3.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.3.conv_block.3.running_mean" and "model.3.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.3.conv_block.6.running_mean" and "model.3.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.4.conv_block.0.running_mean" and "model.4.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.4.conv_block.3.running_mean" and "model.4.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.4.conv_block.6.running_mean" and "model.4.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.5.conv_block.0.running_mean" and "model.5.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.5.conv_block.3.running_mean" and "model.5.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.5.conv_block.6.running_mean" and "model.5.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.6.conv_block.0.running_mean" and "model.6.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.6.conv_block.3.running_mean" and "model.6.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.6.conv_block.6.running_mean" and "model.6.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.7.conv_block.0.running_mean" and "model.7.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.7.conv_block.3.running_mean" and "model.7.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.7.conv_block.6.running_mean" and "model.7.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.8.conv_block.0.running_mean" and "model.8.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.8.conv_block.3.running_mean" and "model.8.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.8.conv_block.6.running_mean" and "model.8.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.9.conv_block.0.running_mean" and "model.9.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.9.conv_block.3.running_mean" and "model.9.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.9.conv_block.6.running_mean" and "model.9.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.10.running_mean" and "model.10.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
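The error lists keys that the old checkpoint contains but the 0.4.0 model no longer registers. A minimal sketch for finding such keys before loading is to compare the two key sets. Plain OrderedDicts stand in here for torch.load('21styles.model') and style_model.state_dict(); the helper name diff_state_dict_keys and the sample keys are hypothetical.

```python
from collections import OrderedDict


def diff_state_dict_keys(checkpoint, expected):
    """Compare a checkpoint's keys with the keys a model expects.

    Both arguments are mappings (e.g. the OrderedDict returned by torch.load
    or by model.state_dict()); only the keys are inspected.
    Returns (unexpected, missing) as sorted lists.
    """
    ckpt_keys = set(checkpoint)
    model_keys = set(expected)
    unexpected = sorted(ckpt_keys - model_keys)  # in checkpoint, not in model
    missing = sorted(model_keys - ckpt_keys)     # in model, not in checkpoint
    return unexpected, missing


# Hypothetical stand-ins: an old checkpoint that still carries InstanceNorm
# running stats, and a new model that no longer registers those buffers.
old_checkpoint = OrderedDict([
    ("model1.0.weight", None),
    ("model1.1.running_mean", None),
    ("model1.1.running_var", None),
])
new_model_keys = OrderedDict([("model1.0.weight", None)])

unexpected, missing = diff_state_dict_keys(old_checkpoint, new_model_keys)
print(unexpected)  # ['model1.1.running_mean', 'model1.1.running_var']
print(missing)     # []
```

With PyTorch installed, the equivalent call would be diff_state_dict_keys(torch.load('21styles.model'), style_model.state_dict()); only keys are compared, so the tensors themselves are never touched.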

mratsim, Jun 10 '18

set track_running_stats=True in InstanceNorm2d should be able to fix this
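A minimal sketch of what the flag does, assuming a recent PyTorch: InstanceNorm2d registers running_mean and running_var buffers only when track_running_stats=True, which is why the old checkpoint keys are unexpected by default. The channel count 64 is arbitrary. Per the InstanceNorm2d documentation, with the flag on, the layer normalizes with those running estimates in eval mode instead of per-image statistics, so this is not equivalent to deleting the keys.

```python
import torch.nn as nn

# Default since PyTorch 0.4.0: no running statistics are registered.
plain = nn.InstanceNorm2d(64)
# Opting in registers running_mean / running_var, the keys the old checkpoint has.
tracked = nn.InstanceNorm2d(64, track_running_stats=True)

print(list(plain.state_dict().keys()))  # []
print(sorted(tracked.state_dict().keys()))
```

Recent versions also register a num_batches_tracked buffer when the flag is set, so the exact key list printed for tracked depends on the PyTorch version.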

zhanghang1989, Jun 10 '18

Setting track_running_stats=True seems buggy and does not work for me (or I missed something).

I went the other way and removed the running-stats keys from the checkpoint instead:

# https://github.com/zhanghang1989/PyTorch-Multi-Style-Transfer/issues/21
# Compatibility shim for PyTorch 0.4

model_dict = torch.load('21styles.model')
model_dict_clone = model_dict.copy() # We can't mutate while iterating

for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]

### Next cell

style_model = Net(ngf=128)
style_model.load_state_dict(model_dict, False)
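The shim above can also be written without mutating the loaded checkpoint, by building a filtered copy instead. A minimal sketch; strip_running_stats is a hypothetical helper name and the sample keys are made up. Like the original loop, it drops every key ending in running_mean or running_var, so it assumes the model has no BatchNorm layers whose running statistics are needed.

```python
from collections import OrderedDict

RUNNING_STAT_SUFFIXES = ("running_mean", "running_var")


def strip_running_stats(state_dict, suffixes=RUNNING_STAT_SUFFIXES):
    """Return a new OrderedDict without running-statistics entries.

    str.endswith accepts a tuple, so one call checks both suffixes.
    The input mapping is left untouched.
    """
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.endswith(suffixes)
    )


# Hypothetical checkpoint; integers stand in for tensors.
checkpoint = OrderedDict([
    ("model.0.0.weight", 1),
    ("model.0.1.running_mean", 2),
    ("model.0.1.running_var", 3),
])
cleaned = strip_running_stats(checkpoint)
print(list(cleaned))  # ['model.0.0.weight']
```

With PyTorch, the notebook cells would collapse to style_model.load_state_dict(strip_running_stats(torch.load('21styles.model')), False). Building a new dict sidesteps the "can't mutate while iterating" concern entirely, at the cost of holding a second (shallow) mapping briefly.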

mratsim, Jun 10 '18

1. I had to downgrade PyTorch to get it working:
pip install torch==0.3.0.post4
2. In the camera_demo.py and main.py files, the above translates into changing
style_model = Net(ngf=args.ngf)
style_model.load_state_dict(torch.load(args.model))

to

model_dict = torch.load(args.model)
model_dict_clone = model_dict.copy() # We can't mutate while iterating
for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]
style_model = Net(ngf=args.ngf) # the model still has to be constructed before loading
style_model.load_state_dict(model_dict, False)
3. Change style_v.data() to style_v.data.

alvinwan, Aug 05 '18

Just got camera_demo.py and main.py working - thanks @alvinwan and @mratsim for the hints above.

For a while I was getting this error:

  File "camera_demo.py", line 105, in <module>
    main()
  File "camera_demo.py", line 102, in main
    run_demo(args, mirror=True)
  File "camera_demo.py", line 75, in run_demo
    simg = simg.transpose(1, 2, 0).astype('uint8')
ValueError: axes don't match array

The quick way to debug this was to replace python on my command line with python -m pdb and, once it crashed and gave me a prompt, check the shape of simg. It turns out simg now has 4 dimensions rather than 3, which I fixed with the reshape in step 3 below.
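The pdb check can be reproduced with a NumPy array of the same rank. A minimal sketch, assuming simg has a leading batch axis of size 1; the 4x4 spatial size is arbitrary.

```python
import numpy as np

# Hypothetical stand-in for style_v.data.numpy(): one RGB image with a
# leading batch axis, i.e. shape (1, C, H, W) instead of (C, H, W).
simg = np.zeros((1, 3, 4, 4), dtype=np.float32)
print(simg.shape)  # (1, 3, 4, 4)

raised = False
try:
    # Three axes for a 4-D array: the same call that fails in camera_demo.py.
    simg.transpose(1, 2, 0)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```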

My full fixes were:

1. Downgrade torch:

pip uninstall torch
pip install torch==0.3.0.post4

2. In camera_demo.py and main.py replace

style_model = Net(ngf=args.ngf)

With

model_dict = torch.load(args.model) # or args.resume,
                                    # matching what's in the line with style_model.load_state_dict
model_dict_clone = model_dict.copy() # We can't mutate while iterating
for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]

style_model = Net(ngf=128) # to run with torch-0.3.0.post4
# style_model = Net(ngf=args.ngf) # to run main.py with torch-0.4.0

Replace

style_model.load_state_dict(torch.load(args.model)) # or (args.resume) in one place

With

style_model.load_state_dict(model_dict, False)

3. Replace

simg = style_v.data().numpy()

With

simg = style_v.data.numpy().reshape((3,512,512))
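The reshape((3,512,512)) in step 3 only works when the output is exactly 512x512. A sketch that drops the batch axis instead, so any image size works; to_hwc_uint8 is a hypothetical helper, and the clip to [0, 255] is an addition not in camera_demo.py that keeps out-of-range values from being cast unpredictably.

```python
import numpy as np


def to_hwc_uint8(chw_or_nchw):
    """Convert a (C, H, W) or (1, C, H, W) array to (H, W, C) uint8."""
    arr = np.asarray(chw_or_nchw)
    if arr.ndim == 4:
        # Drop a batch axis of size 1 instead of hard-coding the image size.
        if arr.shape[0] != 1:
            raise ValueError(f"expected batch size 1, got {arr.shape[0]}")
        arr = arr[0]
    if arr.ndim != 3:
        raise ValueError(f"expected a 3-D or 4-D array, got shape {arr.shape}")
    # Clip before casting so out-of-range values saturate at 0 or 255.
    return arr.transpose(1, 2, 0).clip(0, 255).astype("uint8")


batch = np.full((1, 3, 2, 5), 300.0)
img = to_hwc_uint8(batch)
print(img.shape, img.dtype)  # (2, 5, 3) uint8
```

In camera_demo.py this would stand in for both the step 3 line and the failing transpose: simg = to_hwc_uint8(style_v.data.numpy()).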

QUESTION: Instead of downgrading torch, I also tried setting track_running_stats=True for InstanceNorm2d in net.py. I had to do this in several places, following norm_layer through the code, including the Bottleneck and UpBottleneck classes.

(Note that the documentation shows track_running_stats=True as the default for the BatchNorm classes, while InstanceNorm2d defaults to False.)

I've gotten main.py working with torch upgraded, but camera_demo.py produces an all-black image. I'd welcome any comments or ideas!
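Following norm_layer through net.py by hand can be avoided by binding the flag once with functools.partial and passing the result wherever a norm layer is constructed. A minimal sketch; TinyBlock is a hypothetical stand-in for blocks such as Bottleneck and UpBottleneck, not the repository's code.

```python
import functools

import torch
import torch.nn as nn

# Bind the flag once; every layer built from norm_layer inherits it.
norm_layer = functools.partial(nn.InstanceNorm2d, track_running_stats=True)


class TinyBlock(nn.Module):
    """Hypothetical block that builds its normalization from norm_layer."""

    def __init__(self, channels, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm = norm_layer(channels)

    def forward(self, x):
        return torch.relu(self.norm(self.conv(x)))


block = TinyBlock(8, norm_layer=norm_layer)
out = block(torch.zeros(1, 8, 4, 4))  # padding=1 keeps the 4x4 spatial size
print(block.norm.track_running_stats)  # True
print(sorted(k for k in block.state_dict() if k.startswith("norm.")))
```

If Net itself accepts a norm_layer argument (an assumption worth checking in net.py), the same partial could be passed as Net(ngf=128, norm_layer=norm_layer) instead of editing each class.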

karenerobinson, Oct 05 '18

@karenerobinson I think the most reasonable approach is to wait for PyTorch 1.0, which should be released within days, so that the APIs are more stable and we don't have to fix something new once again when it lands.

mratsim, Oct 06 '18

How do you set track_running_stats = True? I am a beginner, so sorry if it's obvious; I haven't been able to find it for the past hour or so.

Thanks

mertgerdan, Mar 03 '19

Try what @mratsim mentioned above:

model_dict = torch.load('21styles.model')
model_dict_clone = model_dict.copy() # We can't mutate while iterating

for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]

### Next cell

style_model = Net(ngf=128)
style_model.load_state_dict(model_dict, False)

nile649, Mar 03 '19

I fixed my issue: I went into the nn package in my Python site-packages directory and set track_running_stats=True in the InstanceNorm file. I didn't know how to do that before. After a bit more tweaking, I got it to work. Thanks anyway :)

mertgerdan, Mar 04 '19

I really appreciate the comments on fixing the compatibility issue in the code. I haven't worked on this project for a while. Could you consider submitting a pull request to the master branch? Thanks a lot :)

zhanghang1989, Mar 04 '19

Thanks to @alvinwan for sharing the fixes. I have tried them and they worked for both main.py and camera_demo.py. @zhanghang1989 As this is still not fixed in the master branch, I have created a pull request for it (including another fix for load_lua) here.

jianchao-li, Aug 03 '19

> I fixed my issue, I went to NN packages in my python site packages dir and set track_running_stats=True on the instanceNorm file.

I have been looking for this for 50 hours, thanks

IamRafh, Mar 27 '21