tpp-mlir

[Transform][Python] Introduce transform.ffi.callback

Open rolfmorel opened this issue 8 months ago • 4 comments

This transform op allows for invoking an external handler routine which can, for example, implement an analysis or transform in whatever way (and programming language) is convenient.

This PR provides full facilities for registering callbacks from Python and includes test cases showing how MLIR's Python bindings can easily be used to implement at least some transforms.

rolfmorel • Jun 25 '25 21:06

Failed Tests (4):
  TPP_OPT :: BF16/matmul-vnni.mlir
  TPP_OPT :: Conversion/VectorToXsmm/vector-to-transpose.mlir
  TPP_OPT :: Passes/DefaultPipeline/default-tpp-passes.mlir
  TPP_OPT :: Python/transform_callback.py

Hmm, all fail with segfaults. These tests all pass for me on the compile server. 😕

I'll need to think about it. It might have to do with enabling some Python environment variables in lit.cfg.py, as otherwise tpp-sched (which is implicated in the segfaults) is not touched at all by this PR.

rolfmorel • Jun 25 '25 22:06

> I'll need to think about it. It might have to do with enabling some Python environment variables in lit.cfg.py, as otherwise tpp-sched (which is implicated in the segfaults) is not touched at all by this PR.

The issue was the reference counting of the callback object. This is now fixed.

IMO this PR is ready to go (e.g., to be used for prototyping nano-tile selection algorithms).

For reference purposes, another instance of passing callbacks around: https://reviews.llvm.org/D116568

rolfmorel • Jul 21 '25 08:07

Looks neat, I'll have to take it for a spin 😎

A general question about the whole Python-MLIR interaction: I take it that the callback signature is reified into IR. How/Where does the runtime find the actual callback definition to call?

adam-smnk • Jul 21 '25 13:07

> I take it that the callback signature is reified into IR.

Yes, it is converted automatically from the handles' types when a Python callback is invoked. There's a test case with an example: https://github.com/libxsmm/tpp-mlir/pull/1064/files#diff-7aa62724b21b998da9cf032da6e6b77bbdb664258a5dca850f9509a5459646f6R95
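To make the idea concrete, here is a minimal sketch (not the PR's code) of deriving a callback's argument kinds from its handles' types at invocation time. `Handle`, `handle_kind`, `invoke` and the prefix table are hypothetical; the type names themselves (`!transform.any_op`, `!transform.op<...>`, `!transform.any_value`, `!transform.any_param`, `!transform.param<...>`) are real transform-dialect types, but payloads are modelled as plain Python values rather than MLIR objects.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Tuple


@dataclass
class Handle:
    """Hypothetical stand-in for a transform handle and its payload."""
    type_name: str                                    # e.g. "!transform.any_op"
    payload: List[Any] = field(default_factory=list)  # ops, values or params


# Hypothetical table: handle-type prefix -> kind of argument the callback gets.
_KIND_BY_PREFIX = (
    ("!transform.any_op", "ops"),
    ("!transform.op<", "ops"),
    ("!transform.any_value", "values"),
    ("!transform.any_param", "params"),
    ("!transform.param<", "params"),
)


def handle_kind(type_name: str) -> str:
    """Map a handle's IR type to the argument kind the callback receives."""
    for prefix, kind in _KIND_BY_PREFIX:
        if type_name.startswith(prefix):
            return kind
    raise TypeError(f"unsupported handle type: {type_name}")


def invoke(callback: Callable[..., Any], handles: List[Handle]) -> Tuple[List[str], Any]:
    """Derive the signature from the handle types, then call the callback."""
    signature = [handle_kind(h.type_name) for h in handles]
    args = [list(h.payload) for h in handles]  # one Python list per handle
    return signature, callback(*args)


# Usage: an "analysis" over one op handle's payload and one parameter handle.
sig, result = invoke(
    lambda ops, params: len(ops) * params[0],
    [Handle("!transform.any_op", ["linalg.fill", "linalg.matmul"]),
     Handle("!transform.param<i64>", [8])],
)
print(sig, result)  # ['ops', 'params'] 16
```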

> How/Where does the runtime find the actual callback definition to call?

A global function pointer is set when a Python callback handler is provided. When the op executes, it simply reads the global and calls the handler.

The main thing a proper solution should have is that this global gets moved to hang off the MlirContext. That said, I generally consider how the FFI mechanism works to be an implementation detail (e.g., the user could bypass the Python bindings and use something else to set the global, swap out the entire mechanism, or even bind some C++ code into a transform op extremely late).
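A sketch of the direction proposed here, with the handler stored on a context object instead of in a process-wide global. `Context` is a hypothetical stand-in for `MlirContext`, and the two helper functions are hypothetical too; this is not how the MLIR Python bindings expose contexts.

```python
from typing import Any, Callable, Optional


class Context:
    """Hypothetical stand-in for an MlirContext that owns the handler."""

    def __init__(self) -> None:
        self.callback_handler: Optional[Callable[..., Any]] = None


def set_callback_handler(ctx: Context, handler: Callable[..., Any]) -> None:
    """Attach the handler to one context rather than to a global."""
    ctx.callback_handler = handler


def apply_callback_op(ctx: Context, *handles: Any) -> Any:
    """The op looks up the handler on the context it runs in."""
    handler = ctx.callback_handler
    if handler is None:
        raise RuntimeError("no callback handler registered on this context")
    return handler(*handles)


ctx_a, ctx_b = Context(), Context()
set_callback_handler(ctx_a, lambda *handles: len(handles))
print(apply_callback_op(ctx_a, "h0", "h1"))  # 2
# ctx_b has no handler, so applying the op there raises RuntimeError.
```

With per-context storage, two contexts in the same process no longer share, or overwrite, each other's handler.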

rolfmorel • Jul 23 '25 19:07