Flux.jl

Allowing for function replacement in LSTM and GRU layers

Open · NTimmons opened this issue on Dec 09 '19 · 17 comments

This makes the behaviour of the LSTM and GRU recurrent layers consistent with other layers.
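
Roughly what this would enable (the constructor arguments below are only illustrative, not the final API):

using Flux

# Example replacement functions one might want to drop in
# (both names are made up for illustration):
fast_sigmoid(x) = x / (1 + abs(x))
hard_tanh(x)    = clamp(x, -1, 1)

# Today only the default gate functions are possible:
lstm = LSTM(10, 20)

# With this change, something along these lines would work
# (argument order and names are hypothetical):
# lstm = LSTM(10, 20, fast_sigmoid, hard_tanh)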

NTimmons avatar Dec 09 '19 16:12 NTimmons

This seems like a good thing to have, although perhaps with more informative names than sigmRepl etc; is there a better name for these gates?

Also, the CUDA bindings will have to be updated, since only layers with the default setup will be able to use CUDNN.

MikeInnes avatar Dec 09 '19 16:12 MikeInnes

Ah, I hadn't seen the CUDA bindings. I'll get on that.

I'll take a look to see if the papers call the gates something in particular. I was going to go with σ and tanh for the names to stay in sync with the other layers, which just use σ, but thought that would be extra confusing, hah.

NTimmons avatar Dec 09 '19 16:12 NTimmons

For the CUDA bindings, without going into CUDNN, I don't think there is much I can do except throw a message saying that the selected functions aren't GPU compatible? I looked at the CUDNN docs (https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) and it looks like it has a hardcoded implementation? I am not sure...

We would probably have to restrict sending a custom-function LSTM/GRU to the GPU here though?

curnn.jl:10
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
  h, i = length(m.h), size(m.Wi, 2)
  mode = m isa CuRNN ?
    (m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
    m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
  r = CUDNN.RNNDesc{T}(mode, i, h)
  return r
end

For the replacement names, we would need to call them the input, output and forget gates. But the input and output gate functions would each be a pair (sigm, tanh), as seen here...

forgetgate

I would rather keep the current setup with just the two functions, because I want to use it for replacing sigmoid and tanh with better-behaved variants of each.

But I know that others are interested in controlling the model more finely, where they may want to vary the input and output gates independently.
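
To make the two options concrete, something like this (type and field names are made up just for the sake of discussion):

# (a) Two replaceable functions, applied wherever σ and tanh appear today:
struct TwoFunctionLSTMCell{S,T}
  σ::S       # used for the input, forget and output gates
  tanh::T    # used for the cell candidate and the output transform
  # weights and state omitted
end

# (b) One function per gate, for independent control:
struct PerGateLSTMCell{I,F,O,C}
  input_gate::I
  forget_gate::F
  output_gate::O
  cell::C
  # weights and state omitted
end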

I am happy either way though, so let me know what you think would be best.

NTimmons avatar Dec 09 '19 16:12 NTimmons

Re: CUDA, you can look here. The situation is the same for vanilla RNNs: we allow the activation function to be changed, but only two are supported by CUDNN. Changing the CuRNN signature prevents incompatible RNNs from ever dispatching to CUDNN; they just fall through to the Julia implementation instead. The same should be possible for LSTMs.
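
Boiled down to a self-contained sketch, the dispatch idea looks like this (all names are hypothetical, not the actual Flux internals):

myrelu(x) = max(x, zero(x))

struct ToyRNN{F}
  σ::F
end

# Only cells whose activation is one of the CUDNN-supported functions
# match the specialised method and take the fast path:
const CUDNNActivation = Union{typeof(tanh), typeof(myrelu)}

forward(m::ToyRNN{<:CUDNNActivation}, x) = "CUDNN path"       # would call into CUDNN
forward(m::ToyRNN, x)                    = "plain Julia path"  # generic fallback

forward(ToyRNN(tanh), 1.0f0)        # "CUDNN path"
forward(ToyRNN(x -> x^3), 1.0f0)    # "plain Julia path"

The same trick applied to the CUDA LSTM/GRU wrappers would keep custom-activation models off the CUDNN code path entirely.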

MikeInnes avatar Dec 10 '19 15:12 MikeInnes

Hi, I have done the changes required for CUDA. Are there any performance benchmarks or tests I could run to ensure that my changes haven't caused any performance regression?

NTimmons avatar Dec 12 '19 15:12 NTimmons

Updated the repo with the CUDA support. Added tests to the CUDA test files so we can check that explicitly passing the sigmoid and tanh functions works, plus a test with an arbitrary custom function.
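
Roughly the shape of the added checks (the activation-accepting constructor is the proposed API, so the signatures in the comments are illustrative, and this uses the Flux constructors of the time):

using Flux, Test

@testset "recurrent cells with explicit activations" begin
  x = rand(Float32, 3)
  # baseline: default construction still works
  @test length(LSTM(3, 5)(x)) == 5
  # with this change, explicitly passing σ and tanh should match the default,
  # and an arbitrary function should also be accepted, e.g.
  #   LSTM(3, 5, σ, tanh)   and   LSTM(3, 5, x -> x / (1 + abs(x)), tanh)
end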

I am limited to a laptop with a modest NVIDIA chip at the moment, so I couldn't check the performance for any regressions, but I can when I return after Christmas.

NTimmons avatar Dec 15 '19 11:12 NTimmons

Is anything else needed for this to be merged?

NTimmons avatar Jan 14 '20 13:01 NTimmons

bors try

MikeInnes avatar Jan 14 '20 13:01 MikeInnes

try

Build succeeded

bors[bot] avatar Jan 14 '20 14:01 bors[bot]

bors try

NTimmons avatar Jan 17 '20 11:01 NTimmons

🔒 Permission denied

Existing reviewers: click here to make NTimmons a reviewer

bors[bot] avatar Jan 17 '20 11:01 bors[bot]

Damn, I think the normal CI checks are broken. They seem to be failing on downloading data... local testing all passes.

NTimmons avatar Jan 17 '20 11:01 NTimmons

This looks like a network error. Can we rerun the test and merge this?

NTimmons avatar Mar 25 '20 14:03 NTimmons

@NTimmons could you rebase? Then we can merge if tests pass

CarloLucibello avatar Jul 02 '20 17:07 CarloLucibello

The docs currently say: LSTM(in::Integer, out::Integer, σ = tanh). But this fails:

julia> LSTM(1,1,σ=tanh)
ERROR: MethodError: no method matching Flux.LSTMCell(::Int64, ::Int64; σ=tanh)

I presume the docs were updated prematurely, before this was merged...?
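
For reference, with the released Flux at this point only the two-argument form appears to be defined, so something like:

using Flux

LSTM(1, 1)            # works: default σ and tanh baked in
# LSTM(1, 1, σ=tanh)  # any activation argument hits a MethodError
#                     # until this change lands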

tbenst avatar Nov 15 '20 20:11 tbenst

I don't believe this has been merged. I need to check and retest it. I've been busy on other projects and missed the update. I will get on this ASAP.

NTimmons avatar Nov 15 '20 20:11 NTimmons

Those docs are also from v0.5. The latest reference looks correct.

ToucheSir avatar Nov 15 '20 20:11 ToucheSir