RL4LMs

passing extra variable to the forward function

Open · lovodkin93 opened this issue 2 years ago · 1 comment

Hey, I am currently using your repo to finetune a Longformer model. The problem is that this model requires pre-defining a global attention mask (in addition to the regular attention mask), which specifies which tokens get an extra "global attention head". So my question is: is there an easy way to pass this variable that does not require skimming through the code and locating every call to the forward function? In other words, is there an easy way to pass extra model_kwargs? Thanks!
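
For context, in plain Hugging Face Transformers the global attention mask is just another tensor passed alongside the regular attention mask. A minimal sketch (which tokens to mark as global is task-dependent; here only the first token gets global attention):

```python
import torch
from transformers import LongformerTokenizerFast, LongformerModel

tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("some long document ...", return_tensors="pt")

# Global attention mask: 1 = token attends globally, 0 = local (windowed) attention only.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # give the <s> token global attention

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
)
```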

lovodkin93 · Dec 28 '22

Hey, there is no straightforward way to do this. Just adapt the policy implementation to pass these extra arguments.
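
For illustration only, the pattern such an adaptation would follow is roughly the sketch below: store the extra kwargs once on the policy and merge them into every underlying model call. All names here are hypothetical (the actual RL4LMs policy classes and forward signatures differ); this is not the repo's API.

```python
import torch


class PolicyWithExtraKwargs:
    """Illustrative wrapper (hypothetical, not part of RL4LMs): stores extra
    model kwargs once and merges them into every call to the wrapped HF model."""

    def __init__(self, model, extra_model_kwargs=None):
        self.model = model
        # e.g. {"global_attention_mask": <tensor or callable building it per batch>}
        self.extra_model_kwargs = extra_model_kwargs or {}

    def forward(self, input_ids, attention_mask, **kwargs):
        # Build batch-dependent extras if a callable was supplied for that key.
        extras = {
            k: (v(input_ids) if callable(v) else v)
            for k, v in self.extra_model_kwargs.items()
        }
        # Merge the stored extras into this particular forward call.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **extras,
            **kwargs,
        )


def first_token_global(input_ids: torch.Tensor) -> torch.Tensor:
    # Example builder: give only the first token global attention.
    mask = torch.zeros_like(input_ids)
    mask[:, 0] = 1
    return mask


# Usage sketch (model is any Longformer-style HF model):
# policy = PolicyWithExtraKwargs(model, {"global_attention_mask": first_token_global})
# outputs = policy.forward(input_ids, attention_mask)
```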

rajcscw · Jan 02 '23