
Regarding GEMV.AG and O.AG

Open rajagond opened this issue 8 months ago • 9 comments

Hi,

In standard tensor parallelism, we typically have: Attention → Output Projection → All-Reduce → LayerNorm → FFN → All-Reduce → LayerNorm.

However, in the paper, you use GEMV.AG and O.AG. I didn't fully understand this part. Could you please briefly explain the math behind it and why you use 2 AG + 1 AR instead of the conventional 2 AR?

Thanks

rajagond avatar Apr 30 '25 03:04 rajagond

On A100, all-reduce is often implemented as a reduce-scatter followed by an all-gather, so in terms of total execution time AR ≈ 2*AG. However, using GEMV.AG and O.AG allows us to move GEMV.AG to a point where the network is underutilized and effectively start it early, which reduces pipeline bubbles.
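To make the decomposition concrete, here is a minimal NumPy sketch (my own illustration, not NanoFlow code): an all-reduce over P ranks can be implemented as a reduce-scatter, which leaves each rank with one reduced 1/P slice, followed by an all-gather that redistributes all slices, so an AR moves roughly twice the data of a single AG.

import numpy as np

P, N = 8, 1024                                   # 8 simulated ranks, N elements each
ranks = [np.random.rand(N) for _ in range(P)]

# Reduce-scatter: rank r ends up owning only the summed slice r of size N // P
chunks = [x.reshape(P, N // P) for x in ranks]
scattered = [sum(c[r] for c in chunks) for r in range(P)]

# All-gather: every rank reassembles the full reduced vector from the P slices
gathered = np.concatenate(scattered)

# The two-step result matches a direct all-reduce (element-wise sum over ranks)
assert np.allclose(gathered, sum(ranks))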

serendipity-zk avatar Apr 30 '25 03:04 serendipity-zk

In my experiment, I found that AG costs approximately 55–58% as much as AR. However, I still don't fully understand the math behind it.

For example, in LLaMA-70B on 8×A100, the attention output is of shape [num_tokens, 1024]. When we multiply this with the o_projection weights of shape [1024, 8192], the result is [num_tokens, 8192], followed by an all-reduce operation.

How does this process look with GEMV.AG and O.AG? Have you replicated some parts of the Attention block (e.g., QKV or o_projection)? If so, what was the motivation behind this design?

rajagond avatar Apr 30 '25 04:04 rajagond

2AG: all-gather the per-GPU attention output [num_tokens, 1024] into [num_tokens, 8192], multiply with the N-dim (column) partitioned O shard [8192, 1024] to get [num_tokens, 1024] per GPU, then all-gather again to get [num_tokens, 8192].

AR: take the per-GPU attention output [num_tokens, 1024], multiply with the K-dim (row) partitioned O shard [1024, 8192] to get a partial [num_tokens, 8192], then all-reduce to sum the partials into the final [num_tokens, 8192].
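A quick NumPy sketch of the two schemes (a minimal illustration, not NanoFlow code; shapes follow the LLaMA-70B example above, with the full O weight taken as [8192, 8192] in [input, output] orientation):

import numpy as np

num_tokens, hidden, P = 4, 8192, 8
attn_full = np.random.rand(num_tokens, hidden)        # attention output, all heads
O = np.random.rand(hidden, hidden)                    # full O projection weight
attn_shards = np.split(attn_full, P, axis=1)          # [num_tokens, 1024] per GPU

# AR scheme: O is K-dim (row) partitioned, each GPU produces a partial
# [num_tokens, 8192] result, and the partials are summed by an all-reduce.
O_rows = np.split(O, P, axis=0)                       # [1024, 8192] per GPU
out_ar = sum(a @ w for a, w in zip(attn_shards, O_rows))

# 2AG scheme: all-gather the attention output to [num_tokens, 8192], multiply
# with the N-dim (column) partitioned O to get [num_tokens, 1024] per GPU,
# then all-gather those slices back to [num_tokens, 8192].
O_cols = np.split(O, P, axis=1)                       # [8192, 1024] per GPU
out_2ag = np.concatenate([attn_full @ w for w in O_cols], axis=1)

assert np.allclose(out_ar, out_2ag)                   # both give the same output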

serendipity-zk avatar Apr 30 '25 04:04 serendipity-zk

Thanks, this was really helpful. Also, you might want to update the snapshot ID of the Hugging Face models. The snapshot ID is hardcoded in the source code, and I had to update it to run the model on my machine. It’s possible that earlier snapshots are no longer available.

rajagond avatar Apr 30 '25 09:04 rajagond

Could you also clarify what you mean by batch_size?
Does batch_size refer to the shape (B, seq_length)?

global_batch_size = 2048
decode_batch_size = 1280
prefill_batch_size = 768

The values above seem to refer to B x L (total tokens), whereas the code below seems to use only B.

import numpy as np
from scipy.interpolate import RegularGridInterpolator

def getGemvTime(df_gemv, batch_size, seq_len):
    # 1. Sort and extract the two axes we care about
    df_sorted = df_gemv.sort_values(by=['batch_size', 'seqlen']).copy()
    batch_sizes = df_sorted['batch_size'].unique()
    seqlens     = df_sorted['seqlen'].unique()

    # 2. Reshape the measured times into a 2-D grid
    #    (assumes one sample per (batch, seq) pair in df_gemv)
    times = df_sorted['gemv_time'].values.reshape(
        len(batch_sizes),
        len(seqlens)
    )

    # 3. Build a 2-D interpolator over the profiled grid
    interpolator = RegularGridInterpolator(
        (batch_sizes, seqlens),  # the grid points
        times,
        method='linear',
        bounds_error=False,      # allow extrapolation
        fill_value=None
    )

    # 4. Query at the new (batch_size, seq_len) and return a scalar estimate
    pt = np.array([batch_size, seq_len])
    est = interpolator(pt)
    return float(est)
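For example, with a hypothetical profiling table (column names taken from the snippet above, numbers made up):

import pandas as pd

df_gemv = pd.DataFrame({
    'batch_size': [256, 256, 1024, 1024],
    'seqlen':     [1024, 4096, 1024, 4096],
    'gemv_time':  [0.8, 0.9, 2.9, 3.1],   # made-up milliseconds
})
print(getGemvTime(df_gemv, batch_size=512, seq_len=2048))   # interpolated estimate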

rajagond avatar May 01 '25 21:05 rajagond

How do I run it on other GPUs? It seems like I have to manually profile each operation at different sizes and SM counts before I can use autosearch.

rajagond avatar May 01 '25 22:05 rajagond

@serendipity-zk

rajagond avatar May 02 '25 09:05 rajagond

Here batch_size means the number of tokens per round, not the number of requests. That is why, in the config you posted, prefill_batch_size (768) + decode_batch_size (1280) adds up to global_batch_size (2048).

Wazrrr avatar Aug 11 '25 20:08 Wazrrr

To run on multiple GPUs, we have built an engine here: Multi-GPU engine script, which should be helpful for building your own multi-GPU model.

Wazrrr avatar Aug 11 '25 20:08 Wazrrr