aihwkit
Tool/tutorial for computing the symmetry points of a given device and setting up a TransferCompound
Description and motivation
TransferCompound, used for the Tiki-taka learning rule, requires the symmetry point of each device to be set near zero. We have a ReferenceUnitCell to support arbitrary subtraction; however, it is not immediately obvious how to use it for this, since it requires setting hidden parameters, and the symmetry points are not computed numerically automatically.
Proposed solution
Make a tool/script that computes the symmetry points and sets them on the reference devices for arbitrary device model settings.
Hi, @maljoras
Following #288, together with @chaeunl, we calculated the symmetry point analytically and tried to implement it with hidden parameters, yet a numerical issue (the numerator and denominator are both too small) still leaves the symmetry point slightly off from zero. We derived the equation from the definition of the symmetry point: it is the weight value at which the slope of potentiation equals that of depression.
Here is a prototype version of the code:
(var_up2 is defined as torch.normal(mean=zeros, std=ones*var2) to manually implement the variance using hidden parameters.)
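import torch

# dw_min, var_up2, var_down2, ones and layer are assumed to be defined elsewhere.
params = layer.analog_tile.get_hidden_parameters()
# Baseline step sizes (kept unchanged) and a second copy that is adjusted below.
dwminup, dwmindown = dw_min*(1+var_up2), dw_min*(1+var_down2)
dwmin_up, dwmin_down = dw_min*(1+var_up2), dw_min*(1+var_down2)
# Adjust the step sizes, intended to move the symmetry point toward zero.
dwmin_up += (-params["slope_up"] / (params["slope_up"] - params["slope_down"]) * (dwminup - dwmindown))
dwmin_down += (-params["slope_down"] / (params["slope_up"] - params["slope_down"]) * (dwminup - dwmindown))
params["dwmin_up"], params["dwmin_down"] = torch.abs(dwmin_up), torch.abs(dwmin_down)
# Set very wide weight bounds (+/- 1000).
params["max_bound"], params["min_bound"] = 1000.*ones, -1000.*ones
layer.analog_tile.set_hidden_parameters(params)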
We hope this helps.
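For readers following along, the definition used above (the weight at which the up and down step sizes are equal) can be written out for a simple linear model. This is only a sketch under stated assumptions: the step model `dw + slope * w` and the helper name `linear_symmetry_point` are hypothetical and are not claimed to match aihwkit's internal device equations. The guard on the denominator illustrates the numerical issue mentioned above, where `slope_up - slope_down` becomes very small.

```python
import numpy as np


def linear_symmetry_point(dw_up, dw_down, slope_up, slope_down, eps=1e-12):
    """Weight w* at which the up and down step sizes are equal.

    Assumed (hypothetical) linear step model, not aihwkit's internal one:
        up step size   = dw_up   + slope_up   * w
        down step size = dw_down + slope_down * w
    Setting both equal gives w* = (dw_down - dw_up) / (slope_up - slope_down).
    """
    dw_up, dw_down, slope_up, slope_down = np.broadcast_arrays(
        *(np.asarray(a, dtype=float) for a in (dw_up, dw_down, slope_up, slope_down))
    )
    denom = slope_up - slope_down
    # A tiny denominator makes w* ill-conditioned; return NaN instead of a huge value.
    ok = np.abs(denom) > eps
    safe_denom = np.where(ok, denom, 1.0)
    return np.where(ok, (dw_down - dw_up) / safe_denom, np.nan)


# Example: the two step sizes match (both 0.011) at w* = 0.1.
w_star = linear_symmetry_point(dw_up=0.012, dw_down=0.010, slope_up=-0.01, slope_down=0.01)
```

Returning NaN for near-equal slopes makes the ill-conditioned devices easy to spot, rather than silently producing a very large symmetry point.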
Hi @nkyungmi,
many thanks for your input! This is great and very helpful in general. Note, however, that most device models already respect the symmetry point through their parameter settings. For instance, for SoftBoundsDevice, if up_down_dtod=0.0, the devices will be sampled in such a way that they all have a symmetry point at zero.
For this issue, I had thought more of a tool to "simulate" the symmetry point numerically, in the way one would do it in reality (by applying an up/down pulse sequence). This will induce an additional error in the estimate of the symmetry point, which needs to be taken into account as well. The estimated symmetry point would then be coded onto a ReferenceUnitCell.
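The up/down pulse procedure described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not aihwkit code: the soft-bounds step model, the multiplicative pulse noise, and the names `pulse`, `estimate_symmetry_point` and `analytic_symmetry_point` are all hypothetical. Alternating up and down pulses drives the weight toward the point where both step sizes match; averaging the last pulses gives the estimate, which carries an extra error when the pulses are noisy.

```python
import numpy as np


def pulse(w, up, dw_up, dw_down, w_max, w_min, noise_std, rng):
    """Apply one pulse to a hypothetical soft-bounds device (assumed model)."""
    noise = 1.0 + noise_std * rng.standard_normal(w.shape)
    if up:
        return w + dw_up * (1.0 - w / w_max) * noise
    return w - dw_down * (1.0 - w / w_min) * noise


def estimate_symmetry_point(dw_up, dw_down, w_max=1.0, w_min=-1.0,
                            n_pairs=3000, n_avg=1000, noise_std=0.0, seed=0):
    """Estimate the symmetry point from an alternating up/down pulse sequence."""
    rng = np.random.default_rng(seed)
    dw_up = np.atleast_1d(np.asarray(dw_up, dtype=float))
    dw_down = np.atleast_1d(np.asarray(dw_down, dtype=float))
    w = np.zeros(np.broadcast(dw_up, dw_down).shape)
    samples = []
    for i in range(n_pairs):
        for up in (True, False):
            w = pulse(w, up, dw_up, dw_down, w_max, w_min, noise_std, rng)
            # Record the weight after every pulse of the last n_avg pairs.
            if i >= n_pairs - n_avg:
                samples.append(w)
    return np.mean(samples, axis=0)


def analytic_symmetry_point(dw_up, dw_down, w_max=1.0, w_min=-1.0):
    """Solve dw_up * (1 - w / w_max) == dw_down * (1 - w / w_min) for w."""
    return (dw_up - dw_down) / (dw_up / w_max - dw_down / w_min)


reference = analytic_symmetry_point(0.012, 0.010)
noise_free = estimate_symmetry_point(0.012, 0.010)
noisy = estimate_symmetry_point(0.012, 0.010, noise_std=0.3, seed=1)
```

Comparing `noise_free` and `noisy` against `reference` gives a feel for the estimation error that a reference device programmed from such a measurement would inherit.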