
Integrate distillation

Open · StellaAthena opened this issue 4 years ago • 5 comments

@preethamgali wrote a model distillation framework here, which we should aim to integrate into GPT-NeoX.

StellaAthena avatar May 04 '21 14:05 StellaAthena

Is just moving the same code over enough, or should there be any changes to it?

preethamgali avatar May 09 '21 02:05 preethamgali

> Is just moving the same code over enough, or should there be any changes to it?

Much of the code can be copied and pasted from distiller.py; it just needs to be modified to work with the GPT-NeoX framework.

StellaAthena avatar May 09 '21 23:05 StellaAthena

The distillation code only supports data parallelism. You are aware of that, right? Or do we need to use the model from the framework?

preethamgali avatar May 10 '21 04:05 preethamgali

@preethamgali and I discussed this on Discord. We do want to use the GPT-NeoX modeling framework and to capture as many of the optimizations our code provides as possible.

What I had in mind was to make the student model the last stage(s) in a pipeline, so that instead of having T1 -> T2 -> T3 -> T4 and S1 -> S2 you have a single model T1 -> T2 -> T3 -> T4 -> S1 -> S2. Then, when you do backprop, you just stop after finishing the student model. The teacher and the student models will likely have different widths though, and I’m not sure if that’ll do anything wonky.
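
As a rough illustration of this idea, here is a minimal sketch (not GPT-NeoX code) that strings frozen teacher stages and trainable student stages into a single DeepSpeed PipelineModule. The layer widths, the bridge projection, and the stage count are all made-up placeholders, and it assumes the script is launched with the deepspeed launcher so the distributed backend is initialized:

    # Illustrative sketch only: concatenate frozen teacher stages with
    # trainable student stages into one pipeline. Assumes a deepspeed-launched
    # distributed context; layer sizes are hypothetical stand-ins for T1..T4
    # and S1..S2.
    import torch.nn as nn
    from deepspeed.pipe import PipelineModule

    teacher_hidden, student_hidden = 1024, 512  # hypothetical widths

    teacher_layers = [nn.Linear(teacher_hidden, teacher_hidden) for _ in range(4)]  # T1..T4
    bridge = nn.Linear(teacher_hidden, student_hidden)  # handles the width mismatch
    student_layers = [nn.Linear(student_hidden, student_hidden) for _ in range(2)]  # S1, S2

    # Freeze the teacher so backprop effectively stops at the student boundary.
    for layer in teacher_layers:
        layer.requires_grad_(False)

    model = PipelineModule(
        layers=teacher_layers + [bridge] + student_layers,
        num_stages=2,  # e.g. teacher stage(s) followed by a student stage
    )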

@sdtblck suggested running something like

    teacher = GPT2ModelPipe(**kwargs)
    student = GPT2ModelPipe(**student_kwargs)
    ...
    
    # The teacher is frozen and only used for inference, so it does not need
    # the optimizer or LR scheduler.
    teacher, _, _, _ = deepspeed.initialize(
            model=teacher,
            ...)
    # Only the student gets the optimizer and is updated during training.
    student, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=student,
            optimizer=optimizer,
            ...)

However, I worry that DeepSpeed will not play nicely with multiple models. This is purely conjecture though, and Sid’s suggestion is absolutely worth trying, or, more realistically, worth asking the DeepSpeed people about.
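
If the two-engine approach does work, a single distillation step might look something like the sketch below. This is purely illustrative: it assumes plain data-parallel engines whose forward call returns logits directly (a pipeline-parallel engine would use train_batch/eval_batch instead), and the temperature and function name are made up.

    # Hedged sketch: one distillation step with separate teacher and student
    # DeepSpeed engines. Assumes each engine's forward pass returns logits.
    import torch
    import torch.nn.functional as F

    TEMPERATURE = 2.0  # illustrative soft-target temperature

    def distill_step(teacher_engine, student_engine, tokens):
        with torch.no_grad():
            teacher_logits = teacher_engine(tokens)

        student_logits = student_engine(tokens)

        # Standard soft-target KL distillation loss.
        loss = F.kl_div(
            F.log_softmax(student_logits / TEMPERATURE, dim=-1),
            F.softmax(teacher_logits / TEMPERATURE, dim=-1),
            reduction="batchmean",
        ) * TEMPERATURE ** 2

        student_engine.backward(loss)  # the engine handles loss scaling / ZeRO
        student_engine.step()
        return loss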

StellaAthena avatar May 13 '21 12:05 StellaAthena

The cross-entropy loss function used in the framework is mpu.vocab_parallel_cross_entropy, which is implemented to work with 3D parallelism. But for distillation we need KLDivLoss, MSELoss, and CosineEmbeddingLoss, so we also need vocab-parallel implementations of those losses. @StellaAthena, please raise the feature request.
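
For context, here is a rough sketch of what the forward math of a vocab-parallel KL divergence could look like: the logits are sharded along the vocabulary dimension, so the softmax normalizers need all-reduces over the model-parallel group. The function name and `group` argument are hypothetical, and a real implementation would have to wrap the communication in a custom autograd function (as the existing cross-entropy loss does) so gradients flow correctly.

    # Forward-pass-only sketch of a vocab-parallel KL divergence. Each rank
    # holds a [batch, seq, vocab/p] shard of the logits; `group` is the
    # model-parallel process group. Not autograd-safe as written: the
    # all-reduces would need a custom torch.autograd.Function in real code.
    import torch
    import torch.distributed as dist

    def vocab_parallel_kl_div(student_logits, teacher_logits, group):
        def log_softmax_sharded(logits):
            # The global max and global sum-exp require reductions across the
            # vocab-parallel ranks.
            local_max = logits.max(dim=-1, keepdim=True).values
            dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
            shifted = logits - local_max
            sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
            dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
            return shifted - sum_exp.log()

        t_logp = log_softmax_sharded(teacher_logits)
        s_logp = log_softmax_sharded(student_logits)

        # KL(teacher || student): sum over the local vocab shard, then reduce
        # so every rank ends up with the full per-token KL.
        local_kl = (t_logp.exp() * (t_logp - s_logp)).sum(dim=-1)
        dist.all_reduce(local_kl, op=dist.ReduceOp.SUM, group=group)
        return local_kl.mean()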

preethamgali avatar May 14 '21 13:05 preethamgali

We are abandoning this effort as unsuccessful

StellaAthena avatar Sep 18 '22 15:09 StellaAthena