neuralprocesses icon indicating copy to clipboard operation
neuralprocesses copied to clipboard

Torch vs Tensorflow AGNP is different

Open DrJonnyT opened this issue 3 months ago • 12 comments

I've been converting my code from TensorFlow to PyTorch, and it is much easier to get it training quickly. However, the performance after n epochs is worse in torch. After a lot of digging, it looks like the model architectures come out different for the AGNP, whereas this does not seem to be an issue for the GNP:

import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf
import numpy as np

# Some helper functions
def check_length(list1, list2):
    """Return True when both sequences hold the same number of items."""
    size_gap = len(list1) - len(list2)
    return size_gap == 0

def check_shapes(list1, list2):
    """Return True when the two lists match item by item in shape.

    Items are compared pairwise by their ``.shape``. A pair also counts as
    matching when the item from ``list2`` has the same shape as the item
    from ``list1`` after ``np.transpose`` (all axes reversed).

    Args:
        list1: Sequence of objects exposing ``.shape``.
        list2: Sequence of array-like objects exposing ``.shape``.

    Returns:
        ``False`` if the sequences differ in length or any pair of shapes
        differs even after transposing; ``True`` otherwise (including for
        two empty sequences).
    """
    # zip() stops at the shorter input, so without this guard a list that
    # merely shares a matching prefix would be reported as fully matching.
    if len(list1) != len(list2):
        return False

    for item1, item2 in zip(list1, list2):
        if item1.shape == item2.shape:
            continue
        # Shapes differ as stored: retry with the second item's axes reversed.
        if item1.shape != np.transpose(item2).shape:
            return False
    return True



# %%
# Build a GNP in each backend from one shared set of hyperparameters
gnp_config = dict(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_tf = nps_tf.construct_gnp(**gnp_config)
model_torch = nps_torch.construct_gnp(**gnp_config)

# Collect the weights of each model, then record the shape of every entry
tf_weights = model_tf.get_weights()
torch_weights = [param.data.numpy() for param in model_torch.parameters()]
shapes_tf = [weight.shape for weight in tf_weights]
shapes_torch = [weight.shape for weight in torch_weights]

# %%
# Report whether the two GNP weight lists agree in count and in shapes
print("\nGNP models:")
gnp_same_count = check_length(tf_weights, torch_weights)
print(
    "Both models have the same number of layers"
    if gnp_same_count
    else "Both models do not have the same number of layers"
)
gnp_same_shapes = check_shapes(tf_weights, torch_weights)
print(
    "The shapes of all the layers are the same"
    if gnp_same_shapes
    else "The shapes of the layers are not all the same"
)
    


# %%
# Build an AGNP in each backend from one shared set of hyperparameters
agnp_config = dict(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_tf = nps_tf.construct_agnp(**agnp_config)
model_torch = nps_torch.construct_agnp(**agnp_config)

# Collect the weights of each model, then record the shape of every entry
tf_weights = model_tf.get_weights()
torch_weights = [param.data.numpy() for param in model_torch.parameters()]
shapes_tf = [weight.shape for weight in tf_weights]
shapes_torch = [weight.shape for weight in torch_weights]

# %%
# Report whether the two AGNP weight lists agree in count and in shapes
print("\nAGNP models:")
agnp_same_count = check_length(tf_weights, torch_weights)
print(
    "Both models have the same number of layers"
    if agnp_same_count
    else "Both models do not have the same number of layers"
)
agnp_same_shapes = check_shapes(tf_weights, torch_weights)
print(
    "The shapes of all the layers are the same"
    if agnp_same_shapes
    else "The shapes of the layers are not all the same"
)

DrJonnyT avatar Mar 28 '24 16:03 DrJonnyT