basepairmodels

[Suggestion] Read in custom architecture and loss functions from command line

Open mmtrebuchet opened this issue 4 years ago • 1 comments

As I start complicating my interaction with BPNet more, I'd like to add a feature where the user can easily supply a network architecture and loss function to the program. Here's a rough sketch of how I was thinking this could work, and I'd love to get your input. I'm very much not a software engineer, so these may be terrible anti-patterns with a much better way of addressing them.

Here's the issue I'm bumping into. When I want to create a custom model architecture, I have to dive deep into the source of basepairmodels and modify model_archs.py. This is problematic for three reasons. First, the changes are global: I cannot, for example, keep my model architecture file in the same directory as my data without symlinking from basepairmodels into my experiment directory (and even that doesn't help if I'm working on multiple models in different directories). Second, if I download the latest version of basepairmodels, it will overwrite my changes to model_archs.py, or at least break the symlink to my modified version. Third, it is very difficult to provide additional customization options on the command line, as might be necessary when doing a grid search over hyperparameters.

  1. Model definitions.

     1.1. Add a flag --model-src to train; this flag would accept a string naming a Python file. Let's call it customModelDef.py.

     1.2. Add a flag --model-args to train; this flag would be an arbitrary string. Let's call the string modelArgs for the moment.

     1.3. The Python file named by --model-src shall contain a function called model(). It will take a single string as an argument, and this is the string given on the command line to --model-args. The function model() shall return a network just as the functions in model_archs.py do currently.

     1.3.1. The function model() could optionally be required to accept other arguments that are germane to other parts of the program. For example, since the sequence generator needs to know the input length, and the number of bias profiles is required to provide the correct input to the model, model() could require arguments of the form model(input_seq_len, output_len, num_bias_profiles, modelArgs).

     1.3.2. The function could also be designed to accept keyword arguments from the main program, so that model()'s signature would be model(modelArgs, **kwargs). Then the main program would provide additional information that model() could use or discard. For example, the main program could call customModelDef.model(modelArgs, input_seq_len=NNNN, output_len=NNNN, num_bias_profiles=NNNN, ...) and so on. This way, for an architecture where num_bias_profiles is irrelevant, model() could simply ignore that keyword argument.

     1.4. The function model() may do with its argument string what it pleases. For example, the modelArgs string could be something like "num_profiles=5:kernel_size=6:allow_opt=false:add_insult_layer=0.825", in which case model() would return the corresponding network. Or, more likely, modelArgs would be something like "/projects/Sebastian/training/config.json", in which case model() would probably open up that file and read in configuration from it. In any event, the main train command would be completely agnostic to how modelArgs is processed or what it means.
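To make items 1.1 through 1.4 a bit more concrete, here is a minimal sketch of how the main program might load the user's file and call model(). It is only an illustration under assumptions: the helper names load_function, build_parser, build_model and parse_arg_string are made up for this example and are not part of basepairmodels, and the real train parser has many more options than the two flags shown.

```python
import argparse
import importlib.util
from pathlib import Path


def parse_arg_string(arg_string):
    """Split "key=value:key=value" into a dict of strings.

    Hypothetical helper that a user's model() might use (item 1.4);
    the main program never needs to call it.
    """
    options = {}
    for item in arg_string.split(":"):
        if item:
            key, _, value = item.partition("=")
            options[key] = value
    return options


def load_function(src_path, func_name):
    """Import the Python file at src_path and return its callable func_name."""
    path = Path(src_path).resolve()
    spec = importlib.util.spec_from_file_location(path.stem, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot import {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    func = getattr(module, func_name, None)
    if not callable(func):
        raise AttributeError(f"{path} does not define a callable {func_name}()")
    return func


def build_parser():
    """Parser with only the two proposed flags (items 1.1 and 1.2)."""
    parser = argparse.ArgumentParser(prog="train")
    parser.add_argument("--model-src", required=True,
                        help="Python file that defines model()")
    parser.add_argument("--model-args", default="",
                        help="opaque string passed through to model()")
    return parser


def build_model(argv, **context):
    """Load model() from --model-src and call it as in item 1.3.2.

    context holds whatever the main program knows (input_seq_len,
    output_len, num_bias_profiles, ...); model() may ignore any of it.
    """
    args = build_parser().parse_args(argv)
    model_fn = load_function(args.model_src, "model")
    return model_fn(args.model_args, **context)
```

With a customModelDef.py that defines model(modelArgs, **kwargs), the main program would call something like build_model(sys.argv[1:], input_seq_len=..., output_len=..., num_bias_profiles=...) and use whatever network comes back. Because the file is loaded by path, it can live in the experiment directory and survives an upgrade of basepairmodels.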

  2. Add a custom loss function.

     2.1. Add a flag --loss-src to any CLI components that need the loss function. (Since most of the tools that work with the network have to create multinomial_nll before they can load the model, this would probably be most of the CLI tools.) This flag would take a string naming a Python file; call it customLossDef.py.

     2.2. customLossDef.py should contain a function that returns a loss function, or similar. I'm not familiar enough with how the loss is created to know what the precise architecture of this function should be.

     2.2.1. One option would be to have a function getLossFunction(lossArgs), accepting a string like model() would. This function would then return a loss function based on the string lossArgs. But there could be other, better ways of implementing this.
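A rough sketch of option 2.2.1, again under assumptions: load_loss is a hypothetical helper, the "weight=..." format of lossArgs is just one possible convention, and the weighted mean squared error is a pure-Python stand-in so the example runs without TensorFlow. A real loss would operate on tensors.

```python
import importlib.util
import os
import tempfile

# Hypothetical contents of customLossDef.py, written to a temporary
# directory below so that this example is self-contained.
CUSTOM_LOSS_SRC = '''
def getLossFunction(lossArgs):
    # Interpret lossArgs as "key=value:key=value"; only "weight" is used.
    options = {}
    for item in lossArgs.split(":"):
        if item:
            key, _, value = item.partition("=")
            options[key] = value
    weight = float(options.get("weight", 1.0))

    def weighted_mse(y_true, y_pred):
        # Stand-in loss on plain lists of floats.
        errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
        return weight * sum(errors) / len(errors)

    return weighted_mse
'''


def load_loss(loss_src, loss_args):
    """Import the file named by --loss-src and build the loss from lossArgs."""
    spec = importlib.util.spec_from_file_location("custom_loss_def", loss_src)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot import {loss_src}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.getLossFunction(loss_args)


with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "customLossDef.py")
    with open(src, "w") as handle:
        handle.write(CUSTOM_LOSS_SRC)
    loss_fn = load_loss(src, "weight=2.0")

# The loaded function stays usable after the temporary file is gone.
demo_loss = loss_fn([1.0, 2.0], [0.0, 2.0])
```

Each CLI tool that currently has to create multinomial_nll before loading a model could presumably call load_loss() at the same point and use the returned function in its place.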

Your thoughts?

mmtrebuchet, Feb 24 '21 19:02