diff --git a/cranelift/codegen/meta/src/isa/arm64.rs b/cranelift/codegen/meta/src/isa/arm64.rs index e43b74bc2e3a..1f5f0de23707 100644 --- a/cranelift/codegen/meta/src/isa/arm64.rs +++ b/cranelift/codegen/meta/src/isa/arm64.rs @@ -24,6 +24,14 @@ pub(crate) fn define() -> TargetIsa { "", false, ); + settings.add_bool( + "has_dotprod", + "Has Dot Product (FEAT_DotProd) support; enables lowering i8 dot \ + products (e.g. the relaxed-SIMD i8x16 dot) to SDOT/UDOT instead of \ + a smull/saddlp widening fallback.", + "", + false, + ); settings.add_bool( "sign_return_address_all", "If function return address signing is enabled, then apply it to all \ diff --git a/cranelift/codegen/src/isa/aarch64/inst.isle b/cranelift/codegen/src/isa/aarch64/inst.isle index 6e3fbcd2f683..04f91685d887 100644 --- a/cranelift/codegen/src/isa/aarch64/inst.isle +++ b/cranelift/codegen/src/isa/aarch64/inst.isle @@ -1600,6 +1600,8 @@ (Fmla) ;; Floating-point fused multiply-subtract vectors (Fmls) + ;; Signed integer dot product (FEAT_DotProd) + (Sdot) )) ;; A Vector miscellaneous operation with two registers. @@ -1819,6 +1821,10 @@ (decl use_lse () Inst) (extern extractor use_lse use_lse) +;; Matches any instruction when the `has_dotprod` (FEAT_DotProd) setting is on. +(decl use_dotprod () Inst) +(extern extractor use_dotprod use_dotprod) + (decl pure use_fp16 () bool) (extern constructor use_fp16 use_fp16) @@ -3417,6 +3423,12 @@ (rule (bsl ty c x y) (vec_rrr_mod (VecALUModOp.Bsl) c x y (vector_size ty))) +;; SDOT Vd.4S, Vn.16B, Vm.16B -- accumulates the 4-way i8 dot product into +;; `acc` (which is both source and destination). `a`/`b` are i8x16. +(decl sdot (Reg Reg Reg) Reg) +(rule (sdot acc a b) + (vec_rrr_mod (VecALUModOp.Sdot) acc a b (VectorSize.Size32x4))) + ;; Helper for generating a `udf` instruction. 
(decl udf (TrapCode) SideEffectNoResult) diff --git a/cranelift/codegen/src/isa/aarch64/inst/emit.rs b/cranelift/codegen/src/isa/aarch64/inst/emit.rs index 008ee230bfe9..0c261c267e56 100644 --- a/cranelift/codegen/src/isa/aarch64/inst/emit.rs +++ b/cranelift/codegen/src/isa/aarch64/inst/emit.rs @@ -2764,6 +2764,11 @@ impl MachInstEmit for Inst { VecALUModOp::Fmls => { (0b000_01110_10_1 | (size.enc_float_size() << 1), 0b110011) } + // SDOT Vd.4S, Vn.16B, Vm.16B (FEAT_DotProd). The size/element + // field (bits 23:22 = 0b10) is part of the dot-product opcode, + // so it is baked into top11; only Q (from `size`) is variable. + // top11 (Q=0) | q<<9 with bit15_10 yields 0x4E809400 for .4S/.16B. + VecALUModOp::Sdot => (0b000_01110_10_0, 0b100101), }; sink.put4(enc_vec_rrr(top11 | q << 9, rm, bit15_10, rn, rd)); } diff --git a/cranelift/codegen/src/isa/aarch64/inst/mod.rs b/cranelift/codegen/src/isa/aarch64/inst/mod.rs index c550b6dc053c..46ac59086779 100644 --- a/cranelift/codegen/src/isa/aarch64/inst/mod.rs +++ b/cranelift/codegen/src/isa/aarch64/inst/mod.rs @@ -2287,6 +2287,9 @@ impl Inst { VecALUModOp::Bsl => ("bsl", VectorSize::Size8x16), VecALUModOp::Fmla => ("fmla", size), VecALUModOp::Fmls => ("fmls", size), + // Note: the real operand arrangement is .4s, .16b, .16b; + // this debug print renders all lanes as .4s. 
+ VecALUModOp::Sdot => ("sdot", VectorSize::Size32x4), }; let rd = pretty_print_vreg_vector(rd.to_reg(), size); let ri = pretty_print_vreg_vector(ri, size); diff --git a/cranelift/codegen/src/isa/aarch64/lower.isle b/cranelift/codegen/src/isa/aarch64/lower.isle index 42659d4c3272..a81f40bf4678 100644 --- a/cranelift/codegen/src/isa/aarch64/lower.isle +++ b/cranelift/codegen/src/isa/aarch64/lower.isle @@ -359,6 +359,24 @@ (rule -1 (lower (has_type ty (iadd_pairwise _ x y))) (addp x y (vector_size ty))) +;; With FEAT_DotProd, fold the `swiden`/`imul`/`iadd_pairwise`/`iadd` tree that +;; `i32x4.relaxed_dot_i8x16_i7x16_add_s` decomposes into (there is no dot CLIF +;; opcode) back into a single `sdot`. This is bit-identical to the otherwise +;; emitted smull/saddlp fallback only for in-i7 inputs. NOTE(review): if a = b = +;; -128 on both bytes of a pair, the tree's i16 sum wraps but `sdot` does not. +;; Priority 8 stays above the scalar `iadd` rules whose opaque +;; `ty_int_ref_scalar_64` guard the overlap checker can't prove disjoint here. +(rule 8 (lower (and (use_dotprod) + (has_type $I32X4 + (iadd _ + (iadd_pairwise _ + (swiden_low _ dot @ (iadd_pairwise _ + (imul _ (swiden_low _ a) (swiden_low _ b)) + (imul _ (swiden_high _ a) (swiden_high _ b)))) + (swiden_high _ dot)) + c)))) + (sdot c a b)) + ;;;; Rules for `iabs` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; (rule -1 (lower (has_type ty @ (multi_lane _ _) (iabs _ x))) diff --git a/cranelift/codegen/src/isa/aarch64/lower/isle.rs b/cranelift/codegen/src/isa/aarch64/lower/isle.rs index 6e707d695240..720b36413e1b 100644 --- a/cranelift/codegen/src/isa/aarch64/lower/isle.rs +++ b/cranelift/codegen/src/isa/aarch64/lower/isle.rs @@ -179,6 +179,14 @@ impl Context for IsleContext<'_, '_, MInst, AArch64Backend> { } } + fn use_dotprod(&mut self, _: Inst) -> Option<()> { + if self.backend.isa_flags.has_dotprod() { + Some(()) + } else { + None + } + } + fn use_fp16(&mut self) -> bool { self.backend.isa_flags.has_fp16() } diff --git a/cranelift/filetests/filetests/isa/aarch64/simd-sdot.clif 
b/cranelift/filetests/filetests/isa/aarch64/simd-sdot.clif new file mode 100644 index 000000000000..10fff0680c2d --- /dev/null +++ b/cranelift/filetests/filetests/isa/aarch64/simd-sdot.clif @@ -0,0 +1,35 @@ +test compile precise-output +set unwind_info=false +target aarch64 has_dotprod + +;; Tests the aarch64 `sdot` lowering: the i8 dot-product tree contracts to one `sdot`. +function %sdot_i8x16(i8x16, i8x16, i32x4) -> i32x4 { +block0(v0: i8x16, v1: i8x16, v2: i32x4): + v3 = swiden_low v0 + v4 = swiden_low v1 + v5 = imul v3, v4 + v6 = swiden_high v0 + v7 = swiden_high v1 + v8 = imul v6, v7 + v9 = iadd_pairwise v5, v8 + v10 = swiden_low v9 + v11 = swiden_high v9 + v12 = iadd_pairwise v10, v11 + v13 = iadd v12, v2 + return v13 +} + +; VCode: +; block0: +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; sdot v0.4s, v0.4s, v5.4s, v1.4s +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; sdot v0.4s, v5.16b, v1.16b +; ret + diff --git a/cranelift/filetests/filetests/runtests/simd-sdot.clif b/cranelift/filetests/filetests/runtests/simd-sdot.clif new file mode 100644 index 000000000000..fb73ac2e3197 --- /dev/null +++ b/cranelift/filetests/filetests/runtests/simd-sdot.clif @@ -0,0 +1,27 @@ +test interpret +test run +target aarch64 +target aarch64 has_dotprod +target x86_64 has_sse3 has_ssse3 has_sse41 +target s390x + +;; Tests the aarch64 `sdot` lowering of the i8 4-way dot product. 
+function %sdot_i8x16(i8x16, i8x16, i32x4) -> i32x4 { +block0(v0: i8x16, v1: i8x16, v2: i32x4): + v3 = swiden_low v0 + v4 = swiden_low v1 + v5 = imul v3, v4 + v6 = swiden_high v0 + v7 = swiden_high v1 + v8 = imul v6, v7 + v9 = iadd_pairwise v5, v8 + v10 = swiden_low v9 + v11 = swiden_high v9 + v12 = iadd_pairwise v10, v11 + v13 = iadd v12, v2 + return v13 +} +; each i32x4 lane i = c[i] + sum_{j=4i..4i+3} a[j]*b[j] +; run: %sdot_i8x16([1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16], [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1], [0 0 0 0]) == [10 26 42 58] +; run: %sdot_i8x16([-1 -2 -3 -4 1 2 3 4 5 5 5 5 -5 -5 -5 -5], [2 2 2 2 3 3 3 3 1 1 1 1 4 4 4 4], [100 200 300 400]) == [80 230 320 320] +; run: %sdot_i8x16([127 127 127 127 -128 -128 -128 -128 100 -100 100 -100 0 0 0 0], [127 127 127 127 127 127 127 127 63 63 63 63 1 2 3 4], [0 0 0 0]) == [64516 -65024 0 0] diff --git a/cranelift/native/src/lib.rs b/cranelift/native/src/lib.rs index 495913c2e17b..4469df829b58 100644 --- a/cranelift/native/src/lib.rs +++ b/cranelift/native/src/lib.rs @@ -116,6 +116,10 @@ pub fn infer_native_flags(isa_builder: &mut dyn Configurable) -> Result<(), &'st isa_builder.enable("has_fp16").unwrap(); } + if std::arch::is_aarch64_feature_detected!("dotprod") { + isa_builder.enable("has_dotprod").unwrap(); + } + if cfg!(target_os = "macos") { // Pointer authentication is always available on Apple Silicon. 
isa_builder.enable("sign_return_address").unwrap(); diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index 0eb4d2e16a76..454831091280 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -4297,6 +4297,7 @@ fn detect_host_feature(feature: &str) -> Option<bool> { "lse" => Some(std::arch::is_aarch64_feature_detected!("lse")), "paca" => Some(std::arch::is_aarch64_feature_detected!("paca")), "fp16" => Some(std::arch::is_aarch64_feature_detected!("fp16")), + "dotprod" => Some(std::arch::is_aarch64_feature_detected!("dotprod")), _ => None, }; diff --git a/crates/wasmtime/src/engine.rs b/crates/wasmtime/src/engine.rs index 0c4c087a4438..20394981581c 100644 --- a/crates/wasmtime/src/engine.rs +++ b/crates/wasmtime/src/engine.rs @@ -571,6 +571,7 @@ information about this check\ "has_lse" => "lse", "has_pauth" => "paca", "has_fp16" => "fp16", + "has_dotprod" => "dotprod", // aarch64 features which don't need detection // No effect on its own. diff --git a/tests/disas/aarch64-relaxed-simd-dotprod.wat b/tests/disas/aarch64-relaxed-simd-dotprod.wat new file mode 100644 index 000000000000..f0bc44e2d7a7 --- /dev/null +++ b/tests/disas/aarch64-relaxed-simd-dotprod.wat @@ -0,0 +1,22 @@ +;;! target = "aarch64" +;;! test = "compile" +;;! flags = "-C cranelift-has_dotprod=true" + +;; `i32x4.relaxed_dot_i8x16_i7x16_add_s` with FEAT_DotProd: the dot-product tree +;; lowers to a single `sdot`. +(module + (func (param v128 v128 v128) (result v128) + local.get 0 + local.get 1 + local.get 2 + i32x4.relaxed_dot_i8x16_i7x16_add_s + ) +) +;; wasm[0]::function[0]: +;; stp x29, x30, [sp, #-0x10]! +;; mov x29, sp +;; mov v6.16b, v0.16b +;; mov v0.16b, v2.16b +;; sdot v0.4s, v6.16b, v1.16b +;; ldp x29, x30, [sp], #0x10 +;; ret