universal_gemm/gen_instances.py: make parsing more robust
## Proposed changes

### Why

The current parsing logic is brittle: it can break when we parse examples whose arguments appear in a new order, or that define keys which were previously only supplied as defaults.

### What
- Parse based on `split` and kwargs rather than moving a pointer across a list. This makes the code (arguably) a bit more readable, and more robust to changing parameter lengths.
- Use default values explicitly rather than relying on the underlying template. This means every template is fully instantiated, which makes it easier to debug.
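A minimal sketch of what the split/kwargs approach could look like. The names (`PARAM_NAMES`, `DEFAULTS`, `parse_instance`) are illustrative, not the actual `gen_instances.py` identifiers:

```python
# Illustrative sketch of kwargs-based parsing with explicit defaults.
# PARAM_NAMES and DEFAULTS are hypothetical; the real parameter list
# in gen_instances.py is much longer.

PARAM_NAMES = ["a_layout", "b_layout", "block_size", "m_per_block"]
DEFAULTS = {"block_size": "256", "m_per_block": "128"}

def parse_instance(arg_str: str) -> dict:
    values = [v.strip() for v in arg_str.split(",")]
    # zip stops at the shorter sequence, so a shorter argument list simply
    # leaves the trailing parameters to be filled in from DEFAULTS below;
    # no pointer needs to be moved across the list
    kwargs = dict(zip(PARAM_NAMES, values))
    for key, default in DEFAULTS.items():
        kwargs.setdefault(key, default)
    return kwargs

print(parse_instance("Row, Col"))
# {'a_layout': 'Row', 'b_layout': 'Col', 'block_size': '256', 'm_per_block': '128'}
```

Because the defaults are applied explicitly, the resulting dict always contains every key, so the emitted template is fully instantiated regardless of how many arguments the source instance spelled out.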
## Test Plan

```python
import logging

import torch
from torch import nn
from torch._inductor import config as inductor_config

log = logging.getLogger(__name__)


class SimpleModel(nn.Module):
    # minimal matmul model so max-autotune exercises the GEMM backends
    def forward(self, a, b):
        return a @ b


def main():
    # example shapes and dtype for the repro
    M, N, K, dtype = 2048, 2048, 2048, torch.float16
    log.info(f"M: {M}, N: {N}, K: {K}, dtype: {dtype}")
    A = torch.randn(M, K, dtype=dtype).cuda()
    B = torch.randn(K, N, dtype=dtype).cuda()
    # sample the different backends independently
    with inductor_config.patch(
        {
            "max_autotune_gemm_backends": "ATEN,CK",
            "search_autotune_cache": False,
            "fx_graph_cache": False,
            "rocm.n_max_profiling_configs": None,
        }
    ):
        # compile the model
        compiled_model = torch.compile(SimpleModel(), mode="max-autotune")
        log.info("compiled model")
        # run the compiled model
        _ = compiled_model(A, B)
        log.info("ran compiled model")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
```
## Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
- [x] I have added inline documentation which helps the maintainers understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged
@tenpercent if you agree with this general direction, would you be able to make the corresponding changes in batched and conv?
It seems that #1796 ran into this same issue: it fixes the value error and adds tests. Sorry for not catching that earlier, and thanks for adding the tests; we'll have to port those as well! I do think the rewrite here is beneficial and makes the parsing cleaner, but I'm open to feedback.
Hi @coconutruben, thanks for looking into this! The parsing logic does look more readable this way, except for the regex part.
> it seems that https://github.com/ROCm/composable_kernel/pull/1796 ran into this same issue and fixes the value error
So do I understand correctly that there's no parsing error once you move the pin past the #1796 merge? Meaning there's no urgency to merge this.
In general, the parsing code does look hairy. A proper refactor would (1) extract the template arguments as Python ints, strings, and tuples, which is template-instance agnostic and can be shared between universal batched gemms, universal gemms, and grouped convolutions; and (2) try to instantiate a Python dataclass from that list, which is custom for each of the dataclasses. Checking the kwargs could also be done at this step.
I was also thinking of possibly using python libclang instead of manual code parsing but that would probably be overkill at this point
If not urgent, it could be a good introductory task
@tenpercent @coconutruben Please update the status for this PR. It has been open for 8 months - please decide whether we can close it or move forward with changes / review. Thank you.
I believe it's not a high priority to merge it