how to access `stateDict`?
@sbrunk Thank you for this excellent library! I have been trying to re-implement this cart-pole deep Q-learning example using storch and cats-effect. In that article, there is the following Python code which initializes the two networks:
self.policy_net = self.build_network(layer_sizes)
self.target_net = self.build_network(layer_sizes)
self.target_net.load_state_dict(self.policy_net.state_dict())
Every so often, as the policy_net gets trained, the target_net needs to be updated.
However, I cannot seem to find a way to access stateDict in storch. Is there a more Scala/storch-recommended way to take one network (the target_net in the example above) and load it up so that it is initially equivalent to a different network (the policy_net)?
I noticed that the loadStateDict method is available in storch, but I just cannot figure out what to feed into it. Any help is much appreciated. Thanks!
This does not seem right, but here is what I have come up with so far:
import scala.util.chaining.* // for .pipe

// `layerSizes: Seq[Int]` is defined elsewhere in my code (the layer sizes passed to build_network in the article)
class NeuralNetwork extends nn.Module:
  val flatten = nn.Flatten[Float64]()
  // fixed architecture for now of fully connected tanh inner layers
  val linearTanhStack = List.range(0, layerSizes.size - 1)
    .foldLeft(List.empty[TensorModule[Float64]]) { case (layers, index) =>
      val linear = nn.Linear[Float64](layerSizes(index), layerSizes(index + 1))
      val act = if index < layerSizes.size - 2 then nn.Tanh[Float64]() else nn.Identity[Float64]()
      layers :+ linear :+ act
    }
    .pipe(modules => nn.Sequential[Float64](modules*))
    .pipe(register(_))

  def apply(x: Tensor[Float64]): Tensor[Float64] =
    val flattened = flatten(x)
    val logits = linearTanhStack(flattened)
    logits
val policyNet = NeuralNetwork()
val targetNet = NeuralNetwork()
def syncTargetNetToPolicyNet: IO[Unit] =
  for
    stateDict <- IO(policyNet.namedParameters(true).map((s, t) => (s, Tensor.fromNative[DType](t.native))))
  yield targetNet.loadStateDict(stateDict)
Without the Tensor.fromNative[DType] step, I could not get it to compile, because namedParameters returns tensors with the placeholder type Tensor[?] in its signature, whereas loadStateDict requires a Map[String, Tensor[DType]].
@VzxPLnHqr you're right, these should be more consistent. Given that Tensor is currently invariant, Tensor[DType] is a bit awkward in general.
Looking at the implementation of loadStateDict, I think we should be able to change the argument to stateDict: Map[String, Tensor[?]].
https://github.com/sbrunk/storch/blob/2dfa3884b9f0f2d1e2566aad791f44535b48bb09/core/src/main/scala/torch/nn/modules/Module.scala#L54-L60
Could you perhaps try to use something like this loadStateDict function as a workaround?
def loadStateDict(m: Module, stateDict: Map[String, Tensor[?]]): Unit =
  val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
  for ((key, param) <- tensorsToLoad if stateDict.contains(key))
    noGrad {
      param.copy_(stateDict(key))
    }
If it works for you, I think we could also change the signature of the actual loadStateDict method in Module.
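Concretely, the method in Module would then look roughly like this (just a sketch of the proposed change, reusing the implementation linked above):
// sketch: same body as today, only the parameter type widens from Tensor[DType] to Tensor[?]
def loadStateDict(stateDict: Map[String, Tensor[?]]): Unit =
  val tensorsToLoad = namedParameters() ++ namedBuffers()
  for ((key, param) <- tensorsToLoad if stateDict.contains(key))
    noGrad {
      param.copy_(stateDict(key))
    }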
@sbrunk Thank you for your reply. I added the workaround loadStateDict method you provided as an extension method to nn.Module (wrapped in IO since I am using cats-effect):
extension (m: nn.Module)
  def loadStateDict(stateDict: Map[String, Tensor[?]]): IO[Unit] = IO {
    val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
    for ((key, param) <- tensorsToLoad if stateDict.contains(key))
      noGrad {
        param.copy_(stateDict(key))
      }
  }
I can then use the new method like so:
targetNet.loadStateDict(policyNet.namedParameters(true).toMap)
Unfortunately the above gives a bloop/compile error:
Cannot prove that (String, torch.Tensor[?]) <:< (String, V2).
where: V2 is a type variable with constraint <: torch.Tensor[torch.DType]
Update: it turns out the .toMap I had to append there was causing the compilation issue. Changing the signature from loadStateDict(stateDict: Map[String, Tensor[?]]): IO[Unit] to loadStateDict(stateDict: SeqMap[String, Tensor[?]]): IO[Unit] at least fixed that, since the SeqMap returned by namedParameters can then be passed in directly without the .toMap.
Have not had a chance yet to test the actual functionality though. Will do that next and let you know if it works as expected. If so, then yes, this seems like it would be a good usability improvement.
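For completeness, here is the adjusted (still untested) extension and how I intend to call it for the target-network sync. The scala.collection.SeqMap import and the revised syncTargetNetToPolicyNet below are just my current sketch, not something from storch itself:
import scala.collection.SeqMap

extension (m: nn.Module)
  def loadStateDict(stateDict: SeqMap[String, Tensor[?]]): IO[Unit] = IO {
    val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
    for ((key, param) <- tensorsToLoad if stateDict.contains(key))
      noGrad {
        param.copy_(stateDict(key))
      }
  }

// no Tensor.fromNative or .toMap needed any more
def syncTargetNetToPolicyNet: IO[Unit] =
  targetNet.loadStateDict(policyNet.namedParameters(true))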