
Inefficient chaining of get(index) and cat

hglaude opened this issue · 0 comments

In

 def _batch_args(self, arg_lists, values):

the following code avoids unnecessary chunking and cat for computations that involve sequential operations of the same kind:

            if isinstance(arg[0], Fold.Node):
                if arg[0].batch:
                    # Fast path: if the argument list maps onto an already
                    # batched result, reuse that tensor directly.
                    batched_arg = values[arg[0].step][arg[0].op].try_get_batched(arg)
                    if batched_arg is not None:
                        res.append(batched_arg)
                        continue
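For reference, here is a minimal illustration of what this fast path buys, written in plain PyTorch rather than against torchfold's internal data structures (the `batched` tensor below is a made-up stand-in for the output of one batched step): when every element of `arg` maps, in order, onto that single batched output, the tensor can be reused as-is instead of being split into rows and concatenated again.

    import torch

    # Stand-in for the output of one batched op/step (hypothetical data,
    # not torchfold's `values` structure).
    batched = torch.randn(4, 8)

    # Without the fast path: pull each row out individually, then cat them back.
    rows = [batched[i].unsqueeze(0) for i in range(batched.size(0))]
    rebuilt = torch.cat(rows, 0)

    # With the fast path: the batched tensor is already the result we need.
    assert torch.equal(rebuilt, batched)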

For the other cases, I feel there is still room for improvement:

                    # Fallback: fetch every node's value one by one, then cat.
                    for x in arg:
                        r.append(x.get(values))
                    res.append(torch.cat(r, 0))

Do you see any obstacle to converting the sequence of get()s and the final cat into a single index_select? Would it improve performance (by batching the get()s), especially for the backward pass?
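A minimal sketch of what I mean, assuming the nodes in `arg` all point into the same batched source tensor (the names `source` and `indices` below are illustrative, not torchfold's API): collect the row indices first, then issue one `torch.index_select` instead of one `get()` per node followed by a `cat`.

    import torch

    # Stand-in for the batched output of a previous step that the Fold.Nodes
    # refer to (hypothetical; torchfold would look this up in `values`).
    source = torch.randn(6, 8, requires_grad=True)
    indices = [4, 1, 1, 3]  # rows requested by the individual nodes

    # Current pattern: one slice per node, then a cat.
    chained = torch.cat([source[i].unsqueeze(0) for i in indices], 0)

    # Proposed pattern: a single batched gather.
    idx = torch.tensor(indices, dtype=torch.long)
    selected = torch.index_select(source, 0, idx)

    assert torch.equal(chained, selected)

    # The backward of index_select is a single scatter-add into `source`,
    # instead of one gradient op per slice plus the cat's backward.
    selected.sum().backward()

If the nodes point into several different source tensors, grouping them by step/op and issuing one index_select per source would still replace most of the per-node get() calls.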

hglaude · Mar 02 '18 21:03