
JAX Backend

DBraun opened this issue.

Rebased version of https://github.com/grame-cncm/faust/pull/799

DBraun, Sep 14 '22

I'm still working on something.

DBraun, Sep 14 '22

I don't think there are bugs now, but it's missing a feature I'd like: avoiding static tables when possible. This would speed up JAX compilation immensely and improve gradient computation with JAX.

This code

import("stdfaust.lib");
process = os.osc(440.);

ends up generating this Python code (some parts omitted here):

def init_constants(self, state, sample_rate):
	state["fSampleRate"] = sample_rate 
	state["fConst0"] = (4.4e+02 / jnp.minimum(1.92e+05, jnp.maximum(1.0, (state["fSampleRate"])))) 

def classInit(self, state):
	# global declarations:
	state["fRec1"] = jnp.zeros((2,), dtype=jnp.float32)
	state["iVec0"] = jnp.zeros((2,), dtype=jnp.int32)
	state["iRec0"] = jnp.zeros((2,), dtype=jnp.int32)
	state["ftbl0MyDSPSIG0"] = jnp.zeros((65536,), dtype=jnp.float32)
	# inline subcontainers:	
	for i1 in range(0, 65535):
		state["iVec0"] = state["iVec0"].at[0].set(1) 
		state["iRec0"] = state["iRec0"].at[0].set(((state["iVec0"][1] + state["iRec0"][1]) % 65536)) 
		state["ftbl0MyDSPSIG0"]  = state["ftbl0MyDSPSIG0"].at[i1].set(jnp.sin((9.58738e-05 * (state["iRec0"][0])))) 
		state["iVec0"] = jnp.roll(state["iVec0"], 1) 
		state["iRec0"] = jnp.roll(state["iRec0"], 1) 

@staticmethod
def tick(state: dict, inputs: jnp.array):
		
	state["fRec1"] = state["fRec1"].at[0].set((state["fConst0"] + (state["fRec1"][1] - jnp.floor((state["fConst0"] + state["fRec1"][1]))))) 
	_result0 = state["ftbl0MyDSPSIG0"][jnp.int32((65536.0 * state["fRec1"][0]))] 
	state["fRec1"] = jnp.roll(state["fRec1"], 1) 
	return state, jnp.stack([_result0]) 

It turns out that the for i1 in range(0, 65535) loop takes over a minute to compile, because those element-by-element operations are slow in JAX. It would be better to compute jnp.sin on the fly in the tick function.
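For comparison, the table that loop builds can also be produced by a single vectorized call. Tracing the recurrence, iRec0[0] equals the loop index i1 on every iteration, and 9.58738e-05 is approximately 2π/65536, so entry i1 is sin(2π·i1/65536). The sketch below uses hypothetical helper names (not part of the backend) and checks a Python port of the generated loop against the vectorized form. Note that the quoted loop runs over range(0, 65535), so as written it never fills the last entry (index 65535), while the vectorized version fills all 65536.

```python
import jax.numpy as jnp

TABLE_SIZE = 65536
STEP = 9.58738e-05  # constant from the generated code, approximately 2*pi/65536


def sine_table_loop(n: int) -> jnp.ndarray:
    """Python port of the generated classInit loop (one .at[].set per entry; slow for large n)."""
    table = jnp.zeros((n,), dtype=jnp.float32)
    i_vec = jnp.zeros((2,), dtype=jnp.int32)
    i_rec = jnp.zeros((2,), dtype=jnp.int32)
    for i1 in range(n):
        i_vec = i_vec.at[0].set(1)
        i_rec = i_rec.at[0].set((i_vec[1] + i_rec[1]) % TABLE_SIZE)
        table = table.at[i1].set(jnp.sin(STEP * i_rec[0]))
        i_vec = jnp.roll(i_vec, 1)
        i_rec = jnp.roll(i_rec, 1)
    return table


def sine_table_vectorized(n: int) -> jnp.ndarray:
    """Same values from one vectorized call, since iRec0[0] equals i1 on each iteration."""
    return jnp.sin(STEP * jnp.arange(n, dtype=jnp.float32))
```

sine_table_vectorized(TABLE_SIZE) builds the full table as one operation instead of 65535 separate updates, although computing the sine directly in tick avoids the table altogether.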

Ideal generated Python code:

def classInit(self, state):
	# global declarations:
	state["fRec1"] = jnp.zeros((2,), dtype=jnp.float32)

@staticmethod
def tick(state: dict, inputs: jnp.array):
		
	state["fRec1"] = state["fRec1"].at[0].set((state["fConst0"] + (state["fRec1"][1] - jnp.floor((state["fConst0"] + state["fRec1"][1]))))) 
	_result0 = jnp.sin(jnp.pi*2*state["fRec1"][0]) 
	state["fRec1"] = jnp.roll(state["fRec1"], 1) 
	return state, jnp.stack([_result0]) 
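A runnable sketch of that ideal version, assuming a free-standing tick driven by jax.lax.scan. The init_state and run helpers and the 440 Hz default are illustrative, modeled on the generated init_constants; the phasor expression is copied from the generated tick, and only the table lookup is replaced by jnp.sin. Since the output no longer goes through an integer table index, it is also differentiable with respect to the phase.

```python
import jax
import jax.numpy as jnp


def init_state(sample_rate: float, freq: float = 440.0) -> dict:
    # fConst0 as in the generated init_constants: freq / clamp(sample_rate, 1, 192000).
    f_const0 = freq / jnp.minimum(1.92e5, jnp.maximum(1.0, sample_rate))
    return {
        "fConst0": jnp.asarray(f_const0, dtype=jnp.float32),
        "fRec1": jnp.zeros((2,), dtype=jnp.float32),  # phase history: [current, previous]
    }


def tick(state: dict, inputs):
    # Wrap-around phasor, same expression as the generated code.
    c, prev = state["fConst0"], state["fRec1"][1]
    state["fRec1"] = state["fRec1"].at[0].set(c + (prev - jnp.floor(c + prev)))
    # Table-free output: compute the sine directly instead of indexing a 65536-entry table.
    result0 = jnp.sin(2 * jnp.pi * state["fRec1"][0])
    state["fRec1"] = jnp.roll(state["fRec1"], 1)
    return state, jnp.stack([result0])


def run(sample_rate: float, num_samples: int) -> jnp.ndarray:
    # The oscillator has no audio inputs, so scan runs for a fixed length with xs=None.
    _, ys = jax.lax.scan(tick, init_state(sample_rate), None, length=num_samples)
    return jnp.transpose(ys)  # (num_samples, 1) -> (1, num_samples)
```

For example, run(44100.0, 200) returns a (1, 200) array whose n-th sample (counting from 1) is sin(2π·440·n/44100).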

I'm looking at instructions_compiler.cpp: ValueInst* InstructionsCompiler::generateStaticTable(Tree sig, Tree tsize, Tree content). We might need to create a different normalized form that avoids creating rdtables so that generateStaticTable isn’t even called.

DBraun, Sep 15 '22

What I said above can be put on hold, because one can simply write an oscillator that doesn't use rdtable. I copied this one from the Box Tutorial:

import("stdfaust.lib");
process = osc(f1), osc(f2)
with {
    decimalpart(x) = x-int(x);
    phasor(f) = f/ma.SR : (+ : decimalpart) ~ _;
    osc(f) = sin(2 * ma.PI * phasor(f));
    f1 = vslider("Freq1", 300, 100, 2000, 0.01);
    f2 = vslider("Freq2", 500, 100, 2000, 0.01);
};

And that compiles quickly. But if something else in faustlibraries calls os.osc, then I would want a way to use a non-rdtable version.

DBraun, Sep 15 '22

Yes, sure. Changing the way rdtable/rwtable are generated is not trivial, so it's better to use another osc version for now.

sletz, Sep 15 '22

I have some tests here: https://gist.github.com/DBraun/556024cb51ecc7416cc7c9509c7a5259

python -m pytest -v test_jax_backend.py

I would need to add Python continuous integration to the faust repo.

DBraun, Sep 15 '22

Should I add a Python installation and test step to the Ubuntu workflow (https://github.com/grame-cncm/faust/blob/master-dev/.github/workflows/ubuntu.yml)? It should be easy to install JAX on Ubuntu.

The test is at https://github.com/DBraun/faust/tree/jax-backend-2/tests/jax-tests

DBraun, Sep 16 '22

What are the jax-tests actually testing?

sletz, Sep 17 '22

Every function whose name starts with “test” inside test_jax_backend.py is run. Each of those functions has some Faust code that gets converted to JAX, built with the JAX minimal.py architecture, and run. So it tests that the JAX backend works and that the generated JAX code doesn't throw errors. The audio also gets saved to an output folder.

DBraun, Sep 17 '22

Backends are tested in the tests/impulse-tests infrastructure (so not at each commit). How complex would it be to move the JAX backend tests here?

sletz, Sep 19 '22

It would be easy to change the relative paths in test_jax_backend.py. So would you suggest moving the jax-tests to tests/impulse-tests/jax-tests?

DBraun, Sep 19 '22

The current progress:

Two files don't work after being converted to JAX:

osc_enable.dsp doesn't work because the generated code contains if statements inside its per-sample block:

if (iTemp1 != 0):
    fTemp3 = (fTemp0 * fTemp2) 

Any tips for generating code that doesn't do this? I can imagine getting similar results if the visitor went through visit(Select2Inst* inst).
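For context on why the if is a problem: under jax.jit, a Python if on a traced value raises an error, because JAX cannot turn a tracer into a concrete bool. A common branch-free form is jnp.where, which computes both candidate values and selects one. This is a minimal sketch, assuming fTemp3 already holds a value when the condition is false (the quoted snippet does not show that part); the snake_case names are hypothetical stand-ins for the generated variables.

```python
import jax
import jax.numpy as jnp


def branchy(i_temp1, f_temp0, f_temp2, f_temp3):
    # Mirrors the generated per-sample code; fine in plain Python,
    # but a Python `if` on a traced value fails under jax.jit.
    if i_temp1 != 0:
        f_temp3 = f_temp0 * f_temp2
    return f_temp3


@jax.jit
def branch_free(i_temp1, f_temp0, f_temp2, f_temp3):
    # Both candidate values are computed; jnp.where picks one per element.
    return jnp.where(i_temp1 != 0, f_temp0 * f_temp2, f_temp3)
```

branch_free(1, 2.0, 3.0, 5.0) gives 6.0 and branch_free(0, 2.0, 3.0, 5.0) gives 5.0, while jax.jit(branchy) fails at trace time. Evaluating both sides is usually cheap per sample, but side effects in the untaken branch would not be skipped.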

The other file that doesn't compile is bs.dsp. I could eventually get this working, but it's kind of low priority in my opinion.

If I set the filesCompare precision to 2, ignoring bs.dsp and osc_enable.dsp, all the tests pass. Some examples fail at a tighter tolerance, such as carre_volterra.dsp (compared at 1e-4 here):

/Applications/Xcode.app/Contents/Developer/usr/bin/make -f Make.jax outdir=jax/double   FAUSTOPTIONS="-I dsp -double"
python3  ir/jax/double/jax_carre_volterra.py > ir/jax/double/carre_volterra.ir
./filesCompare ir/jax/double/carre_volterra.ir reference/carre_volterra.ir 1e-4       
Line : 2632 output : 0 sample1 : 0.090985 different from sample2 : 0.091085 delta : 0.0001
Line : 2633 output : 0 sample1 : 0.056059 different from sample2 : 0.056159 delta : 0.0001
Line : 2681 output : 0 sample1 : -0.125745 different from sample2 : -0.125846 delta : 0.000101
Line : 2682 output : 0 sample1 : -0.091029 different from sample2 : -0.091131 delta : 0.000102
Line : 2683 output : 0 sample1 : -0.056184 different from sample2 : -0.056285 delta : 0.000101
Line : 2684 output : 0 sample1 : -0.021322 different from sample2 : -0.021424 delta : 0.000102
Line : 2685 output : 0 sample1 : 0.013448 different from sample2 : 0.013347 delta : 0.000101
Line : 2686 output : 0 sample1 : 0.048029 different from sample2 : 0.047928 delta : 0.000101
Line : 2687 output : 0 sample1 : 0.082328 different from sample2 : 0.082228 delta : 0.0001
Line : 2731 output : 0 sample1 : 0.125626 different from sample2 : 0.125727 delta : 0.000101
Line : 2732 output : 0 sample1 : 0.090992 different from sample2 : 0.091094 delta : 0.000102
Line : 2733 output : 0 sample1 : 0.056222 different from sample2 : 0.056325 delta : 0.000103

Should I be testing -double? What are your suggestions for debugging small numerical differences?

DBraun, Sep 28 '22

Ignoring osc_enable.dsp and bs.dsp should be OK. But yes, using -double precision is the way to go! All tests actually use -double, and using the default (float) could explain the precision issues.
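A small sketch of why float versus double matters when comparing at 1e-4: it accumulates the same wrap-around phasor in float32 and float64 with jax.lax.scan and measures how far the resulting sines drift apart. It assumes 64-bit mode is enabled with jax.config.update("jax_enable_x64", True); by default JAX uses 32-bit types, so float64 arrays are not available without it. The phasor and sine_drift helpers and their parameters are illustrative, not part of the backend.

```python
import jax

jax.config.update("jax_enable_x64", True)  # JAX defaults to 32-bit; enable float64

import jax.numpy as jnp


def phasor(freq, sample_rate, num_samples, dtype):
    """Wrap-around phasor y[n] = frac(inc + y[n-1]) accumulated in the given dtype."""
    inc = jnp.asarray(freq / sample_rate, dtype=dtype)

    def step(phase, _):
        s = inc + phase
        phase = s - jnp.floor(s)
        return phase, phase

    _, phases = jax.lax.scan(step, jnp.zeros((), dtype=dtype), None, length=num_samples)
    return phases


def sine_drift(freq=440.0, sample_rate=44100.0, num_samples=100_000):
    """Max |sin32 - sin64| between float32 and float64 runs of the same oscillator."""
    p32 = phasor(freq, sample_rate, num_samples, jnp.float32)
    p64 = phasor(freq, sample_rate, num_samples, jnp.float64)
    # Compare sines rather than raw phases, so a wrap one sample apart is not a jump of 1.
    s32 = jnp.sin(2 * jnp.pi * p32.astype(jnp.float64))
    s64 = jnp.sin(2 * jnp.pi * p64)
    return float(jnp.max(jnp.abs(s32 - s64)))
```

The drift returned by sine_drift() is nonzero and grows with the number of samples, which is the kind of difference a float-versus-double reference comparison picks up.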

sletz, Sep 28 '22

Precision tests all pass! I only have to ignore bs.dsp and osc_enable.dsp in Make.jax.

How can I add the JAX installation and tests to AppVeyor?

DBraun, Sep 29 '22

Starting to look at the code. It would be simpler for me to comment on it after you 1) rebase on master-dev and 2) prepare a PR in --squash mode (that is, a single big commit). Thanks.

sletz, Oct 03 '22

Thanks for the rebase. I will make general comments on the PR here, and specific ones directly in the code. General comments on the generated code:

  • I understand that all "init"-like functions (init/instanceInit/classInit etc.) are grouped in a single classInit. In the generated C++ code, classInit usually holds init code that is shared by all "instances". Here I don't think you really have this "static init" versus "instance init" distinction anymore, so I would prefer that classInit simply become init.
  • all those load_soundfile/add_soundfile/add_nentry/add_button/add_slider functions are just generic code, right? Then why have the backend generate them? Wouldn't it be better to make them part of an architecture file (which is the standard way to do it)?
  • when compiling freeverb.dsp, I see state["IOTA0"] = 0 generated in build_interface. This looks like an initialisation, right? So it's better to move it into init.
  • there is still an indentation issue in the generated code (probably an unmatched "tab" at the end of the file)

sletz, Oct 04 '22

Looking at the generated code, the call state = self.build_interface(state, x, T) seems somewhat strange and does not really follow the "blue/red/green" architecture model (https://faustdoc.grame.fr/manual/architectures/#architecture-files). Why not remove build_interface (and have initialize init the DSP state and nothing more), and have:

def __call__(self, x, T: int) -> jnp.array:
   state = self.initialize(self.sample_rate, x, T)
   state = self.build_interface(self.sample_rate, x, T)
   return jnp.transpose(jax.lax.scan(self.tick, state, jnp.transpose(x, axes=(1, 0)))[1], axes=(1,0))

Or is it because self.build_interface is considered as initialising the DSP state in this JAX backend?

sletz, Oct 06 '22

Could you document jax_code_container.cpp a bit, describing the JAX backend particularities, as in https://github.com/grame-cncm/faust/blob/master-dev/compiler/generator/julia/julia_code_container.cpp#L32 ?

sletz, Oct 06 '22

It seems bool fMutateFun; in JAXInstVisitor is not used anymore?

sletz, Oct 06 '22

In JAXInstVisitor, please use the opcode name instead of 7 in the following line: if (inst->fOpcode > 7 && !fIsDoingWhile) {

sletz, Oct 06 '22

The code is now more like this:

def __call__(self, x, T: int) -> jnp.array:
   state = self.initialize(self.sample_rate, x, T)
   state = self.build_interface(state, x, T)
   state = jax.tree_map(jnp.array, state)
   return jnp.transpose(jax.lax.scan(self.tick, state, jnp.transpose(x, axes=(1, 0)))[1], axes=(1,0))

The state needed to be handed off to build_interface.
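A toy, self-contained sketch of that call sequence, assuming a hypothetical one-parameter gain DSP (ToyGainDSP and its "gain" default are illustrative, not generated code): initialize creates the DSP state, build_interface receives that state and adds the control values, and the whole dict is converted to arrays before becoming the jax.lax.scan carry. It uses jax.tree_util.tree_map, the function that the jax.tree_map alias above points to.

```python
import jax
import jax.numpy as jnp


class ToyGainDSP:
    """Hypothetical stand-in for a generated DSP class: one 'gain' control, any channel count."""

    def __init__(self, sample_rate: int):
        self.sample_rate = sample_rate

    def initialize(self, sample_rate, x, T):
        # DSP state only (constants, delay lines), as in the generated init code.
        return {"fSampleRate": sample_rate}

    def build_interface(self, state, x, T):
        # Receives and extends the existing state, which is why state must be passed in.
        state["gain"] = 0.5  # hypothetical slider default
        return state

    @staticmethod
    def tick(state, inputs):
        # One sample for all channels: inputs has shape (channels,).
        return state, inputs * state["gain"]

    def __call__(self, x, T: int):
        state = self.initialize(self.sample_rate, x, T)
        state = self.build_interface(state, x, T)
        # Convert Python scalars to JAX arrays before using the dict as the scan carry.
        state = jax.tree_util.tree_map(jnp.array, state)
        # scan iterates over the leading axis, so x (channels, T) becomes (T, channels).
        _, ys = jax.lax.scan(self.tick, state, jnp.transpose(x, axes=(1, 0)))
        return jnp.transpose(ys, axes=(1, 0))
```

For example, ToyGainDSP(44100)(jnp.ones((2, 8)), 8) returns a (2, 8) array filled with 0.5.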

I would like to move the add_soundfile, add_slider, etc. functions into an architecture file, but I haven't figured out how to default to a specific architecture file if none is specified.

I removed bool fMutateFun;, added some comments in jax_code_container.cpp and improved if (inst->fOpcode > 7 && !fIsDoingWhile) {

Can you tell me where FAUSTFLOAT(foo) is generated in C++? I had to define a dummy function:

def FAUSTFLOAT(x):
    return x

But it would be better if FAUSTFLOAT(foo) were just foo everywhere. I didn't manage to track it down with breakpoints earlier.

DBraun, Oct 07 '22

"but I didn't figure out how to default to a specific architecture file if none is specified." => well, even in C++ the compiler does not generate a self-contained class. It always assumes dsp/UI/meta etc. are defined in an architecture file. This is why I think we could just do the same with JAX, and also assume that the appropriate architecture will be used.

sletz, Oct 07 '22

Good point. I'll try that tomorrow.

DBraun, Oct 07 '22

Can you tell me where FAUSTFLOAT(foo) is generated in C++? => I'll check that.

sletz, Oct 07 '22

The FAUSTFLOAT machinery makes sense in C/C++. With JAX, I think you should do as with the LLVM backend, where we consider FAUSTFLOAT to be the same type as the internal REAL type (that is, either float or double, depending on whether -single (the default) or -double is used). See https://github.com/grame-cncm/faust/blob/master-dev/compiler/global.hh#L166 and https://github.com/grame-cncm/faust/blob/master-dev/compiler/libcode.cpp#L1396.

sletz, Oct 07 '22

I moved most of the code into the architecture files now. I think setting gGlobal->gFAUSTFLOAT2Internal = true; helped avoid FAUSTFLOAT being called.

DBraun, Oct 08 '22

Great work! Merged in --squash mode after some final cleanup and reformatting in https://github.com/grame-cncm/faust/commit/44d66aac61b05cb172e101a2b4051e2aa0ea248f

Still 2 things to check:

  • Is this FAUSTFLOAT-related block still necessary? https://github.com/grame-cncm/faust/commit/44d66aac61b05cb172e101a2b4051e2aa0ea248f#diff-5ac628298a99ca6f354d9f3f5a64424743f37be2451f2ab84e6645418f1c02e4R149-R165
  • JAXCodeContainer::inlineSubcontainersFunCalls is almost the same as the base class CodeContainer::inlineSubcontainersFunCalls (only the renaming block differs, AFAICS): https://github.com/grame-cncm/faust/commit/44d66aac61b05cb172e101a2b4051e2aa0ea248f#diff-5ac628298a99ca6f354d9f3f5a64424743f37be2451f2ab84e6645418f1c02e4R296. Is renaming an issue here? If not, then I guess the base class CodeContainer::inlineSubcontainersFunCalls should be used.

sletz, Oct 08 '22

Now everything is OK, thanks !

sletz, Oct 09 '22