diff --git a/Project.toml b/Project.toml index 29f0e2bc8..1b464db86 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ Aqua = "0.6, 0.7, 0.8" CUDA = "6" ChainRulesCore = "1" ChainRulesTestUtils = "1" -Enzyme = "0.13.131" +Enzyme = "0.13.148" EnzymeTestUtils = "0.2.5" GenericLinearAlgebra = "0.3.19, 0.4" GenericSchur = "0.5.6" diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 57bc92a53..a7014a463 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -67,13 +67,13 @@ for (f, pb) in ( # so we do not need to cache it. This may change if future pullbacks # depend directly on A! ret = func.val(A.val, arg.val, alg.val) - # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed + # if arg.val === ret, the annotation must be Duplicated or DuplicatedNoNeed # if arg isa Const, ret may still be modified further down the call graph so we should # copy it to protect ourselves A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] A_is_arg = A_is_arg1 || A_is_arg2 - cache_arg = (arg.val !== ret && !A_is_arg) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing + cache_arg = arg.val !== ret || A_is_arg || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const)) make_zero.(ret) elseif EnzymeRules.needs_shadow(config)