
Add Multinomial GLM

Open • jachymb opened this issue 10 months ago • 1 comment

We already have binomial_logit_glm and categorical_logit_glm.

The multinomial distribution generalizes both the binomial and the categorical distributions. It is still in the exponential family, and the logit (softmax) parametrization seems natural; we already have multinomial_logit. So multinomial_logit_glm is the missing piece.
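For reference, here is the model such a function would evaluate, written in the notation of the implementation later in this thread. It is a sketch of the standard multinomial-logit likelihood, not a specification of an eventual Stan function. Here x_i is row i of x, alpha is the intercept row vector, beta is the coefficient matrix, and K is the number of categories. With N_i = 1 it reduces to the categorical logit GLM. With K = 2 and eta_{i2} = 0 it reduces to the binomial logit GLM.

```latex
\eta_i = \alpha + x_i \beta \in \mathbb{R}^{K}, \qquad
N_i = \sum_{k=1}^{K} y_{ik}, \qquad
y_i \sim \operatorname{Multinomial}\left(N_i, \operatorname{softmax}(\eta_i)\right)

\log p(y \mid x, \alpha, \beta)
  = \sum_{i=1}^{n} \Big[ \log\Gamma(N_i + 1)
  - \sum_{k=1}^{K} \log\Gamma(y_{ik} + 1)
  + \sum_{k=1}^{K} y_{ik}\,\eta_{ik}
  - N_i \log \sum_{k=1}^{K} \exp(\eta_{ik}) \Big]
```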

I'm actually using this in an application (modelling daily counts of products sold), where I simply call multinomial_logit in a loop, but an optimized single-call function would be nice.
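For comparison, the loop-based approach described above (one multinomial_logit term per event) can be sketched in Python as a reference implementation. This is an illustration under assumptions, not Stan code. The function names mirror the Stan names but are hypothetical Python helpers. x is a list of rows, beta is a nested list of shape (number of predictors) x (number of categories), and alpha has one entry per category.

```python
import math
from typing import List, Sequence


def log_sum_exp(v: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))), shifting by the maximum first."""
    m = max(v)
    return m + math.log(sum(math.exp(a - m) for a in v))


def multinomial_logit_lpmf(y: Sequence[int], eta: Sequence[float]) -> float:
    """Log pmf of Multinomial(sum(y), softmax(eta)) at the count vector y."""
    n = sum(y)
    lse = log_sum_exp(eta)
    lp = math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in y)
    lp += sum(k * (e - lse) for k, e in zip(y, eta))  # sum_k y_k * log softmax(eta)_k
    return lp


def multinomial_logit_glm_loop_lpmf(
    y: Sequence[Sequence[int]],
    x: Sequence[Sequence[float]],
    alpha: Sequence[float],
    beta: Sequence[Sequence[float]],
) -> float:
    """Hypothetical reference: one multinomial_logit term per event (row of x)."""
    num_categories = len(alpha)
    lp = 0.0
    for y_i, x_i in zip(y, x):
        # eta_i[k] = alpha[k] + sum_j x_i[j] * beta[j][k]
        eta_i: List[float] = [
            alpha[k] + sum(x_ij * beta[j][k] for j, x_ij in enumerate(x_i))
            for k in range(num_categories)
        ]
        lp += multinomial_logit_lpmf(y_i, eta_i)
    return lp


# Two events, one predictor, three categories.
print(multinomial_logit_glm_loop_lpmf(
    y=[[1, 2, 0], [0, 1, 4]],
    x=[[1.0], [2.0]],
    alpha=[0.1, 0.0, -0.2],
    beta=[[0.5, -0.3, 0.2]],
))
```

With two categories and eta = [theta, 0], multinomial_logit_lpmf matches the binomial log pmf with logit theta. This is the "generalizes Binomial" property mentioned in the issue.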

Related issue: https://github.com/stan-dev/math/issues/1964

jachymb, Feb 21 '25 11:02

This is my Stan implementation. It avoids for-loops for better vectorization:

real multinomial_logit_glm_lpmf(array[,] int y, matrix x, row_vector alpha, matrix beta) {
    // beta and alpha have 1 less DOF than columns. See: https://mc-stan.org/docs/stan-users-guide/regression.html#identifiability

    int num_categories = cols(beta);
    int num_events = rows(x); // one event = one independent multinomial draw

    vector[num_categories] ones = rep_vector(1, num_categories);  // used for matrix row-sum

    matrix[num_events, num_categories] y_mat = to_matrix(y);
    matrix[num_events, num_categories] predictors = x * beta;
    predictors += rep_matrix(alpha, num_events);  // corresponds to the scalar alpha in univariate GLMs. Maybe also allow a matrix alpha as an overload

    real lp = 0;

    lp += sum(lgamma(y_mat * ones + 1));  // log multinomial numerator
    lp -= sum(lgamma(y_mat + 1)); // log multinomial denominator
    // note: lgamma is called on reals for coding convenience. Would an integer-specialized version be better?

    lp += sum(y_mat .* predictors); // log softmax numerator. == trace(y_mat * predictors'). see [#3161]
    lp -= sum(log(exp(predictors) * ones)' * y_mat);    // log softmax denominator: sum over events of (event total) * log_sum_exp(row)
    // Could log_softmax, or at least log_sum_exp, run row-wise over a matrix? That would read better and avoid exp() overflow for large predictors

    return lp;
}
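A small Python sketch can check the matrix identities used in the function above. These are row totals via y_mat * ones, sum(y_mat .* predictors) for the softmax numerator, and log(exp(predictors) * ones)' * y_mat for the denominator. The sketch is a hypothetical port for illustration, not Stan's implementation. It also contrasts the naive log(sum(exp(row))) with a row-wise log-sum-exp that subtracts each row's maximum first. That is one standard way to get the row-wise log_sum_exp asked about above without overflow.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: (n x p) times (p x K) gives (n x K)."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
            for row in a]


def rowwise_log_sum_exp(p: Matrix) -> List[float]:
    """Stable log(sum(exp(row))) per row: subtract the row maximum first."""
    out = []
    for row in p:
        m = max(row)
        out.append(m + math.log(sum(math.exp(v - m) for v in row)))
    return out


def rowwise_log_sum_exp_naive(p: Matrix) -> List[float]:
    """Mirrors log(exp(predictors) * ones); math.exp raises OverflowError for large inputs."""
    return [math.log(sum(math.exp(v) for v in row)) for row in p]


def multinomial_logit_glm_lpmf(y, x, alpha: Sequence[float], beta: Matrix, stable: bool = True) -> float:
    # predictors = x * beta + rep_matrix(alpha, num_events)
    predictors = [[v + a for v, a in zip(row, alpha)] for row in matmul(x, beta)]
    totals = [sum(y_i) for y_i in y]  # y_mat * ones: total count per event
    lse = (rowwise_log_sum_exp if stable else rowwise_log_sum_exp_naive)(predictors)
    lp = sum(math.lgamma(n + 1) for n in totals)             # log multinomial numerator
    lp -= sum(math.lgamma(k + 1) for y_i in y for k in y_i)  # log multinomial denominator
    # sum(y_mat .* predictors), i.e. trace(y_mat * predictors')
    lp += sum(k * v for y_i, row in zip(y, predictors) for k, v in zip(y_i, row))
    # sum(lse' * y_mat) equals sum over events of totals[i] * lse[i]
    lp -= sum(n * l for n, l in zip(totals, lse))
    return lp


# Predictors near 1000: only the max-shifted version stays finite.
print(multinomial_logit_glm_lpmf([[1, 1]], [[1000.0]], [0.0, 0.0], [[1.0, 0.0]]))
```

For moderate predictors both paths agree. With predictors around 1000, the naive path raises OverflowError in Python. In double-precision arithmetic in general, exp(1000) overflows to inf. The max-shifted version returns a finite value.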

jachymb, Mar 13 '25 11:03