Sacred makes PyTorch training slow down

Open GoingMyWay opened this issue 3 years ago • 4 comments

I have been using Sacred for a while and found an issue: when training with PyTorch on GPUs, training slows down over time.

OS: Ubuntu; Python: 3.9; tested GPUs: A100 and RTX 3090; PyTorch: 1.9 or newer; Sacred: latest release.

Code to reproduce (forward passes only, no training):

import os
import sys
import copy
import time
import pickle
import logging
import argparse

import tqdm
import numpy as np
import torch as th
import torch.nn as nn
import tensorflow as tf

from types import SimpleNamespace as SN

from tensorboard_logger import configure, log_value

from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds


logger = logging.getLogger()
logger.handlers = []
ch = logging.StreamHandler()
formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S')
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.setLevel('INFO')

ex = Experiment("sacred_experiments")
ex.logger = logger
ex.captured_out_filter = apply_backspaces_and_linefeeds

SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console

tf.config.experimental.set_visible_devices([], 'GPU')  # TF does not use GPUs

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # make CUDA kernel launches synchronous (disables PyTorch's asynchronous GPU execution)

th.backends.cudnn.enabled = True
th.set_flush_denormal(True)


class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.model = nn.Sequential(
                nn.Linear(in_features=3*88*88, out_features=512, bias=False, dtype=th.float),
                nn.ReLU(),
                nn.Linear(in_features=512, out_features=512, bias=False, dtype=th.float),
                nn.ReLU(),
                nn.Linear(in_features=512, out_features=64, bias=False, dtype=th.float), #64
                nn.ReLU(),
                nn.Linear(in_features=64, out_features=64, bias=False, dtype=th.float),
                nn.ReLU(),
                nn.Linear(in_features=64, out_features=128, bias=False, dtype=th.float), #128
                nn.ReLU(),
                nn.Linear(in_features=128, out_features=9, bias=False, dtype=th.float)
        )

    def forward(self, img):
        dim = img.shape[0]
        assert tuple(img.shape) == (420, 3, 88, 88), f"img.shape: {img.shape}"
        q = self.model(img.view(dim, -1).float()/255.0)
        return q, None


@ex.main
def test(_run, _config, _log):
    args = SN(**_config)
    model = Agent()  # prediction network
    model.cuda()
    model_2 = copy.deepcopy(model)  # target network

    configure(os.path.join('scred_results', f'{_run._id}'))
    
    for i in tqdm.tqdm(range(10000)):
        time_costs = []
        for _idx in range(10):
            data = th.randint(low=0, high=255, size=(420, 3, 88, 88), dtype=th.uint8, device='cuda')
            th.cuda.synchronize()  # finish pending GPU work (e.g., creating `data`) before starting the timer
            st = time.time()
            # simulate sequence for target network (model_2)
            for t in range(50):
                out, _ = model_2(data)
            # simulate sequence for prediction network (model), for example, PPO
            for t in range(50):
                out, _ = model(data)
            th.cuda.synchronize()
            time_costs.append(time.time() - st)
            print(f'{i}, sample: {_idx}, sec: {(time.time()-st):.4f}')
        log_value('per_update_sec', np.mean(time_costs), (i+1)*30000)  # there are 30000 dummy steps


if __name__ == '__main__':
    params = copy.deepcopy(sys.argv)

    config_dict = {}
    config_dict['id'] = 1
    
    file_obs_path = "sacred_results"

    ex.add_config(config_dict)
    ex.observers.append(FileStorageObserver(file_obs_path))  # FileStorageObserver.create() is deprecated in recent Sacred versions
    ex.run_commandline(params)

Without Sacred, the time per update is stable:

[Figure: time per update without Sacred]

With Sacred, the time per update keeps increasing (two figures):

[Figure: time per update with Sacred (1 of 2)]

[Figure: time per update with Sacred (2 of 2)]

In actual training with Sacred, the issue is even more severe and gets out of control. Here is an example:

[Figure: time per update during actual training with Sacred]

As you can see, the time cost per update keeps increasing, which drives the ETA out of control.
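A rough way to see why the ETA runs away: if the per-update time grows steadily instead of staying constant, the total runtime grows quadratically with the number of updates. The sketch below assumes, purely for illustration, a linear model in which update i takes `t0 + c * i` seconds; the function name and the values of `t0` and `c` are hypothetical and not measured from the plots above.

```python
def total_time(n_updates: int, t0: float, c: float) -> float:
    """Total seconds for n_updates when update i (0-based) takes t0 + c * i seconds."""
    # Arithmetic series: n * t0 + c * (0 + 1 + ... + (n - 1)).
    return n_updates * t0 + c * n_updates * (n_updates - 1) / 2


if __name__ == "__main__":
    # Hypothetical numbers: 0.5 s per update, without and with a 0.1 ms slowdown per update.
    for n in (10_000, 20_000):
        print(f"{n} updates: constant {total_time(n, 0.5, 0.0):.1f} s, "
              f"growing {total_time(n, 0.5, 1e-4):.1f} s")
```

With these made-up numbers, the constant case takes 5000.0 s and 10000.0 s, while the growing case takes 9999.5 s and 29999.0 s: doubling the run length triples the runtime instead of doubling it.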

GoingMyWay, Jul 08 '22 11:07

Hey @GoingMyWay! Thank you for noticing this and bringing it up. I had never noticed this in my own experiments, but when I checked one of my old experiments, the tendency is there: the time required for one training step increases over the course of training. It's just not as severe as in your example.

[Figure: time per training step in an older experiment] (Ignore the peak; that's an issue with our filesystem.)

I looked into the code of the run and the observer and found two things related to stdout capturing that could be causing this:

  1. The stdout is captured by doing (simplified) captured_out += newly_captured_lines. Python strings are immutable, so every 10s this creates a new, ever larger string. This runs in a background thread, but because of Python's GIL, it can slow down the main thread when this takes too long.
  2. The stdout is written to a file, which is re-opened every 10s to append the new lines. I am not sure whether opening a file in append mode takes longer the larger the file is; this could also depend on the file system.
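Point 2 can be checked on its own, independently of Sacred. The sketch below is a hypothetical stand-alone benchmark, not Sacred's observer code: it re-opens a file in append mode for every write, mimicking the pattern described above, and records how long each open-plus-write takes as the file grows. The function name, file name, chunk size, and number of appends are arbitrary choices.

```python
import os
import tempfile
import time


def time_appends(path, n_appends=1_000, chunk=b"x" * 10_000):
    """Re-open `path` in append mode n_appends times, write `chunk`, and time each round."""
    timings = []
    for _ in range(n_appends):
        start = time.perf_counter()
        with open(path, "ab") as f:  # re-open the file on every simulated heartbeat
            f.write(chunk)
        timings.append(time.perf_counter() - start)
    return timings


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "captured_output.txt")
        timings = time_appends(path)
        print(f"final size: {os.path.getsize(path)} bytes")
        # Compare early and late appends: a flat profile means file size barely matters here.
        print(f"first 10 appends: {sum(timings[:10]) / 10 * 1e6:.1f} us avg")
        print(f"last 10 appends:  {sum(timings[-10:]) / 10 * 1e6:.1f} us avg")
```

If the late appends are not noticeably slower than the early ones on the machine in question, point 2 is unlikely to be the main cause there.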

Both of these points are related to stdout capturing. To pinpoint the issue, you could run your benchmark once without an observer and once without the print statement in the loop (if that is feasible for you; the number of iterations is quite large). If my assumptions are correct, the training time should not increase when you don't print anything in the loop, and it should not depend much on the observer.


To support my first point, I plotted the time required to append a string to another string for different string lengths.

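import time

times = []    # seconds taken by one concatenation
lengths = []  # length of the existing string before appending
for n in range(1, 3_500_000 * 100, 100000):
    s = '_' * n
    st = time.perf_counter()  # higher-resolution clock than time.time()
    s = s + 'a' * 100  # append 100 characters to a string of length n
    times.append(time.perf_counter() - st)
    lengths.append(n)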

[Figure: time to append 100 characters vs. length of the existing string]

The shape roughly matches your plot.
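The benchmark above times one concatenation at a time. Summed over a run, the cost becomes quadratic: if every heartbeat rebuilds the captured output as old text plus new text, and each rebuild copies the whole accumulated string, heartbeat i copies about i chunks' worth of characters. The sketch below only counts characters under that simplified, hypothetical model (it is not Sacred's code) and contrasts it with one common alternative: keeping new chunks in a list and joining them once when the full text is needed.

```python
def chars_copied_concat(n_heartbeats: int, chunk_len: int) -> int:
    """Characters copied if every heartbeat rebuilds captured = captured + new_chunk."""
    captured_len = 0
    copied = 0
    for _ in range(n_heartbeats):
        captured_len += chunk_len
        copied += captured_len  # the rebuilt string copies all text captured so far
    return copied


def chars_copied_chunks(n_heartbeats: int, chunk_len: int) -> int:
    """Characters copied if chunks go into a list and are joined once at the end."""
    # list.append stores a reference (no character copy); "".join copies each character once.
    return n_heartbeats * chunk_len


if __name__ == "__main__":
    # Hypothetical run: one hour of 10 s heartbeats, 1,000 new characters per heartbeat.
    print(chars_copied_concat(360, 1_000))  # 64980000
    print(chars_copied_chunks(360, 1_000))  # 360000
```

Under this model the concatenation cost grows with the square of the number of heartbeats while the list approach grows linearly, which would be consistent with a per-step slowdown that keeps getting worse over a long run.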

thequilo, Jul 18 '22 09:07

@thequilo Hi, thanks for pointing out the cause. Will you fix this in the next release? I would like to fix it myself, but I am not an expert on Sacred.

GoingMyWay, Jul 19 '22 01:07

If I figure out a way to fix this, then yes. But I'm not sure yet how to do it without potentially breaking things (like custom observers).

I would suggest avoiding frequent print or logging statements as a workaround.
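Applied to the reproduction script, one way to follow this workaround is to keep timing every sample but print only a periodic summary. A minimal sketch, where `run`, `step`, and `log_every` are hypothetical names, `step` stands in for the forward passes, and `print` stands in for whatever print or logging call the loop uses:

```python
import time
from typing import Callable


def run(n_iters: int, samples_per_iter: int, step: Callable[[], None], log_every: int = 100) -> int:
    """Time every sample, but print one summary line every `log_every` iterations.

    Returns the number of summary lines printed.
    """
    lines_printed = 0
    window = []  # per-sample durations since the last summary line
    for i in range(n_iters):
        for _ in range(samples_per_iter):
            st = time.perf_counter()
            step()
            window.append(time.perf_counter() - st)
        if (i + 1) % log_every == 0 and window:
            mean = sum(window) / len(window)
            print(f"iter {i + 1}: mean {mean:.4f} s/sample over {len(window)} samples")
            lines_printed += 1
            window.clear()
    return lines_printed


if __name__ == "__main__":
    # A no-op step, just to show how few lines get printed.
    print(run(1_000, 10, lambda: None))
```

With the reproduction's 10,000 iterations of 10 samples each and `log_every=100`, this prints 100 lines instead of 100,000, so far less text ends up in the captured output.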

thequilo, Jul 21 '22 06:07

@thequilo, thanks. I'll reduce printing and logging.

GoingMyWay, Jul 21 '22 07:07