
Can't pickle AlignedUMAP

Status: Open · agombert opened this issue 4 years ago · 4 comments

Hey,

First, thanks for this awesome repo and its bibliography; I really enjoyed diving into the docs!

I ran into an issue when trying to save a fitted AlignedUMAP model. Here is some sample code to reproduce it:

import pickle
from os.path import join

import numpy as np
import umap.aligned_umap

PATH_DATA = "."  # placeholder: any writable directory

aligned_mapper_10D = umap.AlignedUMAP(metric='cosine', n_components=10, n_neighbors=50,
                                      init='spectral', transform_seed=42, random_state=11)

n = 400
X = np.random.normal(0, 1, (1000, 50))
Y = np.random.normal(0, 1, (500, 50))
Y = np.concatenate((X[-n:], Y))           # first n rows of Y are the last n rows of X
relations = {600+i: i for i in range(n)}  # row 600+i of X <-> row i of Y

aligned_mapper_10D.fit([X, Y], relations=[relations])
with open(join(PATH_DATA, 'test.pkl'), 'wb') as f:
    pickle.dump(aligned_mapper_10D, f)  # raises the TypeError below

And I got the following error:

TypeError: cannot pickle '_nrt_python._MemInfo' object

I saw that it could be related to issue #273, but I could not manage to solve it. Has anyone managed to solve it?

Here are my package versions:

numba==0.52
umap-learn==0.5.1
pynndescent==0.5.2

Thanks,

Arnault

agombert · May 19 '21 11:05

Well, I think I may have found a way to save and reload the model. The problem seems to come from the numba typed List (used for embeddings_), which cannot be pickled.
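Since a single unpicklable attribute is enough to make pickle.dump fail on the whole model, it can help to test the attributes one at a time before guessing. A minimal sketch, assuming the fitted state lives in the instance __dict__; find_unpicklable_attributes, NotPicklable and FakeModel are hypothetical names used only for illustration, so the example runs without umap or numba:

```python
import pickle


def find_unpicklable_attributes(obj):
    """Return {attribute_name: error_message} for every instance attribute
    of `obj` that pickle cannot serialise on its own."""
    failures = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # TypeError, PicklingError, AttributeError, ...
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


# Hypothetical stand-ins so the sketch runs without umap/numba installed.
class NotPicklable:
    """Mimics an object (like the numba typed List here) that refuses pickling."""

    def __reduce__(self):
        raise TypeError("cannot pickle 'NotPicklable' object")


class FakeModel:
    pass


m = FakeModel()
m.n_neighbors = 50
m.embeddings_ = NotPicklable()
print(find_unpicklable_attributes(m))
# {'embeddings_': "TypeError: cannot pickle 'NotPicklable' object"}
```

Calling find_unpicklable_attributes(aligned_mapper_10D) on a fitted model should list whichever attributes fail, which is one way to confirm the numba List hypothesis before working around it.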

Going back to the example:

import numpy as np
import umap.aligned_umap

aligned_mapper_10D = umap.AlignedUMAP(metric='cosine', n_components=10, n_neighbors=50,
                                      init='spectral', transform_seed=42, random_state=11,
                                      low_memory=True)

X = np.random.normal(0, 1, (1000, 50))
Y = np.random.normal(0, 1, (600, 50))
Y = np.concatenate((X[-400:], Y))
Z = np.random.normal(0, 1, (600, 50))
Z = np.concatenate((Y[-400:], Z))
relations = {600+i: i for i in range(400)}   # X -> Y
relations_ = {600+i: i for i in range(400)}  # Y -> Z, used later by update()

aligned_mapper_10D.fit([X, Y], relations=[relations])
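The relation dicts only make sense because of how the slices are built: the first 400 rows of Y are copies of the last 400 rows of X (rows 600 to 999, since X has 1000 rows), and likewise for Z and Y. A quick sanity check of that mapping, using a seeded numpy Generator instead of np.random.normal purely so the run is reproducible (relations_are_consistent is a hypothetical helper, not part of umap):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(0, 1, (1000, 50))
Y = np.concatenate((X[-400:], rng.normal(0, 1, (600, 50))))
Z = np.concatenate((Y[-400:], rng.normal(0, 1, (600, 50))))

# Keys index rows of the earlier slice, values index rows of the next slice.
relations = {600 + i: i for i in range(400)}   # X -> Y
relations_ = {600 + i: i for i in range(400)}  # Y -> Z


def relations_are_consistent(A, B, rel):
    """True if every related pair of rows is identical, i.e. rel really
    points at the rows that B copied from the tail of A."""
    return all(np.array_equal(A[a], B[b]) for a, b in rel.items())


print(relations_are_consistent(X, Y, relations))   # True
print(relations_are_consistent(Y, Z, relations_))  # True
```

More generally, when the first n rows of a new slice are the last n rows of a slice A, the mapping is {len(A) - n + i: i for i in range(n)}; it reduces to 600 + i here only because X and Y both have 1000 rows.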

Then you can save the model with something like the following; the key step is converting the numba typed List in embeddings_ into a plain Python list:

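import pickle
from os.path import join

PATH_REPO = "."  # placeholder: root of your project (must contain a 'data' folder)

params = aligned_mapper_10D.get_params()
# public attributes that are not constructor parameters (this also picks up methods)
attributes_names = [attr for attr in aligned_mapper_10D.__dir__() if attr not in params and attr[0] != '_']
attributes = {key: getattr(aligned_mapper_10D, key) for key in attributes_names}
# convert the numba typed List into a plain Python list so it can be pickled
attributes['embeddings_'] = list(aligned_mapper_10D.embeddings_)
# drop the bound methods: pickling them would pickle the model itself, numba List included
for x in ['fit', 'fit_transform', 'update', 'get_params', 'set_params']:
    del attributes[x]

all_params = {'umap_params': params,
              'umap_attributes': attributes}
with open(join(PATH_REPO, 'data', 'test.pkl'), 'wb') as f:
    pickle.dump(all_params, f)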

Then reload everything with pickle:

import pickle
from os.path import join

import umap
from numba.typed import List

with open(join(PATH_REPO, 'data', 'test.pkl'), 'rb') as f:
    params_new = pickle.load(f)
new_aligned_mapper_10D = umap.AlignedUMAP()
new_aligned_mapper_10D.set_params(**params_new.get('umap_params'))
# restore the fitted attributes
for attr, value in params_new.get('umap_attributes').items():
    setattr(new_aligned_mapper_10D, attr, value)
# turn embeddings_ back into a numba typed List, as it was after fit()
new_aligned_mapper_10D.embeddings_ = List(params_new.get('umap_attributes').get('embeddings_'))

new_aligned_mapper_10D should then work, e.g. to update the alignment with the next slice Z:

new_aligned_mapper_10D.update(Z, relations=relations_)

agombert · May 19 '21 16:05

Yeah; numba has some pickleability issues that can be tricky to track down. In principle I can fix this with appropriate __getstate__ and __setstate__ methods that effectively do what you've done here. If you are willing to look into that and submit a PR, I would be happy to help work through it and get it merged. If you don't have the time, I will probably get to it, but I admit it is not a priority for me right now. Thanks for at least documenting your solution; at the very least it goes a long way toward helping others who encounter the issue, and it is greatly appreciated.
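The __getstate__/__setstate__ approach amounts to doing the list conversion inside the class itself. A minimal sketch of that pattern on a toy class, not umap's actual code: TypedListStandIn is a hypothetical substitute for numba.typed.List so the example runs without numba, and in AlignedUMAP the restore step would presumably call numba.typed.List(...) as in the workaround earlier in this thread.

```python
import pickle


class TypedListStandIn(list):
    """Hypothetical stand-in for numba.typed.List: list-like, but pickle
    refuses it, as reported in this issue for the real numba container."""

    def __reduce_ex__(self, protocol):
        raise TypeError("cannot pickle 'TypedListStandIn' object")


class AlignedModelSketch:
    """Toy estimator whose fitted state holds an unpicklable container."""

    def __init__(self, n_neighbors=15):
        self.n_neighbors = n_neighbors

    def fit(self, slices):
        # Pretend each "embedding" is simply a copy of its input slice.
        self.embeddings_ = TypedListStandIn(list(s) for s in slices)
        return self

    def __getstate__(self):
        # Copy the instance dict so the live object is not modified.
        state = self.__dict__.copy()
        if "embeddings_" in state:
            # Swap the unpicklable container for a plain list.
            state["embeddings_"] = list(state["embeddings_"])
        return state

    def __setstate__(self, state):
        if "embeddings_" in state:
            # Rebuild the container type the rest of the class works with.
            state["embeddings_"] = TypedListStandIn(state["embeddings_"])
        self.__dict__.update(state)


model = AlignedModelSketch(n_neighbors=50).fit([[1.0, 2.0], [3.0, 4.0]])
restored = pickle.loads(pickle.dumps(model))
print(type(restored.embeddings_).__name__, restored.embeddings_, restored.n_neighbors)
# TypedListStandIn [[1.0, 2.0], [3.0, 4.0]] 50
```

With the conversion living in the class, a plain pickle.dump(model, f) works and nobody has to maintain a list of method names to delete.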

lmcinnes · May 21 '21 18:05

Thanks for the reply. I don't have time to do a PR right now, but I've put it on my TODO list and hope to get to it soon!

agombert · Jul 05 '21 15:07

Perhaps in later versions an additional attribute was introduced that also needs to be removed for the code to work. Here is my version:

for x in [
    "fit",
    "fit_transform",
    "update",
    "get_params",
    "set_params",
    "get_metadata_routing",   # additional attr
]:
    del attributes[x]
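Hard-coding method names breaks whenever a base class gains a new public method, as get_metadata_routing shows. One way around it, assuming every fitted attribute is stored on the instance (the usual scikit-learn convention), is to read vars() instead of __dir__(), so methods never enter the dict in the first place. fitted_attributes and ToyEstimator are hypothetical names for illustration:

```python
def fitted_attributes(estimator, params):
    """Public, non-parameter instance attributes of a fitted estimator.

    Reads vars() (the instance __dict__) instead of dir(), so methods such
    as fit, update or get_metadata_routing never show up and do not have
    to be deleted by name.
    """
    return {
        name: value
        for name, value in vars(estimator).items()
        if name not in params and not name.startswith("_")
    }


class ToyEstimator:
    def __init__(self, n_neighbors=15):
        self.n_neighbors = n_neighbors

    def get_params(self):
        return {"n_neighbors": self.n_neighbors}

    def get_metadata_routing(self):  # a method that dir() would list
        return None

    def fit(self, data):
        self.embeddings_ = list(data)
        self._cache = "private"
        return self


est = ToyEstimator(n_neighbors=50).fit([1, 2, 3])
print(fitted_attributes(est, est.get_params()))  # {'embeddings_': [1, 2, 3]}
```

For the real model, fitted_attributes(aligned_mapper_10D, aligned_mapper_10D.get_params()) would replace both the __dir__() comprehension and the del loop; embeddings_ still has to be converted to a plain list before pickling.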

vasja34 · Sep 04 '23 15:09