java
java copied to clipboard
How do I update the weights of a BlockLSTM?
trafficstars
Hi,
I am currently using BlockLSTM to build an LSTM layer, but I don't know how to update the LSTM weight parameters. Do I need to use BlockLSTMGrad or something similar? The code is pasted here:
object LstmExample {
/**
 * Samples a float tensor of the requested shape from a truncated normal distribution.
 *
 * The fixed seed is handed to the op as a creation option. Previously
 * `TruncatedNormal.seed(1000L)` was invoked as a bare statement and its result was
 * thrown away, so the seed never reached the op that was actually created.
 *
 * @param shape dimensions of the tensor to generate
 * @param scope scope in which the sampling op is created
 * @return the sampled values as a `TFloat32` tensor
 */
def initializeTruncatedNormalTensor(shape: Operand[TInt32], scope: Scope): TFloat32 = {
  // NOTE(review): every call now shares the same seed, so calls with the same shape
  // may yield identical values — confirm that is acceptable for the callers.
  val seedOption = TruncatedNormal.seed(1000L)
  // The element type is given as a class literal (an earlier draft used TFloat32.DTYPE).
  val truncatedNormal: TruncatedNormal[TFloat32] =
    TruncatedNormal.create(scope, shape, classOf[TFloat32], seedOption)
  truncatedNormal.asTensor
}
/**
 * Builds a constant op holding freshly sampled truncated-normal values.
 *
 * The values come from `TensorValues.initializeTruncatedNormalTensor` and are wrapped
 * with `Constant.create`, so the result is a constant in the given scope.
 *
 * @param shape dimensions of the weight tensor to sample
 * @param scope scope used both for sampling and for creating the constant
 */
private def getWeightMatrix(shape: Operand[TInt32], scope: Scope) =
  Constant.create(scope, TensorValues.initializeTruncatedNormalTensor(shape, scope))
/**
 * Prints the float contents of an operand on one line, prefixed with a label.
 *
 * @param tensor operand whose tensor values are extracted and printed
 * @param name   label written in front of the values
 */
def printTensor(tensor: Operand[TFloat32], name: String): Unit = {
  val values: Array[Float] = TensorResources.extractFloats(tensor.asTensor())
  // Join the elements first so the interpolated line stays easy to read.
  val rendered = values.mkString(" | ")
  println(s"data:$name, $rendered")
}
def main(args: Array[String]): Unit = {
val libraryPath = System.getProperty("java.library.path")
System.out.println(libraryPath)
implicit val session = TestSession.createTestSession(TestSession.Mode.EAGER) // EagerSession.create()
implicit val tf = session.getTF // Ops.create(session).withName("test")
implicit val scope = tf.scope()
// val session = EagerSession.create
// val tf = Ops.create(session)
// Scope scope = new Scope(session);
// val scope = session.baseScope()
val rawInputSequence = Array(Array(Array(0.1f, 0.2f)), Array(Array(0.3f, 0.4f))) //shape (timelen, batch_size, num_inputs).
val inputSequence = tf.constant(rawInputSequence)
val inputSize = 2
val cellSize = 5
val maximumTimeLength = 2
val cellShape = Array(1, cellSize)
val cellDims = Constant.vectorOf(scope, cellShape)
val seqLenMax = tf.array(maximumTimeLength)
// Operand<TFloat32> initialCellState = Zeros.create(scope, cellDims, TFloat32.DTYPE);
// Operand<TFloat32> initialHiddenState = Zeros.create(scope, cellDims, TFloat32.DTYPE);
val initialCellState = Zeros.create(scope, cellDims, classOf[TFloat32])
val initialHiddenState = Zeros.create(scope, cellDims, classOf[TFloat32])
val weightShape = Array(inputSize + cellSize, cellSize * 4)
val weightMatrixDims = Constant.vectorOf(scope, weightShape)
val weightMatrix = getWeightMatrix(weightMatrixDims, scope)
// session.print(weightMatrix)
val weightGatesShape = Array(cellSize)
val weightGatesDims = Constant.vectorOf(scope, weightGatesShape)
val weightInputGate = getWeightMatrix(weightGatesDims, scope)
printTensor(weightInputGate,"weightInputGate")
val weightForgetGate = getWeightMatrix(weightGatesDims, scope)
printTensor(weightForgetGate,"weightForgetGate")
val weightOutputGate = getWeightMatrix(weightGatesDims, scope)
printTensor(weightOutputGate ,"weightOutputGate")
val biasShape = Array(cellSize * 4)
val biasDim = Constant.vectorOf(scope, biasShape)
//classOf[TFloat32]
// Operand<TFloat32> bias = Zeros.create(scope, biasDim, TFloat32.DTYPE);
val bias = Zeros.create(scope, biasDim, classOf[TFloat32])
val blockLSTM = BlockLSTM.create(scope, tf.dtypes.cast(seqLenMax, classOf[TInt64]), inputSequence, initialCellState, initialHiddenState, weightMatrix, weightInputGate, weightForgetGate, weightOutputGate, bias)
// session.print(blockLSTM.i)
// println("&&&cs")
// session.print(blockLSTM.cs)
// println("&&&f")
// session.print(blockLSTM.f)
// println("&&&o")
// session.print(blockLSTM.o)
// println("&&&ci")
// session.print(blockLSTM.ci)
// println("&&&co")
// session.print(blockLSTM.co)
// println("&&&h")
// session.print(blockLSTM.h)
Thanks for your help.
A related question about the BlockLSTMGrad input parameters: it needs cs_grad and h_grad, but I don't know how to obtain cs_grad and h_grad from blockLSTM.cs() and blockLSTM.h().
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/raw_ops/BlockLSTMGrad https://runebook.dev/zh-CN/docs/tensorflow/raw_ops/blocklstmgrad
// seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, h,cs_grad, h_grad, use_peephole, name=None
val blockLSTMGrad = BlockLSTMGrad.create(scope, tf.dtypes.cast(seqLenMax, classOf[TInt64]), inputSequence, initialCellState, initialHiddenState, weightMatrix,weightInputGate, weightForgetGate, weightOutputGate,bias, blockLSTM.i(),blockLSTM.cs(), blockLSTM.f(),blockLSTM.o(),blockLSTM.ci(), blockLSTM.co(),blockLSTM.h(), blockLSTM.cs_grad()?,blockLSTM.h_grad()? ,false)
blockLSTMGrad.bGrad()
blockLSTMGrad.wGrad()
blockLSTMGrad.xGrad()
blockLSTMGrad.wcfGrad()
blockLSTMGrad.csPrevGrad()
blockLSTMGrad.hPrevGrad()
blockLSTMGrad.wciGrad()
blockLSTMGrad.wcoGrad()
Do we need something like this: Gradients grads = Gradients.create(tf.scope(), Arrays.asList(y0, y1), Arrays.asList(x)); ?