
Setting sample_weight on the DSSM model raises an error

Open zhuchenxi opened this issue 1 year ago • 0 comments

Describe the bug: Setting sample_weight on the DSSM model raises an error. The sample_weight follows the expected format: a one-dimensional NumPy array with the same size as the label array.

To Reproduce: Run the following code: history = model.fit(train_model_input, train_label, batch_size=256, epochs=4, verbose=1, validation_split=0.0, sample_weight = sample_weights)
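Since the report states that sample_weight is a one-dimensional NumPy array the same size as the label array, a quick pre-flight check can rule out a shape mismatch before calling model.fit. The following is a minimal sketch with NumPy only; the helper name check_sample_weight and the small arrays are hypothetical and are not part of DeepMatch or Keras.

```python
import numpy as np


def check_sample_weight(labels, sample_weight):
    """Hypothetical helper: return the weights as a 1-D float32 array,
    or raise ValueError if they do not line up with the labels."""
    labels = np.asarray(labels)
    weights = np.asarray(sample_weight, dtype=np.float32)
    if weights.ndim != 1:
        raise ValueError(f"sample_weight must be 1-D, got shape {weights.shape}")
    if weights.shape[0] != labels.shape[0]:
        raise ValueError(
            f"sample_weight has {weights.shape[0]} entries, "
            f"but there are {labels.shape[0]} labels"
        )
    return weights


# Illustrative data only; the real train_label comes from the training script.
train_label = np.array([1, 0, 1, 1], dtype=np.int32)
sample_weights = check_sample_weight(train_label, [1.0, 0.5, 2.0, 1.0])
```

If this check passes, as the report implies it would, the input format is not the cause and the failure has to originate elsewhere in the training graph.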

Operating environment:

  • python version [3.7]
  • tensorflow version [1.15.0]
  • deepmatch version [0.3.1]

Additional context: Resulting output:
Traceback (most recent call last):
  File "dssm_qt_train_noid.py", line 93, in <module>
    sample_weight = sample_weights)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 727, in fit
    use_multiprocessing=use_multiprocessing)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 675, in fit
    steps_name='steps_per_epoch')
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 394, in model_iteration
    batch_outs = f(ins_batch)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3476, in __call__
    run_metadata=self.run_metadata)
  File "/DATA/jupyter/personal/guanggao_dssm/py37_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1472, in __call__
    run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[0], expected a dimension of 1, got 256
  [[{{node loss_1/in_batch_softmax_layer_loss/weighted_loss/Squeeze}}]]
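The last line of the traceback says that a Squeeze op inside the weighted loss expected dimension 0 to have size 1 but found 256, which matches the batch_size passed to model.fit. The snippet below is only a NumPy analogy for what that message means; it does not reproduce TensorFlow's internal weighted-loss code and does not by itself identify the root cause in DeepMatch.

```python
import numpy as np

batch_size = 256  # same value as in the failing model.fit call

# A tensor with a single entry on axis 0 can be squeezed down to a scalar.
one_value = np.ones((1,), dtype=np.float32)
squeezed = np.squeeze(one_value, axis=0)  # resulting shape is ()

# A per-sample tensor has batch_size entries on axis 0, so squeezing that
# axis is rejected, analogous to "expected a dimension of 1, got 256".
per_sample = np.ones((batch_size,), dtype=np.float32)
try:
    np.squeeze(per_sample, axis=0)
    squeeze_failed = False
except ValueError:
    squeeze_failed = True
```

Read this way, the error suggests that somewhere in the loss computation a tensor of length 1 was expected where a per-sample tensor of length 256 arrived.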

zhuchenxi, Oct 15 '23 14:10