Error when training with local precomputed features
Bug when training locally with a LocalDataset
Here is my config (with some personal paths removed), used to run MosaicML's diffusion repo:
algorithms:
  low_precision_groupnorm:
    attribute: unet
    precision: amp_fp16
  low_precision_layernorm:
    attribute: unet
    precision: amp_fp16
model:
  _target_: diffusion.models.models.stable_diffusion_2
  model_name: runwayml/stable-diffusion-v1-5
  pretrained: true
  precomputed_latents: true
  encode_latents_in_fp16: true
  # fsdp: false
  fsdp: true
  val_metrics:
    - _target_: torchmetrics.MeanSquaredError
  val_guidance_scales: []
  loss_bins: []
dataset:
  train_batch_size: 2048  # TODO: explore composer config
  eval_batch_size: 16  # Should be 8 per device
  train_dataset:
    _target_: diffusion.datasets.pixta.pixta.build_custom_dataloader
    data_path: ...
    feature_dim: 32
    num_workers: 8
    pin_memory: false
  eval_dataset:
    _target_: diffusion.datasets.pixta.pixta.build_custom_dataloader
    data_path: ...
    feature_dim: 32
    num_workers: 8
    pin_memory: false
optimizer:
  _target_: torch.optim.AdamW
  lr: 1.0e-5
  weight_decay: 0.01
scheduler:
  _target_: composer.optim.ConstantWithWarmupScheduler
  t_warmup: 10ba
logger:
  wandb:
    _target_: composer.loggers.wandb_logger.WandBLogger
    name: ${name}
    project: ${project}
    group: ${name}
callbacks:
  speed_monitor:
    _target_: composer.callbacks.speed_monitor.SpeedMonitor
    window_size: 10
  lr_monitor:
    _target_: composer.callbacks.lr_monitor.LRMonitor
  memory_monitor:
    _target_: composer.callbacks.memory_monitor.MemoryMonitor
  runtime_estimator:
    _target_: composer.callbacks.runtime_estimator.RuntimeEstimator
  optimizer_monitor:
    _target_: composer.callbacks.OptimizerMonitor
trainer:
  _target_: composer.Trainer
  device: gpu
  max_duration: 10ep
  eval_interval: 2ep
  device_train_microbatch_size: 40
  run_name: ${name}
  seed: ${seed}
  scale_schedule_ratio: ${scale_schedule_ratio}
  save_folder: outputs/${project}/${name}
  save_interval: 5ep
  save_overwrite: true
  autoresume: false
  fsdp_config:
    sharding_strategy: "SHARD_GRAD_OP"
    state_dict_type: "full"
    mixed_precision: 'PURE'
    activation_checkpointing: true
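For context, the diffusion repo's train script consumes this config through Hydra, roughly like the sketch below (my paraphrase of how the `_target_` dataloader entries get instantiated; the exact batch-size handling in diffusion/train.py may differ):

# Rough sketch of how the dataset.*_dataset entries above are instantiated
# via Hydra; argument handling here is assumed, not copied from the repo.
import hydra
from composer.utils import dist
from omegaconf import DictConfig


def build_dataloaders(config: DictConfig):
    # Hydra resolves `_target_` to build_custom_dataloader and forwards the
    # remaining keys (data_path, feature_dim, num_workers, pin_memory) as kwargs.
    train_dataloader = hydra.utils.instantiate(
        config.dataset.train_dataset,
        batch_size=config.dataset.train_batch_size // dist.get_world_size(),
    )
    eval_dataloader = hydra.utils.instantiate(
        config.dataset.eval_dataset,
        batch_size=config.dataset.eval_batch_size // dist.get_world_size(),
    )
    return train_dataloader, eval_dataloader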
Here is my dataset and dataloader code:
# Imports (the exact import paths are inferred from the streaming / composer APIs used below):
import json
import os

import numpy as np
import torch
from composer.utils import dist
from streaming.base.array import Array
from streaming.base.format import reader_from_json
from streaming.base.spanner import Spanner
from torch.utils.data import DataLoader, Dataset, IterableDataset


class CustomDataset(Array, Dataset):

    def __init__(self, data_path, feature_dim=64):
        self.feature_dim = feature_dim
        index_file = os.path.join(data_path, 'index.json')
        data = json.load(open(index_file))
        if data['version'] != 2:
            raise ValueError(f'Unsupported streaming data version: {data["version"]}. ' +
                             f'Expected version 2.')
        shards = []
        for info in data['shards']:
            shard = reader_from_json(data_path, None, info)
            shards.append(shard)
        self.shards = shards
        samples_per_shard = np.array([shard.samples for shard in shards], np.int64)
        self.length = samples_per_shard.sum()
        self.spanner = Spanner(samples_per_shard)

    def __len__(self):
        return self.length

    @property
    def size(self) -> int:
        """Get the size of the dataset in samples.

        Returns:
            int: Number of samples.
        """
        return self.length

    # def __getitem__(self, index):
    def get_item(self, index):
        # `Array.__getitem__` dispatches to `get_item`.
        shard_id, shard_sample_id = self.spanner[index]
        shard = self.shards[shard_id]
        sample = shard[shard_sample_id]
        out = {}
        if 'caption_latents' in sample:
            out['caption_latents'] = torch.from_numpy(
                np.frombuffer(sample['caption_latents'], dtype=np.float16).copy()).reshape(77, 768)
        if 'image_latents' in sample:
            out['image_latents'] = torch.from_numpy(
                np.frombuffer(sample['image_latents'], dtype=np.float16).copy()).reshape(
                    4, self.feature_dim, self.feature_dim)
        return out


def build_custom_dataloader(
    batch_size: int,
    data_path: str,
    image_root: str = None,
    tokenizer_name_or_path: str = 'runwayml/stable-diffusion-v1-5',
    caption_drop_prob: float = 0.0,
    resize_size: int = 512,
    feature_dim: int = 64,
    drop_last: bool = True,
    shuffle: bool = True,  # TODO pass shuffle to dataloader
    **dataloader_kwargs,
):
    print('Using precomputed features!!!')
    dataset = CustomDataset(data_path, feature_dim=feature_dim)
    drop_last = False
    if isinstance(dataset, IterableDataset):
        print('Using IterableDataset!!!')
        sampler = None
    else:
        print('Using Sampler!!!')
        sampler = dist.get_sampler(dataset, drop_last=drop_last, shuffle=shuffle)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        shuffle=shuffle if sampler is None else False,
        **dataloader_kwargs,
    )
    return dataloader
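For reference, the builder can also be exercised standalone, roughly like this (the data_path below is a placeholder, not my real path); this is how I would check whether the worker crash reproduces outside the Composer Trainer:

# Standalone smoke test for the dataloader above (illustrative only).
if __name__ == '__main__':
    loader = build_custom_dataloader(
        batch_size=8,
        data_path='/path/to/local/mds',  # placeholder path
        feature_dim=32,
        num_workers=8,
        pin_memory=False,
    )
    # Two passes, since the crash shows up at the epoch 0 -> epoch 1 boundary.
    for epoch in range(2):
        for i, batch in enumerate(loader):
            if i == 0:
                print(epoch, {k: v.shape for k, v in batch.items()})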
And I got this error when epoch 0 finished and epoch 1 started:
Error executing job with overrides: []
Traceback (most recent call last):
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1120, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/queues.py", line 113, in get
if not self._poll(timeout):
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py", line 262, in poll
return self._poll(timeout)
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py", line 429, in _poll
r = wait([self], timeout)
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py", line 936, in wait
ready = selector.select(timeout)
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/selectors.py", line 416, in select
fd_event_list = self._selector.poll(timeout)
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3007309) is killed by signal: Aborted.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/tungduongquang/workspace/mosaicml/image-generation/run.py", line 26, in <module>
main()
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/home/tungduongquang/workspace/mosaicml/image-generation/run.py", line 22, in main
return train(config)
File "/home/tungduongquang/workspace/mosaicml/image-generation/diffusion/train.py", line 134, in train
return eval_and_then_train()
File "/home/tungduongquang/workspace/mosaicml/image-generation/diffusion/train.py", line 132, in eval_and_then_train
trainer.fit()
File "/home/tungduongquang/workspace/mosaicml/composer/composer/trainer/trainer.py", line 1796, in fit
self._train_loop()
File "/home/tungduongquang/workspace/mosaicml/composer/composer/trainer/trainer.py", line 1938, in _train_loop
for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
File "/home/tungduongquang/workspace/mosaicml/composer/composer/trainer/trainer.py", line 2924, in _iter_dataloader
batch = next(dataloader_iter)
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1316, in _next_data
idx, data = self._get_data()
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1282, in _get_data
success, data = self._try_get_data()
File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 3007309) exited unexpectedly
Please help!
Apologies for the delay. Is the training dataset you are using in MDS streaming format? If so, you can use the StreamingDataset class to load the data even if it is stored locally.
@Landanjs I have the same issue when loading data in MDS streaming format. I built my custom dataset purely on PyTorch's Dataset, not StreamingDataset; could that be the problem?
@hieuphung97 Is there a reason you aren't using a subclass of StreamingDataset to load your data? A custom dataset built purely on PyTorch's Dataset may miss some of the logic needed to load MDS shards, so we recommend subclassing StreamingDataset if possible.
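For example, a subclass along these lines should work for local MDS data with precomputed latents (a rough sketch, not code from the repo; the class and builder names are illustrative):

import numpy as np
import torch
from streaming import StreamingDataset
from torch.utils.data import DataLoader


class StreamingLatentsDataset(StreamingDataset):
    """Loads precomputed latents from a local MDS directory."""

    def __init__(self, local: str, feature_dim: int = 64, **kwargs):
        # For purely local data, only `local` is needed; no `remote` is required.
        super().__init__(local=local, **kwargs)
        self.feature_dim = feature_dim

    def __getitem__(self, index):
        # StreamingDataset handles shard lookup and sampling; we only decode
        # the raw bytes into tensors.
        sample = super().__getitem__(index)
        out = {}
        if 'caption_latents' in sample:
            out['caption_latents'] = torch.from_numpy(
                np.frombuffer(sample['caption_latents'], dtype=np.float16).copy()).reshape(77, 768)
        if 'image_latents' in sample:
            out['image_latents'] = torch.from_numpy(
                np.frombuffer(sample['image_latents'], dtype=np.float16).copy()).reshape(
                    4, self.feature_dim, self.feature_dim)
        return out


def build_streaming_latents_dataloader(batch_size: int, data_path: str, feature_dim: int = 64,
                                       **dataloader_kwargs):
    # StreamingDataset is shard- and rank-aware, so no sampler is passed and
    # shuffling is handled by the dataset itself.
    dataset = StreamingLatentsDataset(local=data_path, feature_dim=feature_dim,
                                      shuffle=True, batch_size=batch_size)
    return DataLoader(dataset=dataset, batch_size=batch_size, drop_last=True, **dataloader_kwargs)

StreamingDataset handles per-rank partitioning, shard reading, and worker coordination internally, so the Spanner/reader bookkeeping and the explicit sampler in your custom dataset should not be needed.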