optimum-graphcore
Support simpler syntax for specifying pipeline splits
What does this PR do?
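This PR is a work in progress (WIP). It lets a model declare its pipeline splits as a `pipeline_splits` class attribute: a list of `(layer path, IPU index)` pairs that `parallelize()` then applies, as in the `IPUUNet2DConditionModel` example below.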
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
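```python
from diffusers import UNet2DConditionModel
from optimum.graphcore.modeling_utils import PipelineMixin


class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
    # Each entry pairs a submodule path (attribute access plus list
    # indexing) with the index of the IPU it is assigned to.
    pipeline_splits = [
        ("conv_in", 0),
        ("down_blocks[2].downsamplers[0]", 1),
        ("up_blocks[0].resnets[2]", 2),
        ("up_blocks[1].attentions[2]", 3),
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def parallelize(self):
        super().parallelize()
        # add_block assigns the named layer to the given IPU.
        for layer, ipu_id in self.pipeline_splits:
            self.add_block(layer, ipu_id)
        return self
```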
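The split syntax relies on turning a string such as `"down_blocks[2].downsamplers[0]"` into the submodule it names. A minimal sketch of that lookup follows; `resolve_layer` and `to_module_path` are hypothetical helpers for illustration, not functions from optimum-graphcore, and the path grammar (dot-separated attribute names, each optionally followed by `[int]` indices) is an assumption based on the entries above.

```python
import re
from typing import Any

# One path segment: an attribute name followed by zero or more [int] indices.
_SEGMENT = re.compile(r"^([A-Za-z_]\w*)((?:\[\d+\])*)$")
_INDEX = re.compile(r"\[(\d+)\]")


def resolve_layer(root: Any, path: str) -> Any:
    """Hypothetical helper: walk `path` from `root` via getattr and indexing."""
    obj = root
    for segment in path.split("."):
        match = _SEGMENT.match(segment)
        if match is None:
            raise ValueError(f"Invalid segment {segment!r} in layer path {path!r}")
        name, indices = match.groups()
        obj = getattr(obj, name)
        # Apply each [i] in order, e.g. "resnets[2]" -> obj.resnets[2].
        for index in _INDEX.findall(indices):
            obj = obj[int(index)]
    return obj


def to_module_path(path: str) -> str:
    """Hypothetical helper: rewrite "a[2].b[0]" into the dotted form "a.2.b.0"."""
    return _INDEX.sub(r".\1", path)
```

With a real `torch.nn.Module`, `model.get_submodule(to_module_path(path))` should reach the same submodule, because `nn.ModuleList` registers its children under their integer positions as names.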