From d443565928227e07c746eaadfa6cf6a85df8176e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 21 May 2026 12:34:08 -0400 Subject: [PATCH 1/2] add StaticArrays extension --- Project.toml | 6 +++++- ext/VectorInterfaceStaticArraysExt.jl | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 ext/VectorInterfaceStaticArraysExt.jl diff --git a/Project.toml b/Project.toml index cc14c2d..480fe9c 100644 --- a/Project.toml +++ b/Project.toml @@ -10,11 +10,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] VectorInterfaceChainRulesCoreExt = "ChainRulesCore" VectorInterfaceEnzymeExt = "Enzyme" VectorInterfaceMooncakeExt = "Mooncake" +VectorInterfaceStaticArraysExt = "StaticArrays" [compat] Aqua = "0.6, 0.7, 0.8" @@ -25,6 +27,7 @@ EnzymeTestUtils = "0.2.6" LinearAlgebra = "1" Mooncake = "0.5" Random = "1" +StaticArrays = "1" Test = "1" TestExtras = "0.2,0.3" julia = "1" @@ -37,8 +40,9 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [targets] -test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Enzyme", "EnzymeTestUtils", "Random"] +test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Enzyme", "EnzymeTestUtils", "Random", "StaticArrays"] diff --git a/ext/VectorInterfaceStaticArraysExt.jl b/ext/VectorInterfaceStaticArraysExt.jl new file mode 100644 index 0000000..9d56f5e --- /dev/null +++ b/ext/VectorInterfaceStaticArraysExt.jl @@ -0,0 +1,12 @@ +module VectorInterfaceStaticArraysExt + +using VectorInterface +using StaticArrays: SArray + +# `SArray` is immutable so make sure !! methods route to non-inplace methods +VectorInterface.zerovector!!(x::SArray) = zerovector(x) +VectorInterface.scale!!(x::SArray, α::Number) = scale(x, α) +VectorInterface.scale!!(y::SArray, x::SArray, α::Number) = scale(x, α * one(scalartype(y))) +VectorInterface.add!!(y::SArray, x::SArray, α::Number, β::Number) = add(y, x, α, β) + +end From 933d46d7edd91646944c60472b4f308535f2bc69 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 21 May 2026 12:37:08 -0400 Subject: [PATCH 2/2] add tests --- test/runtests.jl | 4 ++ test/staticsvec.jl | 108 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 test/staticsvec.jl diff --git a/test/runtests.jl b/test/runtests.jl index 010ce8a..e9fcb97 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,10 @@ end Base.collect(x::MinimalVec) = x.vec @static if isdefined(Base, :get_extension) && isempty(VERSION.prerelease) + println("Testing StaticArrays extension") + println("==============================") + include("staticsvec.jl") + println("Testing AD rules") println("================") println("Testing ChainRules") diff --git a/test/staticsvec.jl b/test/staticsvec.jl new file mode 100644 index 0000000..7d3ce26 --- /dev/null +++ b/test/staticsvec.jl @@ -0,0 +1,108 @@ +module StaticSVec +using VectorInterface +using StaticArrays +using Test +using TestExtras + +deepcollect(x::StaticArray) = collect(x) +deepcollect(x::Number) = x + +x = SVector{3}(randn(3)) +y = SVector{3}(randn(3)) + +@testset "scalartype" begin + s = @constinferred scalartype(x) + @test s == Float64 +end + +@testset "zerovector" begin + z = @constinferred zerovector(x) + @test z isa SVector{3, Float64} + @test all(iszero, deepcollect(z)) + @test all(deepcollect(z) .=== zero(scalartype(x))) + z1 = @constinferred zerovector!!(x) + @test z1 isa SVector{3, Float64} + @test all(deepcollect(z1) .=== zero(scalartype(x))) + + z3 = @constinferred zerovector(x, ComplexF64) + @test z3 isa SVector{3, ComplexF64} + @test all(deepcollect(z3) .=== zero(ComplexF64)) + z4 = @constinferred zerovector!!(x, ComplexF64) + @test z4 isa SVector{3, ComplexF64} + @test all(deepcollect(z4) .=== zero(ComplexF64)) +end + +@testset "scale" begin + α = randn() + z = @constinferred scale(x, α) + @test z isa SVector{3, Float64} + @test all(deepcollect(z) .== α .* deepcollect(x)) + + z2 = @constinferred scale!!(x, α) + @test z2 isa SVector{3, Float64} + @test deepcollect(z2) ≈ (α .* deepcollect(x)) + z2 = @constinferred scale!!(y, x, α) + @test z2 isa SVector{3, Float64} + @test deepcollect(z2) ≈ (α .* deepcollect(x)) + + α = randn(ComplexF64) + z4 = @constinferred scale(x, α) + @test z4 isa SVector{3, ComplexF64} + @test deepcollect(z4) ≈ (α .* deepcollect(x)) + z5 = @constinferred scale!!(x, α) + @test z5 isa SVector{3, ComplexF64} + @test deepcollect(z5) ≈ (α .* deepcollect(x)) + + z6 = @constinferred scale!!(zerovector(x), x, α) + @test z6 isa SVector{3, ComplexF64} + @test deepcollect(z6) ≈ (α .* deepcollect(x)) + + ycomplex = zerovector(y, ComplexF64) + α = randn(Float64) + z8 = @constinferred scale!!(ycomplex, x, α) + @test scalartype(z8) == ComplexF64 + @test all(deepcollect(z8) .== α .* deepcollect(x)) +end + +@testset "add" begin + α, β = randn(2) + z = @constinferred add(y, x) + @test z isa SVector{3, Float64} + @test all(deepcollect(z) .== deepcollect(x) .+ deepcollect(y)) + z = @constinferred add(y, x, α) + @test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y)) + z = @constinferred add(y, x, α, β) + @test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y) .* β) + + z2 = @constinferred add!!(y, x) + @test z2 isa SVector{3, Float64} + @test deepcollect(z2) ≈ (deepcollect(x) .+ deepcollect(y)) + z2 = @constinferred add!!(y, x, α) + @test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) + z2 = @constinferred add!!(y, x, α, β) + @test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) + + α, β = randn(ComplexF64, 2) + z4 = @constinferred add(y, x, α) + @test z4 isa SVector{3, ComplexF64} + @test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) + z4 = @constinferred add(y, x, α, β) + @test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) + + z5 = @constinferred add!!(y, x, α) + @test z5 isa SVector{3, ComplexF64} + @test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y))) + z5 = @constinferred add!!(y, x, α, β) + @test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β)) +end + +@testset "inner" begin + s = @constinferred inner(x, y) + @test s ≈ inner(deepcollect(x), deepcollect(y)) + + α, β = randn(ComplexF64, 2) + s2 = @constinferred inner(scale(x, α), scale(y, β)) + @test s2 ≈ inner(α * deepcollect(x), β * deepcollect(y)) +end + +end