serket
serket copied to clipboard
`at` + regex-based matching for sharding
`at` can match paths using regex patterns. Use this feature to write a sample example showing how to set sharding in a way similar to how the new Keras API implements it [1].
Something along the following lines:
Note that this should work with arbitrary pytrees (e.g., a Flax params dict).
import os
# Ask XLA to expose 8 host (CPU) devices so the 2 x 4 mesh built later can be
# created without accelerators. Set it before importing `jax` so the flag is in
# place before the backend initializes. Note this assignment replaces any
# XLA_FLAGS value already present in the environment.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P
import numpy as np
import serket as sk
import re
class FeedForward(sk.TreeClass):
    """Two-layer MLP block computing ``linear2(relu(linear1(input)))``.

    The hidden width is ``hidden_ratio * d_model``.

    Args:
        d_model: feature size fed into ``linear1`` and produced by ``linear2``.
        hidden_ratio: multiplier for the hidden width (default 4).
        key: PRNG key, split into one sub-key per layer.
    """

    def __init__(self, d_model: int, *, hidden_ratio: int = 4, key: jax.Array):
        # Compute the hidden width once so both layers always agree on it.
        hidden = hidden_ratio * d_model
        k1, k2 = jax.random.split(key)
        self.linear1 = sk.nn.Linear(d_model, hidden, key=k1)
        self.linear2 = sk.nn.Linear(hidden, d_model, key=k2)

    def __call__(self, input: jax.Array) -> jax.Array:
        """Apply ``linear1``, a ReLU, then ``linear2`` to ``input``."""
        # `input` shadows the builtin but is kept so keyword callers still work.
        return self.linear2(jax.nn.relu(self.linear1(input)))
ff = FeedForward(d_model=128, key=jax.random.PRNGKey(0))
# Hide non-inexact (e.g. integer) leaves so only floating-point arrays remain
# visible as pytree leaves to jax.
ff = sk.tree_mask(ff)

# 8 devices arranged as a 2 x 4 grid with axes "data" (2) and "model" (4).
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), axis_names=["data", "model"])

# Start from a fully replicated sharding for every array leaf. Leaves that no
# rule below matches then still carry a valid sharding instead of remaining
# arrays, which `jax.device_put` does not accept as a sharding. This keeps the
# recipe usable for arbitrary pytrees.
replicated = N(mesh, P())
sharding = jax.tree_util.tree_map(lambda _: replicated, ff)

# Select layers whose names start with `linear`.
linear_layers = re.compile(r"linear.*")
# Weights: shard the leading axis over "model", replicate the second axis.
sharding = sk.at(sharding)[linear_layers]["weight"].set(
    N(mesh, P("model", None))
)
# Biases: shard the leading axis over "model".
sharding = sk.at(sharding)[linear_layers]["bias"].set(N(mesh, P("model")))

ff = jax.device_put(ff, sharding)
def vis_sharding(path, leaf):
    """Print a leaf's key path followed by a view of how it is sharded.

    ``jax.debug.visualize_array_sharding`` can only draw 1-D and 2-D arrays,
    so leaves of any other rank (scalars, conv kernels, ...) fall back to
    printing their ``sharding`` object instead of raising.

    Args:
        path: key path of ``leaf``, as supplied by ``tree_map_with_path``.
        leaf: a ``jax.Array``.
    """
    print(jax.tree_util.keystr(path))
    if leaf.ndim in (1, 2):
        jax.debug.visualize_array_sharding(leaf)
    else:
        print(leaf.sharding)
# Visualize each array leaf in flatten order. A plain loop avoids building
# (and discarding) a tree of `None` results just for the side effects.
for path, leaf in jax.tree_util.tree_leaves_with_path(ff):
    vis_sharding(path, leaf)