tfdeploy
There is no support for tf.strided_slice
Currently there's no support for tf.strided_slice(), so if your graph contains something like x[:,:,:,None] anywhere, it won't execute. It's a complicated function, so it won't be easy to implement.
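In NumPy terms, the indexing expression above just appends a length-1 axis. This is the behaviour a tfdeploy op would have to reproduce. Here is a quick illustration with an arbitrary example array (the array itself is made up for illustration):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
y = x[:, :, :, None]   # None (same as np.newaxis) appends a length-1 axis
print(y.shape)         # (2, 3, 4, 1)

# Same result as expanding the last dimension explicitly.
assert np.array_equal(y, np.expand_dims(x, axis=3))
```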
I don't have time to look into this right now, but... challenge accepted ;)
I had to write a very simple strided_slice implementation:
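```python
import tfdeploy as td


class StridedSlice(td.Operation):
    @staticmethod
    def func(input_, begins, ends, strides):
        output = input_.copy()
        for dim, (begin, end, stride) in enumerate(zip(begins, ends, strides)):
            if dim >= len(input_.shape):
                raise ValueError("Dimension mismatch")
            # An end of 0 is treated as "slice to the end".
            end = end if end else None
            # Slice axis 0, then reverse the axis order with .T so the next
            # spec hits the next axis. This only restores the original axis
            # order for 1-D and 2-D inputs with one spec per axis.
            output = output[begin:end:stride].T
        return output,
```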
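To see where the transpose trick can go wrong, here is the same loop pulled out into plain NumPy and compared against indexing with a tuple of slices. Both function names are hypothetical, for illustration only. As far as I can tell, the two agree for a 2-D input with one spec per axis but disagree for a 3-D input:

```python
import numpy as np


def transpose_trick_slice(x, begins, ends, strides):
    # Same loop as the func above, without the tfdeploy wrapper.
    out = x.copy()
    for begin, end, stride in zip(begins, ends, strides):
        end = end if end else None
        out = out[begin:end:stride].T  # slice axis 0, then reverse all axes
    return out


def tuple_slice(x, begins, ends, strides):
    # Reference: index every axis at once with a tuple of slices.
    return x[tuple(slice(b, e if e else None, s)
                   for b, e, s in zip(begins, ends, strides))]


m = np.arange(12).reshape(3, 4)
print(np.array_equal(transpose_trick_slice(m, [1, 0], [3, 2], [1, 1]),
                     tuple_slice(m, [1, 0], [3, 2], [1, 1])))  # True

t = np.arange(24).reshape(2, 3, 4)
print(transpose_trick_slice(t, [0, 0, 0], [0, 0, 2], [1, 1, 1]).shape)  # (4, 3, 2)
print(tuple_slice(t, [0, 0, 0], [0, 0, 2], [1, 1, 1]).shape)            # (2, 3, 2)
```

In the 3-D case, the third spec ends up applied to the original first axis, and the final axis order is reversed.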
It is sufficient for a "simple" strided slice, i.e. one that does not add new dimensions via array[None] or array[np.newaxis].
It is certainly not the best implementation, but maybe it will be useful to someone.
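For the new-axis case, one possible route is to translate the masks into a NumPy index tuple. The following is a minimal plain-NumPy sketch, assuming the bit-mask semantics described in the tf.strided_slice documentation: bit i of begin_mask/end_mask means "ignore begin[i]/end[i]", new_axis_mask inserts a length-1 axis, and shrink_axis_mask treats begin[i] as an integer index. ellipsis_mask is not handled. The function name and signature are hypothetical, not tfdeploy API, and wiring it into a tfdeploy op would also require reading the mask attributes from the graph node, which is not shown. As far as I understand, TF encodes x[:,:,:,None] with begin_mask=7, end_mask=7 and new_axis_mask=8, which the usage line below assumes.

```python
import numpy as np


def strided_slice(x, begin, end, strides, begin_mask=0, end_mask=0,
                  new_axis_mask=0, shrink_axis_mask=0):
    """Sketch of tf.strided_slice semantics in NumPy (no ellipsis_mask)."""
    index = []
    for i, (b, e, s) in enumerate(zip(begin, end, strides)):
        if new_axis_mask & (1 << i):
            # This spec position inserts a length-1 axis, like x[..., None].
            index.append(None)
        elif shrink_axis_mask & (1 << i):
            # Integer index: take element b and drop the axis.
            index.append(int(b))
        else:
            start = None if begin_mask & (1 << i) else int(b)
            stop = None if end_mask & (1 << i) else int(e)
            index.append(slice(start, stop, int(s)))
    # Axes not covered by the spec are kept whole, as in NumPy indexing.
    return x[tuple(index)]


x = np.arange(24).reshape(2, 3, 4)
# Assumed encoding of x[:, :, :, None].
y = strided_slice(x, [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1],
                  begin_mask=7, end_mask=7, new_axis_mask=8)
print(y.shape)  # (2, 3, 4, 1)
```

Building a single tuple and indexing once also avoids the axis bookkeeping that the transpose approach needs.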
@celliern: your implementation did not work for me, but this one did for simple slices:
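```python
import numpy as np
from tfdeploy import Operation


@Operation.factory()
def StridedSlice(input_, begins, ends, strides):
    # One Python slice per dimension; an end of 0 is treated as
    # "slice to the end" (masks are ignored).
    slices = tuple(slice(b, e if e else None, s)
                   for b, e, s in zip(begins, ends, strides))
    # NumPy needs a tuple (not a list) of slices for multi-dimensional indexing.
    return np.copy(input_)[slices],
```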
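One caveat of the `e if e else None` expression is worth checking in plain NumPy. The sketch below reuses the same slice-building expression in a hypothetical helper, assuming that TF passes 0 for ends whose end_mask bit is set (so x[1:, :2] arrives as begins [1, 0], ends [0, 2]):

```python
import numpy as np


def build_slices(begins, ends, strides):
    # Same expression as in StridedSlice above.
    return tuple(slice(b, e if e else None, s)
                 for b, e, s in zip(begins, ends, strides))


x = np.arange(12).reshape(3, 4)

# Masked end of 0 on the first axis is read as "to the end": x[1:, 0:2].
print(x[build_slices([1, 0], [0, 2], [1, 1])].shape)  # (2, 2)

# A literal end of 0 (an intentionally empty slice such as x[0:0]) is
# indistinguishable from "to the end" here.
print(x[build_slices([0], [0], [1])].shape)  # (3, 4)
print(x[0:0].shape)                          # (0, 4)
```

Distinguishing the two cases requires reading end_mask from the op's attributes rather than inferring it from the value 0.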