pyannote-audio
Add `balance_weights` to weight balanced batches
The `balance` option of the segmentation tasks accepts a list of `ProtocolFile` fields, e.g. `['database', 'foo']`. When batches are sampled, it looks at all existing combinations of values for these fields in the task protocol.
For example, if files come from the `aishell` and `ami` databases and their `foo` field is either `a` or `b`, we compute the cartesian product `[('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')]`. Each batch is then created by selecting one of these tuples uniformly at random and picking a sample from a file that matches it.
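A minimal sketch of this uniform sampling, assuming files are plain dicts standing in for `ProtocolFile`. The `FILES` list and the `balance_combinations` / `sample_file` helpers are hypothetical illustrations, not pyannote-audio's API:

```python
import itertools
import random
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-ins for ProtocolFile metadata (not pyannote-audio objects).
FILES: List[Dict[str, str]] = [
    {"uri": "file1", "database": "aishell", "foo": "a"},
    {"uri": "file2", "database": "aishell", "foo": "b"},
    {"uri": "file3", "database": "ami", "foo": "a"},
    {"uri": "file4", "database": "ami", "foo": "b"},
]


def balance_combinations(
    files: Sequence[Dict[str, str]], balance: Sequence[str]
) -> List[Tuple[str, ...]]:
    """Cartesian product of the distinct values taken by each `balance` field."""
    values = [sorted({f[key] for f in files}) for key in balance]
    return list(itertools.product(*values))


def sample_file(
    files: Sequence[Dict[str, str]], balance: Sequence[str], rng: random.Random
) -> Dict[str, str]:
    """Pick a combination uniformly at random, then a file matching it.

    Assumes every combination is matched by at least one file (true here);
    otherwise rng.choice raises IndexError on the empty list.
    """
    combination = rng.choice(balance_combinations(files, balance))
    matching = [f for f in files if tuple(f[key] for key in balance) == combination]
    return rng.choice(matching)


balance = ["database", "foo"]
print(balance_combinations(FILES, balance))
# [('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')]
print(sample_file(FILES, balance, random.Random(0))["uri"])
```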
This PR makes it possible to weight the random choice over the cartesian product. For example, with
```python
balance_weights = {
    ('aishell',): 2.0,   # trailing comma: a 1-tuple prefix, not a plain string
    ('ami', 'b'): 4.0,
}
```
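Because weights are looked up by tuple prefix, a single-field key needs a trailing comma: in Python, `('aishell')` is just the string `'aishell'`, while `('aishell',)` is a 1-tuple that can equal the first element of a combination. A quick check:

```python
combination = ("aishell", "a")

not_a_tuple = ("aishell")   # parentheses alone: this is the str 'aishell'
one_tuple = ("aishell",)    # trailing comma: a 1-tuple

print(type(not_a_tuple).__name__, type(one_tuple).__name__)  # str tuple
print(combination[:1] == not_a_tuple)  # False
print(combination[:1] == one_tuple)    # True
```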
we sample from the cartesian product using `random.choices` with these weights:

```python
import random

selected = random.choices(
    population=[('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')],
    weights=[2.0, 2.0, 1.0, 4.0],
    k=1,  # draw a single tuple per batch
)[0]
```
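`random.choices` treats `weights` as relative, so each tuple is drawn with probability equal to its weight divided by the sum of the weights (here 2/9, 2/9, 1/9 and 4/9). A small check, using a seeded `random.Random` instance only to make the illustration reproducible:

```python
import random
from collections import Counter
from fractions import Fraction

population = [("aishell", "a"), ("aishell", "b"), ("ami", "a"), ("ami", "b")]
weights = [2.0, 2.0, 1.0, 4.0]

# Exact probabilities: weight / sum(weights).
total = sum(weights)
probabilities = {t: Fraction(w) / Fraction(total) for t, w in zip(population, weights)}
print(probabilities[("ami", "b")])  # 4/9

# Drawing k samples in one call is equivalent to k independent k=1 draws,
# so the empirical frequencies should approach the probabilities above.
rng = random.Random(0)
n_draws = 90_000
counts = Counter(rng.choices(population, weights=weights, k=n_draws))
frequencies = {t: counts[t] / n_draws for t in population}
print(frequencies)
```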
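In other words, for each tuple of the cartesian product, we look up the longest prefix of that tuple (itself a tuple) that appears as a key in `balance_weights` and use its weight; tuples with no matching key, such as `('ami', 'a')` above, keep the default weight of 1.0.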
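The longest-prefix lookup can be sketched as follows. The `weight_for` helper and the default weight of 1.0 for tuples without a matching key are assumptions inferred from the example above (`('ami', 'a')` gets 1.0), not necessarily how the PR implements it:

```python
import itertools
from typing import Dict, List, Tuple

Combination = Tuple[str, ...]


def weight_for(
    combination: Combination,
    balance_weights: Dict[Combination, float],
    default: float = 1.0,
) -> float:
    """Return the weight of the longest prefix of `combination` found in `balance_weights`."""
    # Try the full tuple first, then progressively shorter prefixes.
    for length in range(len(combination), 0, -1):
        prefix = combination[:length]
        if prefix in balance_weights:
            return balance_weights[prefix]
    return default


balance_weights = {
    ("aishell",): 2.0,
    ("ami", "b"): 4.0,
}
population: List[Combination] = list(itertools.product(["aishell", "ami"], ["a", "b"]))
weights = [weight_for(c, balance_weights) for c in population]
print(weights)  # [2.0, 2.0, 1.0, 4.0]
```

Checking longer prefixes first means a more specific key such as `('ami', 'b')` takes precedence over a broader one like `('ami',)` when both are present.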
I'm not sure this approach is flexible or clean enough to be PR-ready, and it's hard to keep the docstring concise, but I think it could be really useful :)