ludwig
Add s4 module
This PR begins the work required to add S4 (paper, code) to Ludwig. So far, the implementation has focused on adapting the high-level API of the S4 class to integrate with the wider Ludwig ecosystem through S4Encoder. Several steps remain before this PR is ready for review:
1. Identify the arguments to expose in HippoSSKernel and SSKernelNPLR
In the current implementation, only the high-level arguments in the S4 class are configurable. One next step is to comb through the classes used by S4 (HippoSSKernel and, subsequently, SSKernelNPLR) and identify the class arguments that should be exposed to Ludwig users via S4Encoder.
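One way to start this audit is to list each class's constructor parameters and defaults programmatically rather than by reading the source. The sketch below uses `inspect.signature` on a stand-in class; `FakeKernel` and its parameter names are placeholders for illustration and do not reflect the real HippoSSKernel or SSKernelNPLR signatures.

```python
import inspect


class FakeKernel:
    """Stand-in for a kernel class; the parameters are illustrative only."""

    def __init__(self, width, depth=64, rate=0.1, **kernel_args):
        self.width = width


def list_constructor_args(cls):
    """Return {name: default} for a class's named constructor parameters.

    Required parameters map to inspect.Parameter.empty; `self`, *args and
    **kwargs are skipped because they carry no name/default to expose.
    """
    params = inspect.signature(cls.__init__).parameters
    skip = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return {
        name: p.default
        for name, p in params.items()
        if name != "self" and p.kind not in skip
    }


args = list_constructor_args(FakeKernel)
for name, default in args.items():
    required = default is inspect.Parameter.empty
    print(f"{name}: {'required' if required else default}")
```

Running the same helper against the real classes would give a candidate list to review; which of those arguments are worth surfacing in S4Encoder remains a judgment call.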
2. Confirm the intended behavior for masking and sequence length
At time of writing, S4Encoder.forward ignores the mask argument. We need to confirm that this is reasonable with respect to the original work. Further, the original repository seems to do some automated work to determine l_max in the S4 class. In the current implementation, l_max is configured using the max_sequence_length argument passed into S4Encoder. We need to confirm (1) whether or not these variables are conceptually referring to the same thing and (2) what the expected behavior should be if max_sequence_length is None.
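To make question (2) concrete, here is one possible policy, sketched as a hypothetical helper `resolve_l_max` (not part of Ludwig or the original S4 repository). It assumes the configured max_sequence_length wins when set, falls back to the longest sequence in the batch otherwise, and fails loudly if neither is available. Whether this matches the original l_max semantics is exactly what still needs to be confirmed.

```python
from typing import Optional, Sequence


def resolve_l_max(max_sequence_length: Optional[int],
                  batch_lengths: Optional[Sequence[int]] = None) -> int:
    """Pick a kernel length under an assumed fallback policy (hypothetical helper)."""
    if max_sequence_length is not None:
        if max_sequence_length <= 0:
            raise ValueError("max_sequence_length must be positive")
        return max_sequence_length
    if batch_lengths:
        # Assumption: the longest sequence in the batch is a usable kernel length.
        return max(batch_lengths)
    raise ValueError("cannot determine l_max: no max_sequence_length and no batch lengths")


# Assumed mask convention: 1 marks a real token, 0 marks padding.
mask = [[1, 1, 0], [1, 1, 1]]
lengths = [sum(row) for row in mask]

print(resolve_l_max(256))            # configured value is used as-is
print(resolve_l_max(None, lengths))  # falls back to the longest sequence
```

The mask-derived lengths also show one way the currently ignored mask argument could feed into the decision, should that turn out to be the intended behavior.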
3. Clean up s4_modules.py
s4_modules.py currently contains a large amount of code pulled directly from the authors' original repository. Because that code is well abstracted, most of it can likely remain unchanged. However, a few pieces of the original functionality have not yet been checked against Ludwig, in particular the use of pykeops to compute the Cauchy kernel.
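One option for the pykeops question is to treat it as an optional dependency: detect it at import time and fall back to a slower implementation when it is missing. The sketch below shows only the detection pattern plus a pure-Python reference for the Cauchy sum, out[i] = sum_j v[j] / (z[i] - w[j]). The names `has_pykeops` and `cauchy_reference` are hypothetical, and the reference is meant for checking results on small inputs, not as a replacement for the tensor code in s4_modules.py.

```python
import importlib.util


def has_pykeops() -> bool:
    """True if the optional pykeops package can be found in this environment."""
    return importlib.util.find_spec("pykeops") is not None


def cauchy_reference(v, z, w):
    """Naive Cauchy kernel: out[i] = sum_j v[j] / (z[i] - w[j]).

    Pure-Python reference with O(L * N) cost; works for real or complex
    numbers and is intended for sanity checks only.
    """
    return [sum(vj / (zi - wj) for vj, wj in zip(v, w)) for zi in z]


backend = "pykeops" if has_pykeops() else "reference"
print(f"Cauchy backend: {backend}")
print(cauchy_reference(v=[1.0, 2.0], z=[3.0, 5.0], w=[1.0, 2.0]))
```

With this pattern, pykeops would not need to become a hard requirement for Ludwig users who never enable the S4 encoder.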
4. Documentation
At time of writing, most of the functions and classes have minimal documentation. Some of the functions are quite math-heavy and outside of my area of expertise, so it may make sense to ask the authors for their input :)
5. Write tests
Fill in the tests in test_s4_modules.py and test_sequence_encoders.py to confirm that everything works as expected.