How to input paired audio and video into the pipeline efficiently?
Describe the question.
My use case is easy to describe: paired audio fbank features and video frames. But once the dataset grows to a huge number of samples, the loading phase takes too long.
I have tried several ways to load them from the HDD, given my limited resources. In all of the following cases, the audio fbank feature is saved as a `.npy` file:
- Use DALI external source: load both the `.mp4` file and the fbank `.npy` file with numpy, then decode the video on the GPU.
code:

```python
import json
import os

import numpy as np
from nvidia.dali import fn, pipeline_def


class DALIDatasetCallable:
    def __init__(
        self,
        ann_file,
        batch_size=1,
        shuffled=True,
        media_type='audio_video',
        shard_id=0,
        num_shards=1,
        **kwargs,
    ):
        self.media_type = media_type
        self.batch_size = batch_size
        with open(ann_file.anno_path, 'r') as f:
            self.label_file = json.load(f)
        self.data_root = ann_file.data_root
        self.indices = np.arange(len(self.label_file))
        if shuffled:
            np.random.shuffle(self.indices)
            self.label_file = [self.label_file[i] for i in self.indices]
        self.filenames = list(zip(self.label_file, self.indices))
        self.epoch = 0
        self.shard_id = shard_id
        self.num_shards = num_shards
        self.shard_size = len(self.label_file) // num_shards
        self.shard_offset = self.shard_size * shard_id
        self.full_iterations = self.shard_size // batch_size
        self.perm = None
        self.last_seen_epoch = None

    def __len__(self):
        return self.full_iterations

    def reset(self):
        self.perm = None
        self.last_seen_epoch = None

    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            # Indicate end of the epoch
            raise StopIteration
        if self.last_seen_epoch != sample_info.epoch_idx:
            # Re-generate the permutation once per epoch, deterministically
            self.last_seen_epoch = sample_info.epoch_idx
            rng = np.random.default_rng(seed=42 + sample_info.epoch_idx)
            self.perm = rng.permutation(len(self.filenames))
        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]
        sample, index = self.filenames[sample_idx]
        vfilename = os.path.join(self.data_root, sample['video'])
        afilename = os.path.join(self.data_root, sample['audio'])
        frame_count = sample['num_frames']
        video = np.fromfile(vfilename, dtype=np.uint8)  # raw encoded .mp4 bytes
        fbank = np.load(afilename)
        # get_frame_indices is a user-defined helper that picks 16 frame indices
        frame_idxs = np.array(get_frame_indices(16, frame_count), dtype=np.int32)
        index = np.array([np.int32(index)])
        return video, fbank, frame_idxs, index


@pipeline_def(py_num_workers=4, py_start_method="spawn")
def AudioVideoPipeline(eii, parallel):
    vid, audio, frame_idxs, index = fn.external_source(
        device="cpu", batch=False, num_outputs=4,
        source=eii, parallel=parallel, prefetch_queue_depth=4)
    video = fn.experimental.decoders.video(vid, device="mixed", frames=frame_idxs)
    resize_v = fn.resize(video, resize_shorter=224, device='gpu')
    crop_v = fn.crop_mirror_normalize(
        resize_v, crop=(224, 224), device='gpu',
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        output_layout="FCHW")
    return audio, crop_v, index
```
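For completeness, this is roughly how the pipeline above is built and run (a minimal sketch; the batch size and device id are placeholders):

```python
# Sketch only: ann_file is the annotation object the callable expects.
eii = DALIDatasetCallable(ann_file, batch_size=8)
pipe = AudioVideoPipeline(eii, parallel=True,
                          batch_size=8, num_threads=4, device_id=0)
pipe.build()
audio, video, index = pipe.run()
```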
- Tar all the video `.mp4` files and fbank `.npy` files, and use the DALI webdataset reader to load them.
code:

```python
import numpy as np
from nvidia.dali import fn, pipeline_def, types


def get_indices(num_frames, vlen):
    indices = np.array(get_frame_indices(num_frames, vlen), dtype=np.int32)
    return indices


# Note: fn.python_function only works in pipelines built with
# exec_async=False and exec_pipelined=False.
@pipeline_def(batch_size=4, num_threads=4, exec_async=False, exec_pipelined=False)
def WebDatasetAVPipeline(wds_data, index_paths, shard_id, num_shards):
    vid, fbank, vlen = fn.readers.webdataset(
        paths=wds_data,
        index_paths=index_paths,
        dtypes=[types.UINT8, types.FLOAT, types.INT16],
        ext=["mp4", "fbank", "npy"],
        num_shards=num_shards,
        shard_id=shard_id,
        random_shuffle=False,
        prefetch_queue_depth=4,
        name="reader",
        missing_component_behavior="error")
    indices = fn.python_function(16, vlen[0], function=get_indices, num_outputs=1)
    fbank = fn.reshape(fbank, shape=(1024, 128))
    video = fn.experimental.decoders.video(vid, device="mixed", frames=indices)
    resize_v = fn.resize(video, resize_shorter=224, device='gpu')
    crop_v = fn.crop_mirror_normalize(
        resize_v, crop=(224, 224), device='gpu',
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        output_layout="FCHW")
    index = types.Constant(0, dtype=types.UINT8)
    return fbank, crop_v, index
```
- Extract each `.mp4` video into 16 frames saved as `.jpg` files, then use DALI external source to load the paired frames (with cv2) and the fbank features (with numpy).
code:

```python
import json
import os
from glob import glob

import cv2
import numpy as np
from nvidia.dali import fn, pipeline_def


class DALIFrameDatasetCallable:
    def __init__(
        self,
        ann_file,
        batch_size=1,
        shuffled=True,
        media_type='audio_video',
        shard_id=0,
        num_shards=1,
        **kwargs,
    ):
        self.media_type = media_type
        self.batch_size = batch_size
        with open(ann_file.anno_path, 'r') as f:
            self.label_file = json.load(f)
        self.data_root = ann_file.data_root
        assert "frames" in self.label_file[0], "Please make sure the label file has frames"
        self.indices = np.arange(len(self.label_file))
        if shuffled:
            np.random.shuffle(self.indices)
            self.label_file = [self.label_file[i] for i in self.indices]
        self.filenames = list(zip(self.label_file, self.indices))
        self.epoch = 0
        self.shard_id = shard_id
        self.num_shards = num_shards
        self.shard_size = len(self.label_file) // num_shards
        self.shard_offset = self.shard_size * shard_id
        self.full_iterations = self.shard_size // batch_size
        self.perm = None
        self.last_seen_epoch = None

    def __len__(self):
        return self.full_iterations

    def reset(self):
        self.perm = None
        self.last_seen_epoch = None

    @staticmethod
    def load_video_frames(path, frame_count):
        jpg_paths = sorted(glob(os.path.join(path, "*.jpg")))
        # Load with cv2 and convert BGR -> RGB
        frames = [cv2.imread(jpg_path)[:, :, ::-1] for jpg_path in jpg_paths[:frame_count]]
        return np.stack(frames)

    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            # Indicate end of the epoch
            raise StopIteration
        if self.last_seen_epoch != sample_info.epoch_idx:
            # Re-generate the permutation once per epoch, deterministically
            self.last_seen_epoch = sample_info.epoch_idx
            rng = np.random.default_rng(seed=42 + sample_info.epoch_idx)
            self.perm = rng.permutation(len(self.filenames))
        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]
        sample, index = self.filenames[sample_idx]
        vfilename = os.path.join(self.data_root, sample['frames'])
        afilename = os.path.join(self.data_root, sample['audio'])
        video = self.load_video_frames(vfilename, 16)
        fbank = np.load(afilename, allow_pickle=True)
        index = np.array([np.int32(index)])
        return video, fbank, index


@pipeline_def(py_num_workers=4, py_start_method="spawn")
def AudioFramesPipeline(eii, parallel):
    vid, audio, index = fn.external_source(
        device="cpu", batch=False, num_outputs=3,
        source=eii, parallel=parallel, prefetch_queue_depth=4)
    resize_v = fn.resize(vid.gpu(), resize_shorter=224, device='gpu')
    crop_v = fn.crop_mirror_normalize(
        resize_v, crop=(224, 224), device='gpu',
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        output_layout="FCHW")
    return audio, crop_v, index
```
There is another approach I have not tested: passing file lists to the pipeline. But I do not know DALI's mechanism here. When using two file lists in the same order (for example, one file containing the paths of the frames and another containing the paths of the `.npy` files, where the paths at the same index position in both files are a paired video and audio), does the pipeline reading videos output samples at the same positions as the pipeline reading the `.npy` files? A sketch of what I mean is below.
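Something like this minimal, untested sketch, assuming `video_list` and `npy_list` are pre-shuffled with the same permutation and both readers keep `random_shuffle=False` (the exact reader arguments may need adjusting):

```python
from nvidia.dali import fn, pipeline_def

# Hypothetical, identically ordered lists of paired files
video_list = [...]  # paths to .mp4 files
npy_list = [...]    # paths to the paired fbank .npy files

@pipeline_def(batch_size=4, num_threads=4, device_id=0)
def PairedListPipeline():
    # With random_shuffle=False, each reader iterates its list in order,
    # so sample i of one reader should line up with sample i of the other.
    vid, _ = fn.readers.file(files=video_list, random_shuffle=False, name="video_reader")
    fbank = fn.readers.numpy(files=npy_list, random_shuffle=False, name="fbank_reader")
    video = fn.experimental.decoders.video(vid, device="mixed")
    return fbank, video
```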
But all of the above methods showed poor performance... I have no idea now how to accelerate the loading phase of training 😮‍💨 Currently, less than 10% of the time is spent on GPU training.
I am unable to upgrade the server's hard drive to an SSD, so I want to know: how can I feed paired audio and video into the pipeline more efficiently?
Looking forward to your reply, thanks!
Check for duplicates
- [x] I have searched the open bugs/issues and have found no duplicates for this bug report
Hi @Ash-one,
Thank you for reaching out.
Regarding access, I think using tar archives would be beneficial here, as it minimizes the number of file-open operations (and the associated seeks on an HDD).
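In case it helps, here is a minimal sketch of packing paired samples into a webdataset-style tar. The webdataset convention is that components of one sample share a basename (key); the paths and shard name below are placeholders:

```python
import os
import tarfile

# Hypothetical list of paired files; each pair becomes one sample
# whose components share a key, e.g. "000001.mp4" + "000001.npy".
pairs = [("clips/000001.mp4", "fbanks/000001.npy")]

with tarfile.open("shard-000000.tar", "w") as tar:
    for mp4_path, npy_path in pairs:
        key = os.path.splitext(os.path.basename(mp4_path))[0]
        tar.add(mp4_path, arcname=f"{key}.mp4")
        tar.add(npy_path, arcname=f"{key}.npy")

# The DALI index file can then be generated with the wds2idx tool
# shipped with DALI:  wds2idx shard-000000.tar shard-000000.idx
```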
Since you want to minimize the amount of data read from the drive, using video files should work well, because they offer much better compression than separate frames. For audio, have you considered storing compressed audio files and computing the filter bank (fbank) features on the fly, as demonstrated here?
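A rough sketch of what on-the-fly fbank computation can look like as DALI operators (the FFT size, window parameters, sample rate, and number of mel bins are placeholders, not a drop-in match for your precomputed features):

```python
import math
from nvidia.dali import fn, pipeline_def

@pipeline_def(batch_size=4, num_threads=4, device_id=0)
def FbankPipeline(audio_files):
    encoded, _ = fn.readers.file(files=audio_files, random_shuffle=False,
                                 name="audio_reader")
    # Decode compressed audio (e.g. wav/flac/ogg) and downmix to mono
    audio, sr = fn.decoders.audio(encoded, downmix=True)
    # STFT -> mel filter bank -> log scale
    spec = fn.spectrogram(audio, nfft=512, window_length=400, window_step=160)
    mel = fn.mel_filter_bank(spec, sample_rate=16000, nfilter=128)
    fbank = fn.to_decibels(mel, multiplier=math.log(10), cutoff_db=-80)
    return fbank
```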
Additionally, to make sure that I/O really is the bottleneck, you can run some diagnostics.
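For instance, compare the drive's raw sequential read throughput with the rate at which the pipeline consumes data. A rough sketch (the file path is a placeholder, and the file should be larger than RAM so the page cache does not skew the result):

```python
import time

CHUNK = 64 * 1024 * 1024  # read in 64 MiB chunks

total = 0
start = time.perf_counter()
with open("some_large_file", "rb") as f:  # placeholder path on the HDD
    while chunk := f.read(CHUNK):
        total += len(chunk)
elapsed = time.perf_counter() - start
print(f"read {total / 1e6:.0f} MB in {elapsed:.1f} s "
      f"-> {total / 1e6 / elapsed:.0f} MB/s")
```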