
Run without quantization

Open freQuensy23-coder opened this issue 1 year ago • 9 comments

The quant_config argument is mandatory in the build_model function:


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

Can I run Mixtral with layer offloading, but WITHOUT quantization, using this library?

freQuensy23-coder avatar Jan 22 '24 16:01 freQuensy23-coder

What hardware do you plan on running the model on? It would take quite a large amount of combined RAM + VRAM to run the model without quantization.

dvmazur avatar Jan 22 '24 18:01 dvmazur

I'll use a Tesla A100 with 80 GB VRAM + 512 GB RAM.

freQuensy23-coder avatar Jan 22 '24 18:01 freQuensy23-coder

Yeah, sounds like it'll fit :D

The current codebase doesn't support running the model without quantization, but you could try rewriting the expert wrapper class.

This class moves the expert's parameters into a single storage, so the expert can later be moved efficiently between GPU and CPU memory. Here's a snippet that does this for the original expert class:

import torch

# Assumes the repo's nested_flatten / nested_pack tensor-tree helpers are in scope.

def replace_layer_storage(layer, device):
    state_dict = layer.state_dict()

    # First pass: total byte size and per-tensor offsets inside the flat buffer.
    storage_size = 0
    offsets = [0]

    for x in nested_flatten(state_dict):
        if not isinstance(x, torch.Tensor):
            continue
        storage_size += x.nbytes
        offsets.append(storage_size)

    # One flat, untyped buffer that will hold all of the layer's tensors.
    storage = torch.UntypedStorage(storage_size, device=device)

    # Second pass: create a view into the buffer for each tensor and copy it in.
    i = 0
    new_flattened_states = list()
    for x in nested_flatten(state_dict):
        if not isinstance(x, torch.Tensor):
            new_flattened_states.append(x)
            continue

        start = offsets[i]
        end = offsets[i + 1]
        a_view = torch.as_tensor(storage[start:end], dtype=x.dtype, device=device).view(x.shape)
        a_view[...] = x
        assert a_view.data_ptr() == storage.data_ptr() + start
        i += 1
        new_flattened_states.append(a_view)

    state_dict = nested_pack(new_flattened_states, state_dict)

    # Re-point the layer's parameters at the views inside the shared storage.
    for name, param in layer.named_parameters():
        param.data = state_dict[name]

    return layer, storage
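
To give you an idea of why the single storage matters, here's a rough, untested sketch (my illustration, not the exact code from this repo's expert cache): once an expert lives in one flat storage, loading it onto the GPU is just a single byte-for-byte storage copy into a preallocated GPU buffer.

import copy
import torch

# Rough sketch, illustrative only: swap an expert's weights onto the GPU by
# copying its flat CPU storage into a GPU storage with an identical layout.

def make_gpu_buffer(expert_cpu, device):
    # GPU-resident copy of the expert whose parameters view a single GPU storage.
    gpu_expert, gpu_storage = replace_layer_storage(copy.deepcopy(expert_cpu), device)
    return gpu_expert, gpu_storage

def swap_in(gpu_storage, cpu_storage):
    # Both storages were produced by replace_layer_storage from the same module
    # structure, so they share a layout and one raw copy loads the whole expert.
    assert gpu_storage.nbytes() == cpu_storage.nbytes()
    gpu_storage.copy_(cpu_storage)  # pin the CPU storage if you want async copies

# Usage (hypothetical names):
# gpu_expert, gpu_storage = make_gpu_buffer(cpu_expert, torch.device("cuda:0"))
# _, cpu_storage = replace_layer_storage(cpu_expert, torch.device("cpu"))
# swap_in(gpu_storage, cpu_storage)  # gpu_expert now computes with cpu_expert's weights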

The rest of the codebase is still quite HQQ-specific, so offloading the unquantized model will require rewriting some code in the build_model.py file. Most of it boils down to replacing the HQQ layers with default PyTorch ones, though.
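
For example (a rough sketch on my side — the layer names and sizes are the ones Mixtral-8x7B uses in the HF implementation, so double-check them against your config), an unquantized fp16 expert built from plain nn.Linear layers could look like this:

import torch
from torch import nn

# Hedged sketch: an unquantized Mixtral expert (w1/w2/w3 projections, SiLU
# activation) in fp16, as a stand-in for the HQQ-quantized expert layers.
class UnquantizedExpert(nn.Module):
    def __init__(self, hidden_size=4096, intermediate_size=14336, dtype=torch.float16):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # Standard Mixtral expert MLP: w2(silu(w1(x)) * w3(x))
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))

# expert = UnquantizedExpert()
# expert.load_state_dict(...)            # load this expert's weights from the checkpoint
# expert, storage = replace_layer_storage(expert, torch.device("cpu"))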

If you decide to go down that path, I can help you out a bit in this issue :)

dvmazur avatar Jan 22 '24 19:01 dvmazur

Seems like you'll be a little bit short on VRAM. The full fp16 model requires ~87 GB. The table is taken from our tech report.

[image: memory requirements table from the tech report]

lavawolfiee avatar Jan 22 '24 20:01 lavawolfiee

> Seems like you'll be a little bit short on VRAM. The full fp16 model requires ~87 GB. The table is taken from our tech report.

I'll offload some of the experts to RAM during inference, so it will use less GPU VRAM. That's the main idea of this lib. @dvmazur, am I right?

freQuensy23-coder avatar Jan 22 '24 20:01 freQuensy23-coder

> If you decide to go down that path, I can help you out a bit in this issue :)

Thanks, I'd appreciate your help with this. I'll also try to do it myself this evening.

freQuensy23-coder avatar Jan 22 '24 21:01 freQuensy23-coder

@freQuensy23-coder, yes, you are right - @lavawolfiee must have misunderstood you.

dvmazur avatar Jan 22 '24 21:01 dvmazur

I've tried to rewrite your code to add fp16 support using your tips, but I've run into some difficulties: I don't understand where exactly in replace_layer_storage quantization is used. As far as I can tell, it should work with 16-bit layers too? Can you help me with it?

freQuensy23-coder avatar Jan 27 '24 09:01 freQuensy23-coder

> I've tried to rewrite your code to add fp16 support using your tips, but I've run into some difficulties: I don't understand where exactly in replace_layer_storage quantization is used. As far as I can tell, it should work with 16-bit layers too? Can you help me with it?

The snippet I sent you doesn't use quantization. It simply packs a given layer into one single storage.
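
If it helps, here's a quick, untested sanity check (my own toy example, not code from the repo) showing that it should behave the same for plain fp16 layers:

import torch
from torch import nn

# Toy fp16 module instead of a Mixtral expert; replace_layer_storage only copies
# whatever tensors are in the state_dict, so the dtype doesn't matter to it.
layer = nn.Sequential(nn.Linear(16, 32, bias=False), nn.Linear(32, 16, bias=False)).half()
ref_state = {k: v.clone() for k, v in layer.state_dict().items()}

layer, storage = replace_layer_storage(layer, torch.device("cpu"))

# Every parameter now points inside the single flat storage...
base = storage.data_ptr()
for p in layer.parameters():
    assert base <= p.data_ptr() < base + storage.nbytes()

# ...and the weights are unchanged and still fp16.
for name, param in layer.named_parameters():
    assert param.dtype == torch.float16
    assert torch.equal(param, ref_state[name])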

dvmazur avatar Jan 27 '24 14:01 dvmazur