Error when using wass
Thanks for sharing your code and excellent work!
I am trying to run your code on other experimental data, but I run into an error when using `wass`. The code works when other `imb_fun` options (e.g. `mmd2_rbf`) are used.
The error log is below. Do you have any idea what might cause this error, so that I can fix it?
```
Traceback (most recent call last):
  File "cfr_net_train.py", line 427, in main
    run(outdir)
  File "cfr_net_train.py", line 374, in run
    D_exp_test, logfile, i_exp)
  File "cfr_net_train.py", line 142, in train
    CFR.r_alpha: FLAGS.p_alpha, CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated})
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 964, in _run
    feed_dict_string, options, run_metadata)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1014, in _do_run
    target_list, options, run_metadata)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1034, in _do_call
    raise type(e)(node_def, op, message)
InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [0,10] vs. shape[1] = [1,1]
  [[Node: concat_2 = Concat[N=2, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](concat_2/concat_dim, concat_1, concat)]]
  [[Node: gradients/Cast_2/_111 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_588_gradients/Cast_2", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Caused by op u'concat_2', defined at:
  File "cfr_net_train.py", line 434, in

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [0,10] vs. shape[1] = [1,1]
  [[Node: concat_2 = Concat[N=2, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](concat_2/concat_dim, concat_1, concat)]]
  [[Node: gradients/Cast_2/_111 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_588_gradients/Cast_2", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
```
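For reference, a concat requires every dimension except the concatenation axis to match. The two inputs named in the error, of shapes [0,10] and [1,1], disagree on both axes, so the concat fails whichever axis `concat_2` uses. The NumPy analogue below is only an illustration (NumPy is not part of the original code, and `concat_fails` is a hypothetical helper); it reproduces the same kind of shape failure:

```python
import numpy as np


def concat_fails(x, y, axis):
    """Return True if np.concatenate rejects x and y along the given axis."""
    try:
        np.concatenate([x, y], axis=axis)
    except ValueError:
        return True
    return False


# Shapes taken from the error message above
x = np.zeros((0, 10))  # shape[0] = [0,10]
y = np.zeros((1, 1))   # shape[1] = [1,1]

for axis in (0, 1):
    print(axis, concat_fails(x, y, axis))
```

Along axis 0 the sizes 10 and 1 conflict in dimension 1; along axis 1 the sizes 0 and 1 conflict in dimension 0.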
The reported issue might be caused by a TensorFlow version mismatch. I used TensorFlow 1.14, and the API seems to have changed.
I fixed the code by making the following changes in the `wasserstein()` function:
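```python
# `row` is defined in the unchanged part of wasserstein()
# Extra column: delta for each row of M, plus a 0 in the new corner
col = tf.concat([delta*tf.ones(tf.shape(M[:,0:1])),tf.zeros((1,1))], axis=0)
Mt = tf.concat([M,row], axis=0)
Mt = tf.concat([Mt,col], axis=1)
# Marginal vectors for the treated (a) and control (b) groups
a = tf.concat([p*tf.ones(tf.shape(tf.where(t>0)[:,0:1]))/nt, (1-p)*tf.ones((1,1))], axis=0)
b = tf.concat([(1-p)*tf.ones(tf.shape(tf.where(t<1)[:,0:1]))/nc, p*tf.ones((1,1))], axis=0)
```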
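One relevant API change: TensorFlow 1.0 changed `tf.concat(concat_dim, values)` to `tf.concat(values, axis)`, so `tf.concat` calls written for one argument order will not work unchanged with the other. To make the shapes in the snippet above concrete, here is a minimal NumPy sketch of the same padding step. `pad_cost_and_marginals` is a hypothetical helper, not part of the repository, and it assumes `M` is the (nt x nc) treated-by-control distance matrix, `t` is a 0/1 treatment vector, `p` is the treated proportion, and `delta` is the largest entry of `M` (the snippet does not show how `delta` is computed).

```python
import numpy as np


def pad_cost_and_marginals(M, t, p):
    """Hypothetical NumPy sketch of the padding step in wasserstein().

    Assumptions (not taken from the repository): M has shape (nt, nc),
    t is a 0/1 treatment indicator, p is the treated proportion, and
    delta is the largest entry of M.
    """
    t = np.asarray(t)
    nt = int(np.sum(t > 0))  # number of treated units
    nc = int(np.sum(t < 1))  # number of control units
    delta = M.max()          # assumed choice of delta

    # Extra row of delta values, then an extra column of delta values
    # with a 0 in the new bottom-right corner
    row = delta * np.ones((1, M.shape[1]))
    col = np.concatenate([delta * np.ones((M.shape[0], 1)), np.zeros((1, 1))], axis=0)
    Mt = np.concatenate([M, row], axis=0)
    Mt = np.concatenate([Mt, col], axis=1)

    # a: p spread evenly over treated units, (1 - p) on the extra entry
    # b: (1 - p) spread evenly over control units, p on the extra entry
    a = np.concatenate([p * np.ones((nt, 1)) / nt, (1 - p) * np.ones((1, 1))], axis=0)
    b = np.concatenate([(1 - p) * np.ones((nc, 1)) / nc, p * np.ones((1, 1))], axis=0)
    return Mt, a, b


# Usage: 2 treated and 3 control units, so M is 2 x 3
t = np.array([1, 1, 0, 0, 0])
M = np.arange(6, dtype=float).reshape(2, 3)
Mt, a, b = pad_cost_and_marginals(M, t, p=0.4)
print(Mt.shape, a.ravel(), b.ravel())
```

By construction `Mt` has one more row and one more column than `M`, and both marginals sum to 1: `a` to p + (1 - p) and `b` to (1 - p) + p.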