DeepSurv
TypeError: nesterov_momentum() got an unexpected keyword argument 'logger'
I get an error when I reach this line. Why is that?
metrics = model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)
[INFO] Training CoxMLP
TypeError Traceback (most recent call last)
D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in train(self, train_data, valid_data, n_epochs, validation_frequency, patience, improvement_threshold, patience_increase, verbose, update_fn, **kwargs)
    366 reached, looks at validation improvement to increase patience or
    367 early stop.
--> 368 improvement_threshold: percentage of improvement needed to increase
    369 patience.
    370 patience_increase: multiplier to patience if threshold is reached.
D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in _get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, **kwargs)
    208 updates = update_fn(
    209 scaled_grads, self.params, **kwargs
--> 210 )
    211 else:
    212 updates = update_fn(
D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in _get_loss_updates(self, L1_reg, L2_reg, update_fn, max_norm, deterministic, **kwargs)
    179 Returns Theano expressions for the network's loss function and parameter
    180 updates.
--> 181
    182 Parameters:
    183 L1_reg: float for L1 weight regularization coefficient.
TypeError: nesterov_momentum() got an unexpected keyword argument 'logger'
Are you still having this issue? It looks like the logger is passed from .train() to _get_train_valid_fn(), which then passes it on to update_fn through **kwargs. Since nesterov_momentum() does not accept a logger argument, the call raises the TypeError.
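To illustrate the mechanism, here is a minimal sketch that reproduces the same kind of error without DeepSurv or Lasagne installed. All three functions are hypothetical stand-ins written for this example; they only mimic how an unrecognized keyword argument travels through **kwargs and are not the library's actual code.

```python
def nesterov_momentum_stub(grads, params, learning_rate=0.01, momentum=0.9):
    """Hypothetical stand-in for the update function; it has no 'logger' parameter."""
    return {"learning_rate": learning_rate, "momentum": momentum}


def get_train_valid_fn(learning_rate, update_fn, **kwargs):
    # Everything still in kwargs is forwarded untouched to update_fn.
    return update_fn([], [], learning_rate=learning_rate, **kwargs)


def train(n_epochs, update_fn, **kwargs):
    # 'logger' is not a named parameter here, so it ends up inside kwargs.
    return get_train_valid_fn(learning_rate=0.01, update_fn=update_fn, **kwargs)


error_message = None
try:
    train(n_epochs=10, update_fn=nesterov_momentum_stub, logger=object())
except TypeError as err:
    # The stub rejects the forwarded 'logger' keyword.
    error_message = str(err)
    print(error_message)
```

Judging from the traceback, the installed train() has no named logger parameter either, so the keyword is collected into **kwargs in the same way.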
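Adding a logger=None parameter to the function signature of _get_train_valid_fn might fix the problem, because a named parameter captures logger and keeps it out of the **kwargs that are forwarded to update_fn.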
For example:
_get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, **kwargs)
Would become:
_get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, logger=None, **kwargs)
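As a rough sketch of why the extra parameter helps, the example below uses hypothetical stand-in functions (not DeepSurv's real implementation) to show that a keyword bound by name is removed from **kwargs before the remaining arguments are forwarded.

```python
def update_stub(grads, params, learning_rate=0.01, momentum=0.9):
    """Hypothetical stand-in for an update function without a 'logger' parameter."""
    return {"learning_rate": learning_rate, "momentum": momentum}


def get_train_valid_fn_patched(learning_rate, update_fn, logger=None, **kwargs):
    # 'logger' is bound by name, so it is no longer part of kwargs
    # and is not forwarded to update_fn.
    return update_fn([], [], learning_rate=learning_rate, **kwargs)


# 'momentum' still reaches update_stub through kwargs; 'logger' does not.
result = get_train_valid_fn_patched(0.01, update_stub, logger=object(), momentum=0.8)
print(result)
```

Other keyword arguments meant for the update function continue to pass through unchanged, so only logger is intercepted.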