torchfold
Inefficient chaining of get(index) and cat
In the method
def _batch_args(self, arg_lists, values):
the following code already avoids unnecessary chunking and concatenation for computations that involve sequential operations of the same kind:
if isinstance(arg[0], Fold.Node):
    if arg[0].batch:
        batched_arg = values[arg[0].step][arg[0].op].try_get_batched(arg)
        if batched_arg is not None:
            res.append(batched_arg)
            continue
For the other cases, however, I have the feeling that there is still room for improvement:
for x in arg:
    r.append(x.get(values))
res.append(torch.cat(r, 0))
Do you see any obstacle to converting the sequence of get() calls followed by a cat into a single index_select? Would it improve performance (by batching the get()s), especially for the backward pass?
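For concreteness, here is a minimal sketch of the equivalence I have in mind. The names batched and indices are placeholders, not the actual torchfold internals, and the sketch assumes all nodes in arg point into the same batched output tensor:

import torch

# Placeholder setup (not the torchfold API): `batched` stands for the
# batched output tensor of one (step, op) pair, `indices` for the row
# positions that the individual get() calls would extract.
batched = torch.randn(8, 4, requires_grad=True)
indices = torch.tensor([5, 1, 7])

# Current pattern: one narrow slice per node, then cat.
rows = [batched[i].unsqueeze(0) for i in indices.tolist()]
via_cat = torch.cat(rows, 0)

# Proposed pattern: a single gather along dimension 0.
via_index_select = torch.index_select(batched, 0, indices)

assert torch.equal(via_cat, via_index_select)

# Both paths are differentiable; index_select should run a single
# scatter-style backward (an index_add into the source) instead of
# one gradient slice per cat input.
via_index_select.sum().backward()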