
[SSL] Config for BESTRQ

ncakhoa opened this issue Jul 25, 2024 · 2 comments

Hi, can you provide a config for training the BEST-RQ model? Also, the paper reports that this model supports streaming; have you implemented that? Thank you.

ncakhoa · Jul 25 '24 10:07

These will be available in the next few days:

  • [x] init model: https://github.com/wenet-e2e/wenet/pull/2595
  • [x] init ssl dataset: https://github.com/wenet-e2e/wenet/pulls
  • [ ] bestrq script

Mddct · Aug 07 '24 06:08

# network architecture
# encoder related
input_dim: 80
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    use_sdpa: false
    gradient_checkpointing: false
    pos_enc_layer_type: rope_pos
    activation_type: 'gelu'
    mlp_type: 'gated'
    query_bias: false
    key_bias: false
    value_bias: false
    layer_norm_type: rms_norm
    norm_eps: 1.0e-6
    mlp_bias: false
    selfattention_layer_type: rope_abs_selfattn
    conv_bias: true
    # conv_bias: false
    # cnn_module_norm: "rms_norm"
    # n_kv_head: 1
    # head_dim: 64


# decoder related
decoder: transformer
decoder_conf:
    # linear_bias: false
    # tie_word_embedding: false
    # n_kv_head: 1
    # head_dim: 64
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
    use_sdpa: false
    gradient_checkpointing: false
    activation_type: 'gelu'
    mlp_type: 'gated'
    query_bias: false
    key_bias: false
    value_bias: false
    mlp_bias: false
    layer_norm_type: rms_norm
    norm_eps: 1.0e-6


tokenizer: char
tokenizer_conf:
  symbol_table_path:  null
  split_with_space: false
  bpe_path: null
  non_lang_syms_path: null
  is_multilingual: false
  num_languages: 1
  special_tokens:
    <blank>: 0
    <unk>: 1
    <sos>: 2
    <eos>: 2

ctc: ctc
ctc_conf:
  ctc_blank_id: 0

cmvn: null # global_cmvn
cmvn_conf:
  cmvn_file: null  # or cmvn path
  is_json_cmvn: true

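Aside: cmvn is disabled here (null). If it were set to global_cmvn, each fbank frame would be normalized with mean/variance statistics precomputed over the whole training set and loaded from cmvn_file. A minimal sketch of that normalization (illustrative only, not WeNet's exact code):

```python
# Illustrative global CMVN (mean/variance normalization with dataset-level
# statistics); the names are placeholders, not WeNet's exact code.
import torch

def apply_global_cmvn(feats: torch.Tensor, mean: torch.Tensor,
                      istd: torch.Tensor) -> torch.Tensor:
    """feats: (batch, time, num_mel_bins); mean/istd: (num_mel_bins,)."""
    return (feats - mean) * istd
```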
# BEST-RQ SSL pretraining model
model: bestrq_model
model_conf:
    num_mel_bins: 80
    embedding_dim: 16
    num_embeddings: 8192
    num_codebooks: 1
    mask_prob: 0.01
    mask_length: 10
    min_masks: 2
    norm_epsilon: 1.0e-5
    features_regularization_weight: 0.00

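For intuition, here is a minimal sketch of a BEST-RQ random-projection quantizer using the model_conf values above. This is not WeNet's implementation; all names below are placeholders:

```python
# Illustrative BEST-RQ random-projection quantizer (NOT WeNet's module).
import torch
import torch.nn.functional as F

num_mel_bins, embedding_dim = 80, 16      # model_conf: num_mel_bins / embedding_dim
num_embeddings, num_codebooks = 8192, 1   # model_conf: num_embeddings / num_codebooks

# Projection and codebook are randomly initialized and frozen: they are
# never updated by gradients, which is the core idea of BEST-RQ.
projection = torch.empty(num_codebooks, num_mel_bins, embedding_dim)
torch.nn.init.xavier_uniform_(projection)
codebooks = F.normalize(
    torch.randn(num_codebooks, num_embeddings, embedding_dim), dim=-1)

def quantize(feats: torch.Tensor) -> torch.Tensor:
    """feats: (batch, time, num_mel_bins) -> labels: (num_codebooks, batch, time)."""
    # Project each fbank frame into the codebook space and L2-normalize.
    proj = F.normalize(torch.einsum('btf,gfe->gbte', feats, projection), dim=-1)
    # The nearest codebook entry (max cosine similarity) is the discrete target.
    sim = torch.einsum('gbte,gne->gbtn', proj, codebooks)
    return sim.argmax(dim=-1)
```

During pretraining, spans of about mask_length frames are masked (each frame starts a mask with probability mask_prob, with at least min_masks spans per utterance), and the encoder is trained to classify the masked frames into these quantized labels.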
dataset: ssl
dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 300
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: false
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'dynamic'   # static, bucket, or dynamic
        max_frames_in_batch: 50000
        # batch_type: 'static'
        # batch_size: 10
        # batch_type: 'bucket'
        # bucket_boundaries: [500, 1000, 1500]
        # bucket_batch_sizes: [128, 64, 32, 16]
        # bucket_batch_sizes: [32, 32, 32, 16]

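A note on batch_conf: with batch_type 'dynamic', utterances are packed into a batch until the total frame count would exceed max_frames_in_batch. A simplified sketch of that policy (illustrative only, not WeNet's dataset code):

```python
# Simplified illustration of the 'dynamic' batch policy: pack utterances
# until adding one more would exceed max_frames_in_batch.
from typing import Iterable, Iterator, List

def dynamic_batches(utt_frames: Iterable[int],
                    max_frames_in_batch: int = 50000) -> Iterator[List[int]]:
    batch, total = [], 0
    for idx, frames in enumerate(utt_frames):
        if batch and total + frames > max_frames_in_batch:
            yield batch          # emit the full batch, start a new one
            batch, total = [], 0
        batch.append(idx)
        total += frames
    if batch:
        yield batch              # last partial batch
```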
grad_clip: 20
accum_grad: 1
max_epoch: 240
log_interval: 100
save_interval: 2000

optim: adam
optim_conf:
    lr: 0.0008
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
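
For reference, warmuplr follows the Noam-style schedule used in WeNet/ESPnet's WarmupLR: the learning rate ramps up for warmup_steps and then decays with the inverse square root of the step. A sketch, for intuition only:

```python
# Sketch of the Noam-style schedule behind `warmuplr` (as in
# WeNet/ESPnet's WarmupLR; shown here only for intuition).
def warmup_lr(step: int, base_lr: float = 0.0008,
              warmup_steps: int = 25000) -> float:
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(
        step ** -0.5, step * warmup_steps ** -1.5)

# lr peaks at base_lr when step == warmup_steps (here 25000),
# then decays proportionally to 1/sqrt(step).
```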

Mddct · Aug 08 '24 06:08