
support for torch.cdist

Open SpiraMira opened this issue 10 months ago • 5 comments

🌱 Describe your Feature Request

  • support an optimized torch.cdist implementation
  • a good implementation would support both CPU and GPU (C++ and/or GPU kernel shaders)

How can this feature be used?

  • in my case, it would facilitate seamless conversion of some Hugging Face models (without refactoring)

Describe alternatives you've considered

        # https://stackoverflow.com/questions/27948363/numpy-broadcast-to-perform-euclidean-distance-vectorized
        # this version uses a dot product and is optimized to consume fewer resources
        import torch

        A = torch.ones(576, 4)
        B = torch.ones(8192, 4)

        # ||A||^2 down the rows, cross term via a matmul, ||B||^2 across the columns
        summand1 = torch.sum(torch.square(A)[:, None, :], axis=2)  # (576, 1)
        summand2 = 2 * A @ torch.t(B)                              # (576, 8192); torch.t(B) == B.T
        summand3 = torch.sum(torch.square(B), axis=1)              # (8192,)
        threeSums = summand1 - summand2 + summand3
        euclid_distance_dot = torch.sqrt(threeSums)                # (576, 8192)

OR

        # this is more concise but only "is close" to torch.cdist (floating-point differences),
        # and it uses more resources: broadcasting materializes a (576, 8192, 4) intermediate
        euclid_distance_no_dot = torch.sqrt(torch.sum(torch.square(A[:, None, :] - B), axis=2))
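
A quick check (a minimal sketch; the tolerance is illustrative) confirms that both variants agree with torch.cdist up to floating-point error:

        ref = torch.cdist(A, B)
        assert torch.isclose(euclid_distance_dot, ref, atol=1e-5).all()
        assert torch.isclose(euclid_distance_no_dot, ref, atol=1e-5).all()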

Additional context

  • writing my own op is an alternative, but I am not an expert (see the sketch below for what that might look like).
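
For anyone who wants to try that route: coremltools supports registering a composite op for missing torch ops via register_torch_op. Below is a minimal, untested sketch for the p=2 (Euclidean) case; it assumes rank-2 inputs and that the traced cdist node carries four inputs (x1, x2, p, compute_mode), and it ignores p and compute_mode:

import coremltools as ct
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs

@register_torch_op
def cdist(context, node):
    # torch.cdist(x1, x2, p, compute_mode); only the p=2 case is handled here
    inputs = _get_inputs(context, node, expected=4)
    x1, x2 = inputs[0], inputs[1]
    # row norms ||x1||^2 and ||x2||^2
    x1_sq = mb.reduce_sum(x=mb.mul(x=x1, y=x1), axes=[-1], keep_dims=True)  # (M, 1)
    x2_sq = mb.reduce_sum(x=mb.mul(x=x2, y=x2), axes=[-1], keep_dims=True)  # (N, 1)
    # cross term x1 @ x2^T
    cross = mb.matmul(x=x1, y=x2, transpose_y=True)                         # (M, N)
    d2 = mb.sub(x=mb.add(x=x1_sq, y=mb.transpose(x=x2_sq, perm=[1, 0])),
                y=mb.mul(x=cross, y=2.0))
    # clamp tiny negatives from rounding error before the sqrt
    d2 = mb.maximum(x=d2, y=0.0)
    context.add(mb.sqrt(x=d2, name=node.name))

The decomposition mirrors the dot-product identity above, so only standard MIL ops are needed.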

SpiraMira avatar Apr 16 '24 01:04 SpiraMira

Can you give us a minimal code example to reproduce this issue?

TobyRoseman avatar Apr 22 '24 18:04 TobyRoseman

> Can you give us a minimal code example to reproduce this issue?

Sure, here goes ...

macOS 14.4.1, Xcode 15.3, coremltools 7.1, Python 3.8; PyTorch version 2.4.0.dev20240421 (build py3.8_0, channel pytorch-nightly)

import torch
from torch import nn
import coremltools as ct

class BugModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z):
        # z is only a placeholder so the module can be traced; A and B are fixed
        A = torch.ones(576, 4)
        B = torch.ones(8192, 4)
        _, min_encoding_indices = torch.min(torch.cdist(A, B), dim=1)
        return min_encoding_indices

bug_model = BugModule().eval()
any_shape = (1, 4, 24, 24)
traced_model = torch.jit.trace(bug_model, torch.rand(*any_shape))
coreml_model = ct.convert(traced_model,
                          inputs=[ct.TensorType(shape=any_shape)],
                          debug=True,
                          convert_to="mlprogram",
                          minimum_deployment_target=ct.target.macOS13)

output

Converting PyTorch Frontend ==> MIL Ops:  82%|█████████████████████████████████████████████████████▏          | 18/22 [00:00<00:00, 8612.53 ops/s]
the following model ops are IMPLEMENTED:
  constant
  listconstruct
  min
  ones
the following model ops are MISSING:
  cdist

SpiraMira avatar Apr 22 '24 22:04 SpiraMira

Added note: this also breaks for me under a released PyTorch (2.1.2).

SpiraMira avatar Apr 23 '24 05:04 SpiraMira

@SpiraMira please try this:


import torch
from torch import nn
import coremltools as ct

def my_cdist(x1: torch.Tensor, x2: torch.Tensor, p=2.0):
    # Euclidean case only: ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2
    assert p == 2.0

    x1_norm = x1.pow(2).sum(-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    x2_norm = x2.pow(2).sum(-1, keepdim=True)
    x2_pad = torch.ones_like(x2_norm)

    # pack the three summands so a single matmul yields the squared distances:
    # [-2*x1 | ||x1||^2 | 1] @ [x2 | 1 | ||x2||^2]^T
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)

    result = x1_.matmul(x2_.transpose(0, 1))
    # clamp tiny negatives from rounding error before the sqrt
    result = result.clamp_min_(0.0).sqrt_()
    return result

# test cdist implementation with random values 
def test_cdist():
    A = torch.rand(576, 4)
    B = torch.rand(8192, 4)

    result1 = torch.cdist(A, B)
    result2 = my_cdist(A, B)

    print(result1)
    print(result2)

    assert result1.shape == result2.shape
    # exact equality is fragile across different floating-point paths; compare with a tolerance
    assert torch.allclose(result1, result2, atol=1e-5)

    print("test_cdist: OK")

test_cdist()

class BugModule(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, z):
        A = torch.ones(576, 4)
        B = torch.ones(8192, 4)

        # _, min_encoding_indices = torch.min(torch.cdist(A, B), dim=1)
        _, min_encoding_indices = torch.min(my_cdist(A, B), dim=1)
        return min_encoding_indices
        
bug_model = BugModule().eval()
any_shape = (1, 4, 24, 24)
traced_model = torch.jit.trace(bug_model, torch.rand(*any_shape))
coreml_model = ct.convert(traced_model,
                          inputs=[ct.TensorType(shape=any_shape)],
                          debug=True,
                          convert_to="mlprogram",
                          minimum_deployment_target=ct.target.macOS13)

If you use the default parameter (p=2), it should work well.

https://pytorch.org/docs/stable/_modules/torch/functional.html#cdist


def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
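
One way to use this as a drop-in replacement without touching model code — a sketch, assuming the model calls cdist through the torch namespace — is to patch it around tracing:

_original_cdist = torch.cdist
torch.cdist = my_cdist  # p=2 only
try:
    traced_model = torch.jit.trace(bug_model, torch.rand(*any_shape))
finally:
    torch.cdist = _original_cdist  # restore the real op afterwards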

dneprDroid avatar May 20 '24 17:05 dneprDroid

Thank you! Works well as a drop-in replacement.

SpiraMira avatar May 26 '24 21:05 SpiraMira