pytorch-crf

Torchscripted crf

Open · aravindMahadevan opened this issue 3 years ago · 7 comments

Description

Adds TorchScript support to the CRF model. Fixes issue #93 without changing the current interface. Adds tests verifying that the TorchScripted CRF model's forward and decode outputs match those of the non-scripted CRF model.

>>> import torch
>>> from torchcrf import CRF
>>> num_tags = 5  # number of tags is 5
>>> model = CRF(num_tags)
>>> script = torch.jit.script(model)
>>> script
RecursiveScriptModule(original_name=CRF)

aravindMahadevan · Jun 27 '22

Hi, thanks for the PR. The tests are failing because torch.jit doesn't seem to have the attribute export. I suspect this is because the PyTorch version this library uses is too old. Unfortunately, I have limited time to upgrade the library to a newer PyTorch. Are you perhaps interested in helping out?

kmkurn · Jul 09 '22

Sure, I can help out. What would we need to do to upgrade this library?

aravindMahadevan · Jul 09 '22

Awesome, thank you!

I think we can start simple:

  1. Find the lowest PyTorch version that has torch.jit.export
  2. Update the PyTorch version in .github/workflows/run_tests.yml so the tests run against the new PyTorch version
  3. Make sure all the tests pass
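For step 1, one approach is to install candidate PyTorch versions in turn and probe each one for the decorator. A minimal sketch; the helper name has_jit_export is hypothetical:

```python
def has_jit_export() -> bool:
    """Return True if the installed PyTorch exposes torch.jit.export."""
    try:
        import torch
    except ImportError:
        return False  # PyTorch is not installed in this environment
    return hasattr(torch.jit, "export")


if __name__ == "__main__":
    print("torch.jit.export available:", has_jit_export())
```

Run it once per installed version (for example after pip install torch==<version>); the first version that prints True is the minimum for step 2.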

Also, it's probably better to move the upgrade discussion to a separate issue for tracking. Thanks again for agreeing to help out!

kmkurn · Jul 09 '22

I have created issue #104, where work on upgrading this library to PyTorch 1.2 will take place.

aravindMahadevan · Jul 09 '22

Hey @kmkurn, I tried running this branch with PyTorch 1.2.0, the version that introduced torch.jit.export, but I've been running into some issues. I also propose a fix at the end of this post.

The first issue is that TorchScript doesn't support some basic operations, such as the not in operator:

E               torch.jit.frontend.NotSupportedError: unsupported comparison operator: NotIn:
E                               none|sum|mean|token_mean. none: no reduction will be applied.
E                               sum: the output will be summed over batches. mean: the output will be
E                               averaged over batches. token_mean: the output will be averaged over tokens.
E               
E                       Returns:
E                           ~torch.Tensor: The log likelihood. This will have size (batch_size,) if
E                           reduction is none, () otherwise.
E                       """
E                       self._validate(emissions, tags=tags, mask=mask)
E                       if reduction not in ['none', 'sum', 'mean', 'token_mean']:
E                                   ~~~~~~~~ <--- HERE
E                           raise ValueError(f'invalid reduction: {reduction}')
E                       if mask is None:
E                           mask = torch.ones_like(tags, dtype=torch.uint8)
E               
E                       if self.batch_first:
E                           emissions = emissions.transpose(0, 1)
E                           tags = tags.transpose(0, 1)
E                           mask = mask.transpose(0, 1)

Even if we work around this by replacing that line with

if reduction != 'none' and reduction != 'sum' and reduction != 'mean' and reduction != 'token_mean':

we run into another issue: the TorchScript compiler doesn't recognize the type annotation torch.LongTensor.

E       RuntimeError: 
E       Unknown type name 'torch.LongTensor':
E           def forward(
E                   self,
E                   emissions: torch.Tensor,
E                   tags: torch.LongTensor,
E                         ~~~~~~~~~~~~~~~~ <--- HERE
E                   mask: Optional[torch.ByteTensor] = None,
E                   reduction: str = 'sum',
E           ) -> torch.Tensor:
E               """Compute the conditional log likelihood of a sequence of tags given emission scores.
E       
E               Args:
E                   emissions (`~torch.Tensor`): Emission score tensor of size
E                       ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
E                       ``(batch_size, seq_length, num_tags)`` otherwise.

The error above occurs in every PyTorch version from 1.2.0 through 1.8.0. We could fix it by changing the annotation from torch.LongTensor to torch.Tensor, but then the signature would also accept a FloatTensor for tags, which is undesirable.
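One possible way to keep the integer-tags requirement after switching the annotation to torch.Tensor (not something proposed in this thread) is to check the dtype at runtime inside the scripted code. A minimal sketch, assuming a recent PyTorch; check_tags is a hypothetical helper, not part of torchcrf:

```python
import torch


@torch.jit.script
def check_tags(tags: torch.Tensor) -> torch.Tensor:
    # The annotation is the generic torch.Tensor, which TorchScript accepts;
    # the integer requirement that torch.LongTensor used to document is
    # enforced here at runtime instead.
    if tags.dtype != torch.long:
        raise ValueError("tags must have dtype torch.long")
    return tags


check_tags(torch.tensor([0, 2, 1]))  # OK: integer literals default to torch.long
```

Passing a float tensor such as torch.tensor([0.0, 2.0]) raises an exception instead of silently running with the wrong dtype.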

In 1.9.0, TorchScript does not support the Tensor.new_ones method, which leads to this error:

E       RuntimeError: 
E       'Tensor' object has no attribute or method 'new_ones'.:
E               self._validate(emissions, mask=mask)
E               if mask is None:
E                   mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
E                          ~~~~~~~~~~~~~~~~~~ <--- HERE
E           
E               if self.batch_first:

We could instead use torch.ones with the same shape, explicitly passing dtype=torch.uint8 and device=emissions.device, but it isn't as elegant as Tensor.new_ones.
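A minimal sketch of the torch.ones workaround, assuming a recent PyTorch; default_mask is a hypothetical helper name. It builds the same uint8 mask that emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) produces, taking only the device from emissions (the dtype must stay torch.uint8 rather than emissions.dtype to match the original mask):

```python
import torch


@torch.jit.script
def default_mask(emissions: torch.Tensor) -> torch.Tensor:
    # Equivalent to emissions.new_ones(emissions.shape[:2], dtype=torch.uint8):
    # the mask is uint8 and lives on the same device as the emissions.
    return torch.ones(emissions.shape[:2], dtype=torch.uint8, device=emissions.device)


emissions = torch.randn(4, 2, 5)  # (seq_length, batch_size, num_tags)
mask = default_mask(emissions)    # shape (4, 2), all ones
```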

PyTorch 1.10.0 is the first version in which the TorchScripted code compiles successfully.

My proposal is a utility function that takes a CRF module and returns a scripted CRF module. The function will first assert that the PyTorch version is greater than or equal to 1.10.0, and then return the scripted CRF module.

aravindMahadevan · Jul 12 '22

Thanks for the detailed write up! Really appreciate the time you spent investigating this. It's surprising how half-baked the JIT support feels with all these unsupported operations. With your suggestion, does that mean there'd be essentially 2 versions of the code, one is the normal version and the other is the JIT-friendly version?

kmkurn · Jul 17 '22

I was suggesting adding a function to __init__.py, after the CRF module class, that takes a CRF module as input, checks the PyTorch version, and returns a scripted CRF. Something like:

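import torch
def script_crf(crf):
    # torch.__version__ is a string such as "1.10.0+cu113", so compare parsed numbers
    major, minor = (int(part) for part in torch.__version__.split(".")[:2])
    assert (major, minor) >= (1, 10), "Torchscripting the CRF model requires PyTorch 1.10.0 or higher"
    return torch.jit.script(crf)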

The problem with having multiple versions is that a fix made in one version might not make it into the other. Instead of a utility function that TorchScripts the CRF module, what if we merge these changes once the PyTorch 1.2 update (#104) is merged into main? I can update the TorchScript tests to run only on PyTorch 1.10.0 or higher, and we can also update the documentation to state that the CRF model is TorchScript-compatible with PyTorch 1.10.0 and higher.

aravindMahadevan · Jul 17 '22

Sorry for responding so late. I like your solution. I've merged #104 so you can implement the solution now.

kmkurn · Dec 09 '22

Hi, what is still missing here? I'm interested in using this. Is any help needed?

marcelbischoff · Dec 19 '22

Actually, I think you could add to the README that torchscripting does not work for PyTorch < 1.10.0 and just merge this.

When people try to TorchScript something, they won't look for a helper function in the library; they'll just call torch.jit.script on their final module, get errors anyway, and look them up.

Furthermore, I think the script_crf function isn't actually useful, because you always script the complete module that uses the CRF rather than scripting each submodule individually.
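To illustrate the point about scripting the complete model: torch.jit.script compiles submodules recursively, so a CRF layer gets scripted as part of whatever model contains it. A minimal sketch with a hypothetical stand-in for the CRF (ToyCRFHead is not torchcrf's CRF and only computes a toy score):

```python
import torch
from torch import nn


class ToyCRFHead(nn.Module):
    """Hypothetical stand-in for a CRF layer; it is NOT torchcrf.CRF."""

    def forward(self, emissions: torch.Tensor) -> torch.Tensor:
        # Toy score: log-sum-exp over the tag dimension, summed over positions
        return torch.logsumexp(emissions, dim=-1).sum()


class Tagger(nn.Module):
    def __init__(self, hidden_size: int, num_tags: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, num_tags)
        self.crf = ToyCRFHead()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.crf(self.proj(x))


model = Tagger(hidden_size=8, num_tags=5)
# Scripting the top-level model also compiles self.proj and self.crf,
# so no separate script_crf-style helper is needed for the submodule.
scripted = torch.jit.script(model)
```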

We can also specify a minimum PyTorch version in requirements-test.txt to solve the problem of tests failing on older versions.

erksch · May 09 '23

@kmkurn I recreated a pull request with all the mentioned changes combined :) #113

erksch · May 09 '23