Low GPU Utilization during training
Hi, I have been trying to train a StyleTTS2 model from scratch on the LibriTTS 460 dataset, and I am currently going through the first stage via train_first.py.
The GPU utilization during training is very low (~30%). I am using a single H100 with batch_size = 8 and max_len = 300 so that it fits on a single GPU.
Such low utilization means that the script is not using the GPU efficiently, and there are likely bottlenecks that, if addressed, could make training faster.
Has anyone observed similar issues while training the model from scratch, or does anyone have ideas for improving the GPU utilization?
cc @yl4579
Yes, same here; it seems there is a bottleneck, but using accelerate seems to help a little. Are you using accelerate? Try setting num_processes.
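(For reference, num_processes can be set at launch time, e.g. something like `accelerate launch --num_processes 1 train_first.py --config_path Configs/config.yml`, or in the accelerate config file; adjust the paths and flags to your setup.)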
Yes @lucasgris, I am using accelerate and have played around with num_workers. Even in the graph you shared, utilization consistently drops very low (<25% GPU util). Any luck with improving that?
Not yet, but I think it is worth trying to identify where the code is slow; if I have any updates, I will share them here.
Confirming the problem of low GPU utilization:
It seems that some sort of computation on a single CPU core is the bottleneck:
Also having this problem with train_finetune_accelerate.py. I haven't dug too deep, but the accelerator.backward() calls seem to be taking a very long time, specifically this code block: https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/train_finetune_accelerate.py#L449-L464
I tried the following options one by one, none of which helped:
- Running with and without accelerator
- Increasing num_processes from 1 to 2
- Decreasing max_len from 600 to 290
- Switching the decoder from hifigan to istftnet
Also seeing low GPU utilization and high single-core CPU utilization.
It also seems like the issue goes away after the first epoch finishes: the GPU starts being utilized and the CPU load becomes more evenly distributed.
@borrero-c thanks for looking into this. I didn't observe anything change after 1 epoch; it stays low for me. Also, the accelerator.backward() call might be expected to take some time, since it is doing the backward pass.
I did a little research and ran a line profiler. Pay attention to the % Time column.
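The tables below are line_profiler output (@profile decorators on the functions of interest, run through kernprof). As a rough sketch, an equivalent explicit setup would look like this, assuming the repo layout and a hypothetical run_training() wrapper around the main loop in train_first.py:

```python
from line_profiler import LineProfiler
from Utils.ASR.models import ASRCNN, ASRS2S  # module path as in this repo

lp = LineProfiler()
# Register the functions whose line-by-line timings we want.
lp.add_function(ASRCNN.forward)
lp.add_function(ASRS2S.forward)
lp.add_function(ASRS2S.decode)

# run_training is a stand-in for whatever function wraps the main loop in
# train_first.py; alternatively decorate with @profile and run via kernprof.
lp.runcall(run_training)
lp.print_stats()
```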
MAIN LOOP
Line # Hits Time Per Hit % Time Line Contents
==============================================================
162 2 8.8 4.4 0.0 for epoch in range(start_epoch, 5):
163 1 0.3 0.3 0.0 running_loss = 0
164 1 3.8 3.8 0.0 start_time = time.time()
165
166 1 7624.1 7624.1 0.0 _ = [model[key].train() for key in model]
167
168 2 2430.4 1215.2 0.0 pgbar = tqdm(desc=f"Epoch {epoch + 1}/{epochs}", unit='Step', total=len(train_list) // batch_size, smoothing=0,
169 1 0.1 0.1 0.0 initial=1)
170 102 525418.3 5151.2 0.3 for i, batch in enumerate(train_dataloader):
171 102 73.2 0.7 0.0 if i > 100:
172 1 265667.4 265667.4 0.1 break
173 101 36354.2 359.9 0.0 pgbar.update(1)
174 101 917.4 9.1 0.0 waves = batch[0]
175 101 5605.0 55.5 0.0 batch = [b.to(device) for b in batch[1:]]
176 101 2789.0 27.6 0.0 texts, input_lengths, _, _, mels, mel_input_length, _ = batch
177
178 202 1903.5 9.4 0.0 with torch.no_grad():
179 101 77350.0 765.8 0.0 mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
180 101 12146.6 120.3 0.0 text_mask = length_to_mask(input_lengths).to(texts.device)
181
182 101 20453215.4 202507.1 10.2 ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
183
184 101 911.2 9.0 0.0 s2s_attn = s2s_attn.transpose(-1, -2)
185 101 1402.6 13.9 0.0 s2s_attn = s2s_attn[..., 1:]
186 101 334.4 3.3 0.0 s2s_attn = s2s_attn.transpose(-1, -2)
187
188 202 2396.6 11.9 0.0 with torch.no_grad():
189 101 33850.8 335.2 0.0 attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
190 101 16570.0 164.1 0.0 attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
191 101 5279.7 52.3 0.0 attn_mask = (attn_mask < 1)
192
193 101 3047.2 30.2 0.0 s2s_attn.masked_fill_(attn_mask, 0.0)
194
195 202 1703.6 8.4 0.0 with torch.no_grad():
196 101 48330.6 478.5 0.0 mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
197 101 416141.6 4120.2 0.2 s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
198
199 # encode
200 101 1624539.8 16084.6 0.8 t_en = model.text_encoder(texts, input_lengths, text_mask)
201
202 # 50% of chance of using monotonic version
203 101 416.9 4.1 0.0 if bool(random.getrandbits(1)):
204 43 4864.5 113.1 0.0 asr = (t_en @ s2s_attn)
205 else:
206 58 12170.7 209.8 0.0 asr = (t_en @ s2s_attn_mono)
207
208 # get clips
209 101 5637.2 55.8 0.0 mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
210 101 14759.9 146.1 0.0 mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
211 101 2895.5 28.7 0.0 mel_len_st = int(mel_input_length.min().item() / 2 - 1)
212
213 101 499.9 4.9 0.0 en = []
214 101 521.1 5.2 0.0 gt = []
215 101 432.4 4.3 0.0 wav = []
216 101 427.5 4.2 0.0 st = []
217
218 909 1352.1 1.5 0.0 for bib in range(len(mel_input_length)):
219 808 17282.6 21.4 0.0 mel_length = int(mel_input_length[bib].item() / 2)
220
221 808 6093.9 7.5 0.0 random_start = np.random.randint(0, mel_length - mel_len)
222 808 12116.3 15.0 0.0 en.append(asr[bib, :, random_start:random_start+mel_len])
223 808 6309.5 7.8 0.0 gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
224
225 808 1675.8 2.1 0.0 y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
226 808 69490.9 86.0 0.0 wav.append(torch.from_numpy(y).to(device))
227
228 # style reference (better to be different from the GT)
229 808 5077.0 6.3 0.0 random_start = np.random.randint(0, mel_length - mel_len_st)
230 808 8738.8 10.8 0.0 st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
231
232 101 5238.2 51.9 0.0 en = torch.stack(en)
233 101 2928.9 29.0 0.0 gt = torch.stack(gt).detach()
234 101 2246.8 22.2 0.0 st = torch.stack(st).detach()
235
236 101 7146.6 70.8 0.0 wav = torch.stack(wav).float().detach()
237
238 # clip too short to be used by the style encoder
239 101 202.3 2.0 0.0 if gt.shape[-1] < 80:
240 continue
241
242 202 2124.9 10.5 0.0 with torch.no_grad():
243 101 44210.4 437.7 0.0 real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
244 101 2671261.3 26448.1 1.3 F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
245
246 101 2978410.2 29489.2 1.5 s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
247
248 101 17613113.7 174387.3 8.8 y_rec = model.decoder(en, F0_real, real_norm, s)
249
250 # discriminator loss
251
252 101 70.8 0.7 0.0 if epoch >= TMA_epoch:
253 101 565364.9 5597.7 0.3 optimizer.zero_grad()
254 101 11707820.2 115919.0 5.8 d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
255 101 18847437.9 186608.3 9.4 accelerator.backward(d_loss)
256 101 313779.6 3106.7 0.2 optimizer.step('msd')
257 101 294492.0 2915.8 0.1 optimizer.step('mpd')
258 else:
259 d_loss = 0
260
261 # generator loss
262 101 237334.6 2349.8 0.1 optimizer.zero_grad()
263 101 282369.5 2795.7 0.1 loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
264
265 101 51.6 0.5 0.0 if epoch >= TMA_epoch: # start TMA training
266 101 419.4 4.2 0.0 loss_s2s = 0
267 909 10903.2 12.0 0.0 for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
268 808 89627.5 110.9 0.0 loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
269 101 2396.6 23.7 0.0 loss_s2s /= texts.size(0)
270
271 101 11985.0 118.7 0.0 loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
272
273 101 6983523.9 69143.8 3.5 loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
274 101 4046595.1 40065.3 2.0 loss_slm = wl(wav.detach(), y_rec).mean()
275
276 505 812033.9 1608.0 0.4 g_loss = loss_params.lambda_mel * loss_mel + \
277 101 1428.5 14.1 0.0 loss_params.lambda_mono * loss_mono + \
278 101 1285.5 12.7 0.0 loss_params.lambda_s2s * loss_s2s + \
279 101 1268.2 12.6 0.0 loss_params.lambda_gen * loss_gen_all + \
280 101 1230.7 12.2 0.0 loss_params.lambda_slm * loss_slm
281
282 else:
283 loss_s2s = 0
284 loss_mono = 0
285 loss_gen_all = 0
286 loss_slm = 0
287 g_loss = loss_mel
288
289 101 14339.2 142.0 0.0 running_loss += accelerator.gather(loss_mel).mean().item()
290
291 101 99737870.0 987503.7 49.6 accelerator.backward(g_loss)
292
293 101 199636.4 1976.6 0.1 optimizer.step('text_encoder')
294 101 290944.4 2880.6 0.1 optimizer.step('style_encoder')
295 101 2382230.7 23586.4 1.2 optimizer.step('decoder')
296
297 101 72.7 0.7 0.0 if epoch >= TMA_epoch:
298 101 430973.2 4267.1 0.2 optimizer.step('text_aligner')
299 # optimizer.step('pitch_extractor')
300
301 101 82.0 0.8 0.0 iters = iters + 1
302
303 101 386.2 3.8 0.0 if (i+1)%log_interval == 0 and accelerator.is_main_process:
304 20 1296.7 64.8 0.0 status = 'Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f' % (
305 10 17.5 1.7 0.0 epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, loss_gen_all,
306 10 2.6 0.3 0.0 d_loss, loss_mono, loss_s2s, loss_slm)
307 # log_print (status, logger)
308 10 2629.4 262.9 0.0 pgbar.set_postfix_str(status)
309 10 1553.5 155.4 0.0 writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
310 10 2915.1 291.5 0.0 writer.add_scalar('train/gen_loss', loss_gen_all, iters)
311 10 1903.1 190.3 0.0 writer.add_scalar('train/d_loss', d_loss, iters)
312 10 2026.3 202.6 0.0 writer.add_scalar('train/mono_loss', loss_mono, iters)
313 10 1550.2 155.0 0.0 writer.add_scalar('train/s2s_loss', loss_s2s, iters)
314 10 1451.7 145.2 0.0 writer.add_scalar('train/slm_loss', loss_slm, iters)
315
316 10 11.9 1.2 0.0 running_loss = 0
317
318 # print('Time elasped:', time.time()-start_time)
319
320 1 0.4 0.4 0.0 loss_test = 0
321
322 1 6907.6 6907.6 0.0 _ = [model[key].eval() for key in model]
As expected, if we exclude the accelerator.backward() call, GPU utilization increases by about 20%, but it remains uneven.
If I additionally exclude line 182,
ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
GPU utilization becomes more uniform.
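One caveat when reading the line-by-line numbers: CUDA kernel launches are asynchronous, so a CPU-side profiler tends to attribute GPU time to whichever call eventually forces a synchronization (often the backward pass or the next data-dependent op). A minimal sketch of how individual blocks could be timed more fairly, assuming a CUDA device and the variables from train_first.py:

```python
import time
import torch

def timed(label, fn, *args, **kwargs):
    # CUDA launches are asynchronous: without explicit synchronization,
    # time spent in earlier kernels gets attributed to whichever call
    # happens to block next (e.g. accelerator.backward() or .item()).
    torch.cuda.synchronize()
    t0 = time.time()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    print(f"{label}: {time.time() - t0:.3f}s")
    return out

# Example usage inside the training loop (names from train_first.py):
# y_rec = timed("decoder", model.decoder, en, F0_real, real_norm, s)
```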
Additional performance details about text_aligner
ASRCNN
File: /app/Utils/ASR/models.py
Function: forward at line 37
Line # Hits Time Per Hit % Time Line Contents
==============================================================
37 @profile
38 def forward(self, x, src_key_padding_mask=None, text_input=None):
39 101 116166.6 1150.2 0.5 x = self.to_mfcc(x)
40 101 742848.0 7354.9 3.2 x = self.init_cnn(x)
41 101 4113638.1 40729.1 17.6 x = self.cnns(x)
42 101 576767.4 5710.6 2.5 x = self.projection(x)
43 101 1441.8 14.3 0.0 x = x.transpose(1, 2)
44 101 102451.0 1014.4 0.4 ctc_logit = self.ctc_linear(x)
45 101 51.5 0.5 0.0 if text_input is not None:
46 101 17664905.7 174900.1 75.8 _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
47 101 66.9 0.7 0.0 return ctc_logit, s2s_logit, s2s_attn
48 else:
49 return ctc_logit
ASRS2S
File: /app/Utils/ASR/models.py
Function: forward at line 118
Line # Hits Time Per Hit % Time Line Contents
==============================================================
118 @profile
119 def forward(self, memory, memory_mask, text_input):
120 """
121 moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
122 moemory_mask.shape = (B, L, )
123 texts_input.shape = (B, T)
124 """
125 101 73718.9 729.9 0.5 self.initialize_decoder_states(memory, memory_mask)
126 # text random mask
127 101 8880.2 87.9 0.1 random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
128 101 2321.3 23.0 0.0 _text_input = text_input.clone()
129 101 273189.4 2704.8 1.7 _text_input.masked_fill_(random_mask, self.unk_index)
130 101 8951.4 88.6 0.1 decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
131 202 3652.6 18.1 0.0 start_embedding = self.embedding(
132 101 4901.0 48.5 0.0 torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
133 101 29957.5 296.6 0.2 decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
134
135 101 50.4 0.5 0.0 hidden_outputs, logit_outputs, alignments = [], [], []
136 12503 27916.0 2.2 0.2 while len(hidden_outputs) < decoder_inputs.size(0):
137
138 12402 124859.6 10.1 0.8 decoder_input = decoder_inputs[len(hidden_outputs)]
139 12402 15663074.1 1262.9 95.9 hidden, logit, attention_weights = self.decode(decoder_input)
140 12402 12834.5 1.0 0.1 hidden_outputs += [hidden]
141 12402 4052.0 0.3 0.0 logit_outputs += [logit]
142 12402 4427.6 0.4 0.0 alignments += [attention_weights]
143
144 101 57422.0 568.5 0.4 hidden_outputs, logit_outputs, alignments = \
145 202 37085.3 183.6 0.2 self.parse_decoder_outputs(
146 101 17.3 0.2 0.0 hidden_outputs, logit_outputs, alignments)
147
148 101 40.2 0.4 0.0 return hidden_outputs, logit_outputs, alignments
File: /app/Utils/ASR/models.py
Function: decode at line 149
Line # Hits Time Per Hit % Time Line Contents
==============================================================
149 @profile
150 def decode(self, decoder_input):
151
152 12077 451589.1 37.4 2.9 cell_input = torch.cat((decoder_input, self.attention_context), -1)
153 24154 1601496.2 66.3 10.2 self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154 12077 1534.5 0.1 0.0 cell_input,
155 12077 4154.4 0.3 0.0 (self.decoder_hidden, self.decoder_cell))
156
157 24154 395431.2 16.4 2.5 attention_weights_cat = torch.cat(
158 24154 94829.2 3.9 0.6 (self.attention_weights.unsqueeze(1),
159 24154 67172.2 2.8 0.4 self.attention_weights_cum.unsqueeze(1)),dim=1)
160
161 24154 10655347.3 441.1 68.1 self.attention_context, self.attention_weights = self.attention_layer(
162 12077 2301.5 0.2 0.0 self.decoder_hidden,
163 12077 2838.1 0.2 0.0 self.memory,
164 12077 2822.3 0.2 0.0 self.processed_memory,
165 12077 1527.4 0.1 0.0 attention_weights_cat,
166 12077 3143.5 0.3 0.0 self.mask)
167
168 12077 231777.8 19.2 1.5 self.attention_weights_cum += self.attention_weights
169
170 12077 264389.2 21.9 1.7 hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171 12077 1005236.5 83.2 6.4 hidden = self.project_to_hidden(hidden_and_context)
172
173 # dropout to increasing g
174 12077 860518.1 71.3 5.5 logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
176 12077 5210.3 0.4 0.0 return hidden, logit, self.attention_weights
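These numbers suggest why a single CPU core gets pegged: ASRS2S.forward runs a Python while-loop with one decode() call per output timestep (over 12,000 iterations across the 101 profiled batches), and each decode() launches a handful of small CUDA kernels, so the GPU largely sits idle waiting for the CPU to issue work. A quick way to check whether a call is launch-bound rather than compute-bound is to trace a single text_aligner call; a sketch, assuming the variables from the training loop are in scope:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Trace one text_aligner call from the training loop. Many tiny kernels
# with large gaps between them would indicate the loop is bound by
# CPU-side launch overhead rather than GPU compute.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```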
Looked into it some more: my steps are taking 20-40 seconds each, and the .backward() call accounts for 10-20 seconds of that, respectively.
When training starts to pick up after that first epoch (and the GPU is being more consistently utilized), the steps take ~4 seconds each and the backward() call takes ~2 seconds.
Also interesting to see that this code block is taking a good amount of time to complete too: https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/train_finetune_accelerate.py#L306-L312
It seems that for each step ~25% of the time is spent in the loop above and ~50% is spent in the .backward() call on line 464. Not sure how/if those could be improved; this isn't really my area of expertise.