seamless_communication
Transcriber class
[work in progress]
Hi @gonzalpi!
Thank you for your pull request and welcome to our community.
Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
@gonzalpi here is the example code for the timestamp extraction wrapper:
import logging
from collections import OrderedDict
from typing import Any, List

import numpy as np
import sentencepiece as spm
import torch

from fairseq2.generation.sequence_generator import (
    Seq2SeqGenerator,
    SequenceGeneratorOptions,
    SequenceGeneratorOutput,
    VocabularyInfo,
)
from fairseq2.nn.transformer.multihead_attention import AttentionWeightHook
from fairseq2.nn.transformer.norm_order import TransformerNormOrder
from seamless_communication.models.unity import UnitYModel, UnitYX2TModel
from seamless_communication.models.unity.builder import (
    NllbConfig,
    UnitYConfig,
    UnitYT2UConfig,
    Wav2Vec2EncoderConfig,
    create_unity_model,
)

_log = logging.getLogger(__name__)
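# Hook that accumulates (summed) encoder-decoder attention weights at each
# decoding step; the wrapper below uses them to align tokens to audio positions.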
class EncDecAttentionsCollect(AttentionWeightHook):
    def __init__(self):
        super().__init__()
        self.attn_scores = []

    def __call__(self, m, attn_weights) -> None:
        val = torch.clone(attn_weights).detach().sum(dim=0).squeeze(0).tolist()
        self.attn_scores.append(val)

    def reset(self):
        self.attn_scores = []
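# Longest increasing subsequence (with reconstruction); used to keep only a
# monotonic token-to-audio alignment from the attention argmaxes.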
def lis(arr):
    n = len(arr)
    lis = [1] * n
    prev = [0] * n
    for i in range(0, n):
        prev[i] = i
    for i in range(1, n):
        for j in range(0, i):
            if arr[i] > arr[j] and lis[i] < lis[j] + 1:
                lis[i] = lis[j] + 1
                prev[i] = j
    maximum = 0
    idx = 0
    for i in range(n):
        if maximum < lis[i]:
            maximum = lis[i]
            idx = i
    seq = [arr[idx]]
    while idx != prev[idx]:
        idx = prev[idx]
        seq.append(arr[idx])
    return (maximum, reversed(seq))
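# Wrapper around a small UnitY (M4T) checkpoint that runs speech-to-text
# generation and augments the transcription with word-level timestamps.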
class M4TTiny:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
        encoder_layers: int = 6,
        decoder_layers: int = 3,
        embed_dim: int = 512,
        depthwise_conv_kernel_size: int = 31,
    ):
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path
        self.device = device
        self.dtype = dtype
        self.embed_dim = embed_dim
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.LoadFromFile(self.tokenizer_path)
        self.decoder_vocab_info = VocabularyInfo(
            20010, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
        )
        self.langs = ["English", "Hindi", "Portuguese", "Russian", "Spanish", "Bengali", "Urdu"]
        model = self._load_model()
        self.s2t = UnitYX2TModel(
            encoder_frontend=model.speech_encoder_frontend,
            encoder=model.speech_encoder,
            decoder_frontend=model.text_decoder_frontend,
            decoder=model.text_decoder,
            final_proj=model.final_proj,
            pad_idx=model.pad_idx,
        )
        self.enc_dec_attn_collector = EncDecAttentionsCollect()
        self.s2t.decoder.layers[-1].encoder_decoder_attn.register_attn_weight_hook(
            self.enc_dec_attn_collector
        )
        self.gen_opts = SequenceGeneratorOptions(beam_size=1)
        prefix = [self.decoder_vocab_info.eos_idx]
        self.prefix_len = len(prefix) + 1  # generated langtok also goes as a prefix
        self.generator = Seq2SeqGenerator(
            decoder=self.s2t,
            vocab_info=self.decoder_vocab_info,
            prefix_seq=torch.tensor(prefix, dtype=torch.int64, device=self.device),
            opts=self.gen_opts,
        )
    def _load_model(
        self,
    ) -> UnitYModel:
        unity_config = self._get_config()
        model = create_unity_model(unity_config, device=self.device, dtype=self.dtype)
        state_dict = torch.load(self.checkpoint_path)
        updated = OrderedDict()
        for key, value in state_dict.items():
            for prefix in ["module.", "model."]:
                if key.startswith(prefix):
                    key = key[len(prefix) :]
            updated[key] = value
        model.load_state_dict(updated)
        model.eval()
        return model
    def _get_config(
        self,
    ) -> UnitYConfig:
        encoder_config = Wav2Vec2EncoderConfig(
            model_dim=self.embed_dim,
            max_seq_len=4096,
            feature_dim=160,
            use_fbank=True,
            first_pass_dropout_p=0.0,
            layer_norm_features=False,
            feature_extractor_layer_descs=[],
            feature_extractor_bias=False,
            feature_extractor_layer_norm_convs=False,
            feature_grad_scale=0,
            num_fbank_channels=80,
            fbank_stride=2,
            sample_fbank_every_k=1,
            pos_encoder_type="relative",
            pos_encoder_depth=0,
            pos_conv_kernel_size=0,
            num_pos_conv_groups=0,
            use_conformer=True,
            num_encoder_layers=self.encoder_layers,  # changed
            num_encoder_attn_heads=16,
            ffn_inner_dim=self.embed_dim * 4,  # changed
            dropout_p=0.0,
            attn_dropout_p=0.0,
            layer_drop_p=0.0,
            norm_order=TransformerNormOrder.POST,
            depthwise_conv_kernel_size=self.depthwise_conv_kernel_size,
        )
        nllb_config = NllbConfig(
            model_dim=self.embed_dim,  # changed
            max_seq_len=1024,  # changed
            vocabulary_size=20008,  # changed
            pad_idx=1,
            num_encoder_layers=1,  # changed
            num_decoder_layers=self.decoder_layers,  # changed
            num_encoder_attn_heads=16,
            num_decoder_attn_heads=16,
            ffn_inner_dim=self.embed_dim * 4,  # changed
            dropout_p=0.1,
        )
        t2u_config = UnitYT2UConfig(
            model_dim=self.embed_dim,  # changed
            unit_max_seq_len=2048,  # changed
            unit_vocabulary_size=10015,
            unit_pad_idx=1,
            num_encoder_layers=1,  # changed
            num_decoder_layers=1,  # changed
            num_encoder_attn_heads=16,
            num_decoder_attn_heads=16,
            ffn_inner_dim=self.embed_dim * 8,  # changed
            dropout_p=0.1,
        )
        return UnitYConfig(
            model_dim=self.embed_dim,  # changed
            w2v2_encoder_config=encoder_config,
            nllb_config=nllb_config,
            t2u_config=t2u_config,
            use_text_encoder=False,  # changed
            use_conformer_adaptor=True,
            num_adaptor_layers=1,
            adaptor_kernel_size=8,
            adaptor_stride=8,
            adaptor_layer_norm=True,
            adaptor_dropout_p=0.1,
        )
    @classmethod
    def _word_stats_to_string(cls, lang_name: str, word_stats: List[Any]) -> str:
        words = [
            f"{stat['text']}[t{stat['time_s']:0.2f}, p{stat['prob']:0.2f}]"
            for stat in word_stats
        ]
        return f"[[{lang_name}]]: " + " ".join(words)
    @classmethod
    def _extract_timestamps(cls, attn_weights, audio_len):
        num_out_tokens = len(attn_weights)
        num_encoder_steps = len(attn_weights[0])
        attn_weights = np.array(attn_weights)
        col_maxes = np.argmax(attn_weights, axis=0)
        lis_input = [
            (out_tok_idx, -enc_bin_idx) for enc_bin_idx, out_tok_idx in enumerate(col_maxes)
        ]
        tok_idx_to_start_enc_bin_idx = {
            out_tok_idx: -enc_bin_idx for out_tok_idx, enc_bin_idx in lis(lis_input)[1]
        }
        prev_start = 0
        starts = []
        for tok_idx in range(num_out_tokens):
            start_enc_bin_idx = tok_idx_to_start_enc_bin_idx.get(tok_idx, prev_start)
            starts.append(start_enc_bin_idx)
            prev_start = start_enc_bin_idx
        seconds_per_enc_pos = audio_len / num_encoder_steps
        start_times = [seconds_per_enc_pos * start_pos for start_pos in starts]
        return start_times
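    # Merge sub-word pieces into words, keeping the first piece's timestamp
    # and averaging the per-token probabilities.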
    @classmethod
    def _collect_word_level_stats(
        cls,
        pieces: List[str], token_timestamps: List[float], step_scores: List[float]
    ):
        assert len(pieces) == len(token_timestamps) and len(token_timestamps) == len(
            step_scores
        )
        word_stats: List[List[Any]] = []
        for (
            time_s,
            token,
            score,
        ) in zip(token_timestamps, pieces, step_scores):
            if not word_stats or (token.startswith("▁") and time_s > word_stats[-1][1]):
                word_stats.append(
                    [token.replace("▁", " ").strip(), time_s, [np.exp(score)]]
                )
            else:
                word_stats[-1][0] += token.replace("▁", " ")
                word_stats[-1][2].append(np.exp(score))
        word_stats = [
            {"text": word, "time_s": start, "prob": np.mean(probs)}
            for word, start, probs in word_stats
        ]
        return word_stats
    def run_inference(self, fbanks: torch.Tensor, length_seconds: float) -> str:
        encoder_output, encoder_padding_mask = self.s2t.encode(fbanks.unsqueeze(0), None)
        self.enc_dec_attn_collector.reset()
        output: SequenceGeneratorOutput = self.generator(
            encoder_output=encoder_output, encoder_padding_mask=encoder_padding_mask
        )
        lang_token = output.results[0][0].seq.squeeze(0)[1].item()
        lang_name = self.langs[lang_token - 20000]
        token_ids = output.results[0][0].seq.squeeze(0)[self.prefix_len:].tolist()
        step_scores = output.results[0][0].step_scores[self.prefix_len:].tolist()
        enc_dec_attn_scores = self.enc_dec_attn_collector.attn_scores[self.prefix_len - 1 :]
        token_timestamps = self._extract_timestamps(enc_dec_attn_scores, length_seconds)
        pieces = [self.tokenizer.IdToPiece(token_id) for token_id in token_ids]
        word_stats = self._collect_word_level_stats(
            pieces=pieces, token_timestamps=token_timestamps, step_scores=step_scores
        )
        augmented_text = self._word_stats_to_string(lang_name, word_stats)
        return augmented_text
A few comments on the snippet above:
- _load_model - this is a custom helper that is not needed for your work. You can load the default M4T models using load_unity_model (see the sketch after this list)
- self.tokenizer -- should be the default tokenizer that comes with the model
- self.decoder_vocab_info -- should come from the tokenizer
- self.langs -- should also come from the tokenizer
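To make that concrete, here is a minimal sketch of loading the default model and tokenizer, assuming the seamlessM4T_medium card and that the tokenizer exposes its VocabularyInfo as vocab_info; everything else here is illustrative rather than the final code:

import torch

from seamless_communication.models.unity import (
    load_unity_model,
    load_unity_text_tokenizer,
)

device = torch.device("cpu")

# Default M4T model instead of building one from a custom config and checkpoint.
model = load_unity_model("seamlessM4T_medium", device=device, dtype=torch.float32)
model.eval()

# Default tokenizer that ships with the model; the decoder vocabulary info and
# the language list should be derived from it rather than hard-coded.
tokenizer = load_unity_text_tokenizer("seamlessM4T_medium")
decoder_vocab_info = tokenizer.vocab_info  # assumed attribute name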
Code to generate prefix tokens:
from seamless_communication.models.unity import load_unity_text_tokenizer
prefix_tokens = load_unity_text_tokenizer("seamlessM4T_medium").create_encoder(mode="target", lang="eng").prefix_indices.tolist()
I have a question regarding the following section of the project scope document:
The expected API would be similar to the API of Translator, except that Transcriber will only allow Speech-to-Text transformation (thus lighter in terms of input parameters), and instead of the general predict method it will expose a transcribe method that takes: (1) path to input audio (2) dialect of the input audio (3) (optional) audio sample rate (4) (optional) sequence generation params. The output of the method will be a stream of transcribed words with time codes.
Given that the dialect and seq gen params are used to generate the model object, would it not make more sense for them to go in the __init__ method? Edit: On the same note, do these "params" refer to those already in __init__, or should they be a separate object passed in for the self.gen_opts: SequenceGeneratorOptions object?
Also, it is my understanding that audio sample rate is necessary only when the input audio is a tensor and not a path, is that correct?
@mavlyutovr
@gonzalpi
- VocabularyInfo and SequenceGenerator should be instantiated at inference time (these are lightweight). That will make instances of Transcriber flexible in terms of language and even let them accept batches of inputs with variable languages. We also have models that do language identification, so this interface will be useful for them as well.
- Language should come as an argument of the inference method. vocabulary_info will be assembled based on this input and the parameters of the model and tokenizer (eos_idx, pad_idx, etc.)
- Sequence generation params (e.g. beam size) may go into the constructor inputs (a rough sketch of this interface follows below)
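To summarize the points above, here is a rough sketch of how such an interface could look. The names and signatures (Transcriber, transcribe, TranscribedWord, model_name_or_card, beam_size) are placeholders inferred from the scope document and this thread, not the final API:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class TranscribedWord:
    # One transcribed word with its start time (seconds) and mean probability.
    text: str
    time_s: float
    prob: float


class Transcriber:
    def __init__(
        self,
        model_name_or_card: str = "seamlessM4T_medium",
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
        beam_size: int = 1,  # sequence generation params live in the constructor
    ):
        # Model and tokenizer would be loaded once here (e.g. via
        # load_unity_model / load_unity_text_tokenizer); omitted in this sketch.
        ...

    def transcribe(
        self,
        audio: str,                         # path to the input audio
        lang: str,                          # language/dialect of the input audio
        sample_rate: Optional[int] = None,  # only needed for raw tensor inputs
    ) -> List[TranscribedWord]:
        # VocabularyInfo and the Seq2SeqGenerator are lightweight, so they are
        # assembled here per call from `lang` plus the model/tokenizer
        # parameters, keeping one Transcriber instance usable across languages.
        ...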