how can i use tfp.math.psd_kernels.Polynomial, with feature_ndims=6,
with the following code how can i use tfp.math.psd_kernels.Polynomial, with feature_ndims=6, i ussually get an error , please help.
code =
class PolynomialKernelFn(tf_keras.layers.Layer):
def init(self, bias_amplitude=0.0, slope_amplitude=1.0, shift=0.0, exponent=3.0, feature_ndims=1, **kwargs):
super(PolynomialKernelFn, self).init(**kwargs)
dtype = kwargs.get('dtype', None)
self.jitter = 1e-4; # Adding a small jitter for numerical stability
def call(self, inputs): return self.kernel;
@property
def kernel(self):
return tfp.math.psd_kernels.Polynomial( bias_amplitude=0.0, slope_amplitude=1.0, shift=0.0, exponent=3.0, feature_ndims=6, );
//end_class
//Example data
x_train = np.random.uniform(-3., 3., size=(100, 6));
y_train = np.sin(x_train[:, 0:2]); y_train = x_train;
print( 'x_train', x_train.shape, 'y_train', y_train.shape, );
x_range = [np.min(x_train), np.max(x_train)];
num_inducing_points = 18;
induc_idx_points = np.random.uniform(-3.0, 3.0, size=(num_inducing_points, x_train.shape[1])) induc_idx_point_init = tf.constant_initializer(induc_idx_points) print('induc_idx_points shape:', induc_idx_points.shape)
noise0 = ( tf.constant_initializer(np.array(0.54).astype(x_train.dtype)) );
// noise0 = ( tf.constant_initializer(np.array([ [0.54] ]).astype(x_train.dtype)) );
print( 'noise0', noise0.value.shape, );
polynomial_kernel = tfp.math.psd_kernels.Polynomial( bias_amplitude=1.0, slope_amplitude=1.0, shift=0.0, exponent=2, feature_ndims=6, validate_args=True );
// Build model
model = tf_keras.Sequential([
tf_keras.layers.InputLayer( input_shape=(x_train.shape[1],) ),
tf_keras.layers.Dense(1, kernel_initializer='ones', use_bias=False),
tfp.layers.VariationalGaussianProcess(
num_inducing_points= num_inducing_points,
kernel_provider= PolynomialKernelFn(),
event_shape= [1],
inducing_index_points_initializer= induc_idx_point_init,
unconstrained_observation_noise_variance_initializer= noise0,
),
]);
model.summary();
// Compile the model
model.compile(optimizer=tf_keras.optimizers.Adam(), loss=lambda y, p_y: -p_y.log_prob(y), );
model.fit(x_train, y_train, epochs=3, );
// Predict
x_test = np.linspace(-3., 3., 30).reshape(-1, 6);
y_pred = model.predict(x_test);
error = TypeError: Eager execution of tf.constant with unsupported shape. Tensor [[ 2.6181169 -0.63000023 0.88848853 2.8256464 0.7070044 0.10845122]
[ 1.7892019 0.9782654 0.7722825 -1.3020381 -2.2474375 -2.7869945 ]
[-1.7044429 1.3536608 1.6787233 -1.0234025 -0.56131506 -1.3617553 ]
[-0.67438734 -2.781567 -2.8270533 1.1969172 0.6445283 -2.27967 ]
[ 2.360385 -0.04742741 0.59732205 -2.933202 0.7861253 2.8390534 ]
[ 2.8995507 -2.7624567 1.4198784 2.6410117 -1.4379357 -1.7750715 ]
[ 1.152713 -1.1086357 -1.0503062 -1.4047132 2.0060854 -1.1540136 ]
[-2.5986419 0.97180235 1.4190216 -1.3796597 0.29559672 -1.4235702 ]
[ 2.4803004 2.498159 -0.7950366 1.8512287 2.291289 -0.8069504 ]
[ 2.9575715 1.8804133 2.8104806 -1.9198297 -1.3981376 2.7783432 ]
[ 0.5030356 -1.099595 0.3222227 0.6970968 -0.53885055 -0.7117024 ]
[ 2.3374352 -1.5795385 2.6345415 0.8455149 2.956886 0.8089901 ]
[ 2.72245 -0.10884786 -0.5315385 -0.4300995 -1.4305451 0.99511087]
[ 1.7424549 2.3463657 2.444201 0.5814336 -0.8842765 0.9178352 ]
[ 2.8574386 2.663537 1.0705398 0.31068817 0.22017056 1.424388 ]
[ 1.5786017 1.0860821 -2.569758 -0.1683704 -0.87796944 2.7431 ]
[-2.030426 -2.789794 -1.2321734 -0.4800363 1.2280728 2.926302 ]
[ 1.3148888 2.8946493 2.2137816 -1.0343827 1.1032671 -0.5407205 ]] (converted from [[ 2.61811684 -0.63000021 0.88848855 2.82564638 0.70700445 0.10845122]
[ 1.78920188 0.97826541 0.77228246 -1.30203812 -2.24743741 -2.78699451]
[-1.70444284 1.35366083 1.67872335 -1.02340239 -0.56131508 -1.36175523]
[-0.67438734 -2.78156708 -2.82705339 1.19691717 0.64452832 -2.27967009]
[ 2.36038489 -0.04742741 0.59732205 -2.93320213 0.78612532 2.83905331]
[ 2.89955057 -2.76245673 1.41987839 2.64101165 -1.43793566 -1.77507154]
[ 1.15271291 -1.10863564 -1.05030615 -1.40471315 2.00608531 -1.15401369]
[-2.59864192 0.97180237 1.41902165 -1.37965962 0.29559672 -1.4235702 ]
[ 2.48030032 2.49815891 -0.79503662 1.85122868 2.291289 -0.80695038]
[ 2.95757162 1.88041332 2.81048062 -1.91982973 -1.39813751 2.77834323]
[ 0.50303558 -1.09959492 0.3222227 0.69709683 -0.53885057 -0.71170238]
[ 2.33743521 -1.57953841 2.63454156 0.84551489 2.95688615 0.80899011]
[ 2.72245013 -0.10884786 -0.53153849 -0.43009949 -1.43054515 0.99511088]
[ 1.74245491 2.3463657 2.44420105 0.58143359 -0.88427652 0.91783519]
[ 2.85743859 2.66353695 1.0705398 0.31068815 0.22017056 1.42438809]
[ 1.5786017 1.08608213 -2.56975798 -0.16837039 -0.87796944 2.74310004]
[-2.03042604 -2.78979406 -1.23217341 -0.48003628 1.22807274 2.92630186]
[ 1.31488889 2.89464926 2.21378162 -1.03438269 1.10326708 -0.5407205 ]]) has 108 elements, but got shape (1, 18, None, 1) with None elements).