pyannote-audio

Add `balance_weights` to weight balanced batches

Open • FrenchKrab opened this issue 6 months ago • 0 comments

The `balance` option of the segmentation tasks accepts a list of `ProtocolFile` fields, e.g. `['database', 'foo']`. When batches are sampled, it looks at all existing combinations of values for these fields in the task's protocol.

For example, if files come from the databases `aishell` and `ami`, and their `foo` field is either `a` or `b`, we compute the cartesian product `[('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')]`. Batches are then created by randomly selecting one of these tuples and picking a sample from a file that matches it.
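A minimal sketch of the unweighted grouping and sampling described above, assuming files can be represented as plain dicts of field values. The names `balance_groups` and `sample_file` are hypothetical and not part of pyannote-audio; the sketch also stops at picking a file rather than a chunk within it, and skips combinations that no file actually has (an assumption).

```python
import itertools
import random
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-ins for ProtocolFile objects: plain dicts of field values.
FILES: List[Dict[str, str]] = [
    {"uri": "f1", "database": "aishell", "foo": "a"},
    {"uri": "f2", "database": "aishell", "foo": "b"},
    {"uri": "f3", "database": "ami", "foo": "a"},
    {"uri": "f4", "database": "ami", "foo": "b"},
    {"uri": "f5", "database": "ami", "foo": "b"},
]


def balance_groups(
    files: Sequence[Dict[str, str]], balance: Sequence[str]
) -> Dict[Tuple[str, ...], List[Dict[str, str]]]:
    """Map each combination of `balance` field values to its matching files."""
    # Distinct values observed for each field, in first-seen order.
    values = [list(dict.fromkeys(f[field] for f in files)) for field in balance]
    groups: Dict[Tuple[str, ...], List[Dict[str, str]]] = {}
    for combo in itertools.product(*values):
        matching = [f for f in files if tuple(f[field] for field in balance) == combo]
        if matching:  # assumption: drop combinations with no matching file
            groups[combo] = matching
    return groups


def sample_file(
    groups: Dict[Tuple[str, ...], List[Dict[str, str]]], rng: random.Random
) -> Dict[str, str]:
    """Pick a combination uniformly at random, then a file matching it."""
    combo = rng.choice(list(groups))
    return rng.choice(groups[combo])


groups = balance_groups(FILES, ["database", "foo"])
print(list(groups))
# [('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')]
```

Because each combination is chosen with equal probability, `('ami', 'b')` is not favoured just because it has two files.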

This PR makes it possible to weight the random choice over the cartesian product. For example, with

```python
balance_weights = {
    ('aishell',): 2.0,  # note the trailing comma: ('aishell') alone is just a string
    ('ami', 'b'): 4.0,
}
```

we sample from the cartesian product using `random.choices` with these weights:

```python
import random

selected = random.choices(
    population=[('aishell', 'a'), ('aishell', 'b'), ('ami', 'a'), ('ami', 'b')],
    weights=[2.0, 2.0, 1.0, 4.0],
    k=1,
)[0]
```

That is, for each tuple of the cartesian product, we find the longest matching tuple prefix in `balance_weights` and use its weight. Tuples with no matching prefix, such as `('ami', 'a')` here, get a default weight of 1.0.
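A minimal sketch of the longest-prefix lookup, assuming a default weight of 1.0 as implied by the example weights above. `resolve_weight` and `weighted_choice` are hypothetical helper names, not the PR's actual code.

```python
import random
from typing import Dict, Sequence, Tuple

Combo = Tuple[str, ...]


def resolve_weight(
    combo: Combo, balance_weights: Dict[Combo, float], default: float = 1.0
) -> float:
    """Return the weight of the longest prefix of `combo` found in `balance_weights`."""
    # Try the full tuple first, then progressively shorter prefixes.
    for length in range(len(combo), 0, -1):
        prefix = combo[:length]
        if prefix in balance_weights:
            return balance_weights[prefix]
    return default  # assumption: unmatched tuples keep weight 1.0


def weighted_choice(
    population: Sequence[Combo], balance_weights: Dict[Combo, float], rng: random.Random
) -> Combo:
    """Draw one combination with probability proportional to its resolved weight."""
    weights = [resolve_weight(c, balance_weights) for c in population]
    return rng.choices(population=list(population), weights=weights, k=1)[0]


balance_weights = {("aishell",): 2.0, ("ami", "b"): 4.0}
population = [("aishell", "a"), ("aishell", "b"), ("ami", "a"), ("ami", "b")]
print([resolve_weight(c, balance_weights) for c in population])
# [2.0, 2.0, 1.0, 4.0]
```

`random.choices` treats the weights as relative, so they do not need to sum to 1; with the weights above, `('ami', 'b')` is drawn about 4/9 of the time.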

I'm not sure this approach is flexible or clean enough to be PR-ready, and it's hard to make the docstring concise, but I think it could be really useful :)

FrenchKrab • Dec 14 '23 10:12