la
Add rollaxis method to larry?
I have come across a situation where I would like to roll an axis of a larry, much as one can with np.rollaxis on a regular numpy array. Would that be possible?
larry currently has a swapaxes method, but a rollaxis method sounds like a good addition.
la includes the numpy license, so we can copy and paste np.rollaxis and then add label support to it.
I'll take a look when I get a chance.
Oh, wait, sorry. We'd of course just write the label handling in larry.rollaxis and then call np.rollaxis on the data part of the larry (lar.x). We could still copy the docstring, however.
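A NumPy-only sketch of why that split works: np.rollaxis only reorders axes, so the same order can be computed separately and applied to the label list. The helper `rolled_order` below is hypothetical (it is not part of la or NumPy); the loop checks that transposing by that order matches np.rollaxis for every valid `axis`/`start` pair on a 4-d array, and that reordering a stand-in label list the same way keeps each label aligned with its axis.

```python
import numpy as np


def rolled_order(n, axis, start=0):
    """Return the axis order that np.rollaxis(a, axis, start) gives an n-d array.

    Hypothetical helper: the same list could be used to reorder a larry's labels.
    """
    # Normalize negative indices the way np.rollaxis does.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    if not 0 <= axis < n:
        raise ValueError("axis %d out of range for %d dimensions" % (axis, n))
    if not 0 <= start <= n:
        raise ValueError("start %d out of range for %d dimensions" % (start, n))
    # Removing `axis` shifts every later position left by one.
    if axis < start:
        start -= 1
    order = list(range(n))
    order.remove(axis)
    order.insert(start, axis)
    return order


a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
# Stand-in labels: one list per axis, with a distinct length for each axis.
label = [["a", "b", "c"], list(range(4)), list("vwxyz"), list(range(10, 16))]
for axis in range(-4, 4):
    for start in range(-4, 5):
        order = rolled_order(a.ndim, axis, start)
        rolled = np.rollaxis(a, axis, start)
        # Data: transposing by the order reproduces np.rollaxis exactly.
        assert np.array_equal(np.transpose(a, order), rolled)
        # Labels: reordering by the same list keeps each label with its axis.
        assert [len(label[i]) for i in order] == list(rolled.shape)
```

Because the label order depends only on `ndim`, `axis` and `start`, the data itself never has to pass through the label code.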
Here's a quick hack that borrows heavily from np.rollaxis. Does it behave the way you want?
```python
import numpy as np

def rollaxis(lar, axis, start=0):
    # Roll the data; note that this rebinds lar.x on the larry passed in.
    lar.x = np.rollaxis(lar.x, axis, start)
    n = lar.ndim
    # Normalize negative indices, as np.rollaxis does.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    msg = 'rollaxis: %s (%d) must be >=0 and < %d'
    if not (0 <= axis < n):
        raise ValueError(msg % ('axis', axis, n))
    if not (0 <= start < n+1):
        raise ValueError(msg % ('start', start, n+1))
    if (axis < start): # it's been removed
        start -= 1
    if axis==start:
        return lar
    # Reorder the labels the same way np.rollaxis reorders the axes.
    axes = list(range(0,n))
    axes.remove(axis)
    axes.insert(start, axis)
    label = [lar.label[i] for i in axes]
    lar.label = label
    return lar
```
Demo:

```
>> lar = la.rand(3,4,5,6)
>> lar.shape
(3, 4, 5, 6)
>> np.rollaxis(lar.x, 3, 1).shape
(3, 6, 4, 5)
>> rollaxis(lar, 3, 1).shape
(3, 6, 4, 5)
>> np.rollaxis(lar.x, 2).shape
(4, 3, 6, 5)
>> rollaxis(lar, 2).shape
(4, 3, 6, 5)
>> np.rollaxis(lar.x, 1, 4).shape
(4, 6, 5, 3)
>> rollaxis(lar, 1, 4).shape
(4, 6, 5, 3)
```
BTW, np.rollaxis (NumPy 1.6.0) does not return what its docstring examples say it returns.
I guess my hack doesn't behave the same way as np.rollaxis: it modifies the larry passed in, whereas np.rollaxis leaves its input array unchanged.

```
>> a = np.ones((3,4,5,6))
>> b = np.rollaxis(a, 1, 3)
>> a.shape
(3, 4, 5, 6)
>> b.shape
(3, 5, 4, 6)
>>
>> lar = la.larry(a)
>> lar.shape
(3, 4, 5, 6)
>> lar2 = rollaxis(lar, 1, 3)
>> lar.shape
(3, 5, 4, 6)
>> lar2.shape
(3, 5, 4, 6)
```
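The transcript shows lar itself changing shape after the call, because the first hack assigns the rolled array back to lar.x and overwrites lar.label. In the earlier demo, each call therefore acted on the output of the previous one. A NumPy-only check that starts from a fresh array for every call gives the shapes below, which line up with the np.rollaxis docstring examples, so the earlier mismatch appears to come from lar changing between calls rather than from np.rollaxis:

```python
import numpy as np

shape = (3, 4, 5, 6)

# Each call gets its own fresh array, unlike the earlier demo, which reused lar.
s1 = np.rollaxis(np.ones(shape), 3, 1).shape  # (3, 6, 4, 5)
s2 = np.rollaxis(np.ones(shape), 2).shape     # (5, 3, 4, 6)
s3 = np.rollaxis(np.ones(shape), 1, 4).shape  # (3, 5, 6, 4)
print(s1, s2, s3)

# np.rollaxis returns a new view and leaves the input's shape alone,
# which is the behaviour a larry version would want to copy.
a = np.ones(shape)
b = np.rollaxis(a, 1, 3)
print(a.shape, b.shape)  # (3, 4, 5, 6) (3, 5, 4, 6)
```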
Sorry for the rapid-fire comments and the bugs and mistakes that came with them. I'm trying to crank something out quickly.
Second attempt:
```python
import numpy as np
import la

def rollaxis(lar, axis, start=0):
    # Roll the data into a new array; lar itself is not modified.
    x = np.rollaxis(lar.x, axis, start)
    n = lar.ndim
    # Normalize negative indices, as np.rollaxis does.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    msg = 'rollaxis: %s (%d) must be >=0 and < %d'
    if not (0 <= axis < n):
        raise ValueError(msg % ('axis', axis, n))
    if not (0 <= start < n+1):
        raise ValueError(msg % ('start', start, n+1))
    if (axis < start): # it's been removed
        start -= 1
    if axis==start:
        return lar
    # Reorder the labels the same way np.rollaxis reorders the axes.
    axes = list(range(0,n))
    axes.remove(axis)
    axes.insert(start, axis)
    label = [lar.label[i] for i in axes]
    # Build a new larry from the rolled data and the reordered labels.
    return la.larry(x, label, integrity=False)
```
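For anyone without la installed, here is a NumPy-only sketch of the same logic as the second attempt. It is written as a hypothetical helper `roll_with_labels(x, label, axis, start)` that takes the data array and the label list separately; they stand in for lar.x and lar.label, and this is not la's API. The helper returns new objects instead of touching its inputs, and the usage below checks that each label still matches the length of its axis.

```python
import numpy as np


def roll_with_labels(x, label, axis, start=0):
    """Roll `axis` of `x` to position `start`, reordering `label` to match.

    Hypothetical helper mirroring the second rollaxis attempt; `x` and `label`
    play the roles of lar.x and lar.label. Neither input is modified.
    """
    n = x.ndim
    # np.rollaxis validates axis and start, then returns a new view of x.
    y = np.rollaxis(x, axis, start)
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    if axis < start:
        # Removing `axis` shifts the target position left by one.
        start -= 1
    order = list(range(n))
    order.remove(axis)
    order.insert(start, axis)
    # New outer list; the inner label lists are shared with the input, not copied.
    new_label = [label[i] for i in order]
    return y, new_label


x = np.zeros((3, 4, 5, 6))
label = [list(range(k)) for k in x.shape]
y, new_label = roll_with_labels(x, label, 1, 3)
print(x.shape, y.shape)                 # (3, 4, 5, 6) (3, 5, 4, 6)
print([len(lab) for lab in new_label])  # [3, 5, 4, 6]
```

When `axis` and `start` end up equal, the order is simply the identity, so this version needs no separate early return.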