Char-RNN-TensorFlow
Char-RNN-TensorFlow copied to clipboard
支持断点继续训练,支持TensorFlow2.0,增加predict功能。
支持断点继续训练,若未达到目标次数会按照最后一次保存的模型继续训练;若已经到达目标次数,会直接停止。
支持TF2.0,并未改动大多数代码,只是启用了TF2.0已经弃用的1.0功能,并且关闭了2.0的功能。
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
增加了predict功能,给出start_string。
python predict.py \
--converter_path model/torch_gen/converter.pkl \
--checkpoint_path model/torch_gen \
--max_length 1500 \
--start_string " raise "
会输出如下结果:
raise -> utized_inpu probability: 0.6539345979690552
raise -> es()\r\n probability: 0.1654084473848343
raise -> pistent_and probability: 0.07784435153007507
raise -> al_module_t probability: 0.0615621916949749
raise -> Porgex(self probability: 0.04125040024518967
另外加入了预处理好的pytorch的代码,在data/torch_code.txt中,去除了#注释,把所有字符串都替换成了"msg" 'msg' """msg""" '''msg'''的形式。