swift-apis

Recursive Neural Networks (structured data/trees)

Open tanmayb123 opened this issue 6 years ago • 3 comments

I can't wait for AutoDiff to support control flow. On that topic, I had a quick question: how easy/difficult would it be to implement recursive (not recurrent) neural networks? (https://stackoverflow.com/questions/26022866/how-can-a-tree-be-encoded-as-input-to-a-neural-network)

Could I do something like this (this is a hastily written sample only): https://gist.github.com/tanmayb123/c57c956120d3733ed0691c6406a74284

Is there anything specific in that implementation that would prevent it from working? Is there any other/a better way to do BPTS (Backprop through Structure)?

tanmayb123, Mar 31 '19 09:03

To make the structure a bit more clear, here's a picture: https://imgur.com/a/AF9llFX

Each circle is a node. Every node with two connections has a non-nil value in the connections variable, while every node with one connection is a "leaf node" and has a non-nil value in the encoding variable. The values are shown as words for now, but will be word vectors in the implementation.

tanmayb123, Mar 31 '19 09:03

> Could I do something like this (this is a hastily written sample only): https://gist.github.com/tanmayb123/c57c956120d3733ed0691c6406a74284

This style can work, but it is not recommended. We recommend using recursive algebraic data types (indirect enums) to represent trees.

indirect enum Tree {
    case leaf(Tensor<Float>)
    case node(Tree, Tensor<Float>, Tree)
}

You can make the whole thing conform to Differentiable. The compiler doesn't synthesize the conformance yet, so you need to write it by hand.
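As a rough illustration of why the enum form is convenient, here is a minimal sketch of a recursive forward pass over such a tree. It is not taken from the linked gist. To keep it runnable without the TensorFlow module, it assumes plain Float values in place of Tensor<Float>, and the encode function with its single weight parameter is a hypothetical stand-in for a real composition layer.

```swift
// Plain-Swift stand-in for the Tree type above: Float replaces Tensor<Float>
// so the sketch runs without the TensorFlow module (an assumption).
indirect enum Tree {
    case leaf(Float)
    case node(Tree, Float, Tree)
}

// Hypothetical composition function: scales the sum of the two child
// encodings by `weight` and adds the node's own value. A real model would
// use learned weight matrices and a nonlinearity instead.
func encode(_ tree: Tree, weight: Float) -> Float {
    switch tree {
    case .leaf(let value):
        // A leaf's encoding is just its stored value.
        return value
    case .node(let left, let value, let right):
        // Recurse into both subtrees, then combine the results.
        let children = encode(left, weight: weight) + encode(right, weight: weight)
        return weight * children + value
    }
}

// A three-leaf tree: (1, (2, 3)), with values 0.5 and 0 on the inner nodes.
let tree = Tree.node(.leaf(1), 0.5, .node(.leaf(2), 0, .leaf(3)))
let encoding = encode(tree, weight: 0.5)
print(encoding)
```

Because every case is handled by the switch, there are no optional fields to unwrap, which is the main practical difference from a node type with separate optional connections and encoding properties.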

rxwei, Mar 31 '19 09:03

Backpropagation doesn't necessarily require control flow differentiation once you are comfortable using pullbacks. Here's an RNN example: https://gist.github.com/rxwei/ce6644efad8f229651050e096c05ccbb.
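To make the pullback idea concrete for trees, here is a minimal hand-written sketch of backpropagation through structure. It is not the code from the gist above. It assumes plain Float values instead of Tensor<Float>, a single scalar parameter named weight, and a hypothetical function encodeWithPullback; each recursive call returns its value together with a closure that maps the output gradient to the gradient with respect to weight, so no control flow differentiation is needed.

```swift
// Plain-Swift stand-in for the Tree type: Float replaces Tensor<Float> (assumption).
indirect enum Tree {
    case leaf(Float)
    case node(Tree, Float, Tree)
}

// Hypothetical helper: returns the encoding of `tree` and a pullback that
// maps an upstream gradient to the gradient with respect to `weight`.
func encodeWithPullback(
    _ tree: Tree, weight: Float
) -> (value: Float, pullback: (Float) -> Float) {
    switch tree {
    case .leaf(let x):
        // A leaf does not depend on `weight`, so its gradient is zero.
        return (x, { _ in 0 })
    case .node(let left, let v, let right):
        let (l, pullbackLeft) = encodeWithPullback(left, weight: weight)
        let (r, pullbackRight) = encodeWithPullback(right, weight: weight)
        let sum = l + r
        // value = weight * (l + r) + v
        // The direct term contributes upstream * (l + r); each child receives
        // the upstream gradient scaled by `weight` (chain rule).
        return (weight * sum + v, { upstream in
            upstream * sum
                + pullbackLeft(upstream * weight)
                + pullbackRight(upstream * weight)
        })
    }
}

let tree = Tree.node(.leaf(1), 0.5, .node(.leaf(2), 0, .leaf(3)))
let (value, pullback) = encodeWithPullback(tree, weight: 0.5)
// Seed the pullback with an upstream gradient of 1 to get d(value)/d(weight).
let gradient = pullback(1)
print(value, gradient)
```

For this tree the encoding is w + 5w² + 0.5 as a function of the weight w, so the derivative is 1 + 10w, which the pullback reproduces at w = 0.5. The recursion in the forward pass builds the chain of closures, and calling the outermost closure walks the same structure in reverse.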

rxwei, Mar 31 '19 09:03