
[BUG] Re-initializing the Engine

Open BiEchi opened this issue 8 months ago • 5 comments

Describe the bug: In some use cases, we have to delete the training engine after training and load it again after performing some other operations. What is the correct way to delete the training engine so that the program returns to the state it was in before training?

BiEchi, Apr 05 '25 17:04

@BiEchi, can you clarify what you mean by the state before training? It would be best if you could provide some code examples to illustrate.

Since you can create multiple independent engines, why not destroy the old engine and create a new one?
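The destroy-and-recreate approach can be wrapped so that no reference to the old engine outlives its phase. This is a minimal sketch, assuming the engine exposes the destroy() method discussed in this thread. run_with_fresh_engine, make_engine, work, and DummyEngine are hypothetical names. In real use, make_engine would wrap deepspeed.initialize(...) and return the engine from its returned tuple.

```python
import gc
import weakref
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def run_with_fresh_engine(make_engine: Callable[[], Any],
                          work: Callable[[Any], T]) -> T:
    """Create an engine, run `work` on it, then tear it down (hypothetical helper).

    The engine is only referenced from this frame (plus anything `work`
    stores), so the caller cannot keep it alive by accident.
    """
    engine = make_engine()
    try:
        return work(engine)
    finally:
        # Call destroy() if the object has one; checked inline so that no
        # bound-method reference to the engine is kept in a local variable.
        if callable(getattr(engine, "destroy", None)):
            engine.destroy()
        del engine     # drop the only strong reference held here
        gc.collect()   # also reclaim reference cycles the engine was part of


# Demo with a hypothetical stand-in class instead of a real DeepSpeed engine.
class DummyEngine:
    destroyed = 0

    def destroy(self) -> None:
        DummyEngine.destroyed += 1


refs = []


def train_one_phase(engine: DummyEngine) -> str:
    refs.append(weakref.ref(engine))  # weak reference: does not keep it alive
    return "trained"


result = run_with_fresh_engine(DummyEngine, train_one_phase)
print(result, DummyEngine.destroyed, refs[0]() is None)  # trained 1 True
```

The design choice is to confine the engine's lifetime to a single function call. Once it returns, nothing outside can still reference the engine unless work() stored it somewhere.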

tjruwase, Apr 08 '25 13:04

Dear @tjruwase, thanks, I will try using the destroy method. I meant exactly what you describe. Basically, we launch the .py file with deepspeed, and within this .py file I want to launch an engine, then delete the old engine entirely and start a new one. The new model engine should start with exactly the same CPU/GPU memory usage (almost all 0) as when the old model engine was launched.
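The "start from the same memory" requirement can be made checkable by recording CPU/GPU usage right before each deepspeed.initialize call (for example with the get_cpu_memory()/get_gpu_memory() helpers in the reproduction script later in this thread) and comparing every cycle against the first. This is a minimal sketch. CycleReading, returns_to_baseline, and the 256 MB default tolerance are hypothetical choices, not DeepSpeed APIs.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class CycleReading:
    """CPU/GPU usage in MB sampled just before an engine is created (hypothetical)."""
    cpu_mb: float
    gpu_mb: float


def returns_to_baseline(readings: Sequence[CycleReading],
                        tolerance_mb: float = 256.0) -> bool:
    """True if no cycle starts more than `tolerance_mb` above the first one.

    The first reading is the baseline; `tolerance_mb` is an arbitrary
    allowance for allocator noise and should be tuned for the setup.
    """
    if not readings:
        return True
    base = readings[0]
    return all(
        r.cpu_mb - base.cpu_mb <= tolerance_mb
        and r.gpu_mb - base.gpu_mb <= tolerance_mb
        for r in readings[1:]
    )


# Synthetic numbers, only to show the call shape (not measurements):
steady = [CycleReading(2000.0, 0.0), CycleReading(2100.0, 0.0)]
growing = [CycleReading(2000.0, 0.0), CycleReading(9000.0, 0.0)]
print(returns_to_baseline(steady), returns_to_baseline(growing))  # True False
```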

BiEchi, Apr 08 '25 17:04

@BiEchi, sounds good. Do let us know how it goes. ZeRO memory management is a work in progress, so your feedback would be very helpful. Also, consider this: https://deepspeed.readthedocs.io/en/latest/zero3.html#gpu-memory-management

tjruwase, Apr 08 '25 17:04

Dear @tjruwase,

It seems that the CPU memory is not freed after destroying the model engine under ZeRO-3, although the GPU memory is. Below is the minimal reproduction code (you can replace the 27B model with any other model).

import os
import gc
import time
import torch
import argparse
import deepspeed
import psutil
import types
import transformers
from transformers import Gemma3ForConditionalGeneration, AutoTokenizer
from tqdm import tqdm

# Memory tracking utilities
def get_gpu_memory():
    """Return GPU memory usage in MB"""
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() / (1024 * 1024)

def get_cpu_memory():
    """Return CPU memory usage in MB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)

def print_memory(prefix=""):
    """Print current memory usage"""
    gpu_mem = get_gpu_memory()
    cpu_mem = get_cpu_memory()
    print(f"{prefix} - GPU: {gpu_mem:.2f} MB, CPU: {cpu_mem:.2f} MB")

# DeepSpeed configuration
def get_ds_config(lr=1e-5, offload=True):
    return {
        "bf16": {
            "enabled": True
        },
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "train_batch_size": 8,
        
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": 5e7,
            "stage3_prefetch_bucket_size": 5e7,
            "stage3_param_persistence_threshold": 1e5,
            
            "offload_optimizer": {
                "device": "cpu" if offload else "none",
                "pin_memory": True
            },
            "offload_param": {
                "device": "cpu" if offload else "none",
                "pin_memory": True
            },
            
            "round_robin_gradients": True,
            "stage3_gather_16bit_weights_on_model_save": True
        },
        
        "activation_checkpointing": {
            "partition_activations": True,
            "cpu_checkpointing": True if offload else False,
            "contiguous_memory_optimization": True,
            "number_checkpoints": 2,
            "synchronize_checkpoint_boundary": True,
            "profile": False
        },
        
        "gradient_clipping": 1.0,
        "steps_per_print": 10,
        
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": lr,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 0.01
            }
        },
        
        "wall_clock_breakdown": False
    }

# Initialize model with DeepSpeed
def initialize_model(model_name, local_rank):
    print(f"\n=== Initializing {model_name} with DeepSpeed ===")
    print_memory("Before model load")
    
    # Load model to CPU first with minimal memory usage
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        low_cpu_mem_usage=True,
        attn_implementation='eager'
    )
    print_memory("After model load")
    
    # Set training mode and requires_grad
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    
    # Create parameter groups for optimizer
    model_params = [
        {"params": [p for p in model.parameters() if p.requires_grad], 
         "lr": 1e-5}
    ]
    
    # DeepSpeed initialization
    ds_config = get_ds_config()
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model_params,
        config=ds_config
    )
    
    # Monkey patch the step function to avoid flops profiler issues
    original_step = model_engine.step
    def safe_step(self, *args, **kwargs):
        try:
            return original_step(*args, **kwargs)
        except AttributeError as e:
            if '__flops__' in str(e):
                # Fall back to a plain optimizer step; skip entirely if there is no optimizer
                if self.optimizer is not None:
                    self.optimizer.step()
                    if hasattr(self, "zero_optimization") and self.zero_optimization():
                        self.optimizer.zero_grad(set_to_none=True)
                    else:
                        self.optimizer.zero_grad()
            else:
                raise
    
    # Apply the monkey patch
    model_engine.step = types.MethodType(safe_step, model_engine)
    
    # Enable gradient checkpointing
    model_engine.gradient_checkpointing_enable()
    
    print_memory("After DeepSpeed initialization")
    return model_engine

# Train for a few steps
def train_model(model_engine, tokenizer, steps=10):
    print("\n=== Training Model ===")
    print_memory("Before training")
    
    # Define dummy input data for LM training
    dummy_prompts = [
        "The capital of France is",
        "The speed of light is approximately"
    ]
    
    for i in range(steps):
        # Tokenize inputs
        inputs = tokenizer(
            dummy_prompts, 
            return_tensors="pt", 
            padding=True, 
            truncation=True
        )
        
        # Move inputs to the appropriate device
        inputs = {k: v.to(model_engine.device) for k, v in inputs.items()}
        
        # Labels for the LM loss: a copy of input_ids (the model shifts them internally)
        inputs["labels"] = inputs["input_ids"].clone()
        
        # Forward pass
        outputs = model_engine(**inputs)
        loss = outputs.loss
        
        # Backward pass and optimize
        model_engine.backward(loss)
        model_engine.step()
        
        # Log memory usage periodically
        if (i+1) % 5 == 0:
            print(f"Step {i+1}/{steps}, Loss: {loss.item():.4f}")
            print_memory(f"Training step {i+1}")
    
    print_memory("After training")
    return model_engine

# Clean up DeepSpeed resources
def cleanup_deepspeed(model_engine):
    print("\n=== Cleaning Up DeepSpeed Resources ===")
    print_memory("Before cleanup")
    
    try:
        # Try to use destroy method first
        if hasattr(model_engine, "destroy"):
            print("Using model_engine.destroy()")
            model_engine.destroy()
        else:
            print("No destroy method available, performing manual cleanup")
            
            # Clean up optimizer references
            if hasattr(model_engine, "optimizer") and model_engine.optimizer is not None:
                print("Cleaning optimizer references")
                try:
                    # Zero out optimizer state
                    if hasattr(model_engine.optimizer, "state"):
                        model_engine.optimizer.state = {}
                    
                    # Clear param groups
                    for param_group in model_engine.optimizer.param_groups:
                        for param in param_group["params"]:
                            if hasattr(param, "ds_tensor"):
                                del param.ds_tensor
                            if hasattr(param, "ds_id"):
                                del param.ds_id
                            if hasattr(param, "grad"):
                                param.grad = None
                    
                    model_engine.optimizer = None
                except Exception as e:
                    print(f"Error cleaning optimizer: {e}")
            
            # Clean up module references
            if hasattr(model_engine, "module"):
                print("Cleaning module references")
                try:
                    # Clear parameter references
                    for param in model_engine.module.parameters():
                        if hasattr(param, "ds_tensor"):
                            del param.ds_tensor
                        if hasattr(param, "ds_id"):
                            del param.ds_id
                        if hasattr(param, "grad"):
                            param.grad = None
                    
                    # Remove module reference
                    del model_engine.module
                except Exception as e:
                    print(f"Error cleaning module: {e}")
            
            # Clean up other DeepSpeed components
            for attr in dir(model_engine):
                try:
                    # getattr itself can raise for some properties, so keep it inside the try
                    if not attr.startswith("__") and not callable(getattr(model_engine, attr)):
                        delattr(model_engine, attr)
                except Exception:
                    pass
        
        # Drop this function's local reference (the caller must drop its own too)
        del model_engine
        
        # Force garbage collection
        print("Running garbage collection")
        gc.collect()
        torch.cuda.empty_cache()
        
        print_memory("After cleanup")
        print("Cleanup completed")
    except Exception as e:
        print(f"Error during cleanup: {e}")
        import traceback
        traceback.print_exc()

def main():
    parser = argparse.ArgumentParser(description="DeepSpeed initialization and cleanup test")
    parser.add_argument("--model", type=str, default="google/gemma-3-27b-it", 
                        help="HuggingFace model ID")
    parser.add_argument("--cycles", type=int, default=3, 
                        help="Number of init-train-cleanup cycles")
    parser.add_argument("--steps", type=int, default=5, 
                        help="Number of training steps per cycle")
    parser.add_argument("--local_rank", type=int, default=0,
                        help="Local rank for distributed training")
    args = parser.parse_args()
    
    # Set environment variables before any CUDA work; PYTORCH_CUDA_ALLOC_CONF
    # is only read when the CUDA caching allocator is first initialized
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
    os.environ["DISABLE_DEEPSPEED_FLOPS_PROFILER"] = "1"
    
    # Initialize deepspeed distributed
    deepspeed.init_distributed()
    
    print(f"Running {args.cycles} init-train-cleanup cycles with {args.model}")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Local rank: {args.local_rank}")
    
    # Initial memory state
    print("\nInitial system state:")
    print_memory("Initial")
    
    # Load tokenizer (shared across cycles)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    
    # Run multiple cycles to check for memory leaks
    for cycle in range(args.cycles):
        print(f"\n{'#'*50}")
        print(f"# CYCLE {cycle+1}/{args.cycles}")
        print(f"{'#'*50}")
        
        try:
            # Step 1: Initialize model with DeepSpeed
            model_engine = initialize_model(args.model, args.local_rank)
            
            # Step 2: Perform training steps
            model_engine = train_model(model_engine, tokenizer, steps=args.steps)
            
            # Step 3: Clean up DeepSpeed resources
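            cleanup_deepspeed(model_engine)
            # Drop main()'s own reference; the `del` inside cleanup_deepspeed()
            # only removes that function's local name
            del model_engine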
            
            # Extra cleanup and pause between cycles
            print("\n=== Extra Cleanup Between Cycles ===")
            gc.collect()
            torch.cuda.empty_cache()
            print_memory("After extra cleanup")
            
            # Sleep to allow background processes to finish
            time.sleep(5)
            
        except Exception as e:
            print(f"Error in cycle {cycle+1}: {e}")
            import traceback
            traceback.print_exc()
            break
    
    print("\nAll cycles completed")
    print_memory("Final")

if __name__ == "__main__":
    main()
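A Python detail matters for this kind of cleanup: `del` inside a helper such as cleanup_deepspeed() removes only the helper's local name. If the caller still binds the engine (for example, main()'s model_engine variable until it is deleted or reassigned), the object stays alive and gc.collect() cannot reclaim it. The sketch below shows this with a plain class. Heavy and cleanup are hypothetical stand-ins, and weakref is used only to observe when the object is actually freed.

```python
import gc
import weakref


class Heavy:
    """Hypothetical stand-in for a DeepSpeed engine holding large buffers."""
    def __init__(self) -> None:
        self.buffer = bytearray(10 * 1024 * 1024)  # ~10 MB


def cleanup(obj: Heavy) -> None:
    # Mirrors cleanup_deepspeed(): this only deletes the *local* name.
    del obj
    gc.collect()


engine = Heavy()
probe = weakref.ref(engine)

cleanup(engine)
alive_after_cleanup = probe() is not None  # True: the caller's name still holds it

del engine                                 # drop the caller's reference too
gc.collect()
alive_after_del = probe() is not None      # False: the last reference is gone

print(alive_after_cleanup, alive_after_del)  # True False
```

Relatedly, assigning model_engine.step = types.MethodType(safe_step, model_engine), whose closure also holds the original bound step method, makes the engine reference itself. CPython frees such reference cycles only when the cyclic garbage collector runs, so gc.collect() is still needed after the last outside reference is dropped.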

BiEchi, Apr 08 '25 20:04

For more context: I have to destroy the engine every epoch because I need to run vLLM after each epoch, which is omitted from the code above.

BiEchi, Apr 08 '25 20:04

Is there any update on this issue? Whenever the engine is destroyed and reinitialized, RAM usage seems to accumulate.
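One possible contributor to rule out when RSS keeps growing even though GPU memory is released: on Linux with glibc, memory the process frees may be kept by the allocator for reuse instead of being returned to the OS. In that case, the psutil-based get_cpu_memory() reading in the reproduction script will not drop even after the engine's CPU tensors are gone. glibc's malloc_trim(0) asks the allocator to hand free memory back. This is a minimal sketch, assuming a glibc-based Linux. trim_native_heap is a hypothetical helper name. It only addresses allocator retention, not objects that are still referenced. Pinned host buffers (pin_memory) are allocated through CUDA rather than malloc, so malloc_trim does not release them.

```python
import ctypes
import ctypes.util
import gc
from typing import Optional


def trim_native_heap() -> Optional[bool]:
    """Ask glibc to return freed heap memory to the OS (hypothetical helper).

    Returns True if malloc_trim reports that memory was released, False if
    it released nothing, or None if the C library or malloc_trim is unavailable.
    """
    gc.collect()  # free unreachable Python objects first
    libc_name = ctypes.util.find_library("c") or "libc.so.6"
    try:
        libc = ctypes.CDLL(libc_name)
        malloc_trim = libc.malloc_trim  # AttributeError on non-glibc libcs
    except (OSError, AttributeError):
        return None
    malloc_trim.argtypes = [ctypes.c_size_t]
    malloc_trim.restype = ctypes.c_int
    return bool(malloc_trim(0))


released = trim_native_heap()
print("malloc_trim released memory:", released)
```

Comparing get_cpu_memory() before and after this call at the end of each cycle can help separate the two cases. If RSS falls back, the memory was already freed and merely cached by the allocator. If it does not, something still holds it, or it lives outside malloc (for example, pinned buffers).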

atincuzun, Dec 04 '25 20:12