GANotebooks

Saving, loading, and testing cyclegan_keras

Open · daniezest opened this issue 7 years ago • 6 comments

How do I save the cyclegan_keras model? Also, how do I test the trained model?

Thanks!

daniezest, May 20 '18 16:05

Use the methods described in the Keras FAQ (https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model) to save and load Keras models and/or weights. The generators are netGB and netGA. The function showG shows how to use the functions cycleA_generate and cycleB_generate to transform an image with the trained models.
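A minimal sketch of the save/load route from that FAQ, using tf.keras. Because UNET_G lives in the notebook and is not reproduced here, `build_stand_in_generator` is a hypothetical tiny model; with the notebook you would call `.save(...)` on the trained `netGB` instead. The sketch assumes channels-last images scaled to [-1, 1] (matching a tanh output), and it calls `predict` on the generator directly rather than going through the notebook's `cycleA_generate` function.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_stand_in_generator(image_size=32, nc_in=3, nc_out=3):
    """Hypothetical tiny generator standing in for the notebook's UNET_G."""
    inputs = keras.Input(shape=(image_size, image_size, nc_in))
    x = keras.layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
    outputs = keras.layers.Conv2D(nc_out, 3, padding="same", activation="tanh")(x)
    return keras.Model(inputs, outputs)


net_gb = build_stand_in_generator()

# Save architecture + weights to a single HDF5 file, then load it back.
path = os.path.join(tempfile.mkdtemp(), "netGB.h5")
net_gb.save(path)
restored = keras.models.load_model(path)

# Translate a batch containing one A-domain image (values in [-1, 1]) to domain B.
real_a = np.random.uniform(-1.0, 1.0, size=(1, 32, 32, 3)).astype("float32")
fake_b = restored.predict(real_a, verbose=0)
print(fake_b.shape)
```

Saving the whole model keeps the architecture and weights together, so the loading side does not need the UNET_G definition.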

tjwei, May 24 '18 07:05

Thank you for the help. I noticed that you tried video swapping with CycleGAN-lasagne-fber.ipynb. Is it any different from cyclegan_keras?

I'm not getting a good generator loss for video swapping with cyclegan_keras.

Thanks!

daniezest, May 29 '18 04:05

The Keras version uses a U-Net as the generator, while the Lasagne version uses a ResNet.
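To illustrate the difference, here is a simplified sketch of the building block each generator style is built from, written in tf.keras. These are not the notebooks' actual layers (real CycleGAN ResNet generators typically also use normalization and reflection padding); they only show how information is carried past the convolutions: a U-Net concatenates same-resolution encoder features onto the decoder (channels grow), while a ResNet block adds its output back to its input (shape unchanged).

```python
from tensorflow import keras


def unet_skip(decoder_features, encoder_features):
    """U-Net style skip: concatenate same-resolution encoder features (channels grow)."""
    return keras.layers.Concatenate(axis=-1)([decoder_features, encoder_features])


def resnet_block(x, filters):
    """ResNet style block: two 3x3 convs whose output is added to the input."""
    y = keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    y = keras.layers.Conv2D(filters, 3, padding="same")(y)
    return keras.layers.Add()([x, y])


feat = keras.Input(shape=(16, 16, 8))
skip_out = unet_skip(feat, feat)
res_out = resnet_block(feat, filters=8)
print(skip_out.shape, res_out.shape)
```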

tjwei, May 29 '18 08:05

I was also trying to save the model, but I can't find a model object to call model.save() on. Could you point out which line in CGAN-keras needs to be changed to save the model, and where the model object is so I can save it iteratively?

rishab-sharma, Jul 16 '18 07:07

The generator models are defined in these lines:

```python
netGB = UNET_G(imageSize, nc_in, nc_out, ngf)
netGA = UNET_G(imageSize, nc_out, nc_in, ngf)
```

You can save them after they are trained.
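For saving iteratively during training, one option is to checkpoint the generators' weights every N iterations. This is a sketch, not code from the notebook: `checkpoint_generators` is a hypothetical helper, `build_stand_in_generator` is a hypothetical stand-in for UNET_G, and the training step itself is left as a comment. The `.weights.h5` suffix is used because Keras 3 requires it for `save_weights`; older tf.keras also accepts it as an HDF5 file.

```python
import os
import tempfile

from tensorflow import keras


def build_stand_in_generator(nc_in=3, nc_out=3):
    """Hypothetical stand-in for UNET_G; any keras.Model works here."""
    inputs = keras.Input(shape=(32, 32, nc_in))
    outputs = keras.layers.Conv2D(nc_out, 3, padding="same", activation="tanh")(inputs)
    return keras.Model(inputs, outputs)


def checkpoint_generators(models, iteration, out_dir, every=1000):
    """Save each model's weights every `every` iterations; return the saved paths."""
    saved = []
    if iteration % every != 0:
        return saved
    os.makedirs(out_dir, exist_ok=True)
    for name, model in models.items():
        # Keras 3 requires weight files to end in ".weights.h5"
        path = os.path.join(out_dir, "{}_{:06d}.weights.h5".format(name, iteration))
        model.save_weights(path)
        saved.append(path)
    return saved


generators = {"netGA": build_stand_in_generator(), "netGB": build_stand_in_generator()}
out_dir = tempfile.mkdtemp()
for iteration in range(5):
    # ... one training step for netDA, netDB, netGA and netGB would go here ...
    checkpoint_generators(generators, iteration, out_dir, every=2)
print(sorted(os.listdir(out_dir)))
```

A checkpoint can later be restored by rebuilding the same architecture and calling `load_weights` on the chosen file.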

tjwei, Jul 18 '18 08:07

Just add this snippet at the end of the code:

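```python
import os


def save_model(model, filepath):
    # Serialize the architecture to JSON
    model_json = model.to_json()
    with open("{}.json".format(filepath), "w") as json_file:
        json_file.write(model_json)
    # Serialize the weights to HDF5
    # (with Keras 3, this filename must end in ".weights.h5" instead)
    model.save_weights("{}.h5".format(filepath))


output_dir = 'your/output/path/'
os.makedirs(output_dir, exist_ok=True)  # create the folder if it does not exist
save_model(netDA, os.path.join(output_dir, 'netDA'))
save_model(netDB, os.path.join(output_dir, 'netDB'))
save_model(netGA, os.path.join(output_dir, 'netGA'))
save_model(netGB, os.path.join(output_dir, 'netGB'))
```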

You can then load the models and generate fake images from the trained generators with this code I've made.
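The code referred to here is not included in the thread, so the following is only a sketch of the loading side of the JSON + weights approach, using `keras.models.model_from_json` and `load_weights`. `load_saved_model` is a hypothetical helper, and the demo first saves a tiny stand-in generator (not the real UNET_G) so the block runs on its own. It writes weights as `.weights.h5` so it also works with Keras 3; files written by the `save_model` snippet should load the same way with the Keras version used in the notebook. Mapping tanh outputs from [-1, 1] to 0..255 pixels is an assumption about how images were normalized.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def load_saved_model(json_path, weights_path):
    """Rebuild a model from its JSON architecture and load its weights."""
    with open(json_path) as json_file:
        model = keras.models.model_from_json(json_file.read())
    model.load_weights(weights_path)
    return model


# Demo only: save a hypothetical stand-in generator the same way save_model does.
inputs = keras.Input(shape=(32, 32, 3))
outputs = keras.layers.Conv2D(3, 3, padding="same", activation="tanh")(inputs)
net_gb = keras.Model(inputs, outputs)

out_dir = tempfile.mkdtemp()
json_path = os.path.join(out_dir, "netGB.json")
weights_path = os.path.join(out_dir, "netGB.weights.h5")
with open(json_path, "w") as json_file:
    json_file.write(net_gb.to_json())
net_gb.save_weights(weights_path)

# Load the generator back and translate a batch of A-domain images to domain B.
restored_gb = load_saved_model(json_path, weights_path)
real_a = np.random.uniform(-1.0, 1.0, size=(4, 32, 32, 3)).astype("float32")
fake_b = restored_gb.predict(real_a, verbose=0)

# Map tanh outputs from [-1, 1] back to 8-bit pixel values for saving/viewing.
fake_b_uint8 = ((fake_b + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
print(fake_b_uint8.shape, fake_b_uint8.dtype)
```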

viniciusarruda, Aug 31 '18 05:08