[BUG] Performance regression when opening many YAML files
When opening many files with YAML and using a large MoE, there is a very odd performance regression during prompt processing:
Timings were done on a 192 GB M2 Ultra.
With YAML loading:
Memory GB: 67.658742684 Time sec: 16.188238859176636
Without YAML loading:
Memory GB: 67.658742684 Time sec: 1.250129222869873
```python
import time
import os
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.switch_layers import SwitchGLU
import yaml


def load_yaml():
    # Walk the lm_eval task directory and parse every YAML config.
    yaml.add_constructor("!function", lambda x, y: y)
    for root, dirs, file_list in os.walk(
        "/Users/awnihannun/miniconda3/lib/python3.12/site-packages/lm_eval/tasks"
    ):
        for f in file_list:
            if f.endswith(".yaml"):
                yaml_path = os.path.join(root, f)
                with open(yaml_path, "rb") as file:
                    yaml_config = yaml.full_load(file)


dim = 2048
idim = 768
ne = 128
num_layers = 28
top_k = 8

# Build the MoE layers and materialize their weights.
layers = [SwitchGLU(dim, idim, ne) for _ in range(num_layers)]
mx.eval(layers)

# Uncomment for a ~10x slowdown.
# load_yaml()

inputs = mx.random.normal(shape=(1, 88, dim))
# One set of top_k expert indices per token.
indices = mx.random.randint(low=0, high=ne, shape=(1, 88, top_k))


def fn(x):
    for l in layers:
        x = l(x, indices)
    return x


# Warm up once, then time 10 evaluations.
mx.eval(fn(inputs))
tic = time.time()
for _ in range(10):
    mx.eval(fn(inputs))
print("Memory GB:", mx.get_peak_memory() / 1e9)
print("Time sec:", time.time() - tic)
```
It looks like there is also a regression when using regular linear layers, though it is not nearly as pronounced.
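For reference, a minimal sketch of what that comparison might look like, reusing the structure of the repro above but swapping the SwitchGLU layers for plain mlx.nn.Linear layers (the sizes are reused from the repro; the sketch itself is just illustrative):

```python
# Illustrative variant of the benchmark above using plain Linear layers
# instead of SwitchGLU.
import time
import mlx.core as mx
import mlx.nn as nn

dim = 2048
num_layers = 28

layers = [nn.Linear(dim, dim) for _ in range(num_layers)]
mx.eval(layers)

# load_yaml()  # uncomment (with the function from the repro above) to compare

inputs = mx.random.normal(shape=(1, 88, dim))

def fn(x):
    for l in layers:
        x = l(x)
    return x

# Warm up once, then time 10 evaluations as in the main repro.
mx.eval(fn(inputs))
tic = time.time()
for _ in range(10):
    mx.eval(fn(inputs))
print("Time sec:", time.time() - tic)
```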
I can’t reproduce the slowdown locally.
I think get_peak_memory only reports the memory usage of the Metal device.
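If it helps narrow things down, a quick way to cross-check the Metal number against the process-level peak memory is the standard resource module (note that ru_maxrss is reported in bytes on macOS and in kilobytes on Linux):

```python
import resource
import mlx.core as mx

# Metal allocator peak vs. process peak RSS (bytes on macOS, KB on Linux).
print("Metal peak GB:", mx.get_peak_memory() / 1e9)
print("Process peak RSS GB:",
      resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e9)
```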
A likely cause is Python GC overhead from the many short-lived YAML objects.
How about trying to disable garbage collection around the eval and seeing if the performance improves?
```python
import gc

# Disable automatic garbage collection before the timed eval loop.
gc.disable()
```
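To check whether the collector really is the bottleneck (rather than disabling it blindly), one option is to count collections during the timed loop with gc.callbacks; this is a rough sketch meant to be dropped into the repro script above in place of the final timing loop:

```python
import gc

# Count how many automatic collections fire while the benchmark runs.
collections = 0

def gc_probe(phase, info):
    global collections
    if phase == "stop":
        collections += 1

gc.callbacks.append(gc_probe)

tic = time.time()
for _ in range(10):
    mx.eval(fn(inputs))
print("Time sec:", time.time() - tic)
print("GC collections during loop:", collections)

# If the YAML objects are being kept alive somewhere, gc.freeze() right after
# loading (instead of disabling collection entirely) may also be worth trying.
```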