
Add chronos model(s) to mlx-lm

Open sugatoray opened this issue 1 year ago • 6 comments

I would like to add the Chronos model(s) to mlx-lm. I'm looking for feedback or suggestions from the maintainers.

  • paper: https://arxiv.org/abs/2403.07815
  • model: https://huggingface.co/amazon/chronos-t5-large

sugatoray avatar Mar 14 '24 18:03 sugatoray

Please assign this one to me. I have already started working on it and will send a draft PR soon to collaborate and get suggestions.

cc: @awni @mzbac

sugatoray avatar Mar 14 '24 18:03 sugatoray

Cool! Although I'm wondering how well encoder/decoder-style models will fit into MLX LM. We have a T5 example you can use as a reference.

If it doesn't add too much complexity I would support allowing T5 style models in MLX LM, but otherwise it might make sense to have an alternative package or repo for such things.

awni avatar Mar 14 '24 19:03 awni

@awni Is there any mlx.nn equivalent of torch.nansum (if np.nansum is to be avoided)?

sugatoray avatar Mar 15 '24 15:03 sugatoray

We don't have such an operation, sorry! You could do something like:

import mlx.core as mx

def nansum(x):
  return mx.sum(mx.where(mx.isnan(x), 0, x))

awni avatar Mar 15 '24 16:03 awni

Thanks, I looked at the torch.nansum implementation as well.

https://github.com/pytorch/pytorch/blob/014f91a9d9f94ac9a7f0711600240d7cd7f69844/torch/_decomp/decompositions.py#L4277-L4278

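import mlx.core as mx

def nansum(x: mx.array, axis: int = -1):
    # Zero out NaNs, then reduce along `axis` (the last axis by default).
    return mx.sum(mx.where(mx.isnan(x), 0, x), axis=axis)

mx.nansum = nansum  # attach to the mx namespace for the test below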

@awni Can we add this as a function to mlx.nn? Would you suggest some edits in that case?

Test

import numpy as np
import torch
import mlx.core as mx

testnp = np.array([1, 2., 0., np.nan, -3.5])
testmx = mx.array(testnp)
testch = torch.tensor(testnp)

torch.nansum(testch) == torch.sum(torch.where(torch.isnan(testch), 0, testch), dim=-1) # True
torch.nansum(testch) # tensor(-0.5000, dtype=torch.float64)

mx.nansum(testmx) # array(-0.5, dtype=float32)
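
A 1-D test cannot show one difference: torch.nansum called without dim reduces over every element, while the wrapper above defaults to axis=-1 and reduces only the last axis. For 2-D input the two give different shapes. The sketch below illustrates this with plain Python lists. nansum_all and nansum_last_axis are hypothetical stand-ins written for this illustration, not MLX or PyTorch functions.

```python
import math

def nansum_all(rows):
    """Sum every non-NaN entry of a 2-D nested list (reduce over all axes)."""
    return sum(v for row in rows for v in row if not math.isnan(v))

def nansum_last_axis(rows):
    """Sum the non-NaN entries of each row (reduce over the last axis only)."""
    return [sum(v for v in row if not math.isnan(v)) for row in rows]

data = [[1.0, 2.0, float("nan")],
        [0.0, float("nan"), -3.5]]

total = nansum_all(data)          # a single scalar: -0.5
per_row = nansum_last_axis(data)  # one value per row: [3.0, -3.5]
```

If the goal is to mirror torch.nansum's default behavior, a default that reduces over all axes may be the closer match. That is a design choice for the maintainers.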

sugatoray avatar Mar 15 '24 16:03 sugatoray

@awni What should I use as a substitute for torch.Tensor.median()?

  • is there anything equivalent in mx.array?
  • can I instead use np.median(x, axis=1) as a placeholder for torch.median(x, dim=-1)?
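
One caveat about the np.median placeholder, sketched with the standard library: for an even number of elements, torch.median returns the lower of the two middle values, whereas np.median averages them. torch.median with dim also returns a (values, indices) pair rather than a single array. The helper median_low_sorted below is a hypothetical name used only for illustration. It shows the sort-then-index approach that an array version could follow, assuming the array library provides a sort along an axis.

```python
import statistics

values = [1.0, 2.0, 3.0, 4.0]  # even number of elements

lower = statistics.median_low(values)  # lower of the two middle values: 2.0
averaged = statistics.median(values)   # mean of the two middle values: 2.5

def median_low_sorted(xs):
    """Lower median via sort + index (hypothetical helper for illustration)."""
    ordered = sorted(xs)
    # For even n this picks the lower middle element; for odd n, the middle one.
    return ordered[(len(ordered) - 1) // 2]
```

For odd-length input the two conventions agree, so the difference only appears when the reduced axis has an even length.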

sugatoray avatar Mar 15 '24 17:03 sugatoray