GELU does not appear to support approximate tanh

Open travisjj opened this issue 1 year ago • 4 comments

PyTorch's GELU has an optional mode that internally uses a tanh approximation.

See more here: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html#torch.nn.GELU

I was expecting this to just work:

var gelu = nn.GELU(approximate: "tanh");

When the approximate argument is 'tanh', GELU is estimated with a tanh-based approximation. The default ('none') computes the exact form, so the two modes do not produce identical outputs.
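For reference, the PyTorch documentation linked above defines the two modes as follows, where \Phi is the standard Gaussian cumulative distribution function:

```latex
\text{GELU}(x) = x \cdot \Phi(x)
\qquad \text{(approximate = 'none')}

\text{GELU}(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
\qquad \text{(approximate = 'tanh')}
```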

Is it possible, since this is supported natively, to include the "approximate" property for TorchSharp's GELU?

Is there a way for me to do this myself without having to build and publish a new version of the library?

travisjj • Aug 08 '24 22:08

I'm guessing this could be an option:

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_GELU_ctor(string approximate, out IntPtr pBoxedModule);

and then either replace the current GELU factory function or add an overload (the two seem roughly equivalent):

public static GELU GELU(string approximate = "none")
{
    IntPtr boxedHandle;
    IntPtr intPtr = NativeMethods.THSNN_GELU_ctor(approximate, out boxedHandle);
    if (intPtr == IntPtr.Zero)
    {
        torch.CheckForErrors();
    }
    return new GELU(intPtr, boxedHandle);
}

travisjj • Aug 08 '24 22:08

Two options:

  1. Fix the code and send us a much-appreciated PR. The approximate argument should be an enumeration instead of a string.
  2. Implement your own GELU module using the available mathematical primitives in TorchSharp.
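Option 2 could look roughly like the sketch below. It assumes the tanh formula from the PyTorch GELU documentation and TorchSharp's usual custom-module pattern (deriving from nn.Module<Tensor, Tensor> and overriding forward). The class name TanhGELU is hypothetical and not part of TorchSharp.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Hypothetical module name; this class is not part of TorchSharp.
public sealed class TanhGELU : nn.Module<Tensor, Tensor>
{
    private static readonly double Sqrt2OverPi = Math.Sqrt(2.0 / Math.PI);

    public TanhGELU() : base(nameof(TanhGELU)) { }

    // Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    public override Tensor forward(Tensor x)
    {
        // Intermediate tensors are disposed when the scope ends;
        // only the result is moved out to the caller.
        using var scope = NewDisposeScope();
        var inner = x.pow(3).mul(0.044715).add(x).mul(Sqrt2OverPi);
        var result = inner.tanh().add(1.0).mul(x).mul(0.5);
        return result.MoveToOuterDisposeScope();
    }
}
```

It would be used like any other activation module, for example `var y = new TanhGELU().forward(x);`, and needs no change to the native library.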

NiklasGustafsson • Aug 09 '24 19:08

Sorry if this seems obvious, just trying to make sure it's right.

I'm definitely willing to try the PR approach for this (and anything else I could help with).

  • I am unsure what naming conventions for enums are used within TorchSharp, and what the appropriate namespace or scope would be.
  • I forked the repo, but I am new to submitting PRs, so any guidance would be appreciated (or, if the CONTRIBUTING.md explanation fully applies, I will just follow that)

Would the enum then reside within the same GELU.cs file? Perhaps the changes could look like:

PInvoke change:

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_GELU_ctor(TorchSharp.Modules.ApproxType approximate, out IntPtr pBoxedModule);

GELU.cs change:

(within the Modules namespace)

    public enum ApproxType
    {
        none,
        tanh
    }

the updated constructor:

    public static GELU GELU(ApproxType approximate = ApproxType.none)
    {
        var handle = THSNN_GELU_ctor(approximate, out var boxedHandle);
        if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
        return new GELU(handle, boxedHandle);
    }

travisjj • Aug 10 '24 01:08

I tried the previous code, but it throws an exception when calling the ctor. If I use a string instead of the enum it works, so perhaps the implicit conversion of ApproxType.tanh to 1 is causing the problem. I'm unsure how or where the enum would be converted back to a string to satisfy the approximate parameter.

Perhaps a blend of the two?

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_GELU_ctor(string approximate, out IntPtr pBoxedModule);

public enum ApproxType
{
    none,
    tanh
}

public static GELU GELU(ApproxType approximate = ApproxType.none)
{
    var handle = THSNN_GELU_ctor(approximate.ToString("f"), out var boxedHandle);
    if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
    return new GELU(handle, boxedHandle);
}
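The blended approach keeps a typed enum in the public API and converts it to a string only at the native boundary. As a minimal, self-contained sketch of that conversion, an explicit switch avoids relying on the enum member names happening to match the expected strings. ApproxTypeExtensions and ToNativeString are hypothetical names, and "none" and "tanh" are the values PyTorch's GELU accepts for approximate.

```csharp
using System;

public enum ApproxType
{
    none,
    tanh
}

// Hypothetical helper; not part of TorchSharp.
public static class ApproxTypeExtensions
{
    // Maps each enum member to the string form of the approximate argument.
    public static string ToNativeString(this ApproxType approximate) => approximate switch
    {
        ApproxType.none => "none",
        ApproxType.tanh => "tanh",
        // Reject values outside the enum instead of passing them on.
        _ => throw new ArgumentOutOfRangeException(
            nameof(approximate), approximate, "Unknown GELU approximation.")
    };
}
```

The factory method could then pass `approximate.ToNativeString()` where the snippet above passes `approximate.ToString("f")`, so an out-of-range value fails in managed code before reaching the native call.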

travisjj • Aug 10 '24 02:08