pytensor
Implement pack/unpack helpers
Description
Adds pt.pack and pt.unpack helpers, roughly conforming to the einops functions of the same name.
These helpers are for situations where we have a ragged list of inputs that need to be raveled into a single flat vector for some intermediate step. This occurs in places like optimization.
Example usage:
x = pt.tensor("x", shape=shapes[0])
y = pt.tensor("y", shape=shapes[1])
z = pt.tensor("z", shape=shapes[2])
flat_params, packed_shapes = pt.pack(x, y, z)
Unpack simply undoes the computation, although there's no rewrite to ensure pt.unpack(*pt.pack(*inputs)) is the identity function:
x, y, z = pt.unpack(flat_params, packed_shapes)
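For readers who want the intended semantics spelled out, here is a minimal NumPy sketch of the round trip. It is not the code in this PR: pack_np and unpack_np are hypothetical names used only for illustration, and the sketch assumes every input is raveled in C order and concatenated in argument order.

```python
import numpy as np

def pack_np(*arrays):
    """Ravel each array and join the pieces into one flat vector, remembering the shapes."""
    arrays = [np.asarray(a) for a in arrays]
    packed_shapes = [a.shape for a in arrays]
    flat = np.concatenate([a.ravel() for a in arrays])
    return flat, packed_shapes

def unpack_np(flat, packed_shapes):
    """Split the flat vector and reshape each piece back to its recorded shape."""
    sizes = [int(np.prod(shape)) for shape in packed_shapes]
    split_points = np.cumsum(sizes)[:-1]  # boundaries between consecutive pieces
    pieces = np.split(flat, split_points)
    return [piece.reshape(shape) for piece, shape in zip(pieces, packed_shapes)]

# Ragged inputs: a scalar, a vector, and a matrix
x = np.zeros(())
y = np.arange(5.0)
z = np.ones((3, 3))

flat, packed_shapes = pack_np(x, y, z)
x2, y2, z2 = unpack_np(flat, packed_shapes)
```

In this sketch the flat vector has 1 + 5 + 9 = 15 entries, and unpacking recovers arrays with the original shapes and values.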
The use-case I foresee is creating a replacement for a function of the inputs we're packing, for example:
loss = (x + y.sum() + z.sum()) ** 2
flat_packed, packed_shapes = pt.pack(x, y, z)
new_input = flat_packed.type()
new_outputs = pt.unpack(new_input, packed_shapes)
loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input], loss)
Note that the final compiled function depends only on new_input, but only because the shapes of the 3 packed variables were statically known. This leads to my design choices:
- I decided to work with the static shapes directly if they are available. This means that pack will eagerly return a list of integer shapes as packed_shapes if possible. If not possible, they will be symbolic shapes. This is maybe an anti-pattern -- we might prefer a rewrite to handle this later, but it seemed easy enough to do eagerly.
- I didn't add support for batch dims. This is left to the user to do himself using pt.vectorize.
- The einops API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't get himself using DimShuffle and vectorize.
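To illustrate the first design choice, here is a rough standard-library sketch of why fully static shapes let the split points be computed eagerly as plain integers. The helper flat_slices is hypothetical and not part of this PR. It assumes static shapes are given as tuples in which an unknown dimension is None, and it returns None in that case to stand in for the symbolic-shape fallback.

```python
from itertools import accumulate
from math import prod

def flat_slices(static_shapes):
    """Return one slice per input into the packed vector, or None if any dim is unknown."""
    if any(dim is None for shape in static_shapes for dim in shape):
        return None  # integer offsets are unavailable; symbolic shapes would be needed
    sizes = [prod(shape) for shape in static_shapes]  # prod(()) == 1 for scalars
    ends = list(accumulate(sizes))
    starts = [0] + ends[:-1]
    return [slice(start, end) for start, end in zip(starts, ends)]

known = flat_slices([(), (5,), (3, 3)])
unknown = flat_slices([(None,), (5,)])
```

With all shapes known, every input maps to a fixed integer range of the flat vector, so nothing about the original variables has to remain in the graph. With any unknown dimension, the offsets depend on runtime shapes.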
Related Issue
- [ ] Closes #1553
- [ ] Related to #1550
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [x] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/