
How to use BatchLossFilter callback with blurr

Open maxmatical opened this issue 4 years ago • 2 comments

I've been experimenting with tsai's `BatchLossFilter` callback. When I run training with this callback:

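```python
from fastai.text.all import *
from blurr.data.all import *
from blurr.modeling.all import *
from tsai.all import *  # provides BatchLossFilter

# `dls` and `hf_model` come from the standard blurr setup (not shown here)
model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls,
                model,
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()

learn.unfreeze()

# tsai callback: train on the hardest samples of each batch
cbs = [BatchLossFilter(loss_perc=0.4)]

learn.fit_one_cycle(
    3,
    lr_max=3e-5,
    cbs=cbs
)
```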
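
For context on what `loss_perc=0.4` asks for: as we understand it, `BatchLossFilter` ranks the samples in each training batch by their loss and keeps only the hardest ones, i.e. the highest-loss samples that together account for roughly `loss_perc` of the batch's total loss. The pure-Python sketch below illustrates that selection rule. `select_hard_indices` is a hypothetical helper written for illustration, not tsai's implementation, and tsai's exact cut-off rule may differ.

```python
from typing import List


def select_hard_indices(losses: List[float], loss_perc: float) -> List[int]:
    """Hypothetical helper: return indices of the highest-loss samples whose
    combined share of the total batch loss first reaches `loss_perc`."""
    if not 0.0 < loss_perc <= 1.0:
        raise ValueError("loss_perc must be in (0, 1]")
    total = sum(losses)
    if total == 0:
        # No sample is harder than another; keep the whole batch.
        return list(range(len(losses)))
    # Order sample indices from hardest (largest loss) to easiest.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    kept: List[int] = []
    share = 0.0
    for i in order:
        kept.append(i)
        share += losses[i] / total
        if share >= loss_perc:
            break
    return kept


# Per-sample losses for a toy batch of 5 examples.
batch_losses = [0.1, 2.0, 0.3, 1.2, 0.4]
print(select_hard_indices(batch_losses, loss_perc=0.4))  # → [1]
```

Here the single hardest example already carries half of the batch loss, so with `loss_perc=0.4` it is the only one kept; the remaining samples would be dropped from that training step.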

I get the following error, caused by the `SequenceClassifierOutput` object returned by Hugging Face:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-13-ce9f3d20977f> in <module>()
      2     3,
      3     lr_max=3e-5,
----> 4     cbs = cbs
      5 )

18 frames
/usr/local/lib/python3.7/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     32         if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
     33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
---> 34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
     35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     36 

AttributeError: 'SequenceClassifierOutput' object has no attribute 'view'
```

Is there a way to adapt `BatchLossFilter` to work with blurr? I haven't had any issues using other callbacks with blurr.

maxmatical, Sep 01 '21 17:09

First, cool library... I wasn't aware of that one!

Second, it would be helpful if you could post a gist I can run so I can see the full stack trace. `inp` is expected to be a tensor, but in the case of Blurr (and Hugging Face), what gets passed in at this point is an object holding several pieces of information, such as the loss. I'm sure this can be made to work with Blurr/HF through callbacks.
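
To make the mismatch concrete: fastai's flattened loss calls `inp.view(...)`, which tensors have but a Hugging Face output object does not. `SequenceClassifierOutput` exposes the predictions as its `logits` attribute, so the general remedy is to pull the logits out before the loss is computed. The pure-Python sketch below mocks both sides; `FakeTensor`, `FakeModelOutput`, `flattened_loss_input`, and `unwrap_logits` are hypothetical stand-ins, not fastai, blurr, or transformers APIs.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


class FakeTensor:
    """Hypothetical stand-in for a 2-D tensor that supports `.view(-1)`."""

    def __init__(self, rows: List[List[float]]):
        self.rows = rows

    def view(self, *shape: int) -> List[float]:
        # Always flattens to 1-D; enough for this illustration.
        return [x for row in self.rows for x in row]


@dataclass
class FakeModelOutput:
    """Hypothetical stand-in for an output object like SequenceClassifierOutput."""
    logits: FakeTensor
    loss: Optional[float] = None


def flattened_loss_input(inp: Any) -> List[float]:
    # Like fastai's flattened losses, assume `inp` is tensor-like.
    return inp.view(-1)


def unwrap_logits(pred: Any) -> Any:
    # If the model returned an output object, keep only its logits.
    return pred.logits if hasattr(pred, "logits") else pred


out = FakeModelOutput(logits=FakeTensor([[0.2, 0.8], [0.9, 0.1]]))

try:
    flattened_loss_input(out)  # mirrors the failure in the traceback
except AttributeError as err:
    print("fails:", err)

print("works:", flattened_loss_input(unwrap_logits(out)))
```

In a real fix, the equivalent of `unwrap_logits` would have to run wherever the model is called before the loss function sees its output.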

Btw, any particular reason you're using `BatchLossFilter`? Just curious :)

ohmeow, Sep 01 '21 19:09

Here is a gist using the `BatchLossFilter` callback with the standard blurr training script. I'm experimenting with `BatchLossFilter` because it got some traction on Twitter a while back. Intuitively, focusing on the harder examples could improve performance, and extra tools in the toolbox never hurt 😄

I was able to use the example implementation with other kinds of data (images, time series, etc.), so it looks like an issue specific to Hugging Face, if that helps.

maxmatical, Sep 01 '21 19:09