lightweight_mmm

JAX GPU with Apple M2

Open · ar-asur opened this issue 1 year ago · 0 comments

Hi team, I am building an MMM with 8 media channels and 2 extra features. It is a geo model at the US state level (51 geos). Fitting the model takes a couple of hours, which is a hurdle for experimentation. I was unsuccessful in getting JAX to run on the M2 GPU. Could you share any suggestions, or the method the developers have used, to get it working? Thanks!

ar-asur, Jul 13 '23 18:07