From b533b97a83bff1a917c953e4277ff9501694d07d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 27 May 2026 14:10:39 +0200 Subject: [PATCH 1/2] Force usage of working Enzyme --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 29f0e2bc8..75998207e 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ Aqua = "0.6, 0.7, 0.8" CUDA = "6" ChainRulesCore = "1" ChainRulesTestUtils = "1" -Enzyme = "0.13.131" +Enzyme = "<0.13.148" EnzymeTestUtils = "0.2.5" GenericLinearAlgebra = "0.3.19, 0.4" GenericSchur = "0.5.6" From 74f19d8b758b17b798f32bd624fac431e89ec381 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 27 May 2026 15:35:15 +0200 Subject: [PATCH 2/2] Fix check for cache_arg --- Project.toml | 2 +- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 75998207e..1b464db86 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ Aqua = "0.6, 0.7, 0.8" CUDA = "6" ChainRulesCore = "1" ChainRulesTestUtils = "1" -Enzyme = "<0.13.148" +Enzyme = "0.13.148" EnzymeTestUtils = "0.2.5" GenericLinearAlgebra = "0.3.19, 0.4" GenericSchur = "0.5.6" diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 57bc92a53..a7014a463 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -67,13 +67,13 @@ for (f, pb) in ( # so we do not need to cache it. This may change if future pullbacks # depend directly on A! ret = func.val(A.val, arg.val, alg.val) - # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed + # if arg.val === ret, the annotation must be Duplicated or DuplicatedNoNeed # if arg isa Const, ret may still be modified further down the call graph so we should # copy it to protect ourselves A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] A_is_arg = A_is_arg1 || A_is_arg2 - cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing + cache_arg = arg.val !== ret || A_is_arg || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const)) make_zero.(ret) elseif EnzymeRules.needs_shadow(config)