guanfang12
In load_data.py, the code is as follows:
MN2 = np.concatenate([missing1[np.newaxis, :], (len(kp2_np)) * np.ones((1, len(missing1)), dtype=np.int64)])
MN3 = np.concatenate([(len(kp1_np)) * np.ones((1, len(missing2)), dtype=np.int64), missing2[np.newaxis, :]])
Here MN2 holds the pairs [i, N+1] and MN3 holds the pairs [M+1, j], where M = len(kp1_np) and N = len(kp2_np).
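To make the shapes concrete, here is a minimal sketch that runs the two lines above on toy inputs. The array sizes and the contents of missing1 and missing2 are made up for illustration and are not taken from load_data.py.

```python
import numpy as np

# Hypothetical toy inputs: 4 keypoints in image 1 (M = 4), 3 in image 2 (N = 3).
kp1_np = np.zeros((4, 2))
kp2_np = np.zeros((3, 2))
missing1 = np.array([1, 3], dtype=np.int64)  # unmatched keypoint indices in image 1
missing2 = np.array([2], dtype=np.int64)     # unmatched keypoint indices in image 2

# Each unmatched keypoint i of image 1 is paired with the extra index len(kp2_np).
MN2 = np.concatenate([missing1[np.newaxis, :],
                      len(kp2_np) * np.ones((1, len(missing1)), dtype=np.int64)])
# Each unmatched keypoint j of image 2 is paired with the extra index len(kp1_np).
MN3 = np.concatenate([len(kp1_np) * np.ones((1, len(missing2)), dtype=np.int64),
                      missing2[np.newaxis, :]])

print(MN2)  # rows: [1 3] and [3 3]
print(MN3)  # rows: [4] and [2]
```

With 0-based indexing, the index len(kp2_np) is one past the last real keypoint of image 2, which is the (N+1)-th slot; the same holds for len(kp1_np) and the (M+1)-th slot.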
This is because the code uses the .double() data type and saves with torch.save(model, model_out_path) rather than torch.save(state_dict, model_out_path). If you change it to float(), the file will be 46M.
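The effect of the data type on file size can be checked with a small sketch. The nn.Linear layer and the saved_size helper below are hypothetical stand-ins, not the actual network or code from this repository; the sketch only assumes that the weights dominate the saved file.

```python
import io

import torch
import torch.nn as nn


def saved_size(obj):
    """Return the number of bytes torch.save writes for obj (in-memory buffer)."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    return len(buf.getvalue())


# Hypothetical stand-in model; the real network is much larger.
model = nn.Linear(1024, 1024)

# float64 weights take 8 bytes per parameter, float32 weights take 4.
size_double = saved_size(model.double().state_dict())
size_float = saved_size(model.float().state_dict())

print(size_double, size_float, size_double / size_float)
```

Saving only the state_dict also avoids pickling the model class itself, so the checkpoint can be loaded without depending on the exact module path used at training time.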
I will organize and release the training code about half a month after the NeurIPS 2023 conference (Dec 16, 2023).