
[BUG] Performance regression after opening many YAML files

awni opened this issue 3 months ago

When opening many files with YAML and using a large MoE, there is a very odd performance regression during prompt processing:

Timings done on 192 GB M2 Ultra.

With YAML loading:

Memory GB: 67.658742684 Time sec: 16.188238859176636

Without YAML loading:

Memory GB: 67.658742684 Time sec: 1.250129222869873

import time
import os
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.switch_layers import SwitchGLU
import yaml

def load_yaml():
    yaml.add_constructor("!function", lambda x, y: y)
    for root, dirs, file_list in os.walk("/Users/awnihannun/miniconda3/lib/python3.12/site-packages/lm_eval/tasks"):
        for f in file_list:
            if f.endswith(".yaml"):
                yaml_path = os.path.join(root, f)
                with open(yaml_path, "rb") as file:
                    yaml_config = yaml.full_load(file)

dim = 2048
idim = 768
ne = 128
num_layers = 28
top_k = 8
layers = [SwitchGLU(dim, idim, ne) for _ in range(num_layers)]
mx.eval(layers)

# uncomment for ~10x slowdown
#load_yaml()
inputs = mx.random.normal(shape=(1, 88, dim))
indices = mx.random.randint(shape=(1, 88), low=0, high=128)

def fn(x):
    for l in layers:
        x = l(x, indices)
    return x
mx.eval(fn(inputs))

tic = time.time()
for _ in range(10):
    mx.eval(fn(inputs))
print("Memory GB:", mx.get_peak_memory() / 1e9)
print("Time sec:", time.time() - tic)
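Since load_yaml() discards its results, any lasting slowdown has to come from an interpreter-wide side effect. One possible mechanism is cyclic-GC pressure: a large population of live container objects makes every full collection more expensive, which taxes allocation-heavy code that runs afterwards. Below is a minimal, mlx-free sketch of that effect; the object counts and the work() workload are illustrative, not taken from the repro above.

```python
import gc
import time

def timed(fn, n=5):
    # average wall-clock time over n runs
    tic = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - tic) / n

def work():
    # allocation-heavy loop; creating many containers triggers GC runs
    for _ in range(1000):
        _ = [{"k": i} for i in range(100)]

baseline = timed(work)
# Long-lived containers (a stand-in for parsed YAML configs kept alive):
# full collections must now traverse all of them.
survivors = [{"k": i} for i in range(1_000_000)]
loaded = timed(work)
print(f"baseline {baseline:.4f}s, with survivors alive {loaded:.4f}s")
```

If the second timing is noticeably worse, `gc.freeze()` (Python 3.7+) after the one-time load, or `gc.disable()`, should recover most of it.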

awni · Sep 10 '25 14:09

It looks like there is also a regression when using regular linear layers, though it is not nearly as pronounced.

awni · Sep 10 '25 14:09

I can't reproduce the slowdown locally. I think get_peak_memory only gives you the memory usage of the Metal device.

The likely cause is Python GC overhead from the many short-lived YAML objects.

How about disabling garbage collection before the eval loop and seeing if the performance improves?

import gc
gc.disable()

CC-Yeh · Sep 22 '25 20:09