ScikitLearn.jl
Bug with weighted DBSCAN
Problem
I'm using DBSCAN to find clusters in a 3D dataset that varies with time. Every now and again (<5% of the time), DBSCAN completely fails to detect a very obvious cluster. Sometimes I can make it work by simply circshifting the array, but not always.
There seems to be no clear reason why it fails; it just sometimes does.
Please see the following example (the data file is included for reproducibility):
using JLD2
using ScikitLearn
using PyCall
# Wrapper for DBSCAN
DBSCAN = pyimport("sklearn.cluster").DBSCAN
# Load data
f = jldopen("DBSCAN_BUG.jld2")
x, y, z = f["x"], f["y"], f["z"]
# Format data such that each voxel is given as an (x,y,z) coordinate
X = repeat(x',length(y),1,length(z)) .+ 2*maximum(x)
Y = repeat(y,1,length(x),length(z)) .+ 2*maximum(y)
Z = permutedims(repeat(z,1,length(x),length(y)),[3 2 1]) .+ 2*maximum(z)
dat = zeros(length(X[:]),3)
dat[:,1] = X[:]
dat[:,2] = Y[:]
dat[:,3] = Z[:]
# Perform DBSCAN where points are weighted by density array
decomp = DBSCAN(eps=abs(x[1]-x[2]),min_samples=1).fit_predict(dat,sample_weight=f["dens"][:])
dbscan = replace!(reshape(decomp,size(f["dens"])),-1=>0)
dbscan[dbscan.>0] .= 1.0
# Plot DBSCAN results alongside the density
using CairoMakie
fig = Figure()
ax, hm1 = heatmap(fig[1,1], x, y, f["dens"][:,:,72])
ax, hm2 = heatmap(fig[2,1], x, y, dbscan[:,:,72])
fig
Expected result
There should be a yellow blob in the second heatmap, corresponding to the identified (very obvious) cluster.
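For reference, this is how `sample_weight` interacts with `min_samples` in scikit-learn's DBSCAN: in its implementation, a point is a core sample when the summed weight of its eps-neighbourhood (the point itself included) is at least `min_samples`. With `min_samples=1`, voxels whose neighbourhood density sums to less than 1 cannot be core points and are labelled noise (-1) unless they lie within eps of a core point. The points and weights below are made up for illustration; only `DBSCAN` and `fit_predict` come from scikit-learn.

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Three 1-D points: two close together, one far away (made-up data).
X = np.array([[0.0], [1.0], [10.0]])

# Light weights: the neighbourhood of points 0 and 1 sums to 0.6 < min_samples,
# so neither is a core sample and both end up as noise (-1).
light = DBSCAN(eps=1.5, min_samples=1).fit_predict(X, sample_weight=[0.3, 0.3, 5.0])

# Heavier weights: the same neighbourhood now sums to 1.2 >= min_samples,
# so points 0 and 1 form a cluster of their own.
heavy = DBSCAN(eps=1.5, min_samples=1).fit_predict(X, sample_weight=[0.6, 0.6, 5.0])

print(light)  # [-1 -1  0]
print(heavy)  # [0 0 1]
```

Note that cluster labels start at 0, so in the first result the only cluster found carries label 0.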
Isn't that a problem with the scikit-learn library itself? ScikitLearn.jl is just an interface to Python's scikit-learn. If so, I would encourage you to translate your example to Python and report it there.
Hmm, yes, probably. I'll see if I'm able to translate it.
I'm actually not sure it's an issue with scikit-learn, as it works just fine when I use it natively in Python (see below):
import numpy as np
import h5py
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import matplotlib.colors as colors
# Load data
f = h5py.File("DBSCAN_BUG.jld2", "r")
x, y, z = f["x"][:], f["y"][:], f["z"][:]
# Format data such that each voxel is given as an (x,y,z) coordinate
X = np.repeat(x, len(y) * len(z)).reshape(len(x), len(y), len(z), order='F') + 2 * np.max(x)
Y = np.repeat(y, len(x) * len(z)).reshape(len(y), len(x), len(z), order='C') + 2 * np.max(y)
Z = np.repeat(z, len(x) * len(y)).reshape(len(z), len(x), len(y), order='F').transpose((1, 2, 0)) + 2 * np.max(z)
dat = np.vstack((X.ravel('F'), Y.ravel('F'), Z.ravel('F'))).T
# Perform DBSCAN where points are weighted by density array
decomp = DBSCAN(eps=np.abs(x[0] - x[1]), min_samples=1).fit_predict(dat, sample_weight=f["dens"][:].ravel())
dbscan = np.reshape(np.where(decomp != -1, 1, 0), f["dens"].shape)
# Plot DBSCAN results alongside the density
fig, axs = plt.subplots(2, 1)
hm1 = axs[0].imshow(f["dens"][72, :, :], norm=colors.LogNorm())
hm2 = axs[1].imshow(dbscan[72, :, :], cmap='binary')
plt.show()
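One place where the Julia and Python inputs could differ is the pairing between coordinate rows and weights. For example, reshaping with `order='F'` and then ravelling with `order='F'` returns the original `np.repeat` output unchanged, so in `X` the x value only changes once every `len(y) * len(z)` points; whether that matches the layout of `f["dens"][:].ravel()` is worth checking. A minimal sketch of a more direct construction, assuming each 1-D coordinate vector is matched to the array dimension it samples: `np.meshgrid(..., indexing="ij")` plus C-order `ravel()` makes row n of `dat` describe the voxel whose weight is `dens.ravel()[n]`. Which vector belongs to which dimension must be checked against the file (arrays written from Julia typically appear in h5py with their dimensions reversed). The helper `voxel_coordinates` and the toy arrays are hypothetical.

```python
import numpy as np

def voxel_coordinates(axes):
    """Return an (N, d) array whose row n holds the coordinates of element n
    of a C-ordered array whose dimension i is sampled at axes[i]."""
    # indexing="ij": grids[i] has the data's shape, with axes[i] varying along dimension i
    grids = np.meshgrid(*axes, indexing="ij")
    # ravel() defaults to C order, the same order as data.ravel()
    return np.column_stack([g.ravel() for g in grids])

# Hypothetical axis vectors for a (2, 3, 4) array.
a0 = np.array([10.0, 20.0])
a1 = np.array([1.0, 2.0, 3.0])
a2 = np.array([0.1, 0.2, 0.3, 0.4])
dens = np.arange(24.0).reshape(2, 3, 4)  # stand-in for the density array

dat = voxel_coordinates([a0, a1, a2])
weights = dens.ravel()  # row n of dat is the voxel holding weights[n]
```

With the real data, `dat` and `weights` would then be passed to `fit_predict(dat, sample_weight=weights)`, and the labels reshaped with `labels.reshape(dens.shape)`.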
I'm not super-familiar with DBSCAN; I scanned your code and nothing looked obviously wrong. Beware that in
dbscan = replace!(reshape(decomp,size(f["dens"])),-1=>0)
the array returned by reshape shares its memory with decomp (it is effectively a view), so this line also mutates decomp. But that shouldn't change the outcome.
Beyond that, I can't offer advice other than to try to figure out what differs between the Python and Julia versions. Ultimately, it's the same library doing the work, so presumably the inputs (or the plotting) are different.
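One difference between the two snippets that may matter: scikit-learn numbers DBSCAN clusters from 0 and marks noise as -1. The Julia post-processing (`replace!(..., -1=>0)` followed by `dbscan[dbscan.>0] .= 1.0`) therefore maps the cluster labelled 0 to the same value as noise, whereas the Python version (`np.where(decomp != -1, 1, 0)`) keeps it. If the obvious cluster happens to be the first one found, it would vanish from the Julia plot, and reordering the data (e.g. with circshift) can change which cluster receives label 0. This is a hypothesis, not something confirmed in the thread; the sketch below just applies both post-processing steps to made-up labels in scikit-learn's format.

```python
import numpy as np

# Hypothetical labels in the format returned by DBSCAN.fit_predict:
# -1 marks noise, clusters are numbered 0, 1, 2, ...
labels = np.array([-1, 0, 0, 0, -1, 1, 1, -1])

# Julia-style post-processing: replace -1 by 0, then set every label > 0 to 1.
# Cluster 0 ends up indistinguishable from noise.
julia_style = labels.copy()
julia_style[julia_style == -1] = 0
julia_style[julia_style > 0] = 1

# Python-style post-processing: every non-noise label becomes 1.
python_style = np.where(labels != -1, 1, 0)

print(julia_style)   # [0 0 0 0 0 1 1 0]
print(python_style)  # [0 1 1 1 0 1 1 0]
```

On the Julia side, something like `dbscan = Int.(reshape(decomp, size(f["dens"])) .!= -1)` would mirror the Python logic and also leaves decomp unmodified.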