Basic Recommender Ranking: Big Optimization issue
Hi everyone,
I am a Data Scientist and I started following your examples on TensorFlow Recommenders a few weeks ago.
I noticed that the Basic Recommender (ranking) is deeply flawed.
Here is how to reproduce it:
- Plot the ranking predictions as a histogram or KDE.
- Do the same with the labels from the test dataset.
- Also plot a random integer distribution (from 1 to 5) as a baseline.
If you add the following code to the basic ranking recommender notebook, you should get the same results as mine.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

res_data = pd.DataFrame()
# get predictions from the cached test set
res_data['predictions'] = model.predict(cached_test)[:, 0]
# add a random integer distribution as a baseline
res_data['random'] = np.random.randint(1, 6, len(res_data))
# get the user ratings from the test dataset
test_labels = []
for r in cached_test:
    test_labels.append(r['user_rating'].numpy())
res_data['test_labels'] = np.concatenate(test_labels)

# plot everything as a KDE
plt.figure(figsize=(10, 7), dpi=100)
sns.kdeplot(data=res_data, fill=True, bw_adjust=0.9, alpha=0.6, linewidth=0, legend=False)
plt.legend(["Predictions", "Random Monkey", "Test Labels"][::-1], title="Legend", fontsize=12, title_fontsize=16)
plt.title('Predictions vs. test labels', fontsize=20);
Results:
Did you get the issue? Is it normal that our predictions are not able to properly rank the input data, since they are distributed in a Gaussian way around a mean value of 3.5? Did I miss something?
Thank you in advance! @albertvillanova @maciejkula @MarkDaoust @hojinYang
It looks like the model is undertrained, and just learning the average.
If you train it longer, it learns a distribution similar to the test labels. It's easy to overfit, though, so we should add some validation data. This does okay (it would be better with an EarlyStopping callback):
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
model.fit(cached_train, epochs=75, validation_data=cached_test)
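For reference, a sketch of what that could look like with early stopping; the monitored metric name is an assumption, so adjust it to whatever the model actually reports (e.g. val_root_mean_squared_error):

import tensorflow as tf  # already imported in the notebook

# Stop training once the validation loss stops improving and keep the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",  # assumption: use the validation metric name the model reports
    patience=5,
    restore_best_weights=True,
)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
model.fit(
    cached_train,
    epochs=75,
    validation_data=cached_test,
    callbacks=[early_stop],
)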
Hi, I tried with more epochs (100), but the result is a broader distribution around the average (both with and without EarlyStopping). Could there be an issue within task.Ranking? Or maybe the batch size is too big?
Yes. Broader is exactly what I'd expect, since this is being approached as a regression. MSE loss doesn't encourage sharpness; it tends to give a blurry result. The model isn't meant to predict exactly 1.0, 2.0, 3.0, ... because it needs to hedge its bets: it's approximating a discrete distribution with a continuous one. The plot would be fairer if the labels were plotted as a bar chart.
If you want integer predictions, you could round off the predictions, or train this as a classifier.
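For example, a sketch of the rounding and of the bar-chart comparison, reusing res_data, plt, and sns from the snippet above:

# Round the continuous predictions to the nearest star and clip to the valid range.
res_data['rounded'] = res_data['predictions'].round().clip(1, 5)

# Plot the discrete test labels as bars and the raw predictions as a KDE.
counts = res_data['test_labels'].value_counts(normalize=True).sort_index()
plt.figure(figsize=(10, 7), dpi=100)
plt.bar(counts.index, counts.values, width=0.4, alpha=0.5, label='Test labels')
sns.kdeplot(res_data['predictions'], fill=True, alpha=0.6, label='Predictions')
plt.legend()
plt.title('Predictions (KDE) vs. test labels (bars)', fontsize=20);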
@rageSpin could you send a PR to add your plotting code to the notebook (without the random-monkey)?
Yeah, you're perfectly right. Maybe I'll try a different loss function or a classifier (and compare it with XGBoost), roughly along the lines of the sketch below. I'll redo the plot with a bar chart; which version should I send as a PR?
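For illustration, a rough sketch of the classification reframing (the names here, such as rating_head, are placeholders and not part of the notebook):

import tensorflow as tf

# Treat the five star ratings as classes 0..4 instead of a regression target.
# The head below would sit on top of the concatenated user/movie embeddings
# that the ranking model already produces.
rating_head = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(5),  # one logit per star rating
])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Labels would be shifted from 1..5 to 0..4 (label = int(user_rating) - 1),
# and the predicted rating recovered as argmax(logits) + 1.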
Since I really like TensorFlow, if you need some help with documentation or debugging, I am available for further tasks.
(Sorry for the random-monkey term, but in statistics we use it a lot to describe random behavior in an amusing way 😂)
Thanks for the PR! LGTM.
Don't forget to follow the CLA instructions the bot posted: https://github.com/tensorflow/recommenders/pull/593#issuecomment-1352734325
We can't merge the PR unless you do.
Done!