JAX Backend
Rebased version of https://github.com/grame-cncm/faust/pull/799
I'm still working on something.
I don't think there are bugs now, but it's missing a feature I'd like: avoiding static tables when possible. This would cut JAX compilation time immensely and improve gradient computation with JAX.
This code
import("stdfaust.lib");
process = os.osc(440.);
ends up creating this Python (some parts excluded here):
def init_constants(self, state, sample_rate):
    state["fSampleRate"] = sample_rate
    state["fConst0"] = (4.4e+02 / jnp.minimum(1.92e+05, jnp.maximum(1.0, (state["fSampleRate"]))))

def classInit(self, state):
    # global declarations:
    state["fRec1"] = jnp.zeros((2,), dtype=jnp.float32)
    state["iVec0"] = jnp.zeros((2,), dtype=jnp.int32)
    state["iRec0"] = jnp.zeros((2,), dtype=jnp.int32)
    state["ftbl0MyDSPSIG0"] = jnp.zeros((65536,), dtype=jnp.float32)
    # inline subcontainers:
    for i1 in range(0, 65535):
        state["iVec0"] = state["iVec0"].at[0].set(1)
        state["iRec0"] = state["iRec0"].at[0].set(((state["iVec0"][1] + state["iRec0"][1]) % 65536))
        state["ftbl0MyDSPSIG0"] = state["ftbl0MyDSPSIG0"].at[i1].set(jnp.sin((9.58738e-05 * (state["iRec0"][0]))))
        state["iVec0"] = jnp.roll(state["iVec0"], 1)
        state["iRec0"] = jnp.roll(state["iRec0"], 1)

@staticmethod
def tick(state: dict, inputs: jnp.array):
    state["fRec1"] = state["fRec1"].at[0].set((state["fConst0"] + (state["fRec1"][1] - jnp.floor((state["fConst0"] + state["fRec1"][1])))))
    _result0 = state["ftbl0MyDSPSIG0"][jnp.int32((65536.0 * state["fRec1"][0]))]
    state["fRec1"] = jnp.roll(state["fRec1"], 1)
    return state, jnp.stack([_result0])
It turns out that the for i1 in range(0, 65535) loop takes over a minute to compile because those operations are slow for JAX. It would be better to just compute jnp.sin on the fly in the tick function.
Ideal generated Python code:
def classInit(self, state):
    # global declarations:
    state["fRec1"] = jnp.zeros((2,), dtype=jnp.float32)

@staticmethod
def tick(state: dict, inputs: jnp.array):
    state["fRec1"] = state["fRec1"].at[0].set((state["fConst0"] + (state["fRec1"][1] - jnp.floor((state["fConst0"] + state["fRec1"][1])))))
    _result0 = jnp.sin(jnp.pi*2*state["fRec1"][0])
    state["fRec1"] = jnp.roll(state["fRec1"], 1)
    return state, jnp.stack([_result0])
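To get a feel for how close the two variants are numerically, here is a minimal plain-Python sketch (standard library only, not the generated JAX code). It assumes a 65536-entry table holding one sine period, which matches the 9.58738e-05 (about 2*pi/65536) step in the table fill above; the 48 kHz sample rate and the function names are made up for illustration.

```python
import math

TABLE_SIZE = 65536  # same size as ftbl0MyDSPSIG0 in the generated code

# Static table: one sine period sampled at TABLE_SIZE points.
table = [math.sin(2.0 * math.pi * i / TABLE_SIZE) for i in range(TABLE_SIZE)]

def osc_table(phase):
    """Table lookup, mirroring the generated tick (phase in [0, 1))."""
    return table[int(TABLE_SIZE * phase)]

def osc_direct(phase):
    """On-the-fly sine, mirroring the ideal tick."""
    return math.sin(2.0 * math.pi * phase)

# Run a 440 Hz phasor for one second at 48 kHz (hypothetical settings)
# and record the largest gap between the two oscillators.
step = 440.0 / 48000.0
phase = 0.0
max_error = 0.0
for _ in range(48000):
    phase = (phase + step) - math.floor(phase + step)
    max_error = max(max_error, abs(osc_table(phase) - osc_direct(phase)))

print(max_error)
```

Because the lookup truncates the phase to a table index, the gap is bounded by the table step 2*pi/65536, so the direct version should be at least as accurate as the table while skipping the fill loop entirely.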
I'm looking at instructions_compiler.cpp: ValueInst* InstructionsCompiler::generateStaticTable(Tree sig, Tree tsize, Tree content). We might need to create a different normalized form that avoids creating rdtables so that generateStaticTable isn’t even called.
What I said above can be put on hold because one can just code an oscillator that doesn't use rdtable. I copied from the Box Tutorial:
import("stdfaust.lib");
process = osc(f1), osc(f2)
with {
    decimalpart(x) = x-int(x);
    phasor(f) = f/ma.SR : (+ : decimalpart) ~ _;
    osc(f) = sin(2 * ma.PI * phasor(f));
    f1 = vslider("Freq1", 300, 100, 2000, 0.01);
    f2 = vslider("Freq2", 500, 100, 2000, 0.01);
};
And that compiles quickly. But if something else in faustlibraries calls os.osc, then I would want a way to use a non-rdtable version.
Yes, sure. Changing the way rdtable/rwtable are generated is not trivial, so better to use another osc version for now.
I have some tests here: https://gist.github.com/DBraun/556024cb51ecc7416cc7c9509c7a5259
python -m pytest -v test_jax_backend.py
I would need to add a python continuous integration thing to the faust repo.
Should I add a Python installer and test to Ubuntu (https://github.com/grame-cncm/faust/blob/master-dev/.github/workflows/ubuntu.yml)? It should be easy to install JAX on Ubuntu.
The test is at https://github.com/DBraun/faust/tree/jax-backend-2/tests/jax-tests
What are the jax-tests actually testing?
Every function that starts with “test” inside test_jax_backend.py would be run. Each of those functions has some Faust code that gets converted to JAX, built with the JAX minimal.py architecture, and run. So it tests that the JAX backend works and that the JAX code doesn’t throw errors. The audio gets saved to an output folder too.
Backends are tested in the tests/impulse-tests infrastructure (so not at each commit). How complex would it be to move the JAX backend tests here?
It would be easy to change the relative paths in test_jax_backend.py. So would you suggest moving the jax-tests to tests/interp-tests/jax-tests?
The current progress:
Two files don't work after being converted to JAX:
osc_enable.dsp doesn't work because it results in code with if statements inside its one-sample block.
if (iTemp1 != 0):
    fTemp3 = (fTemp0 * fTemp2)
Any tips for generating code that doesn't do this? I can imagine getting similar results if the visitor went through visit(Select2Inst* inst).
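For context on why this matters: a Python if on a traced value fails inside jax.jit or jax.lax.scan, because the condition is not a concrete boolean at trace time. One common workaround, sketched here as an assumption and not as what the backend emits, is to compute the value unconditionally and pick the result with jnp.where. The tick signature, the input names, and the "hold the previous value" behaviour in the false case are all hypothetical.

```python
import jax
import jax.numpy as jnp

def tick(prev, inputs):
    """One-sample step without a Python `if` (all names are hypothetical)."""
    enable, x, gain = inputs
    # `if enable != 0:` would fail here because `enable` is a tracer inside scan.
    # jnp.where evaluates both candidates and selects element-wise instead.
    out = jnp.where(enable != 0, x * gain, prev)
    return out, out

enable = jnp.array([1, 0, 1, 0], dtype=jnp.int32)
x = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
gain = jnp.full((4,), 0.5, dtype=jnp.float32)

# Carry the previous output through the scan; start from 0.0.
_, ys = jax.lax.scan(tick, jnp.float32(0.0), (enable, x, gain))
print(ys)
```

This only works when both candidates are cheap and free of side effects, since both are always computed; jax.lax.cond is the alternative when one branch should really be skipped.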
The other file that doesn't compile is bs.dsp. I could eventually get this working, but it's kind of low priority in my opinion.
If I set the filesCompare precision to 2, ignoring bs.dsp and osc_enable.dsp, all the tests pass. Some examples fail on small precision differences, such as carre_volterra.dsp:
/Applications/Xcode.app/Contents/Developer/usr/bin/make -f Make.jax outdir=jax/double FAUSTOPTIONS="-I dsp -double"
python3 ir/jax/double/jax_carre_volterra.py > ir/jax/double/carre_volterra.ir
./filesCompare ir/jax/double/carre_volterra.ir reference/carre_volterra.ir 1e-4
Line : 2632 output : 0 sample1 : 0.090985 different from sample2 : 0.091085 delta : 0.0001
Line : 2633 output : 0 sample1 : 0.056059 different from sample2 : 0.056159 delta : 0.0001
Line : 2681 output : 0 sample1 : -0.125745 different from sample2 : -0.125846 delta : 0.000101
Line : 2682 output : 0 sample1 : -0.091029 different from sample2 : -0.091131 delta : 0.000102
Line : 2683 output : 0 sample1 : -0.056184 different from sample2 : -0.056285 delta : 0.000101
Line : 2684 output : 0 sample1 : -0.021322 different from sample2 : -0.021424 delta : 0.000102
Line : 2685 output : 0 sample1 : 0.013448 different from sample2 : 0.013347 delta : 0.000101
Line : 2686 output : 0 sample1 : 0.048029 different from sample2 : 0.047928 delta : 0.000101
Line : 2687 output : 0 sample1 : 0.082328 different from sample2 : 0.082228 delta : 0.0001
Line : 2731 output : 0 sample1 : 0.125626 different from sample2 : 0.125727 delta : 0.000101
Line : 2732 output : 0 sample1 : 0.090992 different from sample2 : 0.091094 delta : 0.000102
Line : 2733 output : 0 sample1 : 0.056222 different from sample2 : 0.056325 delta : 0.000103
Should I be testing -double? What are your suggestions for debugging small numerical differences?
Ignoring osc_enable.dsp and bs.dsp should be OK. But yes, using -double precision is the way to go! All tests actually use -double, and using the default (float) could explain the precision issues.
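To illustrate how single precision alone can produce deltas of this kind in a recursive signal, here is a small standard-library sketch. It is not the carre_volterra.dsp computation: it runs a hypothetical 440 Hz phasor at 48 kHz twice, once with the state rounded to float32 at every step (emulated with struct) and once in double, and compares the resulting sine outputs.

```python
import math
import struct

def to_f32(x):
    """Round a Python float (a double) to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def run_phasor(num_samples, freq, sample_rate, single):
    """Return sin(2*pi*phase) for a recursive phasor in single or double precision."""
    rnd = to_f32 if single else (lambda v: v)
    step = rnd(freq / sample_rate)
    phase = 0.0
    out = []
    for _ in range(num_samples):
        acc = rnd(phase + step)
        phase = rnd(acc - math.floor(acc))  # keep the fractional part
        out.append(math.sin(2.0 * math.pi * phase))
    return out

single = run_phasor(3000, 440.0, 48000.0, single=True)
double = run_phasor(3000, 440.0, 48000.0, single=False)

# Largest per-sample difference between the two runs.
max_delta = max(abs(a - b) for a, b in zip(single, double))
print(max_delta)
```

The rounding error is tiny per sample, but it is fed back through the recursion and accumulates, which is why comparing a float run against a -double reference can fail a tight tolerance.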
Precision tests all pass! I only have to ignore bs.dsp and osc_enable.dsp in Make.jax.
How can I add the JAX installation and tests to appveyor?
Starting to look at the code. It would be simpler for me to comment on the code after you 1) rebase on master-dev and 2) prepare a PR in --squash mode (that is, one single big commit). Thanks.
Thanks for the rebase. I will do general comments on the PR here, and specific ones directly in the code. General comments looking at the generated code:
- I understand that all "init"-like functions (init/instanceInit/classInit etc.) are grouped in a single classInit. The usual semantics of classInit in the C++ generated code is to share init code that is common to all "instances". Here I don't think you really have this "static init" versus "instance init" issue anymore. So I would prefer having classInit simply become init.
- All those load_soundfile/add_soundfile/add_nentry/add_button/add_slider are just generic code, right? Then why have the backend generate them? Wouldn't it be better to have them be part of an architecture file (which is the standard way to do it)?
- When compiling freeverb.dsp, I see state["IOTA0"] = 0 generated in build_interface. This seems like an initialisation, right? So better to move that into init.
- There is still an indentation issue in the generated code (probably an unmatched "tab" at the end of the file).
Looking at the generated code: the call state = self.build_interface(state, x, T) seems somewhat strange and does not really follow the "blue/red/green" model of architecture files: https://faustdoc.grame.fr/manual/architectures/#architecture-files
Why not remove build_interface (and have initialize init the DSP state but nothing more) and have:
def __call__(self, x, T: int) -> jnp.array:
    state = self.initialize(self.sample_rate, x, T)
    state = self.build_interface(self.sample_rate, x, T)
    return jnp.transpose(jax.lax.scan(self.tick, state, jnp.transpose(x, axes=(1, 0)))[1], axes=(1,0))
Or is it because self.build_interface is considered as initialising the DSP state in this JAX backend?
Could you possibly document jax_code_container.cpp a bit with the JAX backend particularities, as in https://github.com/grame-cncm/faust/blob/master-dev/compiler/generator/julia/julia_code_container.cpp#L32 ?
It seems bool fMutateFun; in JAXInstVisitor is not used anymore?
In JAXInstVisitor, please use the opcode name instead of 7 in the following line:
if (inst->fOpcode > 7 && !fIsDoingWhile) {
The code is now more like
def __call__(self, x, T: int) -> jnp.array:
    state = self.initialize(self.sample_rate, x, T)
    state = self.build_interface(state, x, T)
    state = jax.tree_map(jnp.array, state)
    return jnp.transpose(jax.lax.scan(self.tick, state, jnp.transpose(x, axes=(1, 0)))[1], axes=(1,0))
The state needed to be handed off to build_interface.
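For readers unfamiliar with the pattern, here is a self-contained sketch of the same idea: build a state dict, convert its leaves to arrays, then drive a one-sample tick with jax.lax.scan. The initialize and run functions are hypothetical stand-ins, not the architecture file's code; the tick reuses the table-free oscillator from earlier in this thread, and jax.tree_util.tree_map is used as the long-form spelling of tree_map.

```python
import jax
import jax.numpy as jnp

def initialize(sample_rate):
    """Build the DSP state (hypothetical stand-in for the generated init code)."""
    return {
        "fConst0": 440.0 / sample_rate,
        "fRec1": jnp.zeros((2,), dtype=jnp.float32),
    }

def tick(state, inputs):
    """One-sample step: advance the phasor and emit one output channel."""
    phase = state["fConst0"] + state["fRec1"][1]
    state["fRec1"] = state["fRec1"].at[0].set(phase - jnp.floor(phase))
    result = jnp.sin(2.0 * jnp.pi * state["fRec1"][0])
    state["fRec1"] = jnp.roll(state["fRec1"], 1)
    return state, jnp.stack([result])

def run(x, sample_rate=48000.0):
    """x has shape (channels, T); the result has shape (outputs, T)."""
    state = initialize(sample_rate)
    # Turn plain Python numbers into arrays so the scan carry is a pytree of arrays.
    state = jax.tree_util.tree_map(jnp.asarray, state)
    # scan iterates over the leading axis, so feed it time-major input.
    _, ys = jax.lax.scan(tick, state, jnp.transpose(x, axes=(1, 0)))
    return jnp.transpose(ys, axes=(1, 0))

y = run(jnp.zeros((1, 64), dtype=jnp.float32))
print(y.shape)
```

The input is ignored by this particular tick since the oscillator has no audio inputs; it is passed only to set the number of samples scanned.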
I would like to get the add_soundfile, add_slider, etc. functions into an architecture file, but I didn't figure out how to default to a specific architecture file if none is specified.
I removed bool fMutateFun;, added some comments in jax_code_container.cpp, and improved the if (inst->fOpcode > 7 && !fIsDoingWhile) { line.
Can you tell me where FAUSTFLOAT(foo) is generated in C++? I had to define a dummy function
def FAUSTFLOAT(x):
    return x
But it would be better if FAUSTFLOAT(foo) was just foo everywhere. I couldn't figure it out with breakpoints earlier.
"but I didn't figure out how to default to a specific architecture file if none is specified." => Well, even in C++, the compiler does not actually generate a self-contained class. It always assumes dsp/UI/meta etc. are defined in an architecture file. This is why I think we could just do the same with JAX and assume (also) that the appropriate architecture will be used.
Good point. I'll try that tomorrow.
Can you tell me where FAUSTFLOAT(foo) is generated in C++?
=> I'll check that
The FAUSTFLOAT machinery makes sense in C/C++. With JAX I think you should do as with the LLVM backend, where we actually consider that FAUSTFLOAT is the same type as the internal REAL type (that is, either float or double depending on whether -single (= default) or -double is used as an option). See https://github.com/grame-cncm/faust/blob/master-dev/compiler/global.hh#L166 and https://github.com/grame-cncm/faust/blob/master-dev/compiler/libcode.cpp#L1396.
I moved most of the code into the architecture files now. I think setting gGlobal->gFAUSTFLOAT2Internal = true; helped avoid FAUSTFLOAT being called.
Great work! Merged in --squash mode after some final cleanup and reformatting in https://github.com/grame-cncm/faust/commit/44d66aac61b05cb172e101a2b4051e2aa0ea248f
Still 2 things to check:
- Is this FAUSTFLOAT-related block still necessary: https://github.com/grame-cncm/faust/commit/44d66aac61b05cb172e101a2b4051e2aa0ea248f#diff-5ac628298a99ca6f354d9f3f5a64424743f37be2451f2ab84e6645418f1c02e4R149-R165 ?
- JAXCodeContainer::inlineSubcontainersFunCalls is almost the same as the base class CodeContainer::inlineSubcontainersFunCalls (with only the renaming block changing, AFAICS): https://github.com/grame-cncm/faust/commit/44d66aac61b05cb172e101a2b4051e2aa0ea248f#diff-5ac628298a99ca6f354d9f3f5a64424743f37be2451f2ab84e6645418f1c02e4R296. Is renaming an issue here? If not, then I guess the base class CodeContainer::inlineSubcontainersFunCalls should be used.
Now everything is OK, thanks!