Generating-Devanagari-Using-DRAW
Filter Locations
This is a great implementation! One small note:
On line 151:
grid_i = torch.arange(start=0.0, end=N, device=self.device, requires_grad=True,).view(1, -1)
I believe you should set the arange to run from 1 -> N, rather than from 0 -> (N-1):
grid_i = torch.arange(start=1.0, end=N+1, device=self.device, requires_grad=True,).view(1, -1)
Doing this ensures that the filter locations are centered on (gx, gy). Here's a figure from the paper for reference:
Currently, the grid is shifted toward the top-left by one filter spacing (delta) in each direction.
Below are some diagrams showing the differences between the current grid and the correct grid:
The red dot is the center (gx, gy). The green rectangle is the current grid layout (filter locations), and the red rectangle is the correct grid layout (filter locations) using the code change recommended above.
Example 1 (N=5):
Notice how the filter grid contains (gx, gy), but isn't centered on it!
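To make Example 1 concrete, here is a plain-Python sketch (no PyTorch needed) that evaluates Equation 19, mu_x = gx + (grid_i - N / 2 - 0.5) * delta, along one axis for both grid_i ranges, using the figure's parameters gx = 10, delta = 3, N = 5. The helper name filter_centers is hypothetical and does not come from the repository.

```python
def filter_centers(g, delta, N, start):
    """Filter centre coordinates along one axis (Equations 19/20 of the DRAW paper).

    `start` is the first grid index: 0 mimics the current arange (0 .. N-1),
    1 mimics the proposed arange (1 .. N).
    """
    return [g + (i - N / 2 - 0.5) * delta for i in range(start, start + N)]


gx, delta, N = 10, 3, 5
current = filter_centers(gx, delta, N, start=0)
proposed = filter_centers(gx, delta, N, start=1)
print("current: ", current)   # [1.0, 4.0, 7.0, 10.0, 13.0]
print("proposed:", proposed)  # [4.0, 7.0, 10.0, 13.0, 16.0]

# The middle filter lands on gx only for the proposed grid.
print(current[N // 2], proposed[N // 2])  # 7.0 10.0
```

The current grid still contains 10, but as its fourth of five entries rather than the middle one.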
Example 2 (N=2):
Notice how the filter locations don't even contain (gx, gy)! One of DRAW's optimal configurations for MNIST used a read_N of 2.
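Example 2 can be checked the same way with Equations 19 and 20 from the paper. In this plain-Python sketch (the helpers filter_centers and spans are hypothetical names), the filter centres for N = 2 are computed along both axes with gx = 10, gy = 12, delta = 3, and each range is tested for whether it spans the attention centre.

```python
def filter_centers(g, delta, N, start):
    """Equations 19/20 along one axis; `start` is the first grid_i value (0 or 1)."""
    return [g + (i - N / 2 - 0.5) * delta for i in range(start, start + N)]


def spans(centers, g):
    """True if g lies between the smallest and largest filter centre (inclusive)."""
    return min(centers) <= g <= max(centers)


gx, gy, delta, N = 10, 12, 3, 2
for start in (0, 1):
    xs = filter_centers(gx, delta, N, start)
    ys = filter_centers(gy, delta, N, start)
    print(start, xs, ys, spans(xs, gx) and spans(ys, gy))
# 0 [5.5, 8.5] [7.5, 10.5] False
# 1 [8.5, 11.5] [10.5, 13.5] True
```

With the current arange, both filters sit 1.5 and 4.5 pixels before the centre on each axis, so neither axis covers (gx, gy).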
Details about the figures above:
- Image size: 25 x 25
- gx = 10
- gy = 12
- N = 5 (first example), N = 2 (second example)
- Stride (delta) = 3
To calculate the N x N filter locations, we can use Equations 19 and 20 from the paper.
In your code, these are:
# Equation 19.
mu_x = gx + (grid_i - N / 2 - 0.5) * delta
# Equation 20.
mu_y = gy + (grid_i - N / 2 - 0.5) * delta
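Averaging Equation 19 over the grid shows where the one-stride shift comes from (Equation 20 is identical with y in place of x). With grid_i running over i = 1, ..., N, the mean index is (N+1)/2:

$$
\frac{1}{N}\sum_{i=1}^{N}\mu_x^{i} = g_x + \left(\frac{N+1}{2} - \frac{N}{2} - \frac{1}{2}\right)\delta = g_x,
$$

whereas with i = 0, ..., N-1 the mean index is (N-1)/2:

$$
\frac{1}{N}\sum_{i=0}^{N-1}\mu_x^{i} = g_x + \left(\frac{N-1}{2} - \frac{N}{2} - \frac{1}{2}\right)\delta = g_x - \delta.
$$

So, for any N, the current grid's centre sits exactly one stride delta before (gx, gy) on each axis, i.e. toward the top-left.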
To draw the rectangle, we need the top-left and bottom-right filter locations. We can simply set the starting and ending grid_i values as follows:
- With the current arange, we would set grid_i = 0 and grid_i = N-1.
- With the proposed arange, we would set grid_i = 1 and grid_i = N.
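A minimal sketch of that corner computation, assuming the plain-Python form of Equations 19 and 20 (the function name filter_rectangle is hypothetical). It returns the top-left and bottom-right filter centres for a given starting grid_i value, which is what a plotting routine needs to draw the rectangle.

```python
def filter_rectangle(gx, gy, delta, N, start):
    """Return ((x0, y0), (x1, y1)): the top-left and bottom-right filter centres.

    `start` is the first grid_i value: 0 for the current arange, 1 for the
    proposed one. The last grid_i value is therefore start + N - 1.
    """
    def mu(g, i):
        # Equations 19 and 20: mu = g + (i - N/2 - 0.5) * delta
        return g + (i - N / 2 - 0.5) * delta

    first, last = start, start + N - 1
    return (mu(gx, first), mu(gy, first)), (mu(gx, last), mu(gy, last))


# Parameters from the figure details (stride 3), N = 5.
print(filter_rectangle(10, 12, 3, 5, start=0))  # ((1.0, 3.0), (13.0, 15.0))
print(filter_rectangle(10, 12, 3, 5, start=1))  # ((4.0, 6.0), (16.0, 18.0))
```

With start=1 the rectangle's midpoint is exactly (10, 12); with start=0 it is (7, 9), one stride up and to the left.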
Here's the pull request: #2