
Add approximate parameter to GELU activation function#1548

Open
alinpahontu2912 wants to merge 2 commits into dotnet:main from alinpahontu2912:feature/gelu-approximate-parameter

Conversation

@alinpahontu2912
Member

Fixes #1368

Add support for the 'approximate' parameter in GELU, matching PyTorch's torch.nn.GELU(approximate='tanh') functionality.

Changes:

  • Add GELU.Approximate enum with 'none' and 'tanh' values
  • Thread approximate parameter through all layers: native C++, PInvoke, Tensor methods, functional API, and module factory
  • Add new overloads (no breaking changes to existing API)
  • Add test for tanh approximation mode
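To make the new surface concrete, here is a minimal usage sketch assembled from the identifiers visible in this PR's diff (the `GELUApproximate` enum name and overload shapes are taken from the changed files, not independently verified against a released build):

```csharp
using static TorchSharp.torch;

// Module factory: tanh-approximated GELU, mirroring torch.nn.GELU(approximate='tanh')
var gelu = nn.GELU(TorchSharp.GELUApproximate.tanh);

var x = tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });

// Functional API with the new approximate parameter
var yApprox = nn.functional.gelu(x, TorchSharp.GELUApproximate.tanh);

// Tensor method overloads: out-of-place exact, in-place tanh approximation
var yExact = x.gelu();
x.gelu_(TorchSharp.GELUApproximate.tanh);
```

Existing calls without the parameter keep their current (exact) behavior, which is what makes the overloads non-breaking.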

Contributor

Copilot AI left a comment


Pull request overview

Adds support for PyTorch’s approximate mode to GELU (notably "tanh"), threading the option through the native (C++), P/Invoke, Tensor, functional, and module APIs, and adding a regression test.
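For context, the two modes compute the standard GELU definitions (as documented for PyTorch's `torch.nn.GELU`):

```latex
% approximate = "none" (exact, erf-based)
\mathrm{GELU}(x) = x \cdot \Phi(x) = \frac{x}{2}\left[1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right]

% approximate = "tanh"
\mathrm{GELU}_{\tanh}(x) = \frac{x}{2}\left[1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right]
```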

Changes:

  • Introduces Modules.GELU.Approximate (none / tanh) and plumbs it through nn.GELU and nn.functional.gelu.
  • Extends Tensor gelu/gelu_ to accept an approximation mode and updates the corresponding native/PInvoke signatures.
  • Adds a unit test validating the tanh approximation path and that it differs from the exact mode.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

Summary per file:

  • test/TorchSharpTest/NN.cs: Adds a test covering GELU tanh approximation behavior.
  • src/TorchSharp/Tensor/Tensor.cs: Adds gelu/gelu_ overloads that pass the approximation mode through to native code.
  • src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs: Updates P/Invoke signatures to accept the approximation string.
  • src/TorchSharp/NN/Activation/GELU.cs: Adds the approximation enum plus overloads in the module factory and functional API.
  • src/Native/LibTorchSharp/THSTensor.h: Updates native exports for GELU to accept an approximation parameter.
  • src/Native/LibTorchSharp/THSTensor.cpp: Passes the approximation through to torch::gelu / torch::gelu_.


alinpahontu2912 and others added 2 commits March 11, 2026 16:15
Add support for the 'approximate' parameter in GELU, matching PyTorch's
torch.nn.GELU(approximate='tanh') functionality.

Changes:
- Add GELU.Approximate enum with 'none' and 'tanh' values
- Thread approximate parameter through all layers: native C++, PInvoke,
  Tensor methods, functional API, and module factory
- Add new overloads (no breaking changes to existing API)
- Add test for tanh approximation mode

Fixes dotnet#1368

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Move Approximate enum from GELU module class to neutral
  TorchSharp namespace as GELUApproximate, removing Tensor/functional
  layer dependency on Modules layer
- Add CharSet, BestFitMapping, ThrowOnUnmappableChar attributes to
  THSTensor_gelu/gelu_ DllImport declarations to match existing
  LPStr-based imports pattern
- Update all references in Tensor.cs, GELU.cs, and tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
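The DllImport attributes mentioned in the second commit would look roughly like this (a hypothetical sketch assembled from the commit message; the exact library name, parameter list, and marshalling attributes are assumptions, not copied from the diff):

```csharp
// Illustrative shape of the updated P/Invoke declaration.
// CharSet/BestFitMapping/ThrowOnUnmappableChar match the repo's existing
// LPStr-based imports, per the commit message.
[DllImport("LibTorchSharp",
    CharSet = CharSet.Ansi,
    BestFitMapping = false,
    ThrowOnUnmappableChar = true)]
internal static extern IntPtr THSTensor_gelu(
    IntPtr tensor,
    [MarshalAs(UnmanagedType.LPStr)] string approximate);
```

Disabling best-fit mapping and throwing on unmappable characters prevents silent corruption of the `"none"`/`"tanh"` string on its way into native code.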
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.



Comment on lines +2986 to 2991
public Tensor gelu(GELUApproximate approximate)
{
var res = NativeMethods.THSTensor_gelu(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none");
if (res == IntPtr.Zero)
CheckForErrors();
return new Tensor(res);

Copilot AI Mar 11, 2026


gelu(GELUApproximate approximate) silently maps any unrecognized enum value to "none". That hides invalid inputs (e.g., casts) and can make debugging hard. Consider validating approximate and throwing an ArgumentOutOfRangeException (or similar) for unsupported values instead of defaulting to "none".
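A fail-fast mapping along the lines suggested could look like this (a hypothetical sketch; `ToNativeString` is an illustrative helper name, not part of the diff):

```csharp
// Map the enum to the native string, rejecting values outside the
// supported set instead of silently defaulting to "none".
private static string ToNativeString(GELUApproximate approximate) =>
    approximate switch
    {
        GELUApproximate.none => "none",
        GELUApproximate.tanh => "tanh",
        _ => throw new ArgumentOutOfRangeException(
                 nameof(approximate), approximate,
                 "Unsupported GELU approximation mode.")
    };
```

Both `gelu` and `gelu_` could then share this helper, so an invalid cast such as `(GELUApproximate)42` surfaces immediately rather than producing exact-mode results.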

Comment on lines +3002 to 3006
public Tensor gelu_(GELUApproximate approximate)
{
var res = NativeMethods.THSTensor_gelu_(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none");
if (res == IntPtr.Zero)
CheckForErrors();

Copilot AI Mar 11, 2026


gelu_(GELUApproximate approximate) also falls back to "none" for any enum value other than tanh. For consistency and to avoid silently ignoring invalid values, validate the enum and fail fast when it’s outside the supported set.

var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
var exact = torch.nn.functional.gelu(x);
var approx = torch.nn.functional.gelu(x, GELUApproximate.tanh);
Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));

Copilot AI Mar 11, 2026


This new test exercises the out-of-place approximate GELU path via torch.nn.functional.gelu(x, GELUApproximate.tanh), but it doesn’t cover the newly added in-place overload x.gelu_(GELUApproximate.tanh). Adding an assertion that the in-place path runs and matches the out-of-place approximate result would help catch P/Invoke/native wiring regressions.

Suggested change
Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));
Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));
// Verify that the in-place tanh approximate matches the out-of-place result
var xInPlace = x.clone();
xInPlace.gelu_(GELUApproximate.tanh);
Assert.True(approx.allclose(xInPlace, rtol: 1e-5, atol: 1e-5));

Comment on lines +5 to +13
/// Specifies the approximation method for the GELU activation function.
/// </summary>
public enum GELUApproximate
{
/// <summary>
/// Exact GELU computation.
/// </summary>
none,
/// <summary>

Copilot AI Mar 11, 2026


PR description mentions adding a GELU.Approximate enum, but the implementation introduces a top-level GELUApproximate enum. If the intended public API is the nested name, consider renaming/moving the enum; otherwise, update the PR description to match the shipped API surface.


Development

Successfully merging this pull request may close these issues.

GELU does not appear to support approximate tanh
