whisperX
fix(alignment-filling-start-end): fix the bug that may generate non-monotonic timecode when filling Nan
Problem
We get non-monotonic timecodes when a segment has both an invalid start and an invalid end.
See the following example.
Raw data from alignment model:
seg   start  end
seg1  NaN    NaN
seg2  1      3
seg3  4      9
Filled result:
seg   start  end
seg1  1      3
seg2  1      3
seg3  4      9
The end of seg1 (3) is greater than the start of seg2 (1), so the timecodes are non-monotonic.
Why
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
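The current code processes the start and end columns independently, so each NaN is filled from the nearest valid value in its own column: seg1's start takes seg2's start (1), and seg1's end takes seg2's end (3).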
Fix
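Let's interpolate them jointly instead: flatten start and end into one sequence (start1, end1, start2, end2, ...), fill the NaNs in that sequence, then split it back into the two columns. This generates a monotonic result: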
seq_timecode_vals = aligned_subsegments[["start", "end"]].values.ravel("C")
filled_seq_timecodes = interpolate_nans(pd.Series(seq_timecode_vals), method=interpolate_method)
aligned_subsegments["start"] = filled_seq_timecodes.iloc[::2].values
aligned_subsegments["end"] = filled_seq_timecodes.iloc[1::2].values
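To make the flattening step concrete, here is a minimal sketch of how `ravel("C")` interleaves the two columns and how the `[::2]` / `[1::2]` slices split them back. The frame `df` is a made-up toy example that mirrors the table above; it is not taken from whisperX itself.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame with the same values as the example above.
df = pd.DataFrame({"start": [np.nan, 1.0, 4.0], "end": [np.nan, 3.0, 9.0]})

# Row-major ("C") flattening walks the frame row by row, so the sequence is
# start1, end1, start2, end2, ... which is the order the values occur in time.
flat = df[["start", "end"]].values.ravel("C")
print(flat.tolist())  # [nan, nan, 1.0, 3.0, 4.0, 9.0]

# Even positions hold the starts, odd positions hold the ends.
starts, ends = flat[::2], flat[1::2]
```

Because both NaNs of seg1 sit next to each other in the flattened sequence, any fill that copies a neighbouring value gives them the same value, which is why the joint fill cannot place an end after the next start in this example.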
Test Code
import pandas as pd
datas = [
    [None, None],
    [1, 3],
    [4, 9],
]
records = [
    {
        "start": p[0],
        "end": p[1],
        "words": f"{p}",
    }
    for p in datas
]
# copied from utils so this snippet runs on its own
# (method="nearest" needs SciPy installed)
def interpolate_nans(x, method='nearest'):
    if x.notnull().sum() > 1:
        return x.interpolate(method=method).ffill().bfill()
    else:
        return x.ffill().bfill()
# ORIGINAL RESULT
interpolate_method = "nearest"
aligned_subsegments = pd.DataFrame.from_records(records)
## -> old logic
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
print(aligned_subsegments)
#>    start  end         words
#> 0    1.0  3.0  [None, None]  # not monotonic
#> 1    1.0  3.0        [1, 3]
#> 2    4.0  9.0        [4, 9]
## New result
interpolate_method = "nearest"
aligned_subsegments = pd.DataFrame.from_records(records)
## new logic
seq_timecode_vals = aligned_subsegments[["start", "end"]].values.ravel("C")
filled_seq_timecodes = interpolate_nans(pd.Series(seq_timecode_vals), method=interpolate_method)
aligned_subsegments["start"] = filled_seq_timecodes.iloc[::2].values
aligned_subsegments["end"] = filled_seq_timecodes.iloc[1::2].values
print(aligned_subsegments)
#>    start  end         words
#> 0    1.0  1.0  [None, None]  # fixed
#> 1    1.0  3.0        [1, 3]
#> 2    4.0  9.0        [4, 9]
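Instead of eyeballing the printed frames, the result could also be checked automatically. The sketch below is one way to do that; `is_monotonic_timecode` is a hypothetical helper (not part of whisperX), and the two frames are typed in from the outputs above.

```python
import pandas as pd

def is_monotonic_timecode(df: pd.DataFrame) -> bool:
    """Return True if start1 <= end1 <= start2 <= end2 <= ... holds."""
    # Flatten row by row, then reuse pandas' non-strict monotonicity check.
    flat = pd.Series(df[["start", "end"]].values.ravel("C"))
    return bool(flat.is_monotonic_increasing)

# Values copied from the two printed results above.
old_result = pd.DataFrame({"start": [1.0, 1.0, 4.0], "end": [3.0, 3.0, 9.0]})
new_result = pd.DataFrame({"start": [1.0, 1.0, 4.0], "end": [1.0, 3.0, 9.0]})

print(is_monotonic_timecode(old_result))  # False
print(is_monotonic_timecode(new_result))  # True
```

The check is non-strict on purpose: a filled segment with start equal to end (like seg1 after the fix) still counts as monotonic.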
Final Notes: Why do we get NaN for a whole segment?
Consider this example sentence:
1. Quickly put the water to the table
The punc model wrongly splits it into 2 sentences:
1. => Oops, this segment ends up with both start and end as NaN
Quickly put the water to the table