GPJax
bug: gpx.Dataset used in the core fit API is invalid?
There might be a workaround, but I'm not sure yet.
For example, this kernel requires integer inputs:
https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/kernels/non_euclidean/categorical.py#L90
But fit requires an Array, which supports only a single dtype?
https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/fit.py#L43
A simple example is a 2D X with one float column and one categorical column.
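To make the dtype problem concrete, here is a minimal sketch using only jax.numpy, with no GPJax calls. The toy values and the names x_float, x_cat and cat_codes are made up for illustration. It shows that stacking a float column and an integer column into one array promotes everything to float, and that casting the categorical column back to int is one possible workaround (untested against the categorical kernel itself).

```python
import jax.numpy as jnp

# Hypothetical toy data: one continuous feature and one categorical feature.
x_float = jnp.array([0.1, 0.5, 0.9])
x_cat = jnp.array([0, 2, 1])  # integer category codes

# A single 2D array can hold only one dtype, so the integer column
# is promoted to floating point when the columns are stacked.
X = jnp.column_stack([x_float, x_cat])
print(X.shape, X.dtype)

# Possible workaround (an assumption, not a confirmed GPJax pattern):
# cast the categorical column back to integers where it is needed.
cat_codes = X[:, 1].astype(jnp.int32)
print(cat_codes.dtype)
```

Small integer codes survive the round trip through float exactly, so the cast recovers the original categories. Whether this cast can be applied inside the kernel is the open question.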
Also asked in the discussions:
https://github.com/orgs/JaxGaussianProcesses/discussions/403