
We specify GPUs 2 and 3, but all four GPUs 0, 1, 2, 3 are actually used

bingwork opened this issue 10 months ago • 0 comments

Reminder

  • [X] I have read the README and searched the existing issues.

Reproduction

def _train_(rank, world_size, allocated_gpus, parameters, callbacks):
    """Assume the task uses GPUs 6 and 7:
    world_size is 2,
    local_rank is 6 or 7,
    rank is 0 or 1.
    """
    import os

    import torch

    local_rank = allocated_gpus[rank]
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["LOCAL_RANK"] = str(local_rank)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )
    try:
        from llmtuner import run_exp

        run_exp(args=parameters, callbacks=callbacks)
    finally:
        torch.distributed.destroy_process_group()


from torch.multiprocessing import spawn

scheduler = GPUScheduler()
allocated_gpus = scheduler.allocate_gpus()  # e.g. returns [2, 3]
world_size = len(allocated_gpus)
spawn(
    _train_,
    args=(
        world_size,
        allocated_gpus,
        parameters,
        [PrinterCallback(handler=handler, task_log=self.task_log)],
    ),
    nprocs=world_size,
    join=True,
)
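A common workaround for this symptom (not taken from the report above; the env-var name is standard CUDA, but the `visible_devices` helper is hypothetical) is to mask the GPUs in the parent process before torch is first imported, so every spawned worker only sees the allocated cards and logical device indices start at 0:

```python
import os


def visible_devices(allocated_gpus):
    """Build a CUDA_VISIBLE_DEVICES value from the allocated physical
    GPU indices, e.g. [2, 3] -> "2,3". With this mask in place,
    physical GPUs 2 and 3 appear inside the process as logical
    devices 0 and 1, so a worker's local_rank can simply equal its
    rank instead of the physical index."""
    return ",".join(str(g) for g in allocated_gpus)


# Must be set in the parent before CUDA is initialized (i.e. before
# the first torch import touches the driver) and before spawn().
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices([2, 3])
```

Setting only `LOCAL_RANK` does not hide the other devices; processes that touch the default CUDA context can still allocate on GPU 0, which matches the observed usage of cards 0 and 1.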

Here we specify GPUs 2 and 3, but in reality all four GPUs 0, 1, 2, 3 are used. (screenshot: 企业微信截图_20240424191014)

Expected behavior

Only GPUs 2 and 3 should be used.
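To get that behavior, an alternative sketch (an assumption on my part, not part of LLaMA-Factory; `bind_single_gpu` is a hypothetical helper) masks per worker instead of in the parent, pinning each spawned process to exactly one physical card:

```python
import os


def bind_single_gpu(rank, allocated_gpus):
    """Call at the very top of the worker, before torch is imported:
    mask everything except this rank's physical GPU. The surviving
    card then shows up as logical device 0 inside the process."""
    os.environ["CUDA_VISIBLE_DEVICES"] = str(allocated_gpus[rank])
    os.environ["LOCAL_RANK"] = "0"  # only one device is visible now
    return os.environ["CUDA_VISIBLE_DEVICES"]
```

With this scheme, rank 0 of an allocation `[2, 3]` is bound to physical GPU 2 and rank 1 to GPU 3, and neither process can touch GPUs 0 or 1.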

System Info

transformers-cli env

/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.4)
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}")

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

  • transformers version: 4.39.3
  • Platform: Linux-6.2.0-1018-aws-x86_64-with-glibc2.35
  • Python version: 3.9.13
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Others

No response

bingwork · Apr 24 '24 12:04