semg_repro
pytorch version
May I ask, is there a PyTorch version?
Hmm no, but it should not be too hard to make! You need:
- focal loss (I think this is baked into torchvision)
- Learning rate scheduler as described in the paper
- Mish activation (available in PyTorch)
- layer normalization (also in PyTorch)
- Naive attention mechanism (this should be pretty easy to implement, as it is essentially a matrix multiplication and a sum). I think you could also immediately get better results by replacing the attention mechanism, which in the TensorFlow code looks like:
# attention mechanism (Keras / TensorFlow)
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Lambda, Multiply, Permute

def attention_simple(inputs, timesteps):
    # inputs: (batch, timesteps, input_dim)
    input_dim = int(inputs.shape[-1])
    a = Permute((2, 1), name='transpose')(inputs)
    a = Dense(timesteps, activation='softmax', name='attention_probs')(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = Multiply(name='focused_attention')([inputs, a_probs])
    output_flat = Lambda(lambda x: K.sum(x, axis=1), name='temporal_average')(output_attention_mul)
    return output_flat, a_probs
With something that looks like:
def attention_simple(inputs, timesteps):
    # same imports as above; inputs: (batch, timesteps, input_dim)
    input_dim = int(inputs.shape[-1])
    lhs = Dense(input_dim)(inputs)  # linear layer
    rhs = Dense(input_dim)(inputs)  # linear layer
    a = Permute((2, 1), name='transpose')(rhs)
    a = Dense(timesteps, activation='softmax', name='attention_probs')(a)  # alternatively here we could "just" softmax
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = Multiply(name='focused_attention')([lhs, a_probs])  # weight lhs by the attention probabilities
    output_flat = Lambda(lambda x: K.sum(x, axis=1), name='temporal_average')(output_attention_mul)
    return output_flat, a_probs
To convert this to PyTorch, you would replace each Dense(..., activation=...) with a Linear(...) followed by the activation function, and then you can compute the sum directly because PyTorch is great :) It would also be interesting to experiment with not summing across time and treating it like a "real" attention model.
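For reference, here is a rough, untested sketch of what that port could look like. The class and variable names are mine, and I am assuming inputs of shape (batch, timesteps, input_dim), matching the Keras version:

import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    # Hypothetical PyTorch port of attention_simple above; treat it as a starting point.
    def __init__(self, input_dim, timesteps):
        super().__init__()
        self.lhs_proj = nn.Linear(input_dim, input_dim)          # Dense(input_dim)
        self.rhs_proj = nn.Linear(input_dim, input_dim)          # Dense(input_dim)
        self.attention_probs = nn.Linear(timesteps, timesteps)   # Dense(timesteps)

    def forward(self, inputs):
        # inputs: (batch, timesteps, input_dim)
        lhs = self.lhs_proj(inputs)
        rhs = self.rhs_proj(inputs)
        a = rhs.transpose(1, 2)                                  # Permute((2, 1))
        a = torch.softmax(self.attention_probs(a), dim=-1)       # attention over timesteps
        a_probs = a.transpose(1, 2)                               # back to (batch, timesteps, input_dim)
        focused = lhs * a_probs                                   # Multiply([lhs, a_probs])
        output_flat = focused.sum(dim=1)                          # K.sum(x, axis=1)
        return output_flat, a_probs

For the other pieces in the list above: in recent versions, focal loss is available as torchvision.ops.sigmoid_focal_loss, Mish as torch.nn.Mish, and layer normalization as torch.nn.LayerNorm, so the main things left to write yourself are the learning rate schedule from the paper and the training loop.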
Thx!!! I will try it!!!
Yeah! Ping me if you run into trouble, I'd be happy to review and/or give thoughts, but unfortunately I no longer really have the time to write code for this :)