brian2 units lost in numpy functions
It seems that brian2 units can get lost in numpy and list operations:
>>> import numpy as np
>>> from brian2 import *
>>> a = np.random.rand(5)*ms
>>> b = np.random.rand(5)*ms
>>> np.concatenate((a, b))  # units are lost, the result is a plain array of values in seconds
array([0.000908919048280581, 0.000242718536606642, 0.00076358258260824,
       0.000977874469348032, 0.000778319813877588, 5.21045744641396e-05,
       0.000499042622396794, 0.000313784493828753, 0.000355711379832488,
       0.000883566239304045])
>>> mean([a, b])  # here, the mean() function should be the brian2 version of np.mean()
0.00057756237605473032
>>> mean(a)
0.734282890144 ms
What would be the correct way of using numpy functions with brian2 units?
Yes, unfortunately numpy does not always handle subclasses of `ndarray` (such as our `Quantity` class) correctly. We have some documentation on which functions are safe to use and which are not (although it is a bit hidden in the developer's docs...): http://brian2.readthedocs.io/en/stable/developer/units.html
Some functions work automatically for subclasses because they call a method on the object, and we have implemented those methods correctly on `Quantity`. E.g. `np.mean(a)` will just call `a.mean()` and therefore work correctly. Unfortunately, that breaks if you do `np.mean([a, b])`, because the argument is now a list.
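You can see the dispatch difference directly (the numbers are the ones from the random arrays in the session above):

>>> a.mean()         # dispatches to Quantity.mean(), units are kept
0.734282890144 ms
>>> np.mean(a)       # delegates to a.mean(), so this works too
0.734282890144 ms
>>> np.mean([a, b])  # the list is first converted to a plain ndarray, units are lost
0.00057756237605473032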
We have not implemented any wrappers for functions like `concatenate`, `hstack`, `vstack`, etc., but it could indeed be helpful to have those! Pull requests are welcome (they would go into `brian2/units/unitsafefunctions.py`) :smile:
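For illustration, here is a minimal sketch of what such a wrapper might look like. This is not existing brian2 code; it assumes the helpers `Quantity`, `get_dimensions` and `fail_for_dimension_mismatch` from `brian2.units.fundamentalunits`. The idea is: check that all arguments share the same dimensions, concatenate the bare values, and re-attach the dimensions:

import numpy as np
from brian2.units.fundamentalunits import (Quantity, get_dimensions,
                                           fail_for_dimension_mismatch)

def unit_safe_concatenate(arrays, axis=0):
    # Hypothetical wrapper, not part of brian2: all arguments must have
    # the same dimensions, otherwise a DimensionMismatchError is raised.
    first = arrays[0]
    for other in arrays[1:]:
        fail_for_dimension_mismatch(first, other,
                                    'All arguments must have the same units')
    # Strip the units, do the plain numpy operation, re-attach the units.
    stripped = [np.asarray(ar) for ar in arrays]
    return Quantity(np.concatenate(stripped, axis=axis),
                    dim=get_dimensions(first))

With something like this, `unit_safe_concatenate([a, b])` would return a `Quantity` in ms again; a real pull request would of course also need tests and handling of unit-less arguments.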
Having said all that, the standard workaround for this kind of problem is to remove the units, do your operation, and re-attach the units:
>>> c = Quantity(np.concatenate([np.asarray(a), np.asarray(b)]), dim=a.dim)
>>> mean(c)
0.44523276 * msecond
(Of course this does not check that `a` and `b` actually have the same dimensions...)
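If you want such a check, comparing the `dim` attributes (the same attribute the workaround above uses) before concatenating is enough, e.g.:

>>> assert a.dim == b.dim, 'a and b must have the same dimensions'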
The `np.asarray(...)` is not strictly necessary above, but it makes the code safer for the future, when we might provide a unit-safe function for `concatenate`.
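To make the strip/re-attach cycle concrete, you can check the types at each step (a quick sanity check, nothing brian2-specific):

>>> type(np.asarray(a))                       # units stripped
<class 'numpy.ndarray'>
>>> type(Quantity(np.asarray(a), dim=a.dim))  # units re-attached
<class 'brian2.units.fundamentalunits.Quantity'>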
PS: Just to convince you that this is really fundamentally a numpy issue: you can see the same problem with numpy's built-in subclasses of `ndarray`, e.g. `MaskedArray`:
>>> m_ar = np.ma.MaskedArray([1, 100, 3], mask=[False, True, False])
>>> m_ar
masked_array(data = [1 -- 3],
             mask = [False  True False],
       fill_value = 999999)
>>> mean(m_ar) # works fine, masked value is ignored
2.0
>>> mean([m_ar, m_ar]) # wrong, does not use the mask information
34.666666666666664
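And the fix is the same in spirit: use a variant of the function that knows about the subclass. For `MaskedArray`, numpy ships such variants in the `np.ma` module (shown here only to complete the analogy):

>>> np.ma.concatenate([m_ar, m_ar]).mean()  # the mask is respected
2.0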