Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
VectorInterfaceChainRulesCoreExt = "ChainRulesCore"
VectorInterfaceEnzymeExt = "Enzyme"
VectorInterfaceMooncakeExt = "Mooncake"
VectorInterfaceStaticArraysExt = "StaticArrays"

[compat]
Aqua = "0.6, 0.7, 0.8"
Expand All @@ -25,6 +27,7 @@ EnzymeTestUtils = "0.2.6"
LinearAlgebra = "1"
Mooncake = "0.5"
Random = "1"
StaticArrays = "1"
Test = "1"
TestExtras = "0.2,0.3"
julia = "1"
Expand All @@ -37,8 +40,9 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[targets]
test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Enzyme", "EnzymeTestUtils", "Random"]
test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Enzyme", "EnzymeTestUtils", "Random", "StaticArrays"]
12 changes: 12 additions & 0 deletions ext/VectorInterfaceStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module VectorInterfaceStaticArraysExt

using VectorInterface
using StaticArrays: SArray

# `SArray` is immutable so make sure !! methods route to non-inplace methods
VectorInterface.zerovector!!(x::SArray) = zerovector(x)
VectorInterface.scale!!(x::SArray, α::Number) = scale(x, α)
VectorInterface.scale!!(y::SArray, x::SArray, α::Number) = scale(x, α * one(scalartype(y)))
VectorInterface.add!!(y::SArray, x::SArray, α::Number, β::Number) = add(y, x, α, β)

end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ end

Base.collect(x::MinimalVec) = x.vec
@static if isdefined(Base, :get_extension) && isempty(VERSION.prerelease)
println("Testing StaticArrays extension")
println("==============================")
include("staticsvec.jl")

println("Testing AD rules")
println("================")
println("Testing ChainRules")
Expand Down
108 changes: 108 additions & 0 deletions test/staticsvec.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
module StaticSVec
using VectorInterface
using StaticArrays
using Test
using TestExtras

deepcollect(x::StaticArray) = collect(x)
deepcollect(x::Number) = x

x = SVector{3}(randn(3))
y = SVector{3}(randn(3))

@testset "scalartype" begin
s = @constinferred scalartype(x)
@test s == Float64
end

@testset "zerovector" begin
z = @constinferred zerovector(x)
@test z isa SVector{3, Float64}
@test all(iszero, deepcollect(z))
@test all(deepcollect(z) .=== zero(scalartype(x)))
z1 = @constinferred zerovector!!(x)
@test z1 isa SVector{3, Float64}
@test all(deepcollect(z1) .=== zero(scalartype(x)))

z3 = @constinferred zerovector(x, ComplexF64)
@test z3 isa SVector{3, ComplexF64}
@test all(deepcollect(z3) .=== zero(ComplexF64))
z4 = @constinferred zerovector!!(x, ComplexF64)
@test z4 isa SVector{3, ComplexF64}
@test all(deepcollect(z4) .=== zero(ComplexF64))
end

@testset "scale" begin
α = randn()
z = @constinferred scale(x, α)
@test z isa SVector{3, Float64}
@test all(deepcollect(z) .== α .* deepcollect(x))

z2 = @constinferred scale!!(x, α)
@test z2 isa SVector{3, Float64}
@test deepcollect(z2) ≈ (α .* deepcollect(x))
z2 = @constinferred scale!!(y, x, α)
@test z2 isa SVector{3, Float64}
@test deepcollect(z2) ≈ (α .* deepcollect(x))

α = randn(ComplexF64)
z4 = @constinferred scale(x, α)
@test z4 isa SVector{3, ComplexF64}
@test deepcollect(z4) ≈ (α .* deepcollect(x))
z5 = @constinferred scale!!(x, α)
@test z5 isa SVector{3, ComplexF64}
@test deepcollect(z5) ≈ (α .* deepcollect(x))

z6 = @constinferred scale!!(zerovector(x), x, α)
@test z6 isa SVector{3, ComplexF64}
@test deepcollect(z6) ≈ (α .* deepcollect(x))

ycomplex = zerovector(y, ComplexF64)
α = randn(Float64)
z8 = @constinferred scale!!(ycomplex, x, α)
@test scalartype(z8) == ComplexF64
@test all(deepcollect(z8) .== α .* deepcollect(x))
end

@testset "add" begin
α, β = randn(2)
z = @constinferred add(y, x)
@test z isa SVector{3, Float64}
@test all(deepcollect(z) .== deepcollect(x) .+ deepcollect(y))
z = @constinferred add(y, x, α)
@test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y))
z = @constinferred add(y, x, α, β)
@test deepcollect(z) ≈ muladd.(deepcollect(x), α, deepcollect(y) .* β)

z2 = @constinferred add!!(y, x)
@test z2 isa SVector{3, Float64}
@test deepcollect(z2) ≈ (deepcollect(x) .+ deepcollect(y))
z2 = @constinferred add!!(y, x, α)
@test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
z2 = @constinferred add!!(y, x, α, β)
@test deepcollect(z2) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))

α, β = randn(ComplexF64, 2)
z4 = @constinferred add(y, x, α)
@test z4 isa SVector{3, ComplexF64}
@test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
z4 = @constinferred add(y, x, α, β)
@test deepcollect(z4) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))

z5 = @constinferred add!!(y, x, α)
@test z5 isa SVector{3, ComplexF64}
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y)))
z5 = @constinferred add!!(y, x, α, β)
@test deepcollect(z5) ≈ (muladd.(deepcollect(x), α, deepcollect(y) .* β))
end

@testset "inner" begin
s = @constinferred inner(x, y)
@test s ≈ inner(deepcollect(x), deepcollect(y))

α, β = randn(ComplexF64, 2)
s2 = @constinferred inner(scale(x, α), scale(y, β))
@test s2 ≈ inner(α * deepcollect(x), β * deepcollect(y))
end

end
Loading