
BlenderBot 3 Chat Service Multiple GPUs

Open ryanshea10 opened this issue 3 years ago • 0 comments

Hi, I'm trying to host a chat service with BlenderBot 3 using multiple GPUs. Here is my config file:

tasks:
  default:
    onboard_world: MessengerBotChatOnboardWorld
    task_world: MessengerBotChatTaskWorld
    timeout: 1800
    agents_required: 1
task_name: chatbot
world_module: parlai.chat_service.tasks.chatbot.worlds
overworld: MessengerOverworld
max_workers: 100
opt:
  debug: True
  models:
    blenderbot3_3B:
      model: projects.seeker.agents.seeker:ComboFidGoldDocumentAgent
      model_file: zoo:bb3/bb3_3B/model
      interactive_mode: True
      no_cuda: False
      override:
        search_server: https://www.google.com
        model_parallel: True
additional_args:
  page_id: 1 # Configure Your Own Page
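When this config runs with model_parallel across several GPUs, the GPU selection (CUDA_VISIBLE_DEVICES) and synchronous kernel launches (CUDA_LAUNCH_BLOCKING=1) have to be in the environment before the service process initializes CUDA. Below is a minimal stdlib sketch, assuming you start the chat service as a child process. The helper name build_debug_env and the LAUNCH_CMD placeholder are hypothetical and are not part of ParlAI.

```python
import os
from typing import Mapping, Optional, Sequence


def build_debug_env(
    gpu_ids: Sequence[int], base_env: Optional[Mapping[str, str]] = None
) -> dict:
    """Return a copy of the environment with CUDA debugging variables set.

    Hypothetical helper (not a ParlAI API).
    - CUDA_VISIBLE_DEVICES limits which GPUs the child process can see. It only
      takes effect if it is set before CUDA is initialized in that process.
    - CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so a device-side
      assert is reported at the operation that caused it. This is slower, so
      use it only for debugging.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


if __name__ == "__main__":
    env = build_debug_env([0, 1])
    print(env["CUDA_VISIBLE_DEVICES"], env["CUDA_LAUNCH_BLOCKING"])
    # Hypothetical: LAUNCH_CMD is whatever command you normally use to start
    # the chat service with this config, e.g.
    #   import subprocess
    #   subprocess.run(LAUNCH_CMD, env=env, check=True)
```

Passing the environment to a child process, instead of changing os.environ after torch has already touched the GPU, avoids the case where the variables are silently ignored.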

My problem: I can successfully load the model onto multiple GPUs by using the model_parallel option together with CUDA_VISIBLE_DEVICES, but whenever the model tries to respond to a user it fails with "CUDA error: device-side assert triggered". The error only occurs when I use multiple GPUs; everything works fine with a single GPU. Here is the stack trace of the error with CUDA_LAUNCH_BLOCKING=1:

World default had error RuntimeError('CUDA error: device-side assert triggered')
World default had error RuntimeError('CUDA error: device-side assert triggered')
Traceback (most recent call last):
  File "/home/rs4235/miniconda3/envs/parlai_blender3/lib/python3.9/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/local2/data/rs4235/ParlAI/parlai/chat_service/core/world_runner.py", line 128, in _world_fn
    return self._run_world(task, world_name, agents)
  File "/local2/data/rs4235/ParlAI/parlai/chat_service/core/world_runner.py", line 99, in _run_world
    ret_val = world.parley()
  File "/local2/data/rs4235/ParlAI/parlai/chat_service/tasks/chatbot/worlds.py", line 106, in parley
    response = self.model.act()
  File "/local2/data/rs4235/ParlAI/parlai/core/torch_agent.py", line 2148, in act
    response = self.batch_act([self.observation])[0]
  File "/local2/data/rs4235/ParlAI/parlai/agents/fid/fid.py", line 389, in batch_act
    batch_reply = super().batch_act(observations)
  File "/local2/data/rs4235/ParlAI/parlai/core/torch_agent.py", line 2244, in batch_act
    output = self.eval_step(batch)
  File "/local2/data/rs4235/ParlAI/projects/seeker/agents/seeker.py", line 161, in eval_step
    output = TorchGeneratorAgent.eval_step(self, batch)
  File "/local2/data/rs4235/ParlAI/parlai/core/torch_generator_agent.py", line 901, in eval_step
    beam_preds_scores, beams = self._generate(
  File "/local2/data/rs4235/ParlAI/parlai/agents/rag/rag.py", line 684, in _generate
    gen_outs = self._rag_generate(batch, beam_size, max_ts, prefix_tokens)
  File "/local2/data/rs4235/ParlAI/parlai/agents/rag/rag.py", line 727, in _rag_generate
    return self._generation_agent._generate(
  File "/local2/data/rs4235/ParlAI/parlai/core/torch_generator_agent.py", line 1236, in _generate
    incr_state = model.reorder_decoder_incremental_state(
  File "/local2/data/rs4235/ParlAI/parlai/agents/fid/fid.py", line 110, in reorder_decoder_incremental_state
    return {
  File "/local2/data/rs4235/ParlAI/parlai/agents/fid/fid.py", line 111, in <dictcomp>
    idx: layer.reorder_incremental_state(incremental_state[idx], inds)
  File "/local2/data/rs4235/ParlAI/parlai/agents/transformer/modules/decoder.py", line 588, in reorder_incremental_state
    return {
  File "/local2/data/rs4235/ParlAI/parlai/agents/transformer/modules/decoder.py", line 589, in <dictcomp>
    attn_type: attn.reorder_incremental_state(
  File "/local2/data/rs4235/ParlAI/parlai/agents/transformer/modules/attention.py", line 275, in reorder_incremental_state
    return {
  File "/local2/data/rs4235/ParlAI/parlai/agents/transformer/modules/attention.py", line 276, in <dictcomp>
    key: torch.index_select(val, 0, inds.to(val.device)).contiguous()
RuntimeError: CUDA error: device-side assert triggered
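The trace ends in torch.index_select(val, 0, inds.to(val.device)) inside reorder_incremental_state. One common cause of a device-side assert in index_select is an index outside 0 <= i < val.size(0), for example beam reorder indices that do not match the batch dimension of a cached state tensor. Below is a minimal stdlib sketch of a pre-check that turns the opaque CUDA assert into a readable Python error. It assumes you would temporarily call it, as a hypothetical local debugging patch that is not part of ParlAI, just before that line in attention.py. Because it copies the indices to the host, it forces a GPU sync and should only be used while debugging.

```python
from typing import Mapping, Sequence


def check_reorder_indices(inds: Sequence[int], dim0_sizes: Mapping[str, int]) -> None:
    """Raise IndexError if any reorder index is out of range for a cached state.

    Hypothetical debugging helper (not a ParlAI API). It mirrors the
    precondition of torch.index_select(val, 0, inds): every index must satisfy
    0 <= i < val.size(0). With real tensors you might call, e.g.:
        check_reorder_indices(inds.tolist(),
                              {k: v.size(0) for k, v in incremental_state.items()})
    """
    for key, size in dim0_sizes.items():
        bad = [i for i in inds if not 0 <= i < size]
        if bad:
            raise IndexError(
                f"reorder index out of range for {key!r}: "
                f"size(0)={size}, offending indices={bad}"
            )


if __name__ == "__main__":
    check_reorder_indices([0, 1, 2], {"prev_key": 3, "prev_value": 3})  # passes
    try:
        check_reorder_indices([0, 3], {"prev_key": 3})
    except IndexError as err:
        print(err)
```

If the check passes on every call and the assert still fires, the out-of-range index is probably not in this reorder step, and the next thing to compare is which device each state tensor lives on under model_parallel.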

ryanshea10 · Aug 12 '22 16:08