Add approximate parameter to GELU activation function #1548
alinpahontu2912 wants to merge 2 commits into dotnet:main
Conversation
Pull request overview
Adds support for PyTorch’s approximate mode to GELU (notably "tanh"), threading the option through the native (C++), P/Invoke, Tensor, functional, and module APIs, and adding a regression test.
Changes:
- Introduces Modules.GELU.Approximate (none/tanh) and plumbs it through nn.GELU and nn.functional.gelu.
- Extends Tensor gelu/gelu_ to accept an approximation mode and updates the corresponding native/PInvoke signatures.
- Adds a unit test validating the tanh approximation path and that it differs from the exact mode.
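For context on what the approximation mode changes, the two modes correspond to different closed-form expressions. The following is a language-agnostic sketch in plain Python (standard `math` module only, not TorchSharp code) of the exact GELU and the tanh approximation that PyTorch documents for `approximate='tanh'`:

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # "tanh" approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

for v in (-1.0, 0.0, 1.0, 2.0):
    print(f"{v:+.1f}: exact={gelu_exact(v):.6f} tanh={gelu_tanh(v):.6f}")
```

The two formulas agree closely but not exactly, which is why the new unit test asserts that the results differ at tight tolerances.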
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
Summary per file:
| File | Description |
|---|---|
| test/TorchSharpTest/NN.cs | Adds a test covering GELU tanh approximation behavior. |
| src/TorchSharp/Tensor/Tensor.cs | Adds gelu/gelu_ overloads that pass approximation through to native. |
| src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs | Updates P/Invoke signatures to accept the approximation string. |
| src/TorchSharp/NN/Activation/GELU.cs | Adds approximation enum + overloads in module factory and functional API. |
| src/Native/LibTorchSharp/THSTensor.h | Updates native exports for GELU to accept an approximation parameter. |
| src/Native/LibTorchSharp/THSTensor.cpp | Passes approximation through to torch::gelu / torch::gelu_. |
Add support for the 'approximate' parameter in GELU, matching PyTorch's torch.nn.GELU(approximate='tanh') functionality.

Changes:
- Add GELU.Approximate enum with 'none' and 'tanh' values
- Thread approximate parameter through all layers: native C++, PInvoke, Tensor methods, functional API, and module factory
- Add new overloads (no breaking changes to existing API)
- Add test for tanh approximation mode

Fixes dotnet#1368

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Move Approximate enum from GELU module class to neutral TorchSharp namespace as GELUApproximate, removing Tensor/functional layer dependency on Modules layer
- Add CharSet, BestFitMapping, ThrowOnUnmappableChar attributes to THSTensor_gelu/gelu_ DllImport declarations to match existing LPStr-based imports pattern
- Update all references in Tensor.cs, GELU.cs, and tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Force-pushed from 0079b80 to a098ea4
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
```csharp
public Tensor gelu(GELUApproximate approximate)
{
    var res = NativeMethods.THSTensor_gelu(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none");
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}
```
gelu(GELUApproximate approximate) silently maps any unrecognized enum value to "none". That hides invalid inputs (e.g., casts) and can make debugging hard. Consider validating approximate and throwing an ArgumentOutOfRangeException (or similar) for unsupported values instead of defaulting to "none".
```csharp
public Tensor gelu_(GELUApproximate approximate)
{
    var res = NativeMethods.THSTensor_gelu_(Handle, approximate == GELUApproximate.tanh ? "tanh" : "none");
    if (res == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(res);
}
```
gelu_(GELUApproximate approximate) also falls back to "none" for any enum value other than tanh. For consistency and to avoid silently ignoring invalid values, validate the enum and fail fast when it’s outside the supported set.
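The fail-fast validation the reviewer suggests can be sketched language-agnostically. This is plain Python for illustration only (`gelu_mode_to_string` is a hypothetical helper, not part of TorchSharp); the actual fix would validate the C# enum before the P/Invoke call:

```python
def gelu_mode_to_string(approximate):
    # Map the mode to the string the native layer expects.
    # Raise on anything outside the supported set instead of
    # silently defaulting to "none".
    supported = {"none": "none", "tanh": "tanh"}
    if approximate not in supported:
        raise ValueError(f"unsupported GELU approximate mode: {approximate!r}")
    return supported[approximate]
```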
```csharp
var x = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
var exact = torch.nn.functional.gelu(x);
var approx = torch.nn.functional.gelu(x, GELUApproximate.tanh);
Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));
```
This new test exercises the out-of-place approximate GELU path via torch.nn.functional.gelu(x, GELUApproximate.tanh), but it doesn’t cover the newly added in-place overload x.gelu_(GELUApproximate.tanh). Adding an assertion that the in-place path runs and matches the out-of-place approximate result would help catch P/Invoke/native wiring regressions.
Suggested change:

```csharp
Assert.False(exact.allclose(approx, rtol: 1e-5, atol: 1e-5));
// Verify that the in-place tanh approximate matches the out-of-place result
var xInPlace = x.clone();
xInPlace.gelu_(GELUApproximate.tanh);
Assert.True(approx.allclose(xInPlace, rtol: 1e-5, atol: 1e-5));
```
```csharp
/// <summary>
/// Specifies the approximation method for the GELU activation function.
/// </summary>
public enum GELUApproximate
{
    /// <summary>
    /// Exact GELU computation.
    /// </summary>
    none,
    /// <summary>
    /// tanh approximation of GELU.
    /// </summary>
    tanh,
}
```
PR description mentions adding a GELU.Approximate enum, but the implementation introduces a top-level GELUApproximate enum. If the intended public API is the nested name, consider renaming/moving the enum; otherwise, update the PR description to match the shipped API surface.