wenet
wenet copied to clipboard
[SSL] Config for BESTRQ
Hi, can you provide a config for training the BEST-RQ model? Also, the paper reports that this model supports streaming — have you implemented that? Thank you.
Will be available in the next few days:
- [x] init model: https://github.com/wenet-e2e/wenet/pull/2595
- [x] init ssl dataset https://github.com/wenet-e2e/wenet/pulls
- [ ] bestrq script
# network architecture
# encoder related
input_dim: 80
encoder: conformer
encoder_conf:
  output_size: 256    # dimension of attention
  attention_heads: 4
  linear_units: 2048  # the number of units of position-wise feed forward
  num_blocks: 12      # the number of encoder blocks
  dropout_rate: 0.1
  positional_dropout_rate: 0.1
  attention_dropout_rate: 0.0
  input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
  normalize_before: true
  cnn_module_kernel: 15
  use_cnn_module: true
  use_sdpa: false
  gradient_checkpointing: false
  pos_enc_layer_type: rope_pos
  activation_type: 'gelu'
  mlp_type: 'gated'
  query_bias: false
  key_bias: false
  value_bias: false
  layer_norm_type: rms_norm
  norm_eps: 1.0e-6
  mlp_bias: false
  selfattention_layer_type: rope_abs_selfattn
  conv_bias: true
  # conv_bias: false
  # cnn_module_norm: "rms_norm"
  # n_kv_head: 1
  # head_dim: 64
# decoder related
decoder: transformer
decoder_conf:
  # linear_bias: false
  # tie_word_embedding: false
  # n_kv_head: 1
  # head_dim: 64
  attention_heads: 4
  linear_units: 2048
  num_blocks: 6
  dropout_rate: 0.1
  positional_dropout_rate: 0.1
  self_attention_dropout_rate: 0.0
  src_attention_dropout_rate: 0.0
  use_sdpa: false
  gradient_checkpointing: false
  activation_type: 'gelu'
  mlp_type: 'gated'
  query_bias: false
  key_bias: false
  value_bias: false
  mlp_bias: false
  layer_norm_type: rms_norm
  norm_eps: 1.0e-6
# tokenizer related
tokenizer: char
tokenizer_conf:
  symbol_table_path: null
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    <sos>: 2
    <eos>: 2
# ctc / cmvn related
ctc: ctc
ctc_conf:
  ctc_blank_id: 0

cmvn: null # global_cmvn
cmvn_conf:
  cmvn_file: null # or cmvn path
  is_json_cmvn: true
# BEST-RQ self-supervised pre-training model
# (NOTE(review): original comment said "hybrid CTC/attention", which does not
# match model: bestrq_model — confirm against the wenet model registry)
model: bestrq_model
model_conf:
  num_mel_bins: 80
  embedding_dim: 16
  num_embeddings: 8192
  num_codebooks: 1
  mask_prob: 0.01
  mask_length: 10
  min_masks: 2
  norm_epsilon: 1.0e-5
  features_regularization_weight: 0.00
# dataset related
dataset: ssl
dataset_conf:
  filter_conf:
    max_length: 40960
    min_length: 300
    token_max_length: 200
    token_min_length: 1
  resample_conf:
    resample_rate: 16000
  speed_perturb: true
  fbank_conf:
    num_mel_bins: 80
    frame_shift: 10
    frame_length: 25
    dither: 0.1
  spec_aug: false
  spec_aug_conf:
    num_t_mask: 2
    num_f_mask: 2
    max_t: 50
    max_f: 10
  shuffle: true
  shuffle_conf:
    shuffle_size: 1500
  sort: true
  sort_conf:
    sort_size: 500 # sort_size should be less than shuffle_size
  batch_conf:
    # batch_type: static
    # batch_size: 1
    batch_type: 'dynamic'
    max_frames_in_batch: 50000
    # batch_type: 'bucket'
    # bucket_boundaries: [500, 1000, 1500]
    # batch_type: 'static'
    # batch_size: 10
    # batch_type: 'bucket' # static or bucket or dynamic
    # bucket_batch_sizes: [128, 64, 32, 16]
    # bucket_batch_sizes: [32, 32, 32, 16]
# training related
grad_clip: 20
accum_grad: 1
max_epoch: 240
log_interval: 100
save_interval: 2000

optim: adam
optim_conf:
  lr: 0.0008
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
  warmup_steps: 25000