TensorFlowSharp
Add Run(Expression)
In Python, this code is possible:
session.run(a+b)
This both creates the Add operation with the two operands and automatically fetches the result.
In general, this requires a "context" to work properly: in the more common scenario you would not only be doing Add operations but would likely be invoking other methods, and those methods currently live in the TFGraph class.
So either we do what the Python bindings do, which is to keep a global variable holding a default context and surface an API that operates on it, or we end up with an API that is not as pretty as it could be.
Interesting.
What do you think about the following API:
C#:
float n = 5; // specifies input type and value
float p = 10;
// run a multiply expression, which will be translated to Graph.Mul()
// n, p will be translated to placeholders
var mul = session.Run(() => n * p);
F#:
let result = session.Run(translate <@@ fun () -> n * p @@>) // where 'translate' is a function that translates an F# quotation to a LINQ expression
The TFSession class has a Graph property that can be used as the "context". Here is a proof-of-concept implementation of such a Run(Expression) API:
public TFTensor[] Run<TResult>(Expression<Func<TResult>> lambdaExpression)
{
    // This is a very simple example of expression-tree parsing:
    // it only handles a single BinaryExpression whose operands are captured variables.
    // GetValue and ToTFDataType are helper methods not shown here.
    var binaryExpression = lambdaExpression.Body as BinaryExpression;
    if (binaryExpression == null)
        throw new NotSupportedException("Only binary expressions are supported.");
    var leftValue = GetValue((MemberExpression)binaryExpression.Left);
    var rightValue = GetValue((MemberExpression)binaryExpression.Right);
    switch (binaryExpression.NodeType)
    {
    // translate the expression into the appropriate TensorFlow graph
    case ExpressionType.Multiply:
        var left = Graph.Placeholder(ToTFDataType(binaryExpression.Left.Type));
        var right = Graph.Placeholder(ToTFDataType(binaryExpression.Right.Type));
        var y = Graph.Mul(left, right);
        // feed the captured values (assumed to be floats) into the placeholders and fetch y;
        // no break is needed because the case returns
        return Run(
            new[] { left, right },
            new[] { new TFTensor((float)leftValue), new TFTensor((float)rightValue) },
            new[] { y });
    }
    ...
}
Mhm, I do like this idea. The next step, though: how do you "reference" the various methods in TFGraph from within a nested expression?
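One possible direction for nested expressions such as `() => n * p + q` is to make the translator recursive, so each BinaryExpression translates its operands first and then emits its own operation. This is a hedged, self-contained sketch: `ToyGraph` and `ExpressionTranslator` are hypothetical names, not TensorFlowSharp APIs, and `ToyGraph` only records operation names where a real version would call `Graph.Placeholder`, `Graph.Mul`, `Graph.Add` and so on.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical stand-in for TFGraph: it only records the operations it is asked to build.
public sealed class ToyGraph
{
    public readonly List<string> Ops = new List<string>();
    public readonly List<object> FeedValues = new List<object>();

    // Appends an operation and returns a name ("t0", "t1", ...) for its output.
    string Emit(string op)
    {
        Ops.Add(op);
        return "t" + (Ops.Count - 1);
    }

    // A real version would call Graph.Placeholder and remember the value to feed.
    public string Placeholder(object value)
    {
        FeedValues.Add(value);
        return Emit("Placeholder");
    }

    // A real version would call Graph.Add, Graph.Mul, etc.
    public string Binary(string opName, string x, string y) => Emit($"{opName}({x}, {y})");
}

public static class ExpressionTranslator
{
    // Recursively walks the tree, so (n * p) + q becomes Mul followed by Add.
    public static string Translate(Expression e, ToyGraph g)
    {
        switch (e)
        {
            case BinaryExpression b:
                string left = Translate(b.Left, g);
                string right = Translate(b.Right, g);
                return g.Binary(OpName(b.NodeType), left, right);
            case MemberExpression m:
                // A captured local is a field of a compiler-generated closure;
                // compile and invoke a tiny lambda to read its current value.
                object value = Expression.Lambda<Func<object>>(
                    Expression.Convert(m, typeof(object))).Compile()();
                return g.Placeholder(value);
            case ConstantExpression c:
                return g.Placeholder(c.Value);
            default:
                throw new NotSupportedException($"Unsupported node type: {e.NodeType}");
        }
    }

    // Maps LINQ operators to TensorFlow operation type names.
    static string OpName(ExpressionType type)
    {
        switch (type)
        {
            case ExpressionType.Add: return "Add";
            case ExpressionType.Subtract: return "Sub";
            case ExpressionType.Multiply: return "Mul";
            case ExpressionType.Divide: return "Div";
            default: throw new NotSupportedException($"Unsupported operator: {type}");
        }
    }
}
```

Because the recursion only needs a graph object to call into, this still leaves open the question of where that graph comes from: the session's Graph property, or a default graph as discussed next.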
I looked into the Python implementation a bit. The function get_default_graph is used to obtain the default graph: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py#L4771
It is based on a thread-local stack.
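A minimal C# sketch of that pattern, under stated assumptions: `DefaultGraph` is a hypothetical class (its `Instance` property matches the name used in the operator example, but nothing like it exists in TensorFlowSharp yet), and `Graph` is a stand-in for `TFGraph`. As in Python, each thread has its own stack, and an empty stack falls back to a single process-wide graph.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for TFGraph, only here to keep the sketch self-contained.
public sealed class Graph
{
    public string Name { get; }
    public Graph(string name) { Name = name; }
}

// Hypothetical default-graph registry loosely modeled on Python's get_default_graph():
// every thread has its own stack of graphs; an empty stack falls back to one shared graph.
public static class DefaultGraph
{
    [ThreadStatic] static Stack<Graph> frames;
    static readonly Lazy<Graph> fallback = new Lazy<Graph>(() => new Graph("global"));

    // [ThreadStatic] fields cannot use initializers, so the per-thread stack is created lazily.
    static Stack<Graph> Frames => frames ?? (frames = new Stack<Graph>());

    public static Graph Instance => Frames.Count > 0 ? Frames.Peek() : fallback.Value;

    // Makes g the default for the current thread until the returned scope is disposed,
    // similar to `with g.as_default():` in Python. Dispose it on the same thread.
    public static IDisposable Use(Graph g)
    {
        Frames.Push(g);
        return new Scope();
    }

    sealed class Scope : IDisposable
    {
        bool disposed;
        public void Dispose()
        {
            if (disposed) return;
            disposed = true;
            Frames.Pop();
        }
    }
}
```

Usage would look like `using (DefaultGraph.Use(myGraph)) { var z = x + y; }`, where any operator overload calls `DefaultGraph.Instance` instead of taking a graph parameter.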
If we had something similar, it would be possible to add operator overloads to TFOutput. For example:
public static TFOutput operator + (TFOutput o1, TFOutput o2)
{
    // DefaultGraph.Instance is the proposed default-graph accessor; it does not exist yet
    var graph = DefaultGraph.Instance;
    var desc = new TFOperationDesc (graph, "Add", graph.MakeName ("Add", null));
    desc.AddInput (o1);
    desc.AddInput (o2);
    var op = desc.FinishOperation ();
    // Add has a single output, at index 0
    return new TFOutput (op, 0);
}
I could implement it if the idea is OK.
Just as an aside, the operator idea here could also work if my musings regarding issue #3 are acceptable: each output would then have access to a TFGraph reference.
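A self-contained sketch of that alternative, in which the operator finds its graph through an operand instead of a global default. `Graph` and `Output` are hypothetical stand-ins for `TFGraph` and `TFOutput`, and rejecting operands from different graphs is an assumption about how mixing graphs should be handled, not existing TensorFlowSharp behavior.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for TFGraph: records operation types in creation order.
public sealed class Graph
{
    public readonly List<string> Ops = new List<string>();

    // Adds an operation after checking that every input belongs to this graph.
    internal Output AddOp(string type, params Output[] inputs)
    {
        foreach (var input in inputs)
            if (input.Graph != this)
                throw new ArgumentException("All inputs must belong to the same graph.");
        Ops.Add(type);
        return new Output(this, Ops.Count - 1);
    }

    public Output Placeholder() => AddOp("Placeholder");
}

// Hypothetical stand-in for TFOutput that keeps a reference to its owning graph.
public readonly struct Output
{
    public Graph Graph { get; }
    public int OpIndex { get; }

    public Output(Graph graph, int opIndex)
    {
        Graph = graph;
        OpIndex = opIndex;
    }

    // The graph comes from the left operand, so no global default graph is needed.
    public static Output operator +(Output x, Output y) => x.Graph.AddOp("Add", x, y);
    public static Output operator *(Output x, Output y) => x.Graph.AddOp("Mul", x, y);
}
```

The trade-off compared with a default graph is that every output carries an extra reference, but expressions such as `a * b + c` need no ambient state and behave the same on any thread.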