cs-ranking

FATENetwork#fit crashes when X is a dict

Open • daanvdn opened this issue 4 years ago • 3 comments

Hi @kiudee, first off, thanks for sharing this very interesting project with the community. I would very much like to experiment with it, especially because of its support for learning discrete choices.

While playing around with the API, I seem to have run into a bug in the csrank.core.fate_network.FATENetwork#fit method.

The data for which I would like to learn a discrete choice model has a variable number of objects (i.e., each instance may have a different value for n_objects). According to the documentation, the csrank.FATENetwork#fit method supports this scenario by allowing X to be a dict that maps n_objects to NumPy arrays:

    X : numpy array or dict
        Feature vectors of the objects
        (n_instances, n_objects, n_features) if numpy array or map from n_objects to numpy arrays
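To make the dict layout concrete, here is a small hand-built example with NumPy. The sizes are invented for illustration. Only the convention comes from the docstring: each key is n_objects, and each value is an array of shape (n_instances, n_objects, n_features).

```python
import numpy as np

# Illustrative data only: two groups of instances with different object counts.
rng = np.random.default_rng(0)
n_features = 2

X = {
    3: rng.random((4, 3, n_features)),  # 4 instances with 3 objects each
    5: rng.random((6, 5, n_features)),  # 6 instances with 5 objects each
}

for n_objects, X_n in X.items():
    n_instances, n_obj, n_feat = X_n.shape
    # Each key must match the object axis of its array.
    assert n_obj == n_objects and n_feat == n_features
    print(n_objects, X_n.shape)
```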

I am using csrank.DiscreteChoiceDatasetGenerator to create synthetic data, specifically its get_dataset_dictionaries method. However, when I pass the resulting X_train and Y_train to the fit method, I get the following error:

```
Traceback (most recent call last):
  File "C:/Users/Daan_Vandennest/Git/landc-working-dx-ml/src/main/python/models/neural_ranking.py", line 21, in <module>
    fate.fit(X_train, Y_train, verbose=True, epochs=1)
  File "C:\Users\Daan_Vandennest\.virtualenvs\landc-working-dx-ml-eYqgHMIQ\lib\site-packages\csrank\objectranking\fate_object_ranker.py", line 98, in fit
    super().fit(X, Y, **kwd)
  File "C:\Users\Daan_Vandennest\.virtualenvs\landc-working-dx-ml-eYqgHMIQ\lib\site-packages\csrank\core\fate_network.py", line 539, in fit
    _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
AttributeError: 'dict' object has no attribute 'shape'
```

It seems that line 539 of fate_network.py reads the shape attribute of X regardless of whether X is a numpy.ndarray or a dict.
This error can be reproduced by running the code below:

```python
from csrank import DiscreteChoiceDatasetGenerator
from csrank import FATEObjectRanker
from csrank.losses import smooth_rank_loss

seed = 123
n_train = 10000
n_test = 10000
n_features = 2
n_objects = 5
gen = DiscreteChoiceDatasetGenerator(dataset_type='medoid', random_state=seed,
                                     n_train_instances=n_train,
                                     n_test_instances=n_test,
                                     n_objects=n_objects,
                                     n_features=n_features)

# Alternative: fixed-size NumPy arrays of shape (n_instances, n_objects, n_features)
# X_train, Y_train, X_test, Y_test = gen.get_single_train_test_split()
# Dict input (n_objects -> arrays); fit() raises the AttributeError shown above
X_train, Y_train, X_test, Y_test = gen.get_dataset_dictionaries()

fate = FATEObjectRanker(loss_function=smooth_rank_loss)
fate.fit(X_train, Y_train, verbose=True, epochs=1)
```
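One possible direction for a fix is to check the type of X before unpacking .shape at line 539. The helper below is a hypothetical sketch and not part of csrank. Its name is an assumption, and so is the choice to report the largest key as n_objects for dict input. The rest of fit would still have to consume the dict entries.

```python
import numpy as np


def infer_fit_shapes(X):
    """Return (n_instances, n_objects, n_features) for array or dict input.

    Hypothetical helper, not csrank API. For a dict mapping n_objects to
    arrays of shape (n_instances, n_objects, n_features), n_instances is
    summed over all entries and n_objects is taken as the largest key.
    """
    if isinstance(X, dict):
        if not X:
            raise ValueError("X must not be an empty dict")
        feature_counts = {array.shape[2] for array in X.values()}
        if len(feature_counts) != 1:
            raise ValueError("all arrays in X need the same number of features")
        n_instances = sum(array.shape[0] for array in X.values())
        return n_instances, max(X), feature_counts.pop()
    # Fixed-size input: the same unpacking that line 539 performs today.
    n_instances, n_objects, n_features = np.asarray(X).shape
    return n_instances, n_objects, n_features
```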

Can you confirm that this is a bug, or am I using the API incorrectly? If it is indeed a bug, could you give me some guidance on how to fix it? If it's a relatively straightforward fix, I can implement it and open a pull request.

Thanks

daanvdn, Jun 12 '20 07:06

Thank you for the detailed report. This does appear to be a bug. We originally supported both fixed-length input (NumPy arrays) and variable-length input (dicts), but the dict interface was not actively maintained.

Passing variable-length inputs is of course desirable, which is why we will work on restoring that interface.

kiudee, Jun 12 '20 07:06

Thanks for the feedback, @kiudee. Could I work around this by padding my inputs so that n_objects is constant? Or would these dummy objects confuse the FATENetwork and prevent it from converging?

daanvdn, Jun 12 '20 07:06
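Here is a minimal NumPy sketch of the padding described in the question. It assumes X is a dict in the layout from the fit docstring and that Y holds one-hot discrete-choice labels keyed the same way. The helper name is hypothetical. Ranking labels would need different handling, because a padded 0 would read as a real rank. The returned mask records which objects are real so the dummies can be ignored later.

```python
import numpy as np


def pad_to_max_objects(X_dict, Y_dict):
    """Zero-pad dicts of variable-size inputs into fixed-size arrays.

    Hypothetical helper. Assumes X_dict maps n_objects -> array of shape
    (n, n_objects, n_features) and Y_dict maps n_objects -> one-hot choice
    array of shape (n, n_objects). Returns padded X, padded Y and a boolean
    mask that is True for real (non-dummy) objects.
    """
    max_objects = max(X_dict)
    X_parts, Y_parts, mask_parts = [], [], []
    for n_objects in sorted(X_dict):
        X_n, Y_n = X_dict[n_objects], Y_dict[n_objects]
        pad = max_objects - n_objects
        # Append `pad` all-zero dummy objects along the object axis.
        X_parts.append(np.pad(X_n, ((0, 0), (0, pad), (0, 0))))
        # Dummy objects are never chosen, so their label is 0.
        Y_parts.append(np.pad(Y_n, ((0, 0), (0, pad))))
        mask = np.zeros((X_n.shape[0], max_objects), dtype=bool)
        mask[:, :n_objects] = True
        mask_parts.append(mask)
    return (np.concatenate(X_parts), np.concatenate(Y_parts),
            np.concatenate(mask_parts))
```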

Yes, one common workaround is to pad the inputs with zeros up to the maximum number of objects. At prediction time, you would then need to call predict_scores and select the highest-scoring non-dummy object.

kiudee, Jun 12 '20 08:06
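The selection step from the last reply can be sketched as follows. This assumes predict_scores returns an array of shape (n_instances, max_objects) for the padded input. It also assumes a boolean mask of the real objects was kept during padding. The helper name and the numbers are made up for illustration.

```python
import numpy as np


def choose_real_object(scores, mask):
    """Pick the highest-scoring non-dummy object for each instance.

    Hypothetical helper. scores: (n_instances, max_objects), e.g. the
    output of predict_scores on padded input; mask: boolean array of the
    same shape, True where the object is real.
    """
    # Give dummy objects -inf so they can never win the argmax.
    masked = np.where(mask, scores, -np.inf)
    return np.argmax(masked, axis=1)


scores = np.array([[0.2, 0.1, 0.3, 0.9],
                   [0.4, 0.6, 0.95, 0.8]])
# Padding appends dummies at the end: 3 real objects, then 2 real objects.
mask = np.array([[True, True, True, False],
                 [True, True, False, False]])
print(choose_real_object(scores, mask))  # [2 1]
```

Masking with -inf rather than 0 matters because real scores may be low or negative. In this example, a plain argmax would pick a dummy object in both rows (indices 3 and 2).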