[GraphOptz] Constant folding required before quantization
When lowering, for example, a 1x1 Conv to a FullyConnected, the weights end up with Reshape and Transpose nodes attached to them. If this FullyConnected is then to be quantized row-wise, the quantization does not succeed because the weights are not seen as a Constant (they have the Reshape and Transpose attached). Therefore, for the purpose of channel-wise/row-wise quantization, an additional constant-folding pass must be run right before quantization to make sure the weights are seen as constant.
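To make the failure mode concrete, here is a minimal sketch of the check that effectively fails, assuming Glow's C++ node API roughly as in `glow/Graph/Nodes.h` (the helper `hasConstantWeights` is made up for illustration, not code from the quantizer):

```cpp
#include "glow/Graph/Nodes.h"
#include "llvm/Support/Casting.h"

/// Row-wise quantization needs to read the weight values at compile time, so
/// it effectively requires the weights input of the FullyConnected to be a
/// Constant. After lowering a 1x1 Conv, the weights input is instead the top
/// of a Transpose(Reshape(Constant)) chain, so a check like this fails.
static bool hasConstantWeights(const glow::FullyConnectedNode *FC) {
  return llvm::isa<glow::Constant>(FC->getWeights().getNode());
}
```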
@jfix71 Can I add an additional optimization right after lowering to make sure all constants are folded?
You can try it out, but I think it will cause problems. We need the quantization profile names to match the Node names in the graph when we do quantization. Doing the constant folding will result in a new Constant with a different name. Perhaps one way this could be fixed is by writing a targeted optimization for folding Reshape/Transpose into Constants without changing the names of the Constants?
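For illustration, something along these lines could serve as that targeted optimization; `foldShapingChainKeepingName()` is a hypothetical placeholder for the actual folding/rewiring logic, and the node accessors are the generated ones I believe exist, so treat this as a sketch rather than existing Glow code:

```cpp
#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "llvm/Support/Casting.h"

using namespace glow;

/// Hypothetical helper: materialize the Reshape/Transpose chain into the
/// weights tensor and rewire the FullyConnected to a single Constant that
/// keeps the original Constant's name, so the profile names still match.
void foldShapingChainKeepingName(Function *F, FullyConnectedNode *FC,
                                 Constant *weights);

/// Targeted fold: only touch Reshape/Transpose chains that sit between a
/// Constant and a FullyConnected's weights input.
void foldFCWeightShaping(Function *F) {
  for (Node &N : F->getNodes()) {
    auto *FC = llvm::dyn_cast<FullyConnectedNode>(&N);
    if (!FC) {
      continue;
    }
    // Walk down through any shaping nodes on the weights input.
    Node *W = FC->getWeights().getNode();
    while (llvm::isa<ReshapeNode>(W) || llvm::isa<TransposeNode>(W)) {
      W = W->getNthInput(0).getNode();
    }
    if (auto *C = llvm::dyn_cast<Constant>(W)) {
      foldShapingChainKeepingName(F, FC, C);
    }
  }
}
```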
@jfix71 But if I add an extra optimization right before `transformForPrecisionMode()`, such that both the quantization and the profiling paths benefit from the same optimizations, then it should be fine, right?
@mciprian13 I think there would also be a problem involving lowering differences between profiling and quantization. An example makes this clearer.
Let's say there's a FullyConnected that has a Constant input and weights but a Placeholder bias (not really realistic, but easier to reason about than nodes like e.g. LayerNorm, which has many sub-nodes when lowered). Let's call this `FullyConnected(Constant1, Constant2, Placeholder)`.

Assume we do ConstantFolding after lowering but before `transformForPrecisionMode()`, as you're suggesting. When we lower it fully for profiling, we would end up with `BatchedAdd(MatMul(Constant1, Constant2), Placeholder)`. As you can see, MatMul would have all-Constant inputs and be constant folded, so we'd end up with `BatchedAdd(Constant3, Placeholder)` to profile. The profile would then have quantization info only for `Constant3` and not for `Constant1` or `Constant2`.
Then when we go to use that profile for quantization, the backend we're targeting may prevent lowering for the FullyConnected, and so we wouldn't constant fold. The quantizer would then look for quantization info for `Constant1` and `Constant2` and wouldn't find it.
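To put that mismatch in code form: the profile is keyed by node names, so (as a toy illustration in plain C++, not Glow's actual profiling data structures) the lookup on the quantization side comes up empty:

```cpp
#include <iostream>
#include <map>
#include <optional>
#include <string>

// Toy stand-ins for a name-keyed profile; the real infrastructure differs,
// but the naming problem is the same.
struct Range { float min; float max; };
using Profile = std::map<std::string, Range>;

std::optional<Range> lookup(const Profile &profile, const std::string &name) {
  auto it = profile.find(name);
  if (it == profile.end()) {
    return std::nullopt;
  }
  return it->second;
}

int main() {
  // Profile gathered on the constant-folded, fully lowered graph: the two
  // original weight Constants only show up as the folded "Constant3".
  Profile profile = {{"Constant3", {-1.0F, 1.0F}}, {"Placeholder", {0.0F, 4.0F}}};

  // Quantizing the un-lowered FullyConnected asks for the original names,
  // which the profile no longer contains.
  for (const char *name : {"Constant1", "Constant2"}) {
    std::cout << name << (lookup(profile, name) ? ": found\n" : ": missing\n");
  }
}
```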