RL4LMs

passing extra variable to the forward function

Open lovodkin93 opened this issue 1 year ago • 1 comments

Hey, I am currently using your repo to fine-tune a Longformer model. The problem is that this model requires a pre-defined global attention mask (in addition to the regular attention mask), which specifies which tokens get an extra "global attention" head. So my question is: is there an easy way to pass this variable without having to skim through the code and locate every call to the forward function? In other words, is there an easy way to pass extra model_kwargs? Thanks!
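
For context, here is a minimal sketch of the kind of workaround I have in mind (this is not RL4LMs code, just plain Hugging Face `transformers`): wrapping the Longformer so that `global_attention_mask` is built inside `forward` itself, so every caller of `forward()` picks it up without needing to pass extra kwargs. The default mask below (global attention on the first token only) is just an illustrative assumption.

```python
# Minimal sketch (plain transformers, not RL4LMs API): subclass the model so
# that a default global_attention_mask is injected inside forward() itself.
import torch
from transformers import LongformerModel, LongformerTokenizer

class LongformerWithGlobalAttention(LongformerModel):
    def forward(self, input_ids=None, attention_mask=None,
                global_attention_mask=None, **kwargs):
        # If the caller did not supply one, build a default mask that puts
        # global attention on the first token (e.g. the <s>/CLS token).
        if global_attention_mask is None and input_ids is not None:
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            **kwargs,
        )

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerWithGlobalAttention.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("passing extra kwargs example", return_tensors="pt")
outputs = model(**inputs)  # global_attention_mask is filled in automatically
```

This works regardless of where forward is called from, but it hardcodes the masking policy into the model class, which is why a general way to thread extra model_kwargs through the library would be nicer.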

lovodkin93 · Dec 28 '22