
Help debug a slowdown with argmax

Open · wjessup opened this issue 2 years ago · 5 comments

I'm working on the DQN code, but overall it's still 10x slower than PyTorch. I've narrowed the issue down to a line containing only an argmax. Can anyone help me fix this?

On an M2 Max with 96 GB.

Here's the code for part of the training loop:

```python
import math
import random
from itertools import count
from timeit import default_timer

import mlx.core as mx

# DQN_net, env, current_state, steps_done and the EPS_* constants
# are defined elsewhere in the training script.
for step in count():
    pre_opt_start = default_timer()

    # Explore vs exploit
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)

    action = None
    explore = random.random() > eps_threshold
    if explore:
        print("current_state = ", current_state)
        x1 = default_timer()
        preds = DQN_net(current_state)
        x2 = default_timer()
        print(f"DQN net timer: {x2 - x1:.6f} seconds.")

        z1 = default_timer()
        print("preds = ", preds)
        action = preds.argmax()
        z2 = default_timer()
        if z2 - z1 > 0.001:
            # ANSI escape code for red text
            print(f"\033[91mZ1 timer: {z2 - z1:.6f} seconds.\033[0m. action = {action}")
        else:
            print(f"Z1 timer: {z2 - z1:.6f} seconds. action = {action}")
    else:
        action = mx.array(env.action_space.sample())

    pre_opt_end = default_timer()
    print(f"pre OPT {pre_opt_end - pre_opt_start:.6f} seconds. explore? {explore}, action = {action}")
```

Notice that the 3rd pre OPT loop took 0.000026 seconds, while the slowest took about 0.086 seconds. You can see that Z1 was 0.085 seconds, and the ONLY thing it's timing is the argmax.


```text
current_state =  array([-0.0210095, -0.393987, 0.0341211, 0.625354], dtype=float32)
DQN net timer: 0.000026 seconds.
preds =  array([1.04471, 0.954782], dtype=float32)
Z1 timer: 0.085405 seconds.. action = array(0, dtype=uint32)
pre OPT 0.086100 seconds. explore? True, action = array(0, dtype=uint32)

current_state =  array([-0.0288893, -0.589569, 0.0466282, 0.928585], dtype=float32)
DQN net timer: 0.000029 seconds.
preds =  array([1.15783, 1.00868], dtype=float32)
Z1 timer: 0.005387 seconds.. action = array(0, dtype=uint32)
pre OPT 0.005603 seconds. explore? True, action = array(0, dtype=uint32)

pre OPT 0.000026 seconds. explore? False, action = array(1, dtype=int64)

current_state =  array([-0.0563864, -0.591062, 0.0899109, 0.963984], dtype=float32)
DQN net timer: 0.000029 seconds.
preds =  array([1.15084, 0.99671], dtype=float32)
Z1 timer: 0.010032 seconds.. action = array(0, dtype=uint32)
pre OPT 0.010220 seconds. explore? True, action = array(0, dtype=uint32)
```

wjessup · Dec 23 '23 17:12

You need to be careful with timing because computation is asynchronous. I would call mx.eval right before each time measurement to make sure everything has finished executing.
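One way to apply that advice consistently is to wrap each measurement so that pending work is forced to finish both before the clock starts and before it stops. This is a minimal sketch: `timed` is a hypothetical helper, not an MLX API, and `sync` is left as a parameter so that you can pass `mx.eval` in real code (the sketch itself does not import MLX).

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, sync: Callable[..., None]) -> Tuple[Any, float]:
    """Call fn(*args) and return (result, elapsed_seconds).

    `sync` must block until pending computation is done (e.g. mx.eval).
    It is called on the inputs before the clock starts, so work queued
    earlier is not billed to fn, and on the output before the clock stops,
    so lazily deferred work inside fn is actually counted.
    """
    if args:
        sync(*args)
    start = time.perf_counter()  # higher resolution than time.time()
    out = fn(*args)
    sync(out)
    return out, time.perf_counter() - start
```

In the training loop above this might look like `preds, dt = timed(DQN_net, current_state, sync=mx.eval)` for the forward pass, and `action, dt = timed(mx.argmax, preds, sync=mx.eval)` for the argmax on its own.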

Once you do that, the picture of where the bottleneck is will probably change. Without knowing more about your model, it's very hard to say what it could be. We do have some performance cliffs in some of our ops at the moment that we are optimizing, but it would still be really helpful to know more about your case and where the bottleneck is.

awni · Dec 23 '23 17:12

Overall, I'm trying to get this faster than the equivalent in PyTorch, which does 100 episodes (6358 total steps) with an average time per step of 0.00089 seconds. Currently my average time per step is ~0.004 seconds, roughly 4.5x slower, with some steps taking 0.02 seconds or more.
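When comparing against the PyTorch baseline, it can help to record every step's duration instead of only the average, so the occasional 0.02-second steps are counted separately from the typical ones. A minimal sketch, assuming each step's elapsed time is appended to a list; `summarize_step_times` and its parameters are hypothetical, not part of MLX or PyTorch.

```python
from statistics import mean
from typing import Dict, List, Optional


def summarize_step_times(
    step_times: List[float],
    baseline_avg: Optional[float] = None,
    slow_threshold: float = 0.02,
) -> Dict[str, float]:
    """Summarize per-step wall-clock times in seconds.

    baseline_avg: optional average step time of a reference implementation
    (e.g. the PyTorch version) used to compute a slowdown ratio.
    slow_threshold: steps taking at least this long are counted as outliers.
    """
    avg = mean(step_times)
    summary: Dict[str, float] = {
        "steps": len(step_times),
        "total_s": sum(step_times),
        "avg_s": avg,
        "slow_steps": sum(1 for t in step_times if t >= slow_threshold),
    }
    if baseline_avg:
        summary["slowdown"] = avg / baseline_avg
    return summary
```

With the numbers above, an average of 0.004 s against 0.00089 s gives a slowdown of about 4.5x; if both versions ran the same 6358 steps, that would be roughly 25.4 s versus 5.7 s in total.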

OK, I've added the mx.eval() lines, and this moved the problem into the model, not the argmax. The model is simple, but some of the DQN net timer lines come in at 0.02 seconds.

Here's a gist you can run to see: https://gist.github.com/wjessup/f352251c6504fc17c05ddfa02a99c11a

One other test:

```python
import random
import time

import mlx.core as mx
import numpy as np

# Numpy
start_time = time.time()
preds = np.random.rand(2,)
max_i = preds.argmax()
val = preds[max_i]
end_time = time.time()
print(f"numpy  time: {end_time - start_time:.5f} ")

# vanilla python
start_time = time.time()
preds = [random.uniform(0, 1) for _ in range(2)]
m = max(preds)
max_i = preds.index(m)
val = preds[max_i]
end_time = time.time()
print(f"python time: {end_time - start_time:.5f} ")

# mlx INTS
start_time = time.time()
preds = mx.random.randint(0, 1, (2,))
max_i = preds.argmax()
val = max_i.item()
end_time = time.time()
print(f"MLX int time: {end_time - start_time:.5f} ")

# mlx FLOATS
start_time = time.time()
preds = mx.random.normal((2,))
max_i = preds.argmax()
val = preds[max_i.item()]
end_time = time.time()
print(f"MLX float time: {end_time - start_time:.5f} ")
```
```text
numpy  time: 0.00011
python time: 0.00007
MLX int time: 0.00127
MLX float time: 0.00040
```

If you remove the `val =` lines (`val = max_i.item()` or `val = preds[max_i.item()]`), it goes much faster.

wjessup · Dec 24 '23 00:12

Will take a look at the DQN bit.

But just FYI, in your other test it makes sense that removing the call to `.item()` makes things seem faster: without it there is no evaluation, so you aren't timing the full op. There are also a few things you should be careful of when microbenchmarking ops (especially MLX ops):

  1. You need a warmup. This is because the kernels have to be loaded the first time they are used, along with possibly some other machinery in the GPU backend.
  2. Do multiple iterations and take the mean / sum to reduce variance.
  3. Make sure there is an eval before and after the op you want to time. (In this case your timing includes the random number generation. Not sure if that's intentional.)
  4. For such small sizes it's very likely that any CPU backend will be faster than running the op on the GPU, which has a lot more overhead.
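The four points above can be folded into one small harness: untimed warmup runs, many timed iterations, and a sync on the input before the clock starts and on the output before it stops, with input creation kept outside the timed region. This is a minimal sketch; `benchmark` and its parameters are hypothetical, and `sync` stands in for `mx.eval` so the sketch does not depend on MLX.

```python
import time
from statistics import mean, stdev
from typing import Any, Callable, Dict


def benchmark(
    op: Callable[[Any], Any],
    setup: Callable[[], Any],
    sync: Callable[..., None],
    warmup: int = 10,
    iters: int = 100,
) -> Dict[str, float]:
    """Microbenchmark op(x), where x = setup() is built outside the timed region."""
    # Untimed warmup runs absorb one-time costs such as kernel loading.
    for _ in range(warmup):
        x = setup()
        sync(op(x))

    times = []
    for _ in range(iters):
        x = setup()
        sync(x)  # the input is fully computed before the clock starts
        start = time.perf_counter()
        sync(op(x))  # the op is forced to finish before the clock stops
        times.append(time.perf_counter() - start)

    return {
        "mean_s": mean(times),
        "std_s": stdev(times) if len(times) > 1 else 0.0,
        "min_s": min(times),
    }
```

With MLX, something like `benchmark(mx.argmax, setup=lambda: mx.random.normal((2,)), sync=mx.eval)` would time only the argmax. To check point 4, MLX ops accept a `stream` argument, so passing `op=lambda x: mx.argmax(x, stream=mx.cpu)` should run the same op on the CPU for comparison.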

awni · Dec 24 '23 05:12

Thanks for the comments.

Is this the right way to implement the warmup?

Regarding #3, would I put an mx.eval() with no arguments above the first timer?

```python
import time

import mlx.core as mx
import numpy as np

# Warmup
for _ in range(10):
    preds = mx.random.randint(0, 1, (2,))
    _ = preds.argmax().item()  # Force evaluation

# Timing
times = []
for _ in range(100):
    start_time = time.time()
    preds = mx.random.randint(0, 1, (2,))
    max_i = preds.argmax()
    val = max_i.item()  # Force evaluation
    end_time = time.time()
    times.append(end_time - start_time)
print(f"MLX INT time (average): {np.mean(times):.6f} seconds")
```

wjessup · Dec 24 '23 14:12

What you have is great! The warmup is exactly right, and `.item()` is an implicit eval, so you are good there. The only caveat is that the timings will include the time of the randint. If you want to avoid that, you could do:

```python
import time

import mlx.core as mx
import numpy as np

# Timing
times = []
for _ in range(100):
    preds = mx.random.randint(0, 1, (2,))
    mx.eval(preds)  # generate the input before starting the clock
    start_time = time.time()
    max_i = preds.argmax()
    val = max_i.item()  # Force evaluation
    end_time = time.time()
    times.append(end_time - start_time)
print(f"MLX INT time (average): {np.mean(times):.6f} seconds")
```

awni · Dec 24 '23 15:12