UMAP.jl
WIP: supervised learning, data views
I've come back to doing some supervised UMAP things, and had the need to use a different `n_neighbors` when building the fuzzy simplex representation of the "supervising" data than for the usual data, as well as to use non-categorical data. To accommodate these needs, I "unwrapped" the main `UMAP_` function into separate functions for each piece (build the KNNs, then build the fuzzy simplex, i.e. the graph, then generate the embedding), which I chose to represent with types, like `UMAP_` itself did.
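To make the staged workflow concrete, here is a rough sketch of the kind of pipeline this unwrapping enables. All of the function names below (`knn_graph`, `fuzzy_simplicial_set`, `general_intersection`, `optimize_embedding`) are hypothetical placeholders for illustration, not the actual API of this PR or of `v0.2-dev`:

```julia
# Hypothetical staged pipeline; every function name here is a placeholder.
# Stage 1: build k-nearest-neighbor graphs, with a different n_neighbors
# for the data than for the "supervising" labels/view.
knns_data   = knn_graph(X;      n_neighbors=15)
knns_labels = knn_graph(labels; n_neighbors=50)

# Stage 2: turn each KNN graph into a fuzzy simplicial set (a sparse graph),
# then combine the two views with a general fuzzy set intersection.
graph_data   = fuzzy_simplicial_set(knns_data)
graph_labels = fuzzy_simplicial_set(knns_labels)
graph        = general_intersection(graph_data, graph_labels; mix_weight=0.5)

# Stage 3: optimize the low-dimensional embedding from the combined graph.
embedding = optimize_embedding(graph; n_components=2, n_epochs=300)
```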
I think this is similar to your plans @dillondaudert as discussed in https://github.com/dillondaudert/UMAP.jl/issues/27 and the wiki, although my immediate goal was just to get something working for my use case so I did not adhere super closely to your guidance there. However, I thought I could put this up as a working draft and over time try to bring it to compliance with how you'd like it done.
I also dropped support for `transform`, but I think that could be brought back without too much work.
Example: MNIST
Like in my previous attempt (https://github.com/dillondaudert/UMAP.jl/pull/29), I've used the MNIST example. I didn't compare to python this time because for some reason I got a segfault when I tried. MNIST is a bit interesting because you can treat the numeric labels as categorical, but you could also treat them as numbers and say, e.g., that 5 is closer to 6 than it is to 7. However, since the label data is highly degenerate (there are many 5s, for example), you need `n_neighbors` greater than the number of examples of each label value in order to recover something meaningfully different from the categorical case. (It actually works OK as categorical, even though all the distances are zero, thanks to the `left_min` and `right_min` handling in the general fuzzy intersection code that came from the python-umap implementation. Without that, if you do a naive `graph1 .* graph2`, you get garbage.)
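For reference, here is a small sketch of the `left_min`/`right_min` idea, based on my reading of the python-umap general intersection; the exact formula, floors, and `mix_weight` handling are assumptions from that reading rather than a verified port:

```julia
using SparseArrays

# Sketch of a general fuzzy set intersection with left_min/right_min floors.
# Entries missing from one graph get a small floor value instead of zero, so
# the combination doesn't collapse the way a naive graph1 .* graph2 does.
function general_intersection(left::SparseMatrixCSC, right::SparseMatrixCSC;
                              mix_weight::Real=0.5)
    left_min  = max(minimum(nonzeros(left))  / 2, 1e-8)
    right_min = max(minimum(nonzeros(right)) / 2, 1e-8)
    rows, cols, _ = findnz(left + right)      # union of nonzero positions
    vals = map(zip(rows, cols)) do (i, j)
        l = max(left[i, j],  left_min)        # floor replaces a missing entry
        r = max(right[i, j], right_min)
        if mix_weight < 0.5
            l * r^(mix_weight / (1 - mix_weight))
        else
            l^((1 - mix_weight) / mix_weight) * r
        end
    end
    return sparse(rows, cols, vals, size(left)...)
end
```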
(Embedding plots for `max_weight = 0.001`, `max_weight = 0.5`, and `max_weight = 0.999`; images not reproduced here.)
This all looks great to me, I appreciate the effort to get this functionality into UMAP.jl. I haven't been as focused on the refactor (in the `v0.2-dev` branch) lately, but this PR is certainly motivating me.
I want to get the `v0.2-dev` branch into a usable state so that PRs like this can build on top of it ASAP. The good news there is all the functionality for supervised UMAP comes before the optimization step, so that's what I can focus on first. I'm adding documentation and examples as I go along (see the Pluto notebook in that branch at `docs/examples/advanced/advanced_usage.jl`).
Supervised UMAP (and categorical data in general) does require some special handling, which you've pointed out. As a baseline, I want to handle it identically to the python implementation, the details of which I don't recall off the top of my head at the moment. It is already possible to treat numeric labels as continuous features as well on that branch (with separate KNN parameterization, etc).
> This all looks great to me, I appreciate the effort to get this functionality into UMAP.jl. I haven't been as focused on the refactor (in the `v0.2-dev` branch) lately, but this PR is certainly motivating me.
Awesome, glad to hear it!
> I want to get the `v0.2-dev` branch into a usable state so that PRs like this can build on top of it ASAP. The good news there is all the functionality for supervised UMAP comes before the optimization step, so that's what I can focus on first. I'm adding documentation and examples as I go along (see the Pluto notebook in that branch at `docs/examples/advanced/advanced_usage.jl`).
Cool, thanks for the pointer! Looks good :)
> Supervised UMAP (and categorical data in general) does require some special handling, which you've pointed out. As a baseline, I want to handle it identically to the python implementation, the details of which I don't recall off the top of my head at the moment. It is already possible to treat numeric labels as continuous features as well on that branch (with separate KNN parameterization, etc).
Yes, makes sense. I see we've both implemented the general fuzzy intersection method. In my new commit https://github.com/dillondaudert/UMAP.jl/pull/36/commits/8b762a512d2c78f2fd8e249c9cacec2cc15c4b60 I made a very literal translation of the python code (I think; I found that code hard to read, so maybe I mixed something up, although I get about 1e-5 relative error between my implementation and the python one, which at least is fairly small). Yours in https://github.com/dillondaudert/UMAP.jl/compare/v0.2-dev#diff-47c27891e951c8cd946b850dc2df31082624afdf57446c21cb6992f5f4b74aa2R20-R37 looks much more readable, although I think scalar indexing into sparse arrays that way won't be the most performant. Hopefully we can converge on something readable and performant :).
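To make the performance point concrete, here is a generic comparison (not tied to either implementation) of scalar indexing into a `SparseMatrixCSC` versus iterating its stored entries with `nzrange`: each `A[i, j]` lookup does a binary search within column `j`, while `nzrange` walks the compressed columns directly.

```julia
using SparseArrays

A = sprand(10_000, 10_000, 1e-3)

# Scalar indexing: every A[i, j] is a binary search in column j.
function sum_by_indexing(A)
    rows, cols, _ = findnz(A)
    s = 0.0
    for (i, j) in zip(rows, cols)
        s += A[i, j]
    end
    return s
end

# Iterating stored entries: walk each column's nonzeros in order.
function sum_by_nzrange(A)
    s = 0.0
    rv, nz = rowvals(A), nonzeros(A)
    for j in 1:size(A, 2), k in nzrange(A, j)
        s += nz[k]        # rv[k] gives the row index if it's needed
    end
    return s
end
```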
All existing functionality has been reimplemented on branch v0.2-dev. The largest outstanding piece is ironing out how to combine multiple dataset views, including when those views are for categorical data (your use case). This latter case is handled specially in the Python code by this function https://github.com/lmcinnes/umap/blob/e077dfd46b2086f865ae8d4e1c2ed8f801bf0656/umap/umap_.py#L711 and I am still reviewing the implementation details there.
Awesome! My version of that is https://github.com/dillondaudert/UMAP.jl/blob/de7a46771cdb3d11cb115fe07e298f1be83c2ae6/src/utils.jl#L157-L173, by the way.
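For anyone following along, the categorical case (as I understand the Python function linked above) boils down to down-weighting graph edges between points whose labels disagree, with a milder penalty when a label is unknown. The sketch below reflects my reading only; the parameter names `far_dist`/`unknown_dist` and their defaults are assumptions, not a verified port:

```julia
using SparseArrays

# Sketch of categorical intersection: penalize edges between points whose
# categorical labels differ. Names/defaults follow my reading of python-umap.
function categorical_intersection!(graph::SparseMatrixCSC, labels;
                                   far_dist=5.0, unknown_dist=1.0)
    rv, nz = rowvals(graph), nonzeros(graph)
    for j in 1:size(graph, 2), k in nzrange(graph, j)
        i = rv[k]
        if ismissing(labels[i]) || ismissing(labels[j])
            nz[k] *= exp(-unknown_dist)   # one label unknown: mild penalty
        elseif labels[i] != labels[j]
            nz[k] *= exp(-far_dist)       # labels disagree: strong penalty
        end                               # labels agree: weight unchanged
    end
    return graph
end
```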
Hey @dillondaudert, hope all is well!
I was wondering what the status of `v0.2-dev` is, and in particular what steps are needed to bring it to a release. I think if that were written out somewhere, outside contributors might be able to help chip away at it.