
Variable batchsize decrease factor

Open ChielWH opened this issue 4 years ago • 0 comments

First of all, many thanks for this handy utility package! My use-case is to detect the largest possible batchsize for translation during inference, which is rather low. I see that the batchsize decrease factor is fixed at 2:

@dataclass
class Batchsize:
    ...
    def decrease_batchsize(self):
        self.value //= 2
        assert self.value

For my use-case, this is quite an aggressive decrease factor. Might it be possible to make this configurable? So that it can be used in such a way:

@toma.batch(initial_batchsize=32, cache_type=GlobalBatchsizeCache, decrease_factor=1.2)
def infer(batchsize, *args, **kwargs):
    ...
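To illustrate, a minimal sketch of what the change could look like inside the Batchsize dataclass (decrease_factor is an assumed parameter name, not part of toma's current API):

```python
# Hypothetical sketch: Batchsize with a configurable decrease factor.
# Assumes `decrease_factor` is passed down from toma.batch; not toma's actual API.
from dataclasses import dataclass


@dataclass
class Batchsize:
    value: int
    decrease_factor: float = 2.0  # default keeps the current halving behavior

    def decrease_batchsize(self):
        # Divide by the factor instead of the hard-coded 2, truncating
        # to an integer batch size; assert it has not dropped to zero.
        self.value = int(self.value / self.decrease_factor)
        assert self.value


# A factor of 1.2 shrinks the batchsize more gently than halving:
bs = Batchsize(32, decrease_factor=1.2)
bs.decrease_batchsize()  # 32 -> 26 instead of 32 -> 16
```

With the default factor of 2.0 this reduces to the existing `self.value //= 2` behavior, so it would be backward compatible.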

I'm happy to help, so if you would like me to make a PR, please let me know.

ChielWH avatar Jun 10 '21 06:06 ChielWH