serket

`at` + regex-based matching for sharding

Open ASEM000 opened this issue 1 year ago • 0 comments

`at` can match paths using regex patterns. Use this feature to write a sample example showing how to set sharding, in a similar way to how the new Keras API implements it [1].

Something along the following lines:

Note that this should work with arbitrary pytrees (e.g., a Flax params dict).


import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P
import numpy as np
import serket as sk
import re

class FeedForward(sk.TreeClass):
    """Two-layer MLP mapping ``d_model -> 4 * d_model -> d_model`` with ReLU.

    The sub-layers live in attributes named ``linear1`` and ``linear2``.
    The sharding example selects them with the regex ``linear.*``, so these
    attribute names are part of the example's contract.
    """

    def __init__(self, d_model: int, *, key: jax.Array):
        # Independent PRNG keys for the expanding and contracting layers.
        up_key, down_key = jax.random.split(key)
        hidden = 4 * d_model
        self.linear1 = sk.nn.Linear(d_model, hidden, key=up_key)
        self.linear2 = sk.nn.Linear(hidden, d_model, key=down_key)

    def __call__(self, input: jax.Array) -> jax.Array:
        """Apply ``linear1``, then ReLU, then ``linear2`` to *input*."""
        activated = jax.nn.relu(self.linear1(input))
        return self.linear2(activated)


ff = FeedForward(d_model=128, key=jax.random.PRNGKey(0))
# Mask the leaves that are not inexact (float/complex) arrays, so that the
# remaining pytree leaves are arrays `jax.device_put` can place.
ff = sk.tree_mask(ff)
# 2 x 4 mesh over the 8 host devices forced via XLA_FLAGS at the top.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), axis_names=["data", "model"])
# Begin with a fully replicated sharding for *every* leaf. The sharding tree
# then mirrors `ff` exactly, and any leaf not matched by the rules below
# (e.g. extra params in an arbitrary pytree) stays replicated. Without this,
# an unmatched leaf would remain a raw array where `device_put` expects a
# sharding.
sharding = jax.tree_util.tree_map(lambda _: N(mesh, P()), ff)
# Override attributes whose name matches `linear.*`: shard the weights along
# their first axis and the biases along their only axis over "model".
sharding = sk.at(sharding)[re.compile("linear.*")]["weight"].set(N(mesh, P("model", None)))
sharding = sk.at(sharding)[re.compile("linear.*")]["bias"].set(N(mesh, P("model")))
ff = jax.device_put(ff, sharding)


def vis_sharding(path, leaf):
    """Print the key path of *leaf*, then show how it is sharded.

    ``jax.debug.visualize_array_sharding`` can only draw 1-D and 2-D arrays
    and raises for any other rank. Scalars and rank >= 3 leaves (common in
    arbitrary pytrees, e.g. conv kernels) therefore print their sharding
    object instead, so the traversal does not abort.

    Args:
        path: key path as supplied by ``jax.tree_util.tree_map_with_path``.
        leaf: a ``jax.Array`` leaf of the sharded tree.
    """
    print(jax.tree_util.keystr(path))
    if leaf.ndim in (1, 2):
        jax.debug.visualize_array_sharding(leaf)
    else:
        print(f"  rank-{leaf.ndim} array, sharding: {leaf.sharding}")

# Visit every array leaf of the sharded model, in pytree flatten order, and
# display its key path and placement on the mesh.
for leaf_path, leaf in jax.tree_util.tree_flatten_with_path(ff)[0]:
    vis_sharding(leaf_path, leaf)

(image: output of the sharding visualization)

ASEM000 avatar Jan 30 '24 05:01 ASEM000