
Machine Learning Provider Abstraction Layer

Open javadba opened this issue 9 years ago • 5 comments

I am a developer on a team considering using SparkNet with an ML library other than Caffe. The intent of this issue is to capture discussion of an ML Provider Abstraction Layer (MLPAL?) that would permit pluggable use of Caffe vs. SomeOtherMLLibrary.

To the core committers: do you already have thoughts and/or a roadmap for this? In any case, our thoughts will start appearing here.

javadba avatar Feb 16 '16 22:02 javadba

Thanks for bringing this up; we are very interested in this question. Providing unified APIs and data loading procedures is one of the areas where we can add value compared to what is already out there in terms of deep learning libraries. Data loading/processing is one of Spark's main strengths.

Let us know about your suggestions, our current plan is to provide interfaces that can be implemented by various backends.

For the network, the interface would look like this:

trait NetInterface {
  // run the forward pass on a partition of rows and return the outputs
  def forward(rowIt: Iterator[Row]): Array[Row]
  // run the forward and backward passes, accumulating gradients
  def forwardBackward(rowIt: Iterator[Row]): Unit
  def getWeights(): WeightCollection
  def setWeights(weights: WeightCollection): Unit
  def outputSchema(): StructType
}

For the Solver:

trait Solver {
  // perform one optimization step on a partition of rows
  def step(rowIt: Iterator[Row]): Unit
}
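To illustrate the pluggable-backend idea these traits express, here is a minimal Python analogue (not SparkNet code; the DummyNet backend and its bias-only "weights" are hypothetical stand-ins for a real backend such as Caffe or TensorFlow):

```python
from abc import ABC, abstractmethod

class NetInterface(ABC):
    """Python analogue of the proposed Scala NetInterface trait."""
    @abstractmethod
    def forward(self, rows): ...
    @abstractmethod
    def forward_backward(self, rows): ...
    @abstractmethod
    def get_weights(self): ...
    @abstractmethod
    def set_weights(self, weights): ...

class DummyNet(NetInterface):
    """Trivial backend: its only 'weight' is a bias added to each row sum."""
    def __init__(self, bias=0.0):
        self.bias = bias
    def forward(self, rows):
        # forward pass: one output per input row
        return [sum(row) + self.bias for row in rows]
    def forward_backward(self, rows):
        # a real backend would compute gradients here
        self.bias -= 0.1
    def get_weights(self):
        return {"bias": self.bias}
    def set_weights(self, weights):
        self.bias = weights["bias"]

net = DummyNet()
net.set_weights({"bias": 1.0})
print(net.forward([[1, 2], [3, 4]]))  # [4.0, 8.0]
```

The caller only ever sees the interface, so swapping Caffe for another library is a matter of providing another implementation.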

Data would be loaded in a unified way from Spark DataFrames. We are working on this in the javacpp+dataframes branch, see for example this file.

pcmoritz avatar Feb 17 '16 02:02 pcmoritz

Awesome! This may be a fair bit less complicated than we anticipated. I am interested in trying out that trait with another ML library.

Please suggest which test(s) to run that would best validate the usability of your NetInterface with the OtherMlLibrary framework.

javadba avatar Feb 17 '16 02:02 javadba

Thanks, the least complicated approaches are often the best.

I can sketch how we plan to implement the interface for TensorFlow.

Assume you have a TensorFlow graph definition like this (in Python):

import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
tf.initialize_all_variables().run()

You can then serialize the graph and the weights in the following way:

saver = tf.train.Saver()
saver.save(sess, "model.bin")
g = tf.get_default_graph()
graph_def = g.as_graph_def()

SparkNet would provide a TensorFlowNet class that implements the Net trait and takes as a constructor argument the protocol buffer definition generated by g.as_graph_def(). In addition, there would be a procedure for loading the weights saved by tf.train.Saver into a WeightCollection object, and an implementation of setWeights that loads those weights into the network.
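One way to picture the WeightCollection piece is as a mapping from variable names (e.g. 'W', 'b') to raw values that can be saved and restored as a unit. A hedged sketch in pure Python (the class and its pickle-based serialization are hypothetical; a real implementation would parse the tf.train.Saver checkpoint format instead):

```python
import os
import pickle
import tempfile

class WeightCollection:
    """Hypothetical container mapping variable names to weight values."""
    def __init__(self, weights=None):
        self.weights = dict(weights or {})

    @classmethod
    def load(cls, path):
        # stand-in for reading a real checkpoint file
        with open(path, "rb") as f:
            return cls(pickle.load(f))

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self.weights, f)

# Round-trip: save the softmax model's weights, then reload them,
# as setWeights would before pushing them into the network.
path = os.path.join(tempfile.mkdtemp(), "model.bin")
WeightCollection({"W": [[0.0] * 10] * 784, "b": [0.0] * 10}).save(path)
restored = WeightCollection.load(path)
print(sorted(restored.weights))  # ['W', 'b']
```

A name-to-array mapping also makes the interface backend-neutral: Caffe blobs and TensorFlow variables can both be flattened into the same structure.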

If you are interested in pursuing this, you can start from the JavaCPP TensorFlow preset (https://github.com/bytedeco/javacpp-presets/tree/master/tensorflow) and implement the TensorFlowNet as well as the TensorFlowSolver class. This is a high priority for us, but we would like to improve a few other things before we get to it.

pcmoritz avatar Feb 17 '16 05:02 pcmoritz

OK, I will first dig a bit into the javacpp-presets for background and then look at the JavaCPP preset for TensorFlow. ETA for an update is late Thursday 2/18.


javadba avatar Feb 17 '16 05:02 javadba

Great! Any progress on this will be very helpful for the project; don't hesitate to ask questions if you run into problems. We have a fair amount of experience with JavaCPP by now and may be able to help you.

To get started, you can try running ExampleTrainer.java from the TensorFlow preset, as well as our Cifar training app in the SparkNet javacpp+dataframes branch. The branch is almost ready to merge; we just haven't gotten around to creating the AMI yet.

pcmoritz avatar Feb 17 '16 06:02 pcmoritz