Yeounoh Chung
## 🚀 Feature

We propose `XLAShardedTensor` to represent a sharded tensor that wraps around `torch.Tensor`, and a `mark_sharding()` API for tensor sharding annotation. `XLAShardedTensor` allows annotating tensors with sharding specs and...
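A minimal sketch of how the proposed annotation API can be used, following the module layout PyTorch/XLA eventually shipped (`torch_xla.distributed.spmd`); the mesh shape, axis names, and tensor shapes here are illustrative, not part of this proposal:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Build a 2D logical mesh over all addressable devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# Annotate a tensor: shard dim 0 across the 'data' mesh axis and
# leave dim 1 on the 'model' axis. mark_sharding attaches the
# sharding spec and yields an XLAShardedTensor-backed tensor.
t = torch.randn(16, 128).to(xm.xla_device())
sharded = xs.mark_sharding(t, mesh, ('data', 'model'))
```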
This is a follow-up to #3476 and contributes to #3871. The changes include:
- `Compile` the partitioned HLO computation graph with sharding annotations.
- `PjRtComputationClient` integration to support SPMD sharded operations....
This helps improve the matmul-with-bias computation and partially addresses the performance regression with the ViT model. It replaces `Dot` with the `CreateMatMul` op, which still calls `Dot` internally for simpler multiplications...
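For context, matmul-with-bias is the pattern that ViT layers typically hit via `torch.addmm` / `nn.Linear`; a minimal repro sketch under that assumption (shapes are illustrative, and whether a given model routes through `addmm` is not guaranteed by this PR):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(64, 256, device=device)
w = torch.randn(256, 512, device=device)
bias = torch.randn(512, device=device)

# addmm computes bias + a @ w in one op; this is the kind of
# matmul-with-bias graph the lowering change targets.
out = torch.addmm(bias, a, w)
xm.mark_step()  # materialize and execute the XLA graph
```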
## 🚀 Feature

There are two code paths for composing and executing `torch.gather` in PyTorch/XLA, and a custom heuristic decides between the two. The heuristic requires a...
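For reference, the op in question; which PyTorch/XLA lowering path it takes is what the heuristic above decides (example values are illustrative):

```python
import torch

src = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])

# For dim=0: out[i][j] = src[index[i][j]][j]
out = torch.gather(src, 0, index)
print(out)  # tensor([[1, 2], [3, 2]])
```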
This is part of the auto-sharding PoC described in #6322. This PR addresses the following:
* Adapt `xla::OpSharding::UNKNOWN` as the implicit replication type.
* Refactor `ShardingUtil::GetOutputSharding` and `ShardingUtil::CreateShardedData`.
* Remove resolved...
This implements a PoC prototype on XLA:TPU, as described in #6322. The auto-sharding feature is enabled via `XLA_SPMD_AUTO` or

```
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```

Aside from the auto-sharding feature, I also adapted `xla::OpSharding::UNKNOWN` to...
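A sketch of what the PoC enables end to end: with auto-sharding on, no manual `mark_sharding` annotations are needed, and the pass picks shardings at compile time (the model and shapes below are illustrative; only `xr.use_spmd(auto=True)` comes from this PR):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd(auto=True)  # let the XLA:TPU auto-sharding pass pick shardings

model = torch.nn.Linear(128, 128).to(xm.xla_device())
x = torch.randn(32, 128).to(xm.xla_device())

loss = model(x).sum()
loss.backward()
xm.mark_step()  # compile & run; shardings are inferred, not annotated
```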
This addresses the following to support #6322:
- Add `IsVirtualDevice()`.
- Un-const the `tensors` argument to the `XLAGraphExecutor::Compile` method, since the tensors can be updated during the auto-sharding pass.
- Refactor to remove redundant...
This enables PJRT plugin 0.45 and removes the quantization OpenXLA patch from our repo. Tested locally with ResNet.
## Why are these changes needed?

The Gemini model API introduced a new context caching feature that caches the prompt prefix. This PR enables this new feature in `GeminiClient` to...
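A hedged sketch of the underlying Gemini caching flow this PR builds on, using the `google-generativeai` Python SDK directly; the `GeminiClient` wiring is not shown, and the prefix variable is a stand-in (real caches also require a minimum prefix size):

```python
import datetime
import google.generativeai as genai
from google.generativeai import caching

genai.configure(api_key="YOUR_API_KEY")  # placeholder credential

# Stand-in for a long shared prefix (e.g., a large system context).
long_shared_prefix = "A very long prompt prefix shared across requests."

# Cache the shared prefix once, with a time-to-live...
cache = caching.CachedContent.create(
    model='models/gemini-1.5-flash-001',
    contents=[long_shared_prefix],
    ttl=datetime.timedelta(minutes=10),
)

# ...then serve subsequent requests against the cached prefix.
model = genai.GenerativeModel.from_cached_content(cached_content=cache)
response = model.generate_content('User question here')
print(response.text)
```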