spektral
Training a network with heterogeneous data
Hi!
I'm trying to build a regression model that predicts a parameter from a dataset in which each input X is a pair (Graph, integer parameter) and the output is a plain float value.
I was wondering if Spektral could help me with my task. I've managed to load my graph data correctly into Graph objects, but after a couple of weeks I still couldn't figure out a way to either:
- Make a composite network that first processes the Graphs into some kind of numerical embedding, and then feeds that embedding (plus the integer parameter) into a neural network regressor;
- Do all of this with your library only.
Thanks in advance for the help!
Hi,
Spektral layers are just like all other Keras layers so they can be composed however you like (as long as the shapes of the tensors are compatible).
In your case, it sounds like you want something like GeneralGNN or some other GNN with a global pooling layer at the end.
You can then concatenate your integer to the output of the GNN using tf.keras.layers.Concatenate and finally feed the result to one or more tf.keras.layers.Dense layers.
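To make the pooling and concatenation steps concrete, here is a small NumPy sketch of what a global sum pooling layer computes in Spektral's disjoint mode, where the nodes of all graphs in a batch are stacked into one matrix and a vector i assigns each node to its graph. The feature values and parameters are made up for illustration; in an actual model, Spektral's GlobalSumPool and tf.keras.layers.Concatenate do this on tensors.

```python
import numpy as np

# Disjoint mode: node features of all graphs stacked into one matrix.
x = np.array([[1.0, 0.0],
              [2.0, 1.0],   # graph 0: nodes 0-1
              [0.5, 0.5],
              [1.0, 1.0],
              [3.0, 0.0]])  # graph 1: nodes 2-4
i = np.array([0, 0, 1, 1, 1])  # graph index of each node (sorted, starting at 0)


def global_sum_pool(x, i):
    """Sum node features per graph: (n_nodes, F) -> (n_graphs, F)."""
    n_graphs = i.max() + 1
    out = np.zeros((n_graphs, x.shape[1]))
    np.add.at(out, i, x)  # out[i[k]] += x[k] for every node k
    return out


emb = global_sum_pool(x, i)  # [[3.0, 1.0], [4.5, 1.5]]

# Append the per-graph integer parameter: this is the regressor's input.
param = np.array([[7.0], [2.0]])
features = np.concatenate([emb, param], axis=1)  # shape (2, 3)
```

Each row of features corresponds to one graph, so the integer parameter stays aligned with its graph as long as both are kept in the same order.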
To train this architecture you'll need to write a custom training loop (see an example here), because the model.fit approach won't work with Spektral in this case.
Thanks for your quick answer, I'll try that in the next days and tell you how it went :)
Hi, I finally got around to writing some code, and I think I'm stuck:
```python
from spektral.models import GeneralGNN
from tensorflow.keras.layers import Input, Dense, Flatten, concatenate
from tensorflow.keras.models import Model

gnn = GeneralGNN(1, activation='linear')
inputlayer = Input(shape=(1,))
parallel = Dense(128, activation='relu')(inputlayer)
parallel = Dense(128, activation='relu')(parallel)
outpar = Flatten()(parallel)
conc = concatenate([gnn.output(), outpar])
out = Dense(1, activation='linear')(conc)
model = Model([gnn.input(), inputlayer], out)
```
This is how I'm trying to concatenate the layers. The concatenation gives me this error:
```
conc = concatenate([gnn.output(), outpar])
AttributeError: Layer general_gnn_26 is not connected, no input to return.
```
Am I missing something here?
I'm not sure that calling .output() is the correct way of creating the model: this "static" way of building the computational graph has now been superseded by model subclassing.
Do you still have the issue if you do it with model subclassing?