ddsp
Adding a PretrainedCREPEEmbeddingLoss to training
Hello, I've trained a model for a while using the solo_instrument config at 48 kHz, but the audio is still fairly noisy even after 117k steps (spectral loss is ~9 on average).
I'd like to continue training with the PretrainedCREPEEmbeddingLoss() enabled as well to encourage more natural / perceptually realistic synthesis.
I've tried just adding the loss into the ae.gin file, but get the following error which I don't really understand:
Traceback (most recent call last):
File "/home/hans/.conda/envs/ddsp/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 231, in <module>
console_entry_point()
File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 227, in console_entry_point
app.run(main)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 205, in main
report_loss_to_hypertune=FLAGS.hypertune,
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/config.py", line 1078, in gin_wrapper
utils.augment_exception_message_and_reraise(e, err_str)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/utils.py", line 49, in augment_exception_message_and_reraise
six.raise_from(proxy.with_traceback(exception.__traceback__), None)
File "<string>", line 3, in raise_from
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/config.py", line 1055, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "/home/hans/code/maua-ddsp/ddsp/training/train_util.py", line 185, in train
trainer.build(next(dataset_iter))
File "/home/hans/code/maua-ddsp/ddsp/training/trainers.py", line 134, in build
_ = self.run(tf.function(self.model.__call__), batch)
File "/home/hans/code/maua-ddsp/ddsp/training/trainers.py", line 129, in run
return self.strategy.run(fn, args=args, kwargs=kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1211, in run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2585, in call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 585, in _call_for_each_replica
self._container_strategy(), fn, args, kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_run.py", line 78, in call_for_each_replica
return wrapped(args, kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 904, in _call
return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2828, in __call__
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:896 fn_with_cond *
functools.partial(self._concrete_stateful_fn._filtered_call, # pylint: disable=protected-access
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper **
return target(*args, **kwargs)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507 new_func
return func(*args, **kwargs)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py:1180 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/ops/cond_v2.py:92 cond_v2
op_return_value=pred)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1848 _filtered_call
cancellation_manager=cancellation_manager)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1877 _call_flat
for v in self._func_graph.variables:
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:489 variables
return tuple(deref(v) for v in self._weak_variables)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:489 <genexpr>
return tuple(deref(v) for v in self._weak_variables)
/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:482 deref
"Called a function referencing variables which have been deleted. "
AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.
In call to configurable 'train' (<function train at 0x7f995ef00268>)
How can I train with this loss enabled?
Can you give more details on your initial config/run command and the one used for restarting the job? Are you warmstarting from the pretrained checkpoint but adding a new loss?
Yes I want to warmstart with the pretrained checkpoint. Although I get the same error when training from scratch with the crepe embedding loss added in ae.gin.
My original training command:
python -m ddsp.training.ddsp_run \
--mode=train \
--alsologtostderr \
--save_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz/" \
--gin_file=models/solo_instrument.gin \
--gin_file=datasets/tfrecord.gin \
--gin_param="TFRecordProvider.file_pattern='/home/hans/datasets/neuro-bass-ddsp/48kHz/train.tfrecord*'" \
--gin_param="batch_size=16" \
--gin_param="train_util.train.num_steps=300000" \
--gin_param="train_util.train.steps_per_save=3000" \
--gin_param="trainers.Trainer.checkpoints_to_keep=10" \
--gin_param="TFRecordProvider.example_secs=4" \
--gin_param="TFRecordProvider.sample_rate=48000" \
--gin_param="TFRecordProvider.frame_rate=250" \
--gin_param="Additive.n_samples=192000" \
--gin_param="Additive.sample_rate=48000" \
--gin_param="FilteredNoise.n_samples=192000"
Then, after training overnight, I added PretrainedCREPEEmbeddingLoss() in ae.gin (which solo_instrument.gin inherits from):
Autoencoder.losses = [
@losses.SpectralLoss(),
@losses.PretrainedCREPEEmbeddingLoss(),
]
Then I run the following and get the error (the error is the same with or without --restore_dir):
python -m ddsp.training.ddsp_run \
--mode=train \
--alsologtostderr \
--save_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz-crepe/" \
--restore_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz/" \
--gin_file=models/solo_instrument.gin \
--gin_file=datasets/tfrecord.gin \
--gin_param="TFRecordProvider.file_pattern='/home/hans/datasets/neuro-bass-ddsp/48kHz/train.tfrecord*'" \
--gin_param="batch_size=16" \
--gin_param="train_util.train.num_steps=300000" \
--gin_param="train_util.train.steps_per_save=3000" \
--gin_param="trainers.Trainer.checkpoints_to_keep=10" \
--gin_param="TFRecordProvider.example_secs=4" \
--gin_param="TFRecordProvider.sample_rate=48000" \
--gin_param="TFRecordProvider.frame_rate=250" \
--gin_param="Additive.n_samples=192000" \
--gin_param="Additive.sample_rate=48000" \
--gin_param="FilteredNoise.n_samples=192000"
Update: I've found that running with only a single GPU (via CUDA_VISIBLE_DEVICES=0) does work to train with the PretrainedCREPEEmbeddingLoss.
Is there a way to allow the PretrainedCREPEEmbeddingLoss to work with multi-gpu training?