java icon indicating copy to clipboard operation
java copied to clipboard

How do I update the weights of BlockLSTM?

Open mullerhai opened this issue 3 years ago • 1 comments
trafficstars

HI :

Now I am using BlockLSTM to build an LSTM layer, but I don't know how to update the LSTM weight parameters. Do I need to use BlockLSTMGrad or something else to do it? The code is pasted here:

object LstmExample {

  /** Creates a `TFloat32` tensor of the given shape filled with truncated-normal
   *  random values, seeded for reproducibility.
   *
   *  @param shape 1-D shape vector describing the tensor to create
   *  @param scope op-building scope the `TruncatedNormal` op is added to
   *  @return the materialized tensor of random values
   */
  def initializeTruncatedNormalTensor(shape: Operand[TInt32], scope: Scope): TFloat32 = {
    // Bug fix: the original called `TruncatedNormal.seed(1000L)` as a bare
    // statement and discarded the returned Options object, so the seed was
    // never applied to the op. Pass the option into `create` instead.
    val truncatedNormal: TruncatedNormal[TFloat32] =
      TruncatedNormal.create(scope, shape, classOf[TFloat32], TruncatedNormal.seed(1000L))
    // Idiomatic Scala: the last expression is the result; no explicit `return`.
    truncatedNormal.asTensor
  }
  /** Builds a constant weight operand of the given shape, initialized with
   *  truncated-normal values produced by `TensorValues`.
   *  NOTE(review): delegates to the external `TensorValues` helper rather than
   *  the local `initializeTruncatedNormalTensor` — presumably equivalent; confirm.
   */
  private def getWeightMatrix(shape: Operand[TInt32], scope: Scope) =
    Constant.create(scope, TensorValues.initializeTruncatedNormalTensor(shape, scope))

  /** Dumps the float contents of `tensor` to stdout, labelled with `name`. */
  def printTensor(tensor: Operand[TFloat32], name: String): Unit = {
    val values: Array[Float] = TensorResources.extractFloats(tensor.asTensor())
    // Output format is kept byte-identical to the original implementation.
    println(s"data:$name,  ${values.mkString(" | ")}")
  }
  def main(args: Array[String]): Unit = {
    val libraryPath = System.getProperty("java.library.path")
    System.out.println(libraryPath)
    implicit val session = TestSession.createTestSession(TestSession.Mode.EAGER) // EagerSession.create()
    implicit val tf = session.getTF // Ops.create(session).withName("test")
    implicit val scope = tf.scope()
    //    val session = EagerSession.create
    //    val tf = Ops.create(session)
    //        Scope scope = new Scope(session);
    //    val scope = session.baseScope()
    val rawInputSequence = Array(Array(Array(0.1f, 0.2f)), Array(Array(0.3f, 0.4f))) //shape (timelen, batch_size, num_inputs).
    val inputSequence = tf.constant(rawInputSequence)
    val inputSize = 2
    val cellSize =  5
    val maximumTimeLength = 2
    val cellShape = Array(1, cellSize)
    val cellDims = Constant.vectorOf(scope, cellShape)
    val seqLenMax = tf.array(maximumTimeLength)
    //        Operand<TFloat32> initialCellState = Zeros.create(scope, cellDims, TFloat32.DTYPE);
    //        Operand<TFloat32> initialHiddenState = Zeros.create(scope, cellDims, TFloat32.DTYPE);
    val initialCellState = Zeros.create(scope, cellDims, classOf[TFloat32])
    val initialHiddenState = Zeros.create(scope, cellDims, classOf[TFloat32])
    val weightShape = Array(inputSize + cellSize, cellSize * 4)
    val weightMatrixDims = Constant.vectorOf(scope, weightShape)
    val weightMatrix = getWeightMatrix(weightMatrixDims, scope)
    //    session.print(weightMatrix)
    val weightGatesShape = Array(cellSize)
    val weightGatesDims = Constant.vectorOf(scope, weightGatesShape)
    val weightInputGate = getWeightMatrix(weightGatesDims, scope)
    printTensor(weightInputGate,"weightInputGate")
    val weightForgetGate = getWeightMatrix(weightGatesDims, scope)
    printTensor(weightForgetGate,"weightForgetGate")
    val weightOutputGate = getWeightMatrix(weightGatesDims, scope)
    printTensor(weightOutputGate ,"weightOutputGate")
    val biasShape = Array(cellSize * 4)
    val biasDim = Constant.vectorOf(scope, biasShape)
    //classOf[TFloat32]
    //        Operand<TFloat32> bias = Zeros.create(scope, biasDim, TFloat32.DTYPE);
    val bias = Zeros.create(scope, biasDim, classOf[TFloat32])
    val blockLSTM = BlockLSTM.create(scope, tf.dtypes.cast(seqLenMax, classOf[TInt64]), inputSequence, initialCellState, initialHiddenState, weightMatrix, weightInputGate, weightForgetGate, weightOutputGate, bias)

//    session.print(blockLSTM.i)
//    println("&&&cs")
//    session.print(blockLSTM.cs)
//    println("&&&f")
//    session.print(blockLSTM.f)
//    println("&&&o")
//    session.print(blockLSTM.o)
//    println("&&&ci")
//    session.print(blockLSTM.ci)
//    println("&&&co")
//    session.print(blockLSTM.co)
//    println("&&&h")
//    session.print(blockLSTM.h)

thanks for your help

mullerhai avatar May 25 '22 15:05 mullerhai

A related question about the BlockLSTMGrad input parameters: it needs cs_grad and h_grad, but I don't know how to obtain cs_grad and h_grad from blockLSTM.cs() and blockLSTM.h().

https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/raw_ops/BlockLSTMGrad https://runebook.dev/zh-CN/docs/tensorflow/raw_ops/blocklstmgrad

   //  seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, h,cs_grad, h_grad, use_peephole, name=None
  
    val blockLSTMGrad = BlockLSTMGrad.create(scope, tf.dtypes.cast(seqLenMax, classOf[TInt64]), inputSequence, initialCellState, initialHiddenState, weightMatrix,weightInputGate, weightForgetGate, weightOutputGate,bias, blockLSTM.i(),blockLSTM.cs(), blockLSTM.f(),blockLSTM.o(),blockLSTM.ci(), blockLSTM.co(),blockLSTM.h(), blockLSTM.cs_grad()?,blockLSTM.h_grad()? ,false)
    blockLSTMGrad.bGrad()
    blockLSTMGrad.wGrad()
    blockLSTMGrad.xGrad()
    blockLSTMGrad.wcfGrad()
    blockLSTMGrad.csPrevGrad()
    blockLSTMGrad.hPrevGrad()
    blockLSTMGrad.wciGrad()
    blockLSTMGrad.wcoGrad()

Do we need something like this: `Gradients grads = Gradients.create(tf.scope(), Arrays.asList(y0, y1), Arrays.asList(x));` ?

mullerhai avatar May 25 '22 16:05 mullerhai