[sparsity] Add PartialLinear module for structured sparsity
This PR adds the PartialLinear module to torchao.prototype.sparsity as suggested by @jcaip in pytorch/pytorch#149853.
Implementation Details
- PartialLinear is a linear layer that implements structured sparsity by connecting each output neuron to only the top-k input features by weight magnitude
- Supports dynamic connectivity updates during training with configurable update frequency
- Provides both a masked-dense computation path and a fully sparse computation path (see the sketch below)
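For reviewers skimming the PR, here is a minimal sketch of the idea; it is simplified and illustrative rather than the exact code in this PR, and names like `update_interval` are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PartialLinear(nn.Module):
    """Linear layer where each output neuron keeps only its top-k inputs (sketch)."""

    def __init__(self, in_features, out_features, k, update_interval=100):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.k = k
        self.update_interval = update_interval  # illustrative name
        self._steps = 0
        self.register_buffer("mask", torch.ones_like(self.linear.weight))
        self._update_mask()

    def _update_mask(self):
        # Keep the k largest-magnitude weights in each row (one row per output neuron).
        with torch.no_grad():
            topk = self.linear.weight.abs().topk(self.k, dim=1).indices
            self.mask.zero_().scatter_(1, topk, 1.0)

    def forward(self, x):
        if self.training:
            self._steps += 1
            if self._steps % self.update_interval == 0:
                self._update_mask()  # dynamic connectivity update
        # Masked-dense path shown here; a fully sparse path could use torch.sparse ops.
        return F.linear(x, self.linear.weight * self.mask, self.linear.bias)
```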
The implementation is placed in a PartialLinear folder as suggested. I'm open to any feedback or modifications needed to better integrate with the torchao ecosystem and sparsity framework.
This addresses the TODO in pytorch/pytorch's linear.py module that mentions "PartialLinear - maybe in sparse?".
Related issue: pytorch/pytorch#135091
cc @jcaip
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1982
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:heavy_exclamation_mark: 1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @lakshminarasimmanv!
Thank you for your pull request and welcome to our community.
Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
@pytorchbot label "topic: not user facing" "sparsity"
Thank you @jcaip for the feedback and guidance!
I'll take a look at the requested changes and address them.
Regarding standardizing under sparsify_: I agree this would be valuable for long-term integration. Would you like me to attempt integrating with the existing sparsity framework in this PR, or would you prefer to keep this as a standalone implementation for now and handle the standardization in a follow-up effort?
I'll also work on adding memory benchmarks comparing to nn.Linear.
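Concretely, I'm thinking of something along these lines, comparing parameter-plus-buffer footprint (a rough sketch; the `PartialLinear` constructor arguments and the choice of `k` below are placeholders):

```python
import torch.nn as nn

def module_memory_bytes(m: nn.Module) -> int:
    # Count parameters plus registered buffers (e.g., the connectivity mask).
    params = sum(p.numel() * p.element_size() for p in m.parameters())
    buffers = sum(b.numel() * b.element_size() for b in m.buffers())
    return params + buffers

dense = nn.Linear(4096, 4096)
print(f"nn.Linear: {module_memory_bytes(dense) / 2**20:.1f} MiB")
# partial = PartialLinear(4096, 4096, k=512)  # placeholder arguments
# print(f"PartialLinear: {module_memory_bytes(partial) / 2**20:.1f} MiB")
```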
I appreciate your guidance and direction on how best to align this with the torchao ecosystem.
Looking forward to your further thoughts!
cc @lakshminarasimmanv just curious here: are you working on this for fun, or do you have a use case in mind that needs this? If you have a specific use case, I'd be really interested in learning more.
> Regarding standardizing under sparsify_: I agree this would be valuable for long-term integration. Would you like me to attempt integrating with the existing sparsity framework in this PR, or would you prefer to keep this as a standalone implementation for now and handle the standardization in a follow-up effort?
Let's keep this self-contained for now; adding to sparsify_ can be done in a subsequent PR. If you're interested in doing that I'm happy to review and accept :)
Feel free to ping me when you're ready for another review or have any questions!
Thanks for asking about my motivation, @jcaip!
I've been working with structured sparsity approaches to optimize transformer models, where I noticed that attention heads often focus strongly on a subset of input features. This seemed like a natural use case for a partial connectivity approach that could reduce computation while preserving model quality.
When exploring the PyTorch codebase to understand implementation patterns, I came across the "TODO: PartialLinear" comment and the related issue, which aligned perfectly with what I was exploring.
Also, I've updated the code to address your requested changes. I've been running benchmarks comparing Linear and PartialLinear across different model sizes and will share those results in the next day or two.
I'm definitely interested in contributing actively to PyTorch and would be happy to maintain this implementation and help with the sparsity framework going forward.
Looking forward to your thoughts on the changes!