
How to use basenji_train.py for multi-genome training?

Open xfchen0912 opened this issue 3 years ago • 5 comments

Hi, David! Thanks for creating this incredible work! I would like to use your scripts to train a model for plant genomes, but I can't find a tutorial for multi-genome training. So I've made some attempts, but now I'm running into some problems.

First, I used basenji_data_align.py and basenji_data.py to build a dataset.

Then I modified params_human.json from manuscripts/cross2020 for my job, adding heads like this:

        "head_rice": {
            "name": "final",
            "units": 12,
            "activation": "softplus"
        },
        "head_maize": {
            "name": "final",
            "units": 8,
            "activation": "softplus"
        }

Finally, I ran this script, but there was no output for a long time. I would like to know which step went wrong.

basenji_train.py -o models/cross ./params/params_cross.json ./dataset_cross/rice_maize_cross/rice/ ./dataset_cross/rice_maize_cross/maize/

Could you give me some advice on this, or provide a standard pipeline for multi-genome training?

I would be very grateful if you could help!

Thank you very much!

xfchen0912 avatar Oct 21 '21 14:10 xfchen0912

That looks correct to me. Are you using a fast GPU for training? It's possible that you haven't seen output because it's still working on the first epoch. Do you see activity from your CPU and GPU? You might recreate a sample dataset using -d 0.1 to use 10% of the sequences to make sure everything else is working properly. You might also send the exact commands that you used to generate the dataset, and I can check them for correctness.

davek44 avatar Oct 21 '21 20:10 davek44
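
As an aside for readers debugging the same symptom: it's worth checking that basenji_data.py actually produced non-empty TFRecord files before launching training. A minimal sketch, assuming the standard out_dir/tfrecords/*.tfr layout; the helper name `check_tfrecords` is mine, not part of Basenji:

```python
import glob
import os

def check_tfrecords(out_dir):
    """Return the number of .tfr files under out_dir/tfrecords,
    raising if the dataset looks empty (hypothetical helper)."""
    tfr_files = sorted(glob.glob(os.path.join(out_dir, "tfrecords", "*.tfr")))
    if not tfr_files:
        raise RuntimeError("no TFRecords found under %s/tfrecords" % out_dir)
    empty = [f for f in tfr_files if os.path.getsize(f) == 0]
    if empty:
        raise RuntimeError("%d empty TFRecord files, e.g. %s" % (len(empty), empty[0]))
    return len(tfr_files)
```

Running this over each genome's output directory (e.g. dataset_cross/rice_maize_cross/rice) before basenji_train.py would catch an empty dataset early.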

Thanks for your reply!

I checked the Volatile GPU-Util long after the model started, but it stays at 0%. When I test with a single-genome dataset, the GPU runs normally. I am confused.

Here are the commands I used to generate the dataset; please check them for me.

basenji_data_align.py -a rice,maize -c 8192 -d 1.0 -g IRGSP_1.0.gaps.bed,B73_RefGen_v4.gaps.bed -l 131072 -o dataset_cross/rice_maize_cross -t 0.1 -w 128 -v 0.1 IRGSP_1.0vsB73_RefGen_v4.syn.net public/home/xwli/xfchen/reference/rice/nip7chr.fa,/public/home/xwli/xfchen/reference/B73/B73.fa

basenji_data.py --restart --local -c 8192 -d 1.0 -g IRGSP_1.0.gaps.bed -l 131072 -o dataset_cross/rice_maize_cross/rice -t 0.1 -w 128 -v 0.1 /public/home/xwli/xfchen/reference/rice/nip7chr.fa ./target/targets_riceNip.txt

basenji_data.py --restart --local -c 8192 -d 1.0 -g B73_RefGen_v4.gaps.bed -l 131072 -o dataset_cross/rice_maize_cross/maize -t 0.1 -w 128 -v 0.1 /public/home/xwli/xfchen/reference/B73/B73.fa ./target/targets_maizeB73.txt

Thank you very much!

xfchen0912 avatar Oct 22 '21 02:10 xfchen0912

Hi, David! The above problem was because I used the newest Basenji scripts with an old version of the module, so the dataset was empty. I have now resolved that, but I have encountered a new problem:

Traceback (most recent call last):
  File "/public/home/xwli/xfchen/basenji/bin/basenji_train.py", line 165, in <module>
    main()
  File "/public/home/xwli/xfchen/basenji/bin/basenji_train.py", line 159, in main
    seqnn_trainer.fit2(seqnn_model)
  File "/public/home/xwli/xfchen/basenji/basenji/trainer.py", line 305, in fit2
    train_step1(x, y)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /public/home/xwli/xfchen/basenji/basenji/trainer.py:184 train_step1  *
        loss = self.loss_fn(y, pred) + sum(seqnn_model.models[1].losses)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:152 __call__  **
        losses = call_fn(y_true, y_pred)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:256 call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:1686 poisson
        return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper
        raise e
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper
        return func(x, y, name=name)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py:1496 _mul_dispatch
        return multiply(x, y, name=name)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py:518 multiply
        return gen_math_ops.mul(x, y, name)
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py:6077 mul
        _, _, _op, _outputs = _op_def_library._apply_op_helper(
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py:748 _apply_op_helper
        op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:590 _create_op_internal
        return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3528 _create_op_internal
        ret = Operation(
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:2015 __init__
        self._c_op = _create_c_op(self._graph, node_def, inputs,
    /public/home/xwli/anaconda3/envs/basenji/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1856 _create_c_op
        raise ValueError(str(e))

    ValueError: Dimensions must be equal, but are 8 and 12 for '{{node poisson/mul}} = Mul[T=DT_FLOAT](y, poisson/Log)' with input shapes: [4,896,8], [4,896,12].

I haven't changed the commands. Can you give me some advice? Thank you.

xfchen0912 avatar Oct 22 '21 05:10 xfchen0912

Is it possible that the heads specified in your parameters file are in a different order than the datasets given to the training script? The 8 vs 12 mismatch could be due to that.

davek44 avatar Oct 25 '21 18:10 davek44
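
The mismatch in the traceback can be illustrated with plain NumPy (shapes taken from the error message; this is only a sketch of the Poisson-loss arithmetic, not Basenji code):

```python
import numpy as np

# The trainer paired the rice dataset's targets (12 tracks) with the
# maize head's predictions (8 tracks): [4, 896, 12] vs [4, 896, 8].
y_true = np.random.rand(4, 896, 12)  # rice targets
y_pred = np.random.rand(4, 896, 8)   # maize head output

try:
    # Poisson loss: mean(y_pred - y_true * log(y_pred + eps), axis=-1).
    # The elementwise multiply requires matching final dimensions.
    loss = np.mean(y_pred - y_true * np.log(y_pred + 1e-7), axis=-1)
except ValueError as err:
    print("shape mismatch:", err)
```

With consistent head/dataset ordering, both tensors end with the same number of tracks and the multiply succeeds.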

I encountered a similar error. The reason is that the final heads are sorted first. So to make it work, you can rename your two final heads to "head0_rice" and "head1_maize" to make them in the correct order. Hope this helps after almost 3 years.

nzhang89 avatar Apr 17 '24 22:04 nzhang89
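
A quick illustration of the ordering fix, assuming the sort-by-name behavior nzhang89 describes (head names mirror the params file above):

```python
# With the original names, sorting puts the 8-unit maize head before the
# 12-unit rice head, even though the rice dataset comes first on the
# basenji_train.py command line.
heads = ["head_rice", "head_maize"]
print(sorted(heads))  # ['head_maize', 'head_rice'] -- wrong pairing

# Numeric prefixes pin the sorted order to the dataset order.
heads_fixed = ["head0_rice", "head1_maize"]
print(sorted(heads_fixed))  # ['head0_rice', 'head1_maize'] -- correct
```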