RL4LMs
Passing an extra variable to the forward function
Hey, I am currently using your repo to fine-tune a Longformer model. The problem is that this model requires a pre-defined global attention mask (in addition to the regular attention mask), which specifies which tokens get an extra "global attention head". So my question is: is there an easy way to pass this variable that does not require skimming through the code and locating every call to the forward functions? In other words, is there an easy way to pass extra model_kwargs? Thanks!
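One possible workaround, sketched here under assumptions: wrap the model's forward once so the extra argument is injected in a single place instead of at every call site. `GlobalAttentionInjector` and `first_token_global` are hypothetical helpers, not part of RL4LMs or transformers. The sketch uses plain nested lists so it runs without dependencies, and it assumes the calling code passes `attention_mask` as a keyword argument. Giving global attention to the first token only is also an assumption (a common choice for classification-style inputs).

```python
from typing import Any, Callable, List

Mask = List[List[int]]


class GlobalAttentionInjector:
    """Hypothetical wrapper: adds `global_attention_mask` to every forward call.

    Not part of RL4LMs or transformers; a sketch of the "wrap once" idea.
    """

    def __init__(self, forward_fn: Callable[..., Any],
                 make_global_mask: Callable[[Mask], Mask]) -> None:
        self.forward_fn = forward_fn
        self.make_global_mask = make_global_mask

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Inject only when the caller did not supply a mask explicitly and
        # `attention_mask` was passed as a keyword (positional use is skipped).
        if "global_attention_mask" not in kwargs and "attention_mask" in kwargs:
            kwargs["global_attention_mask"] = self.make_global_mask(
                kwargs["attention_mask"]
            )
        return self.forward_fn(*args, **kwargs)


def first_token_global(attention_mask: Mask) -> Mask:
    # Assumption: global attention on the first token of each sequence only.
    return [[1 if i == 0 else 0 for i in range(len(row))] for row in attention_mask]


if __name__ == "__main__":
    # Dummy forward that just returns the mask it received.
    def dummy_forward(input_ids, attention_mask=None, global_attention_mask=None):
        return global_attention_mask

    forward = GlobalAttentionInjector(dummy_forward, first_token_global)
    print(forward([[5, 6, 7]], attention_mask=[[1, 1, 1]]))  # [[1, 0, 0]]
```

With a real PyTorch model, the same idea would be `model.forward = GlobalAttentionInjector(model.forward, mask_fn)`, where `mask_fn` builds a tensor (for example with `torch.zeros_like` and setting column 0 to 1). This relies on `nn.Module.__call__` dispatching to the instance's `forward` attribute, which is worth checking against the library versions in use.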