DESC
Excessive recompilation in `ProximalProjection`
Running with `JAX_LOG_COMPILES="1"`, I see output like this between every iteration of the outer loop:
Finished tracing + transforming jac_scaled for pjit in 2.1464920043945312 sec
Compiling jac_scaled with global shapes and types []. Argument mapping: [].
Finished jaxpr to MLIR module conversion jit(jac_scaled) in 0.13467741012573242 sec
Finished XLA compilation of jit(jac_scaled) in 0.7804932594299316 sec
Finished tracing + transforming compute_scaled_error for pjit in 0.4098942279815674 sec
Compiling compute_scaled_error with global shapes and types [ShapedArray(float64[298])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(compute_scaled_error) in 0.09883785247802734 sec
Finished XLA compilation of jit(compute_scaled_error) in 0.16284894943237305 sec
Finished tracing + transforming project for pjit in 0.0073757171630859375 sec
Compiling project with global shapes and types [ShapedArray(float64[298])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(project) in 0.014415740966796875 sec
Finished XLA compilation of jit(project) in 0.07303285598754883 sec
Finished tracing + transforming recover for pjit in 0.01389932632446289 sec
Compiling recover with global shapes and types [ShapedArray(float64[120])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(recover) in 0.0190582275390625 sec
Finished XLA compilation of jit(recover) in 0.1190042495727539 sec
Finished tracing + transforming jac_scaled for pjit in 1.3376684188842773 sec
Compiling jac_scaled with global shapes and types []. Argument mapping: [].
Finished jaxpr to MLIR module conversion jit(jac_scaled) in 0.13845014572143555 sec
Finished XLA compilation of jit(jac_scaled) in 0.7922420501708984 sec
Finished tracing + transforming compute_scaled_error for pjit in 0.38459253311157227 sec
Compiling compute_scaled_error with global shapes and types [ShapedArray(float64[298])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(compute_scaled_error) in 0.10509967803955078 sec
Finished XLA compilation of jit(compute_scaled_error) in 0.15252304077148438 sec
Finished tracing + transforming recover for pjit in 0.007174253463745117 sec
Compiling recover with global shapes and types [ShapedArray(float64[120])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(recover) in 0.01448822021484375 sec
Finished XLA compilation of jit(recover) in 0.06177163124084473 sec
Finished tracing + transforming jac_scaled for pjit in 1.1975443363189697 sec
Compiling jac_scaled with global shapes and types []. Argument mapping: [].
Finished jaxpr to MLIR module conversion jit(jac_scaled) in 0.14225172996520996 sec
Finished XLA compilation of jit(jac_scaled) in 0.7997219562530518 sec
Finished tracing + transforming compute_scaled_error for pjit in 0.4294285774230957 sec
Compiling compute_scaled_error with global shapes and types [ShapedArray(float64[298])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(compute_scaled_error) in 0.13470077514648438 sec
Finished XLA compilation of jit(compute_scaled_error) in 0.20583796501159668 sec
Finished tracing + transforming project for pjit in 0.007157325744628906 sec
Compiling project with global shapes and types [ShapedArray(float64[298])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(project) in 0.012563467025756836 sec
Finished XLA compilation of jit(project) in 0.06066417694091797 sec
Finished tracing + transforming recover for pjit in 0.012331247329711914 sec
Compiling recover with global shapes and types [ShapedArray(float64[120])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(recover) in 0.01821756362915039 sec
Finished XLA compilation of jit(recover) in 0.05025196075439453 sec
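The same logging can also be turned on from Python with `jax.config.update("jax_log_compiles", True)`. The minimal sketch below uses a hypothetical function, `scaled_error`, to show the expected behaviour. A `jax.jit` function is traced and compiled once for each distinct input shape and dtype. Later calls with a matching shape hit the cache and produce no log lines. The Python counter changes only while JAX is tracing, so it counts traces rather than calls.

```python
import jax
import jax.numpy as jnp

# Programmatic equivalent of setting JAX_LOG_COMPILES=1 in the environment.
jax.config.update("jax_log_compiles", True)

trace_count = 0

@jax.jit
def scaled_error(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(x ** 2)

scaled_error(jnp.ones(298))   # first call: traced + compiled (logged)
scaled_error(jnp.zeros(298))  # same shape/dtype: cache hit, nothing logged
scaled_error(jnp.ones(120))   # new shape: traced + compiled again

print(trace_count)  # 2
```

In a healthy outer loop, each jitted function should therefore show up in the log only on the first iteration, or again when an input shape changes.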
We expect things to compile on the first iteration and the compiled versions to be reused afterwards. Here, though, `jac_scaled`, `compute_scaled_error`, and `recover` are retraced and recompiled on every step, and `project` on most steps. Possibly related to #957, but I'm not sure.
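One hint in the log is the repeated `Compiling jac_scaled with global shapes and types []`. It means `jac_scaled` is compiled with no array arguments, so whatever data it uses must be closed over. In JAX, a common cause of per-iteration recompiles is building a new `jax.jit`-wrapped closure on each iteration. Each new function object starts with an empty cache, and the captured arrays are baked in as constants. The sketch below is a generic illustration of that pattern under this assumption, not DESC's actual code. `make_error_fn` and `error_explicit` are hypothetical names. It compares that pattern with jitting once and passing the changing data as an argument.

```python
import jax
import jax.numpy as jnp

traces = {"closure": 0, "explicit": 0}

def make_error_fn(target):
    # A fresh jitted closure per call: `target` is captured as a constant
    # and the new function object has its own, empty compilation cache.
    @jax.jit
    def error():
        traces["closure"] += 1  # counts traces, not calls
        return jnp.sum((target - 1.0) ** 2)
    return error

@jax.jit
def error_explicit(target):
    # Defined once; `target` is an argument, so the cache is keyed on its
    # shape/dtype rather than on its value.
    traces["explicit"] += 1
    return jnp.sum((target - 1.0) ** 2)

for k in range(3):  # stand-in for the outer optimization loop
    target = jnp.full(298, float(k))
    make_error_fn(target)()
    error_explicit(target)

print(traces)  # {'closure': 3, 'explicit': 1}
```

The closure version retraces on every iteration, while the explicit-argument version is traced once and reused. This is the behaviour expected from the outer loop above.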
Regarding #957, I'd be interested to see how the JAX log output changes once you change the jitting.
After the fixes in #1043, nothing is recompiled after the first iteration.