brainstorm
bugfix in PyCudaHandler merge and split operations
pycuda reads the gpudata attribute of each array argument inside the kernel call itself, so there is no need to access it outside the function: the array can be passed directly.
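To illustrate why passing .gpudata broke the call, here is a minimal, self-contained sketch that does not use pycuda at all. FakeDeviceAllocation, FakeGPUArray, build_arg and launch are hypothetical stand-ins for a simplified model of kernel-argument packing. The sketch assumes, as the TypeError reported in this commit suggests, that the packer accepts allocation handles and objects exposing a gpudata attribute, but rejects a bare integer pointer such as the gpudata of an array view. pycuda's real _build_arg_buf handles more argument types than this.

```python
class FakeDeviceAllocation:
    """Hypothetical stand-in for a driver allocation handle; int() yields its pointer."""

    def __init__(self, ptr):
        self.ptr = ptr

    def __int__(self):
        return self.ptr


class FakeGPUArray:
    """Hypothetical stand-in for a GPU array.

    A base array holds an allocation handle in .gpudata; a view into a
    larger buffer (an assumption of this sketch) holds a plain int pointer.
    """

    def __init__(self, gpudata):
        self.gpudata = gpudata


def build_arg(i, arg):
    """Simplified model of kernel-argument packing; returns a raw pointer."""
    if isinstance(arg, FakeDeviceAllocation):
        return int(arg)
    try:
        # Array-like arguments: the packer reads .gpudata itself.
        gpudata = arg.gpudata
    except AttributeError:
        # A bare int (e.g. view.gpudata) has no .gpudata and is rejected.
        raise TypeError("invalid type on parameter #%d (0-based)" % i)
    return int(gpudata)


def launch(*args):
    """Pack every argument the way a kernel call would (simplified)."""
    return [build_arg(i, arg) for i, arg in enumerate(args)]


base = FakeGPUArray(FakeDeviceAllocation(4096))
view = FakeGPUArray(4096 + 64)  # view starting 64 bytes into base

print(launch(base.gpudata))  # allocation handle is accepted: [4096]
print(launch(view))          # passing the array itself: [4160]
try:
    launch(view.gpudata)     # bare int pointer, as before the fix
except TypeError as err:
    print(err)               # invalid type on parameter #0 (0-based)
```

Under this model, passing the array object and letting the packer read gpudata works for both base arrays and views, which mirrors the change described above for merge_tt and split_add_tt.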
Before this change, using a merge layer with pycuda 2016.1 raised the following errors in the forward and backward passes:
Forward pass:

  File "test.py", line 214, in main
    trainer.train(network, getter_tr, valid_getter=getter_va)
  File "brainstorm/training/trainer.py", line 99, in train
    self.stepper.run()
  File "brainstorm/training/steppers.py", line 103, in run
    self.net.forward_pass(training_pass=True)
  File "brainstorm/structure/network.py", line 430, in forward_pass
    layer.forward_pass(self.buffer[layer_name], training_pass)
  File "brainstorm/layers/merge_layer.py", line 52, in forward_pass
    buffers.outputs.default)
  File "brainstorm/handlers/pycuda_handler.py", line 323, in merge_tt
    block=block, grid=grid)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 383, in function_call
    handlers, arg_buf = _build_arg_buf(args)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 158, in _build_arg_buf
    raise TypeError("invalid type on parameter #%d (0-based)" % i)
TypeError: invalid type on parameter #0 (0-based)
Backward pass:

  File "test.py", line 214, in main
    trainer.train(network, getter_tr, valid_getter=getter_va)
  File "brainstorm/training/trainer.py", line 99, in train
    self.stepper.run()
  File "brainstorm/training/steppers.py", line 104, in run
    self.net.backward_pass()
  File "brainstorm/structure/network.py", line 444, in backward_pass
    layer.backward_pass(self.buffer[layer_name])
  File "brainstorm/layers/merge_layer.py", line 59, in backward_pass
    buffers.input_deltas.inputs_2)
  File "brainstorm/handlers/pycuda_handler.py", line 364, in split_add_tt
    block=block, grid=grid)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 383, in function_call
    handlers, arg_buf = _build_arg_buf(args)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 158, in _build_arg_buf
    raise TypeError("invalid type on parameter #%d (0-based)" % i)
TypeError: invalid type on parameter #0 (0-based)