JGAN
fix(cluster-gan): `valid` shape error
Reproduce
python ./JGAN/models/cluster_gan/clustergan.py -b 64
Error
Traceback (most recent call last):
File "./cluster_gan.py", line 396, in <module>
real_loss = bce_loss(D_real, valid)
File "/usr/local/lib/python3.8/dist-packages/jittor/__init__.py", line 1172, in __call__
return self.execute(*args, **kw)
File "/usr/local/lib/python3.8/dist-packages/jittor/nn.py", line 487, in execute
return bce_loss(output, target, self.weight, self.size_average)
File "/usr/local/lib/python3.8/dist-packages/jittor/nn.py", line 409, in bce_loss
loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.__mul__)).
Types of your inputs are:
self = Var,
b = Var,
The function declarations are:
VarHolder* multiply(VarHolder* x, VarHolder* y)
Failed reason:[f 0422 19:43:25.634921 80 binary_op.cc:432] Check failed xshape(64) == yshape(32) Shape not match, x:float32[64,1,] y:float32[32,1,]
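
The check failure reports that the two operands of the multiply inside bce_loss have 64 and 32 rows. Reading the failing expression, `target` is the left operand, which suggests `valid` has 64 rows while `D_real` has 32; that is an inference from the trace, not something confirmed here. A minimal sketch of that kind of mismatch, and of sizing the labels from the batch actually produced, is below. It uses plain Python lists instead of Jittor vars, and `make_labels`, `bce_loss`, and `d_real` are hypothetical stand-ins, not the code in the repository.

```python
import math


def make_labels(batch_size, value):
    """Build a [batch_size, 1] label column, like `valid` (1.0) or `fake` (0.0)."""
    return [[value] for _ in range(batch_size)]


def bce_loss(output, target, eps=1e-20):
    """Mean binary cross-entropy over [N, 1] columns; rejects mismatched N."""
    if len(output) != len(target):
        raise ValueError(
            f"batch mismatch: target has {len(target)} rows, output has {len(output)}"
        )
    total = 0.0
    for (o,), (t,) in zip(output, target):
        total -= t * math.log(max(o, eps)) + (1 - t) * math.log(max(1 - o, eps))
    return total / len(output)


# Stand-in discriminator output with 32 rows (assumed, per the trace reading above).
d_real = [[0.5] for _ in range(32)]

# Labels sized from a configured batch size of 64 do not line up with it.
try:
    bce_loss(d_real, make_labels(64, 1.0))
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# Labels sized from the batch that was actually produced do line up.
loss = bce_loss(d_real, make_labels(len(d_real), 1.0))
```

Deriving the label size from the discriminator output keeps the two in step whatever value is passed to -b, and also covers a smaller final batch.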