
implemented backend-aware node alias conversion for initialization - similar to the @model macro

Open skoghoern opened this issue 1 month ago • 9 comments

This is an attempt to solve GitHub issue #507. I built it on the assumption that we only want to convert the distribution aliases for which a factor_alias/interface_aliases() is defined (as mentioned by wouterwln), so the first part only filters down to the types we want to convert. For the actual conversion I "copied" the strategy from the @model macro with make_node! (reusing factor_alias and interface_aliases) to convert the aliases. There were a few other changes I needed to include:

  1. I needed to change the include order in the RxInfer.jl file so that the ReactiveMPGraphPPLBackend struct is precompiled before it is used in the initialization plugin.
  2. I also added and adapted existing tests. When running the test file, though, I ran into an error in "init_macro_interior". It was caused by the unwanted Normal(0, 1) definition:

Error in testset "init_macro_interior" on worker 2344: Got exception outside of a @test Normal cannot be constructed without keyword arguments. Use Normal(mean = ..., var = ...) or Normal(mean = ..., precision = ...).

⇒ so I changed it to Normal(m = 0, v = 1), after which all tests passed in my setup.

  3. One problem remained: Gamma(shape = 1, scale = 1) converts to Gamma{Float64}(α = 1.0, θ = 1.0) with aliased_fform: Gamma. GraphPPL.factor_alias correctly maps it to ExponentialFamily.GammaShapeScale, but apparently Julia internally resolves that immediately to Distributions.Gamma, since it was defined as an alias in the ExponentialFamily package (as bvdmitri wrote). When running code this is actually not a problem; it only caused problems in my new tests (see test 5 in "convert_init_node_aliases"), where for now I used a dirty workaround. ⇒ I just saw issue #468, which is probably connected; I can try to look into it next time.

In general, I guess the long-term goal might be to unify the logic of make_node! and the transformation here into one shared helper stored in GraphPPL?
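The const-alias quirk with Gamma/GammaShapeScale described above can be reproduced in a couple of lines of plain Julia (stand-in structs, not the actual ExponentialFamily.jl definitions):

```julia
# Toy illustration of why the alias "disappears": a const binding is just
# another name for the same type object, so comparing or printing the
# aliased form gives back the original type. Stand-in types only.
struct Gamma end
const GammaShapeScale = Gamma   # analogous to the alias in ExponentialFamily.jl

GammaShapeScale === Gamma       # true — tests comparing the names cannot tell them apart
nameof(GammaShapeScale)         # :Gamma — the alias name is gone at runtime
```

This is why a test asserting on the alias name fails even though the conversion itself did the right thing.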

skoghoern avatar Oct 06 '25 18:10 skoghoern

Thanks for your PR and the impressive attempt at solving this issue! I think there might be an easier way to go about this. The way it is done in GraphPPL is the following: instead of trying to find out what the factor alias would be using eval and injecting this into the code, we generate code that calls factor_alias for us (so no need to evaluate code during the macro transformation, and no calls to eval!). I think this transformation differs in method from how you implemented it, but it would be more robust and more desirable (as exposing eval may lead to undesired behavior).

The only clunky thing is that we would have to separate the functional form (Normal) from a NamedTuple with the arguments (something like (mean=1, var=1)). If you can split these in the macro, then passing them into a "guard function" (like we did here https://github.com/ReactiveBayes/GraphPPL.jl/blob/2f7cd70a58be2e47ea7d1932cc811ad602fb8f28/src/graph_engine.jl#L2071-L2091 with make_node!) allows you to have easy access to factor_alias without having to interpolate this code yourself! I think this is a cleaner approach to fix the issue, and it would require less esoteric (yet very impressive) Julia magic. Let me know what you think!
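The guard-function pattern can be sketched in a few lines (all names here — init_guard, the toy factor_alias method — are hypothetical stand-ins, not the GraphPPL API):

```julia
# The macro would emit a call like `init_guard(Normal, (m = 0, v = 1))`
# instead of eval'ing anything itself. At runtime `Normal` is the actual
# type, so dispatch on Type{...} and the keyword names just works.
struct Normal end
struct NormalMeanVariance end

factor_alias(::Type{Normal}, ::Val{(:m, :v)}) = NormalMeanVariance

init_guard(fform, args::NamedTuple) = factor_alias(fform, Val(keys(args)))

init_guard(Normal, (m = 0, v = 1))   # resolves to NormalMeanVariance
```

No eval is ever called: the macro only rearranges syntax, and the alias lookup happens at runtime via ordinary multiple dispatch.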

wouterwln avatar Oct 07 '25 19:10 wouterwln

Hi Wouter, thanks for your good feedback. Yeah, I totally agree, it's a rather ugly piece of code^^ I was trying to get it to work quickly (fast and furious ;). My main problems were:

  1. Filtering the expression tree for the functional-form calls vs. other expressions (that's the part before the `# split keyword (NamedTuple) vs positional (Tuple) user calls` comment) -> maybe you have a better idea how to filter them?
  2. Once identified, using @capture we get typeof(fform) == Symbol, whereas when writing the make_node! call in convert_tilde_expression in the @model macro, we insert it with $fform into the quoted block. Julia's compiler then treats it as if it represents a type variable and converts it (in the make_node! function we get typeof(fform) == UnionAll). This makes dispatch work for the defined methods factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :v)}), which doesn't work when fform is a Symbol.
  3. Once the aliased_fform is obtained, we actually need to transform it back into an expression/symbol, since the macro must return an expression.
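The Symbol-vs-type dispatch issue described above can be demonstrated with plain Julia (toy method, not the real factor_alias signature):

```julia
# At macro-expansion time the captured functional form is only a Symbol,
# so a method taking Type{Normal} is not applicable. Once the name is
# interpolated into the generated code and that code runs, the same name
# evaluates to the actual type and dispatch succeeds.
struct Normal end
factor_alias(::Type{Normal}) = :aliased

fform = :Normal
fform isa Symbol                        # true — what @capture hands you
hasmethod(factor_alias, Tuple{Symbol})  # false — no dispatch on a Symbol
factor_alias(Normal)                    # :aliased — works on the runtime type
```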

The separation of functional form vs. arguments actually worked easily using @capture fform_(args__) - or did you mean something else? Do you mean to just write the "correct" calls of the guard function as an expression in the init macro, which then gets executed during the infer() call, similar to the model macro?

My next steps would be to look again into model_macro.jl to understand how you did the filtering for the functional forms, and then build an expression that calls the "guard function" with interface_aliases, factor_alias, and default_parametrization during inference (as in make_node!). This would solve the problem of transforming back and forth between expressions/symbols and types.

Let me know what you think. (I hope to find time to look into it further by tomorrow or Friday.)

skoghoern avatar Oct 08 '25 17:10 skoghoern

I think if you interpolate fform into an expression, you'll still get the actual type! The mechanism that does this in the @model macro is here:

https://github.com/ReactiveBayes/GraphPPL.jl/blob/2f7cd70a58be2e47ea7d1932cc811ad602fb8f28/src/model_macro.jl#L640-L657

You'd really only need the first half of this function, as the second half deals with where clauses. The idea is to convert whatever you put in the arguments to a NamedTuple. Then you're free to manipulate the arguments, and it's surprisingly easy to then evaluate with the new transformed fform and arguments (https://github.com/ReactiveBayes/GraphPPL.jl/blob/2f7cd70a58be2e47ea7d1932cc811ad602fb8f28/src/graph_engine.jl#L1931). Does this point you in the right direction?
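A minimal sketch of that args-to-NamedTuple step (toy helper name, not the model_macro.jl code):

```julia
# Split a captured keyword call like :(Normal(m = 0, v = 1)) into the
# functional-form name and a NamedTuple of its keyword arguments.
function call_to_namedtuple(expr::Expr)
    expr.head === :call || error("expected a call expression")
    fform = expr.args[1]
    kws   = expr.args[2:end]                 # each is Expr(:kw, name, value)
    names = Tuple(kw.args[1] for kw in kws)
    vals  = Tuple(kw.args[2] for kw in kws)
    return fform, NamedTuple{names}(vals)
end

call_to_namedtuple(:(Normal(m = 0, v = 1)))  # (:Normal, (m = 0, v = 1))
```

With the call split this way, the NamedTuple can be manipulated freely before being handed to the guard function.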

wouterwln avatar Oct 09 '25 07:10 wouterwln

Yeah, thanks for the good hints. I actually realized that you already provided the main step for "filtering" the functional forms (with if @capture(e, (fform_(var_) = init_obj_))), and my last attempt was quite some bulls**t^^. If you are fine with it, I would now simply extend convert_init_variables by also processing the init_obj part, splitting it into functional form vs. (kw)args and wrapping it into a new function that is called when esc()-ing the quoted block. That function would differentiate between kwargs and args, and reuse factor_alias and default_parametrization (just like make_node!).

skoghoern avatar Oct 09 '25 08:10 skoghoern

Hi Wouter, I've now been trying to proceed more along the lines of your previous code:

  1. Extract the init_obj and split it into functional form vs. (kw)args.
  2. Then build an expression that inserts the variables into a newly built convert_fform function, which basically follows make_node! for kwargs and uses default_parametrization for args.
  3. When the expression gets executed, convert_fform transforms them automatically into their respective aliases, or throws an error if that is not possible. The only "hard-coded" pieces are the filtering of vague(), huge(), etc. Please have a look at it and let me know what you think (e.g. if it's still too complicated). I also adapted the tests accordingly and ran them all ⇒ which works fine for me. I wasn't sure how to correctly build the new tests and just followed my gut feeling; if you have any further ideas or feedback, I am happy to know and learn :)
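The convert_fform pipeline described in the steps above could look roughly like this (toy types and methods; the actual implementation reuses GraphPPL's factor_alias and default_parametrization):

```julia
struct Normal end
struct NormalMeanVariance end

# Positional arguments first get names via a default parametrization ...
default_parametrization(::Type{Normal}, args::Tuple) = NamedTuple{(:m, :v)}(args)
# ... and the keyword names then select the aliased node type.
factor_alias(::Type{Normal}, ::Val{(:m, :v)}) = NormalMeanVariance

convert_fform(fform, kwargs::NamedTuple) = factor_alias(fform, Val(keys(kwargs)))
convert_fform(fform, args::Tuple)        = convert_fform(fform, default_parametrization(fform, args))

convert_fform(Normal, (m = 0, v = 1))  # keyword path
convert_fform(Normal, (0, 1))          # positional path, via default_parametrization
```

Both paths converge on the same alias lookup, so the macro itself never has to distinguish the two cases.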

skoghoern avatar Oct 09 '25 16:10 skoghoern

I think this is a lot more like what I meant indeed! I will do a more thorough review tomorrow, but I think this is 90% of the way in the right direction! I like your solution where you first transform into the RxInfer.convert_fform form and afterwards into actual init objects; it's a lot cleaner than what I did in the @model macro. I think most of my comments will be style-wise rather than functionality-wise (I don't like isa checks in Julia, as I'd rather dispatch on types, and I don't like try-catch constructions, as they are computationally expensive), so I think you did an amazing job on the functionality and on getting Julia's metaprogramming down.

wouterwln avatar Oct 09 '25 18:10 wouterwln

Totally fine for me, and thanks for the extra explanations. I used the try-catch just to be able to indicate to the user that the error is coming from the init macro. But maybe you have a better workaround :D
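One try/catch-free way to attribute errors to the init macro (hypothetical names, just to illustrate the pattern): validate up front and throw a macro-specific exception instead of catching and re-wrapping downstream errors.

```julia
# A dedicated exception type makes the origin obvious without wrapping
# arbitrary downstream errors in a try-catch block.
struct InitMacroError <: Exception
    msg::String
end
Base.showerror(io::IO, e::InitMacroError) = print(io, "@initialization: ", e.msg)

function check_init_call(fform, args::NamedTuple)
    isempty(args) && throw(InitMacroError("$(fform) requires keyword arguments"))
    return (fform, args)
end

check_init_call(:Normal, (m = 0, v = 1))  # passes through unchanged
```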

skoghoern avatar Oct 10 '25 15:10 skoghoern

@wouterwln what do you think? Shall we merge?

bvdmitri avatar Oct 20 '25 15:10 bvdmitri

Coming back to this PR, I see some of the tests are failing? There is a merge conflict I will solve, but apart from that I think we still need to brush up the implementation to make sure the tests pass.

wouterwln avatar Nov 04 '25 13:11 wouterwln

Hi Wouter, are you already working on it, or shall I get back on it? If I see it correctly, most errors are just due to the use of either Gamma without args or the previous Dirichlet implementation with probvec. The remaining error comes from the Multi-Agent Example, where it seems the repeat([PointMass(1)], nr_steps) definition isn't caught correctly.

skoghoern avatar Nov 05 '25 18:11 skoghoern

I changed the parsing to correctly handle repeat() and vector constructs. On my setup the tests are passing. I'm not sure how to handle the failed "gamma cases" here.
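The repeat()/vector handling can be sketched as a small expression walk (hypothetical helper names, not the PR's actual code): container constructs are recursed into, while other calls are treated as functional forms.

```julia
# `repeat(...)` calls and vector literals are containers whose elements
# may themselves hold functional-form calls such as PointMass(1).
is_container(e) =
    e isa Expr && ((e.head === :call && e.args[1] === :repeat) || e.head === :vect)

function collect_fforms!(out, e)
    e isa Expr || return out
    if e.head === :call && !is_container(e)
        push!(out, e.args[1])               # a functional-form call
    else
        foreach(x -> collect_fforms!(out, x), e.args)
    end
    return out
end

collect_fforms!(Symbol[], :(repeat([PointMass(1)], nr_steps)))  # [:PointMass]
```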

skoghoern avatar Nov 05 '25 21:11 skoghoern

Sorry, that was my bad. I hadn't pulled the last version of the updated tests ⇒ an end argument was lost during the merge process ⇒ and the correct handling of PointMass(int) wasn't fixed. I made sure everything is updated and hope it also works in the online version now.

skoghoern avatar Nov 06 '25 22:11 skoghoern

Codecov Report

:x: Patch coverage is 90.74074% with 5 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 80.05%. Comparing base (cf35888) to head (dcbf222).
:warning: Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
src/model/plugins/initialization_plugin.jl 90.74% 5 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #525      +/-   ##
==========================================
- Coverage   80.18%   80.05%   -0.14%     
==========================================
  Files          25       25              
  Lines        2019     2071      +52     
==========================================
+ Hits         1619     1658      +39     
- Misses        400      413      +13     

:umbrella: View full report in Codecov by Sentry.


codecov[bot] avatar Nov 14 '25 16:11 codecov[bot]