Many-batches predictions
Hi,
When trying to get predictions from Lattice models on more than one batch of data at once, errors are raised. Predicting on several batches at once is a nice way to get predictions efficiently, and it is present in basic neural-network Keras models; find some examples in this colab.
As far as I can tell from looking at the API docs and source code, this should be related to the inputs admitted by the PWL calibration layers, but I wonder if there is an easy way around it.
In particular, this piece of code captures what I would like to get (and raises an error when the model is called on batched_inputs):
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl

class LatticeModel(tf.keras.Model):
    def __init__(self, nodes=[2,2], nkeypoints=100):
        super(LatticeModel,self).__init__()
        # one PWL calibrator per input feature, combined in parallel
        self.combined_calibrators = tfl.layers.ParallelCombination()
        for ind in range(2):
            calibration_layer = tfl.layers.PWLCalibration(input_keypoints=np.linspace(0,1,nkeypoints),output_min=0.0, output_max=nodes[ind])
            self.combined_calibrators.append(calibration_layer)
        self.lattice = tfl.layers.Lattice(lattice_sizes=nodes,interpolation="simplex")

    def call(self, x):
        rescaled = self.combined_calibrators(x)
        feat = self.lattice(rescaled)
        return feat

# we define some input data
x1 = np.random.randn(100,1).astype(np.float32)
x2 = np.random.randn(100,1).astype(np.float32)
inputs = tf.concat([x1,x2], axis=-1)

# we initialize our model, and feed it with a batch of size 100
model = LatticeModel()
model(inputs)

# now we would like to efficiently predict the output of the lattice model
# on many batches of data at once (in this case 2)
batched_inputs = np.random.randn(2,100,1)
model(batched_inputs)
Thanks a lot! Matías.
The model you have constructed here expects an input of shape (B, 2). The input you are passing in your last call has shape (2, 100, 1). Maybe you meant to pass (2, 100, 2), which would be 2 batches, each of shape (100, 2). A general way of approaching this is to use tf.reshape:
num_batches, batch_size, input_dim = 2, 100, 2
batched_inputs = np.random.randn(num_batches, batch_size, input_dim).astype(np.float32)
reshaped_batched_inputs = tf.reshape(batched_inputs, [-1, input_dim]) # would be of shape (num_batches * batch_size, input_dim)
flat_preds = model(reshaped_batched_inputs) # would be of shape (num_batches * batch_size, 1)
preds = tf.reshape(flat_preds, [num_batches, batch_size])
You can do the reshaping inside the model before and after the call to the layers.
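For illustration, here is a minimal sketch of the flatten, predict, unflatten pattern as a small helper. It uses NumPy only so that it runs without TensorFlow installed. The names predict_on_batches and toy_model are hypothetical and not part of TensorFlow Lattice; toy_model is just a stand-in for any callable that expects inputs of shape (B, input_dim). The same two reshapes could equally be placed at the start and end of the model's call method.

```python
import numpy as np

def predict_on_batches(model_fn, batched_inputs):
    """Apply a model expecting (B, input_dim) to inputs of shape (..., input_dim)."""
    batched_inputs = np.asarray(batched_inputs, dtype=np.float32)
    lead_shape = batched_inputs.shape[:-1]   # e.g. (num_batches, batch_size)
    input_dim = batched_inputs.shape[-1]
    # Collapse all leading dimensions into a single batch dimension.
    flat_inputs = batched_inputs.reshape(-1, input_dim)
    flat_preds = np.asarray(model_fn(flat_inputs))  # (num_batches * batch_size, output_dim)
    # Restore the leading dimensions; the last axis holds the model outputs.
    return flat_preds.reshape(*lead_shape, -1)

# Hypothetical stand-in for the lattice model: sums the two features.
def toy_model(x):
    return x.sum(axis=-1, keepdims=True)

batched_inputs = np.random.rand(2, 100, 2)
preds = predict_on_batches(toy_model, batched_inputs)
print(preds.shape)
```

Passing the LatticeModel instance in place of toy_model should work the same way, assuming the model accepts a float32 array of shape (B, 2) as in the first call above.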
cool, many thanks!