Denis Vieriu
Summary of changes:
- Support for scatter slice
- Enable index put

With whole-model delegation, I am seeing the following crash in llama2:
```
in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Buffer...
```
Fixes https://github.com/pytorch/pytorch/issues/124850 Replace the previous MPSGraph nonzero construction with the native nonzero op. On older OSes, fall back to CPU (the previous implementation was unreliable and only comparable to CPU in speed). cc...
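As a reminder of what the op computes, here is a minimal plain-Python sketch of nonzero semantics for a 2-D input: the indices of every non-zero element, in row-major order. It is not the MPS or CPU implementation, and the helper name `nonzero_2d` is hypothetical.

```python
def nonzero_2d(rows):
    """Return (row, col) index pairs of non-zero entries, in row-major order.

    Hypothetical reference helper for illustration only; it mirrors the
    semantics of a nonzero op on a 2-D input given as nested lists.
    """
    return [
        (i, j)
        for i, row in enumerate(rows)
        for j, value in enumerate(row)
        if value != 0
    ]


print(nonzero_2d([[0, 3, 0],
                  [5, 0, 7]]))  # [(0, 1), (1, 0), (1, 2)]
```

The output size depends on the data, which is one reason a device implementation of this op is harder to get right than a fixed-shape elementwise op.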
- Add support for cumprod.
- Extend the cumsum test case to also test cumprod.
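For reference, the semantics the two ops share can be sketched in plain Python for the 1-D case. This is not the MPS implementation, and the helper names `cumsum` and `cumprod` below are hypothetical stand-ins.

```python
from itertools import accumulate
import operator


def cumsum(values):
    """Running sum of a 1-D sequence (hypothetical reference helper)."""
    return list(accumulate(values, operator.add))


def cumprod(values):
    """Running product of a 1-D sequence (hypothetical reference helper)."""
    return list(accumulate(values, operator.mul))


print(cumsum([1, 2, 3, 4]))   # [1, 3, 6, 10]
print(cumprod([1, 2, 3, 4]))  # [1, 2, 6, 24]
```

Both are the same scan with a different binary operator, so a single test case can plausibly be parameterized over the two.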
Since **NHWC** is represented as a view operation in PyTorch, we can execute the convolution ops directly in NCHW if the **suggested memory format** is NHWC but the **actual memory...
Add support for MPS Int4 per channel group-wise quantization through MPSGraph. --- Testing: **AOT export** ``` python -m examples.models.llama2.export_llama --checkpoint /Volumes/Source/weights/llama2/llama2-7b/llama-2-7b/consolidated.00.pth --params /Volumes/Source/weights/llama2/llama2-7b/llama-2-7b/params.json -kv --use_sdpa_with_kv_cache --mps -d fp32 --disable_dynamic_shape -qmode...