machine_learning_POC

sparsemax_plot issue.

Open pangjac opened this issue 11 months ago • 0 comments

Hi! I am a reader of your wonderful blog: https://towardsdatascience.com/what-is-sparsemax-f84c136624e4

I am curious what the sparsemax function looks like in 2-dimensional space, so I tried to plot it, but my plot does not align with the plot from the paper (https://miro.medium.com/v2/resize:fit:640/format:webp/1*EK55_sEqmQDSXcIV2CGLKg.png).

My code and plot are below:

import numpy as np
import matplotlib.pyplot as plt

# this is the code from the repo, not from the blog
def sparsemax(z):
    sum_all_z = sum(z)
    z_sorted = sorted(z, reverse=True)
    k = np.arange(len(z))
    k_array = 1 + k * z_sorted
    z_cumsum = np.cumsum(z_sorted) - z_sorted
    k_selected = k_array > z_cumsum
    k_max = np.where(k_selected)[0].max() + 1
    threshold = (z_cumsum[k_max-1] - 1) / k_max
    return np.maximum(z-threshold, 0)

# plot: evaluate sparsemax on each value of z, passed as a 1-element vector
z = np.linspace(-3, 3, 50)
sparsemax_values = np.array([sparsemax(np.array([zi])) for zi in z])

plt.figure(figsize=(8, 6))
plt.plot(z, sparsemax_values, label='Sparsemax Output')
plt.title('Sparsemax Activation Function')
plt.xlabel('Input values (z)')
plt.ylabel('Sparsemax Output')
plt.legend()
plt.grid(True)
plt.show()

The resulting plot is below.
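For comparison, here is a minimal sketch of the curve I would expect, assuming the 2D plot in the paper shows the first coordinate of sparsemax([t, 0]) as t varies. The names `sparsemax_reference`, `t`, and `first_coord` are hypothetical and are not from the repo or the blog; the function follows the sort-based closed form with an inclusive cumulative sum in the threshold.

```python
import numpy as np

def sparsemax_reference(z):
    """Sketch of sparsemax: Euclidean projection of z onto the probability simplex."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]            # values in descending order
    cumsum = np.cumsum(z_sorted)           # inclusive cumulative sums
    k = np.arange(1, z.size + 1)           # 1-based positions
    support = 1 + k * z_sorted > cumsum    # positions that stay non-zero
    k_max = k[support].max()               # size of the support
    tau = (cumsum[k_max - 1] - 1) / k_max  # threshold subtracted from z
    return np.maximum(z - tau, 0.0)

# Assumed setup of the 2D figure: input [t, 0], plot the first output coordinate.
t = np.linspace(-3, 3, 61)
first_coord = np.array([sparsemax_reference([ti, 0.0])[0] for ti in t])
```

Under this assumption, `first_coord` is 0 for t <= -1, rises linearly as (t + 1) / 2 on [-1, 1], and is 1 for t >= 1, so `plt.plot(t, first_coord)` gives a clipped ramp. A 1-element input is a different case: the simplex is then the single point 1, so the projection returns 1 for any value. As far as I can tell, the repo function above computes its threshold from the cumulative sum that excludes the current element, which for a 1-element input gives max(z + 1, 0) instead; that is my reading of the code and may be wrong.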

pangjac, Mar 01 '24 20:03