bayesnf
Integer division of the MAP ensemble_size causes a crash when ensemble_size < device_count
https://github.com/google/bayesnf/blob/fb59400ab86aa16a548f6df566bc0d5ba6e19eb5/src/bayesnf/inference.py#L445
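If ensemble_size < jax.device_count(), the per-device ensemble size, which is computed with integer division, evaluates to 0, so 0 particles are fitted.
From the API's point of view, .fit silently completes without fitting anything, but .predict
then raises an error, because it performs a min/max reduction over empty arrays: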
ValueError: zero-size array to reduction operation min which has no identity
/bayesnf/spatiotemporal.py in predict(self, table, quantiles)
259 def predict(self, table, quantiles=(0.5,)):
260 test_data = self.data_handler.get_test(table)
--> 261 return inference.predict_bnf(
262 test_data,
263 self.observation_model,
/bayesnf/inference.py in predict_bnf(features, observation_model, params, model_args, quantiles, ensemble_dims, approximate_quantiles)
468 (means, scales) = forecast_params
469 forecast_means = means
--> 470 forecast_quantiles = _get_percentile_normal(
471 forecast_means,
472 scales,
/bayesnf/inference.py in _get_percentile_normal(means, scales, quantiles, axis, approximate)
82 for q in quantiles:
83 forecast_quantiles.append(
---> 84 quantile_fn(means, scales[..., jnp.newaxis], q, axis)
85 )
86 return forecast_quantiles
/bayesnf/inference.py in _normal_quantile_via_root(means, scales, q, axis)
31 res = tfp.math.find_root_chandrupatla(
32 lambda x: n.cdf(x).mean(axis) - q,
---> 33 low=jnp.amin(means) - 5 * jnp.amax(scales),
34 high=jnp.amax(means) + 5 * jnp.amax(scales),
35 value_tolerance=1e-5,
jax/_src/numpy/reductions.py in min(a, axis, out, keepdims, initial, where)
276 keepdims: bool = False, initial: ArrayLike | None = None,
277 where: ArrayLike | None = None) -> Array:
--> 278 return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
279 keepdims=keepdims, initial=initial, where=where)
280
jax/_src/numpy/reductions.py in _reduce_min(a, axis, out, keepdims, initial, where)
268 keepdims: bool = False, initial: ArrayLike | None = None,
269 where: ArrayLike | None = None) -> Array:
--> 270 return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
271 axis=axis, out=out, keepdims=keepdims,
272 initial=initial, where_=where, parallel_reduce=lax.pmin)
jax/_src/numpy/reductions.py in _reduction(a, name, np_fun, op, init_val, has_identity, preproc, bool_op, upcast_f16_for_computation, axis, dtype, out, keepdims, initial, where_, parallel_reduce, promote_integers)
99 shape = np.shape(a)
100 if not _all(shape[d] >= 1 for d in pos_dims):
--> 101 raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
102
103 result_dtype = dtype or dtypes.dtype(a)