semg_repro

pytorch version

Open Yihang6688 opened this issue 2 years ago • 3 comments

May I ask, is there a PyTorch version?

Yihang6688 · Dec 15 '22

Hmm no, but it should not be too hard to make! You need:

  1. Focal loss (I think the sigmoid/binary form is baked into torchvision.ops)
  2. Learning rate scheduler as described in the paper
  3. Mish activation (this is torch.nn.Mish in pytorch)
  4. Layer normalization (torch.nn.LayerNorm; a short PyTorch sketch of 1-4 follows the attention code below)
  5. Naive attention mechanism (which should be pretty easy to implement, as it is pretty much matrix multiplication and a sum). I think you could also immediately get better results by replacing the attention mechanism, which in tensorflow looks like:
# attention mechanism (keras; imports added for completeness, use tensorflow.keras.* instead if that is how the repo imports)
from keras import backend as K
from keras.layers import Dense, Lambda, Multiply, Permute

def attention_simple(inputs, timesteps):
    input_dim = int(inputs.shape[-1])
    a = Permute((2, 1), name='transpose')(inputs)                           # (batch, input_dim, timesteps)
    a = Dense(timesteps, activation='softmax', name='attention_probs')(a)   # softmax over the time axis
    a_probs = Permute((2, 1), name='attention_vec')(a)                      # back to (batch, timesteps, input_dim)
    output_attention_mul = Multiply(name='focused_attention')([inputs, a_probs])
    output_flat = Lambda(lambda x: K.sum(x, axis=1), name='temporal_average')(output_attention_mul)
    return output_flat, a_probs

With something that looks like this:

def attention_simple(inputs, timesteps):
    input_dim = int(inputs.shape[-1])
    lhs = Dense(input_dim)(inputs)  # linear layer on the "values"
    rhs = Dense(input_dim)(inputs)  # linear layer on the "scores"
    a = Permute((2, 1), name='transpose')(rhs)
    a = Dense(timesteps, activation='softmax', name='attention_probs')(a)  # alternatively here we could "just" softmax
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = Multiply(name='focused_attention')([lhs, a_probs])  # weight the lhs projection by the attention probs
    output_flat = Lambda(lambda x: K.sum(x, axis=1), name='temporal_average')(output_attention_mul)
    return output_flat, a_probs
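
As an aside, most of pieces 1-4 already exist in PyTorch / torchvision. Here is a rough, untested sketch of what I mean; the scheduler is just a placeholder (use whatever the paper actually describes), and note that torchvision's focal loss is the sigmoid/binary form:

import torch
import torch.nn as nn
from torchvision.ops import sigmoid_focal_loss  # 1. focal loss (sigmoid/binary form)

mish = nn.Mish()          # 3. Mish activation, built into torch.nn
norm = nn.LayerNorm(128)  # 4. layer normalization (128 is just an example feature size)

# 2. learning rate scheduler: placeholder only, swap in the schedule from the paper
model = nn.Linear(128, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)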

To convert to pytorch, you would replace each Dense(..., activation=...) with a Linear(...) followed by the activation, replace Permute with a transpose, and then you could compute the sum over the time axis directly because pytorch is great :) It would also be interesting to experiment with not summing across time and treating it like a "real" attention model.
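
For reference, a rough, untested PyTorch sketch of that conversion could look like this (the class and argument names are just placeholders):

import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    # rough pytorch translation of attention_simple above (a sketch, not tested)
    def __init__(self, input_dim, timesteps):
        super().__init__()
        self.lhs = nn.Linear(input_dim, input_dim)       # linear layer on the "values"
        self.rhs = nn.Linear(input_dim, input_dim)       # linear layer on the "scores"
        self.time_mix = nn.Linear(timesteps, timesteps)  # the Dense(timesteps) from the keras code

    def forward(self, inputs):
        # inputs: (batch, timesteps, input_dim)
        lhs = self.lhs(inputs)
        rhs = self.rhs(inputs)
        a = rhs.transpose(1, 2)                  # Permute((2, 1)) -> (batch, input_dim, timesteps)
        a = F.softmax(self.time_mix(a), dim=-1)  # Dense(timesteps, activation='softmax')
        a_probs = a.transpose(1, 2)              # back to (batch, timesteps, input_dim)
        focused = lhs * a_probs                  # Multiply(): "focused_attention"
        return focused.sum(dim=1), a_probs       # K.sum over the time axis

To treat it like a "real" attention model, you would just skip the final sum and return focused with the time axis intact.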

josephsdavid · Dec 15 '22

Thx!!! I will try it!!!

Yihang6688 · Dec 19 '22

Yeah! Ping me if you run into trouble; I would be happy to review and/or share thoughts, but unfortunately I no longer really have the time to write code for this :)

josephsdavid · Dec 19 '22