Char-RNN-TensorFlow icon indicating copy to clipboard operation
Char-RNN-TensorFlow copied to clipboard

支持断点继续训练,支持TensorFlow2.0,增加predict功能。

Open LZY2006 opened this issue 5 years ago • 0 comments

支持断点继续训练,若未达到目标次数会按照最后一次保存的模型继续训练;若已经到达目标次数,会直接停止。

支持TF2.0,并未改动大多数代码,只是启用了TF2.0已经弃用的1.0功能,并且关闭了2.0的功能。

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

增加了predict功能,给出start_string。

python predict.py \
  --converter_path model/torch_gen/converter.pkl \
  --checkpoint_path  model/torch_gen \
  --max_length 1500 \
  --start_string "    raise "

会输出如下结果:

    raise  ->  utized_inpu   probability: 0.6539345979690552
    raise  -> es()\r\n       probability: 0.1654084473848343
    raise  ->  pistent_and   probability: 0.07784435153007507
    raise  ->  al_module_t   probability: 0.0615621916949749
    raise  ->  Porgex(self   probability: 0.04125040024518967

另外加入了预处理好的pytorch的代码,在data/torch_code.txt中,去除了#注释,把所有字符串都替换成了"msg" 'msg' """msg""" '''msg'''的形式。

LZY2006 avatar Jan 10 '21 10:01 LZY2006