jmp
Update loss_scale.py
DynamicLossScale.min_loss_scale isn't included in the outputs of DynamicLossScale.tree_flatten, so DynamicLossScale.tree_unflatten falls back to the field's default factory when it rebuilds the object. The original value is lost, and the dtype may change as well.
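
The failure mode can be reproduced with a minimal sketch. The classes below are hypothetical stand-ins, not the actual jmp code: `BuggyScale` and `FixedScale` keep only two fields, and the default of `1.0` in `float32` is an assumption made for illustration. The point is only that a field left out of `tree_flatten` is rebuilt from its default on `tree_unflatten`.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np


def _default_min():
    # Assumed default for illustration; the real default in jmp may differ.
    return np.ones([], np.float32)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class BuggyScale:
    """Hypothetical stand-in: min_loss_scale is missing from tree_flatten."""
    loss_scale: jnp.ndarray
    min_loss_scale: jnp.ndarray = dataclasses.field(default_factory=_default_min)

    def tree_flatten(self):
        # min_loss_scale is not returned, so it cannot be restored later.
        return (self.loss_scale,), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (loss_scale,) = children
        # min_loss_scale is not passed, so the default factory runs.
        return cls(loss_scale)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class FixedScale:
    """Hypothetical stand-in: min_loss_scale is carried as a child."""
    loss_scale: jnp.ndarray
    min_loss_scale: jnp.ndarray = dataclasses.field(default_factory=_default_min)

    def tree_flatten(self):
        return (self.loss_scale, self.min_loss_scale), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        loss_scale, min_loss_scale = children
        return cls(loss_scale, min_loss_scale)


def roundtrip(tree):
    """Flatten and unflatten a pytree, as jit, tree_map, etc. do internally."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return jax.tree_util.tree_unflatten(treedef, leaves)


scale = jnp.asarray(8.0, jnp.float16)
min_scale = jnp.asarray(4.0, jnp.float16)

buggy = roundtrip(BuggyScale(scale, min_scale))
fixed = roundtrip(FixedScale(scale, min_scale))

# buggy.min_loss_scale has been reset to the default value and dtype;
# fixed.min_loss_scale keeps the value and dtype it was constructed with.
print(buggy.min_loss_scale, buggy.min_loss_scale.dtype)
print(fixed.min_loss_scale, fixed.min_loss_scale.dtype)
```

Whether the field is best carried as a child (as here) or as auxiliary data depends on whether it should be traced; this sketch treats it as a child only to keep the example short.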