neuralprocesses
neuralprocesses copied to clipboard
Torch vs Tensorflow AGNP is different
I've been converting my code from TensorFlow to PyTorch, and it's much easier to get it training quickly. However, the performance after n epochs is worse in torch. After lots of digging, it seems like the model architectures come out different for AGNP, although this doesn't seem to be an issue for a GNP:
import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf
import numpy as np
# Some helper functions
def check_length(list1, list2):
    """Return True when both collections contain the same number of items."""
    same_length = len(list1) == len(list2)
    return same_length
def check_shapes(list1, list2):
    """Return True if each paired item has a matching shape.

    A pair also counts as matching when the second item's shape matches
    after transposition (axes reversed), since TF and torch may store the
    same weight matrix with swapped dimensions.
    """
    for a, b in zip(list1, list2):
        if a.shape == b.shape:
            continue
        # Retry with the second array transposed before declaring a mismatch.
        if a.shape != np.transpose(b).shape:
            return False
    return True
# %%
# Construct GNP models
# Build the same GNP architecture in both frameworks so their parameter
# lists can be compared layer by layer.
model_tf = nps_tf.construct_gnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
# Construct the model
model_torch = nps_torch.construct_gnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
# Flatten each model's parameters into plain numpy arrays for comparison.
tf_weights = model_tf.get_weights()
torch_weights = [p.data.numpy() for p in model_torch.parameters()]
shapes_tf = [w.shape for w in tf_weights]
shapes_torch = [w.shape for w in torch_weights]
# %%
# Check GNP models
print("\nGNP models:")
if check_length(tf_weights, torch_weights):
    print("Both models have the same number of layers")
else:
    print("Both models do not have the same number of layers")
if check_shapes(tf_weights, torch_weights):
    print("The shapes of all the layers are the same")
else:
    print("The shapes of the layers are not all the same")
# %%
# Construct AGNP models
# Repeat the comparison for the attentive variant (AGNP), which is where
# the architectures reportedly diverge between frameworks.
model_tf = nps_tf.construct_agnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
# Construct the model
model_torch = nps_torch.construct_agnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
# Flatten each model's parameters into plain numpy arrays for comparison.
tf_weights = model_tf.get_weights()
torch_weights = [p.data.numpy() for p in model_torch.parameters()]
shapes_tf = [w.shape for w in tf_weights]
shapes_torch = [w.shape for w in torch_weights]
# %%
# Check AGNP models
print("\nAGNP models:")
if check_length(tf_weights, torch_weights):
    print("Both models have the same number of layers")
else:
    print("Both models do not have the same number of layers")
if check_shapes(tf_weights, torch_weights):
    print("The shapes of all the layers are the same")
else:
    print("The shapes of the layers are not all the same")