machine_learning_POC
sparsemax_plot issue.
Hi! I am a reader of your wonderful blog post: https://towardsdatascience.com/what-is-sparsemax-f84c136624e4
I was curious what the sparsemax function looks like in a two-dimensional space, so I tried to plot it, but my plot does not align with the one from the paper (https://miro.medium.com/v2/resize:fit:640/format:webp/1*EK55_sEqmQDSXcIV2CGLKg.png).
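For reference, here is a short derivation of the 2-D curve. It assumes the figure plots the first coordinate of sparsemax applied to the two-element input $\mathbf z = (t, 0)$ as $t$ varies, which is my reading of the figure rather than something stated here. Sparsemax is the Euclidean projection onto the probability simplex, so with $\mathbf p = (p_1, 1 - p_1)$:

$$
\operatorname{sparsemax}(\mathbf z) = \arg\min_{\mathbf p \in \Delta^{K-1}} \lVert \mathbf p - \mathbf z \rVert^2
\quad\Rightarrow\quad
\min_{p_1 \in [0,1]} (p_1 - t)^2 + (1 - p_1)^2
$$

$$
\operatorname{sparsemax}_1([t, 0]) =
\begin{cases}
0 & t \le -1 \\
\frac{t + 1}{2} & -1 < t < 1 \\
1 & t \ge 1
\end{cases}
$$

The curve is flat at 0, rises linearly with slope 1/2, then saturates at 1, whereas $\operatorname{softmax}_1([t, 0]) = 1 / (1 + e^{-t})$ is the smooth sigmoid.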
My code is below:
```python
import numpy as np
import matplotlib.pyplot as plt

# this is the code from the repo, not from the blog
def sparsemax(z):
    sum_all_z = sum(z)                           # computed but not used below
    z_sorted = sorted(z, reverse=True)           # scores in descending order
    k = np.arange(len(z))                        # 0-based ranks
    k_array = 1 + k * z_sorted
    z_cumsum = np.cumsum(z_sorted) - z_sorted    # sum of the entries before each position
    k_selected = k_array > z_cumsum              # positions kept in the support
    k_max = np.where(k_selected)[0].max() + 1    # support size
    threshold = (z_cumsum[k_max-1] - 1) / k_max
    return np.maximum(z - threshold, 0)

# plot: each call passes a 1-element array [zi]
z = np.linspace(-3, 3, 50)
sparsemax_values = np.array([sparsemax(np.array([zi])) for zi in z])
plt.figure(figsize=(8, 6))
plt.plot(z, sparsemax_values, label='Sparsemax Output')
plt.title('Sparsemax Activation Function')
plt.xlabel('Input values (z)')
plt.ylabel('Sparsemax Output')
plt.legend()
plt.grid(True)
plt.show()
```
The resulting plot is below.
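For comparison, a minimal sketch, assuming the paper's 2-D figure plots the first output of sparsemax applied to the two-element input `[t, 0]`. It differs from the snippet above in two places. First, the threshold uses the inclusive cumulative sum: in the snippet, `z_cumsum` subtracts the current element, so `z_cumsum[k_max-1]` leaves out the `k_max`-th score, and a 1-element input `[z]` comes back as `max(z + 1, 0)`. Second, each input is a 2-element vector, because sparsemax of a 1-element vector is always `[1.0]`. The names `t`, `sparsemax_curve` and `softmax_curve` are illustrative only.

```python
import numpy as np


def sparsemax(z):
    """Project a 1-D score vector z onto the probability simplex (sort-based sketch)."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]               # scores in descending order
    k = np.arange(1, z.size + 1)              # 1-based ranks
    z_cumsum = np.cumsum(z_sorted)            # inclusive cumulative sums
    support = 1 + k * z_sorted > z_cumsum     # ranks whose entries stay non-zero
    k_max = k[support][-1]                    # support size (rank 1 always qualifies)
    tau = (z_cumsum[k_max - 1] - 1) / k_max   # threshold from the top-k_max scores
    return np.maximum(z - tau, 0.0)           # outputs keep the original order of z


# Assumption: the 2-D figure shows the first coordinate of sparsemax([t, 0])
# and softmax([t, 0]) as t sweeps over an interval.
t = np.linspace(-3, 3, 61)
sparsemax_curve = np.array([sparsemax([ti, 0.0])[0] for ti in t])
softmax_curve = 1.0 / (1.0 + np.exp(-t))      # softmax_1([t, 0]) equals sigmoid(t)
```

To compare against the figure, plot `t` against `sparsemax_curve` (and `softmax_curve`) with the same matplotlib calls as in the code above; the sparsemax curve should be 0 up to `t = -1`, linear in between, and 1 from `t = 1` on.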