Repository: pytorch_geometric_temporal

Problem in model

Status: open — KoperSloper opened this issue 1 year ago · 0 comments

Hello everyone,

I’ve been training a PyTorch Geometric Temporal model (an LRGCN recurrent layer followed by a small MLP head), but I’m running into problems. Despite following the usual training steps, the model doesn’t seem to learn properly. I suspect the problem is in how I’m updating the model parameters, but I can’t figure out what’s wrong.

Here’s the relevant part of my code:

class RecurrentGCN2(torch.nn.Module):
    """LRGCN recurrent encoder with a two-layer MLP regression head.

    ``forward`` steps the LRGCN cell once per slice along dimension 0 of
    ``inputs``. It averages each step's hidden state over the node
    dimension and maps the result to one scalar per step.

    Args:
        node_features: Number of input features per node.
        hidden_size: Width of the LRGCN hidden state. Defaults to 256, the
            value that was previously hard-coded.
    """

    def __init__(self, node_features, hidden_size=256):
        super().__init__()
        # NOTE(review): positional args assumed to be (in_channels,
        # out_channels, num_relations, num_bases) — confirm against the
        # LRGCN documentation.
        self.recurrent = LRGCN(node_features, hidden_size, 1, 1)
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1),
        )

    def forward(self, inputs, edge_index, edge_weight, h_0, c_0):
        """Run the recurrence over ``inputs`` and predict one value per step.

        Args:
            inputs: Tensor iterated along dim 0. Each slice ``inputs[i]`` is
                passed to LRGCN as node features (presumably
                ``(num_nodes, node_features)`` — confirm).
            edge_index: Graph connectivity, forwarded unchanged to LRGCN.
            edge_weight: Forwarded as LRGCN's third positional argument.
            h_0: Initial hidden state. The training loop passes ``None``.
            c_0: Initial cell state. The training loop passes ``None``.

        Returns:
            Tensor of shape ``(inputs.shape[0], 1)``: one prediction per
            step along dim 0.
        """
        # NOTE(review): the training loop passes a whole DataLoader batch as
        # ``inputs``, so dim 0 is the batch dimension. Stepping the
        # recurrence over it carries h/c from one sample into the next. If
        # samples are independent (or the loader shuffles), state leaks
        # across samples — confirm this is intended.
        # NOTE(review): LRGCN's third forward argument is a relation-type
        # tensor (edge_type), not edge weights — confirm this is intended.
        hidden_states = []
        for step_input in inputs:
            h_0, c_0 = self.recurrent(
                step_input, edge_index, edge_weight, h_0, c_0
            )
            hidden_states.append(h_0)

        # Stacking keeps the states on the model's device/dtype and avoids a
        # buffer with hard-coded node/channel counts:
        # (steps, num_nodes, hidden) -> mean over nodes -> (steps, hidden).
        # NOTE(review): no activation sits between LRGCN and the MLP head.
        # An earlier ``F.relu(h_0)`` was computed but immediately
        # discarded, so add one deliberately if it was wanted.
        h_mean = torch.stack(hidden_states).mean(dim=1)
        return self.layers(h_mean)

import matplotlib.pyplot as plt

for epoch in range(200):
    # Per-epoch accumulators: sample-weighted loss and values for plotting.
    epoch_loss = 0.0
    num_samples = 0
    actual_values = []
    predicted_values = []

    model2.train()
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Start every batch from a fresh recurrent state.
        h, c = None, None

        optimizer.zero_grad()

        y_hat = model2(batch_x, edge_index_sectors, edge_weights, h, c)

        # The model returns (batch, 1). If batch_y is (batch,), comparing
        # the two directly broadcasts to (batch, batch) in MSE-style losses
        # and trains against the wrong targets. Align shapes first.
        # Assumes one target value per sample: reshape fails loudly
        # otherwise instead of silently broadcasting.
        y_hat = y_hat.reshape(batch_y.shape)
        loss = criterion(y_hat, batch_y)
        loss.backward()
        optimizer.step()

        # Assumes criterion uses mean reduction (the PyTorch default), so
        # loss.item() is already a per-sample mean. Weight it by batch size
        # so the epoch figure is the true per-sample mean, including a
        # smaller final batch. The previous `/ batch_x.shape[0]` divided a
        # mean by the batch size a second time.
        batch_size = batch_x.shape[0]
        epoch_loss += loss.item() * batch_size
        num_samples += batch_size

        # Store the actual and predicted values for plotting
        actual_values.extend(batch_y.detach().cpu().numpy().flatten())
        predicted_values.extend(y_hat.detach().cpu().numpy().flatten())

    # Guard against an empty loader.
    epoch_loss /= max(num_samples, 1)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss}')

    # Plot the graph of predicted vs actual values after each epoch.
    # NOTE(review): this opens one figure per epoch (200 in total); in a
    # non-interactive script plt.show() may block until each is closed.
    plt.figure(figsize=(10, 5))
    plt.plot(actual_values, label='Actual')
    plt.plot(predicted_values, label='Predicted')
    plt.legend()
    plt.show()

— KoperSloper, Oct 4, 2023, 18:10