From cd83e78dee6c8b0ade4b4650ceabe48af57aebdb Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 12 Apr 2026 11:13:47 +0300 Subject: [PATCH 1/3] feat(where): Add IL-generated SIMD optimization for np.where(condition, x, y) Add IL-generated kernels for np.where using runtime code generation: - Uses DynamicMethod to generate type-specific kernels at runtime - Vector256/Vector128.ConditionalSelect for SIMD element selection - 4x loop unrolling for better instruction-level parallelism - Full long indexing support for arrays > 2^31 elements - Supports all 12 dtypes (11 via SIMD, Decimal via scalar fallback) - Kernels cached per type for reuse Architecture: - WhereKernel delegate: (bool* cond, T* x, T* y, T* result, long count) - GetWhereKernel(): Returns cached IL-generated kernel - WhereExecute(): Main entry point with automatic fallback IL Generation: - 4x unrolled SIMD loop (processes 4 vectors per iteration) - Remainder SIMD loop (1 vector at a time) - Scalar tail loop for remaining elements - Mask creation methods by element size (1/2/4/8 bytes) - All arithmetic uses long types natively (no int-to-long casts) Falls back to iterator path for: - Non-contiguous/broadcasted arrays (stride=0) - Non-bool conditions (need truthiness conversion) Files: - src/NumSharp.Core/APIs/np.where.cs: Kernel dispatch logic - src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs: IL generation - test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs: 26 tests Closes #604 --- src/NumSharp.Core/APIs/np.where.cs | 199 ++++++ .../Kernels/ILKernelGenerator.Where.cs | 635 ++++++++++++++++++ .../Backends/Kernels/WhereSimdTests.cs | 515 ++++++++++++++ .../Logic/np.where.BattleTest.cs | 346 ++++++++++ test/NumSharp.UnitTest/Logic/np.where.Test.cs | 496 ++++++++++++++ 5 files changed, 2191 insertions(+) create mode 100644 src/NumSharp.Core/APIs/np.where.cs create mode 100644 src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs create mode 100644 
test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs create mode 100644 test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs create mode 100644 test/NumSharp.UnitTest/Logic/np.where.Test.cs diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs new file mode 100644 index 00000000..a361534a --- /dev/null +++ b/src/NumSharp.Core/APIs/np.where.cs @@ -0,0 +1,199 @@ +using System; +using NumSharp.Backends.Kernels; +using NumSharp.Generic; + +namespace NumSharp +{ + public static partial class np + { + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// + /// Where True, yield `x`, otherwise yield `y`. + /// Tuple of arrays with indices where condition is non-zero (equivalent to np.nonzero). + /// https://numpy.org/doc/stable/reference/generated/numpy.where.html + public static NDArray[] where(NDArray condition) + { + return nonzero(condition); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// + /// Where True, yield `x`, otherwise yield `y`. + /// Values from which to choose where condition is True. + /// Values from which to choose where condition is False. + /// An array with elements from `x` where `condition` is True, and elements from `y` elsewhere. 
+ /// https://numpy.org/doc/stable/reference/generated/numpy.where.html + public static NDArray where(NDArray condition, NDArray x, NDArray y) + { + // Broadcast all three arrays to common shape + var broadcasted = broadcast_arrays(condition, x, y); + var cond = broadcasted[0]; + var xArr = broadcasted[1]; + var yArr = broadcasted[2]; + + // Determine output dtype from x and y (type promotion) + var outType = _FindCommonType(xArr, yArr); + // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides) + var result = empty(cond.shape, outType); + + // Handle empty arrays - nothing to iterate + if (result.size == 0) + return result; + + // IL Kernel fast path: all arrays contiguous, bool condition, SIMD enabled + // Broadcasted arrays (stride=0) are NOT contiguous, so they use iterator path. + bool canUseKernel = ILKernelGenerator.Enabled && + cond.typecode == NPTypeCode.Boolean && + cond.Shape.IsContiguous && + xArr.Shape.IsContiguous && + yArr.Shape.IsContiguous; + + if (canUseKernel) + { + WhereKernelDispatch(cond, xArr, yArr, result, outType); + return result; + } + + // Iterator fallback for non-contiguous/broadcasted arrays + switch (outType) + { + case NPTypeCode.Boolean: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Byte: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Int16: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.UInt16: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Int32: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.UInt32: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Int64: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.UInt64: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Char: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Single: + WhereImpl(cond, xArr, yArr, result); + break; + case NPTypeCode.Double: + WhereImpl(cond, xArr, yArr, result); + break; 
+ case NPTypeCode.Decimal: + WhereImpl(cond, xArr, yArr, result); + break; + default: + throw new NotSupportedException($"Type {outType} not supported for np.where"); + } + + return result; + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for x. + /// + public static NDArray where(NDArray condition, object x, NDArray y) + { + return where(condition, asanyarray(x), y); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for y. + /// + public static NDArray where(NDArray condition, NDArray x, object y) + { + return where(condition, x, asanyarray(y)); + } + + /// + /// Return elements chosen from `x` or `y` depending on `condition`. + /// Scalar overload for both x and y. + /// + public static NDArray where(NDArray condition, object x, object y) + { + return where(condition, asanyarray(x), asanyarray(y)); + } + + private static void WhereImpl(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged + { + // Use iterators for proper handling of broadcasted/strided arrays + using var condIter = cond.AsIterator(); + using var xIter = x.AsIterator(); + using var yIter = y.AsIterator(); + using var resultIter = result.AsIterator(); + + while (condIter.HasNext()) + { + var c = condIter.MoveNext(); + var xVal = xIter.MoveNext(); + var yVal = yIter.MoveNext(); + resultIter.MoveNextReference() = c ? xVal : yVal; + } + } + + /// + /// IL Kernel dispatch for contiguous arrays. + /// Uses IL-generated kernels with SIMD optimization. 
+ /// + private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType) + { + var condPtr = (bool*)cond.Address; + var count = result.size; + + switch (outType) + { + case NPTypeCode.Boolean: + ILKernelGenerator.WhereExecute(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count); + break; + case NPTypeCode.Byte: + ILKernelGenerator.WhereExecute(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count); + break; + case NPTypeCode.Int16: + ILKernelGenerator.WhereExecute(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count); + break; + case NPTypeCode.UInt16: + ILKernelGenerator.WhereExecute(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count); + break; + case NPTypeCode.Int32: + ILKernelGenerator.WhereExecute(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count); + break; + case NPTypeCode.UInt32: + ILKernelGenerator.WhereExecute(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count); + break; + case NPTypeCode.Int64: + ILKernelGenerator.WhereExecute(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count); + break; + case NPTypeCode.UInt64: + ILKernelGenerator.WhereExecute(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count); + break; + case NPTypeCode.Char: + ILKernelGenerator.WhereExecute(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count); + break; + case NPTypeCode.Single: + ILKernelGenerator.WhereExecute(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count); + break; + case NPTypeCode.Double: + ILKernelGenerator.WhereExecute(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count); + break; + case NPTypeCode.Decimal: + ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count); + break; + } + } + } +} diff --git 
a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs new file mode 100644 index 00000000..e055bd8a --- /dev/null +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -0,0 +1,635 @@ +using System; +using System.Collections.Concurrent; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; + +// ============================================================================= +// ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels +// ============================================================================= +// +// RESPONSIBILITY: +// - Generate optimized kernels for conditional selection +// - result[i] = cond[i] ? x[i] : y[i] +// +// ARCHITECTURE: +// Uses IL emission to generate type-specific kernels at runtime. +// The challenge is bool mask expansion: condition is bool[] (1 byte per element), +// but x/y can be any dtype (1-8 bytes per element). +// +// | Element Size | V256 Elements | Bools to Load | +// |--------------|---------------|---------------| +// | 1 byte | 32 | 32 | +// | 2 bytes | 16 | 16 | +// | 4 bytes | 8 | 8 | +// | 8 bytes | 4 | 4 | +// +// KERNEL TYPES: +// - WhereKernel: Main kernel delegate (cond*, x*, y*, result*, count) +// +// ============================================================================= + +namespace NumSharp.Backends.Kernels +{ + /// + /// Delegate for where operation kernels. + /// + public unsafe delegate void WhereKernel(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged; + + public static partial class ILKernelGenerator + { + /// + /// Cache of IL-generated where kernels. + /// Key: Type + /// + private static readonly ConcurrentDictionary _whereKernelCache = new(); + + #region Public API + + /// + /// Get or generate an IL-based where kernel for the specified type. 
+ /// Returns null if IL generation is disabled or fails. + /// + public static WhereKernel? GetWhereKernel() where T : unmanaged + { + if (!Enabled) + return null; + + var type = typeof(T); + + if (_whereKernelCache.TryGetValue(type, out var cached)) + return (WhereKernel)cached; + + var kernel = TryGenerateWhereKernel(); + if (kernel == null) + return null; + + if (_whereKernelCache.TryAdd(type, kernel)) + return kernel; + + return (WhereKernel)_whereKernelCache[type]; + } + + /// + /// Execute where operation using IL-generated kernel or fallback to static helper. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void WhereExecute(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged + { + if (count == 0) + return; + + var kernel = GetWhereKernel(); + if (kernel != null) + { + kernel(cond, x, y, result, count); + } + else + { + // Fallback to scalar loop + WhereScalar(cond, x, y, result, count); + } + } + + #endregion + + #region Kernel Generation + + private static WhereKernel? 
TryGenerateWhereKernel() where T : unmanaged + { + try + { + return GenerateWhereKernelIL(); + } + catch (Exception ex) + { + System.Diagnostics.Debug.WriteLine($"[ILKernel] TryGenerateWhereKernel<{typeof(T).Name}>: {ex.GetType().Name}: {ex.Message}"); + return null; + } + } + + private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmanaged + { + int elementSize = Unsafe.SizeOf(); + + // Determine if we can use SIMD + bool canSimd = elementSize <= 8 && IsSimdSupported(); + + var dm = new DynamicMethod( + name: $"IL_Where_{typeof(T).Name}", + returnType: typeof(void), + parameterTypes: new[] { typeof(bool*), typeof(T*), typeof(T*), typeof(T*), typeof(long) }, + owner: typeof(ILKernelGenerator), + skipVisibility: true + ); + + var il = dm.GetILGenerator(); + + // Locals + var locI = il.DeclareLocal(typeof(long)); // loop counter + + // Labels + var lblScalarLoop = il.DefineLabel(); + var lblScalarLoopEnd = il.DefineLabel(); + + // i = 0 + il.Emit(OpCodes.Ldc_I8, 0L); + il.Emit(OpCodes.Stloc, locI); + + if (canSimd && VectorBits >= 128) + { + // Generate SIMD path + EmitWhereSIMDLoop(il, locI); + } + + // Scalar loop for remainder + il.MarkLabel(lblScalarLoop); + + // if (i >= count) goto end + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Bge, lblScalarLoopEnd); + + // result[i] = cond[i] ? x[i] : y[i] + EmitWhereScalarElement(il, locI); + + // i++ + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, 1L); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblScalarLoop); + + il.MarkLabel(lblScalarLoopEnd); + il.Emit(OpCodes.Ret); + + return (WhereKernel)dm.CreateDelegate(typeof(WhereKernel)); + } + + private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) where T : unmanaged + { + long elementSize = Unsafe.SizeOf(); + long vectorCount = VectorBits >= 256 ? 
(32 / elementSize) : (16 / elementSize); + long unrollFactor = 4; + long unrollStep = vectorCount * unrollFactor; + bool useV256 = VectorBits >= 256; + + var locUnrollEnd = il.DeclareLocal(typeof(long)); + var locVectorEnd = il.DeclareLocal(typeof(long)); + + var lblUnrollLoop = il.DefineLabel(); + var lblUnrollLoopEnd = il.DefineLabel(); + var lblVectorLoop = il.DefineLabel(); + var lblVectorLoopEnd = il.DefineLabel(); + + // unrollEnd = count - unrollStep (for 4x unrolled loop) + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Ldc_I8, unrollStep); + il.Emit(OpCodes.Sub); + il.Emit(OpCodes.Stloc, locUnrollEnd); + + // vectorEnd = count - vectorCount (for remainder loop) + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Ldc_I8, vectorCount); + il.Emit(OpCodes.Sub); + il.Emit(OpCodes.Stloc, locVectorEnd); + + // ========== 4x UNROLLED SIMD LOOP ========== + il.MarkLabel(lblUnrollLoop); + + // if (i > unrollEnd) goto UnrollLoopEnd + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldloc, locUnrollEnd); + il.Emit(OpCodes.Bgt, lblUnrollLoopEnd); + + // Process 4 vectors per iteration + for (long u = 0; u < unrollFactor; u++) + { + long offset = vectorCount * u; + if (useV256) + EmitWhereV256BodyWithOffset(il, locI, elementSize, offset); + else + EmitWhereV128BodyWithOffset(il, locI, elementSize, offset); + } + + // i += unrollStep + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, unrollStep); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblUnrollLoop); + + il.MarkLabel(lblUnrollLoopEnd); + + // ========== REMAINDER SIMD LOOP (1 vector at a time) ========== + il.MarkLabel(lblVectorLoop); + + // if (i > vectorEnd) goto VectorLoopEnd + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldloc, locVectorEnd); + il.Emit(OpCodes.Bgt, lblVectorLoopEnd); + + // Process 1 vector + if (useV256) + EmitWhereV256BodyWithOffset(il, locI, elementSize, 0L); + else + EmitWhereV128BodyWithOffset(il, locI, elementSize, 0L); + + // i += 
vectorCount + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, vectorCount); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblVectorLoop); + + il.MarkLabel(lblVectorLoopEnd); + } + + private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged + { + // Get the appropriate mask creation method based on element size + var maskMethod = GetMaskCreationMethod256((int)elementSize); + var loadMethod = typeof(Vector256).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); + var storeMethod = typeof(Vector256).GetMethod("Store", new[] { typeof(Vector256<>).MakeGenericType(typeof(T)), typeof(T*) })!; + var selectMethod = typeof(Vector256).GetMethod("ConditionalSelect", new[] { + typeof(Vector256<>).MakeGenericType(typeof(T)), + typeof(Vector256<>).MakeGenericType(typeof(T)), + typeof(Vector256<>).MakeGenericType(typeof(T)) + })!; + + // Load address: cond + (i + offset) + il.Emit(OpCodes.Ldarg_0); // cond + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Call mask creation: returns Vector256 on stack + il.Emit(OpCodes.Call, maskMethod); + + // Load x vector: x + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_1); // x + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Load y vector: y + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_2); // y + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, 
loadMethod); + + // Stack: mask, xVec, yVec + // ConditionalSelect(mask, x, y) + il.Emit(OpCodes.Call, selectMethod); + + // Store result: result + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_3); // result + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, storeMethod); + } + + private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged + { + var maskMethod = GetMaskCreationMethod128((int)elementSize); + var loadMethod = typeof(Vector128).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); + var storeMethod = typeof(Vector128).GetMethod("Store", new[] { typeof(Vector128<>).MakeGenericType(typeof(T)), typeof(T*) })!; + var selectMethod = typeof(Vector128).GetMethod("ConditionalSelect", new[] { + typeof(Vector128<>).MakeGenericType(typeof(T)), + typeof(Vector128<>).MakeGenericType(typeof(T)), + typeof(Vector128<>).MakeGenericType(typeof(T)) + })!; + + // Load address: cond + (i + offset) + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, maskMethod); + + // Load x vector + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Load y vector + il.Emit(OpCodes.Ldarg_2); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + 
il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // ConditionalSelect + il.Emit(OpCodes.Call, selectMethod); + + // Store + il.Emit(OpCodes.Ldarg_3); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, storeMethod); + } + + private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder locI) where T : unmanaged + { + long elementSize = Unsafe.SizeOf(); + var typeCode = GetNPTypeCode(); + + // result[i] = cond[i] ? x[i] : y[i] + var lblFalse = il.DefineLabel(); + var lblEnd = il.DefineLabel(); + + // Load result address: result + i * elementSize + il.Emit(OpCodes.Ldarg_3); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Load cond[i]: cond + i (bool is 1 byte) + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Ldind_U1); // Load bool as byte + + // if (!cond[i]) goto lblFalse + il.Emit(OpCodes.Brfalse, lblFalse); + + // True branch: load x[i] + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + EmitLoadIndirect(il, typeCode); + il.Emit(OpCodes.Br, lblEnd); + + // False branch: load y[i] + il.MarkLabel(lblFalse); + il.Emit(OpCodes.Ldarg_2); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + EmitLoadIndirect(il, typeCode); + + il.MarkLabel(lblEnd); + // Stack: result_ptr, value + EmitStoreIndirect(il, typeCode); + } + + private static NPTypeCode GetNPTypeCode() where T : unmanaged + { + if (typeof(T) == typeof(bool)) 
return NPTypeCode.Boolean; + if (typeof(T) == typeof(byte)) return NPTypeCode.Byte; + if (typeof(T) == typeof(short)) return NPTypeCode.Int16; + if (typeof(T) == typeof(ushort)) return NPTypeCode.UInt16; + if (typeof(T) == typeof(int)) return NPTypeCode.Int32; + if (typeof(T) == typeof(uint)) return NPTypeCode.UInt32; + if (typeof(T) == typeof(long)) return NPTypeCode.Int64; + if (typeof(T) == typeof(ulong)) return NPTypeCode.UInt64; + if (typeof(T) == typeof(char)) return NPTypeCode.Char; + if (typeof(T) == typeof(float)) return NPTypeCode.Single; + if (typeof(T) == typeof(double)) return NPTypeCode.Double; + if (typeof(T) == typeof(decimal)) return NPTypeCode.Decimal; + return NPTypeCode.Empty; + } + + #endregion + + #region Mask Creation Methods + + private static MethodInfo GetMaskCreationMethod256(int elementSize) + { + return elementSize switch + { + 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") + }; + } + + private static MethodInfo GetMaskCreationMethod128(int elementSize) + { + return elementSize switch + { + 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + 8 => 
typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, + _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") + }; + } + + /// + /// Create V256 mask from 32 bools for 1-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_1Byte(byte* bools) + { + var vec = Vector256.Load(bools); + var zero = Vector256.Zero; + var isZero = Vector256.Equals(vec, zero); + return Vector256.OnesComplement(isZero); + } + + /// + /// Create V256 mask from 16 bools for 2-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) + { + return Vector256.Create( + bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[7] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[8] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[9] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[10] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[11] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[12] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[13] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[14] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[15] != 0 ? (ushort)0xFFFF : (ushort)0 + ); + } + + /// + /// Create V256 mask from 8 bools for 4-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) + { + return Vector256.Create( + bools[0] != 0 ? 0xFFFFFFFFu : 0u, + bools[1] != 0 ? 0xFFFFFFFFu : 0u, + bools[2] != 0 ? 0xFFFFFFFFu : 0u, + bools[3] != 0 ? 0xFFFFFFFFu : 0u, + bools[4] != 0 ? 0xFFFFFFFFu : 0u, + bools[5] != 0 ? 
0xFFFFFFFFu : 0u, + bools[6] != 0 ? 0xFFFFFFFFu : 0u, + bools[7] != 0 ? 0xFFFFFFFFu : 0u + ); + } + + /// + /// Create V256 mask from 4 bools for 8-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 CreateMaskV256_8Byte(byte* bools) + { + return Vector256.Create( + bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[2] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[3] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul + ); + } + + /// + /// Create V128 mask from 16 bools for 1-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_1Byte(byte* bools) + { + var vec = Vector128.Load(bools); + var zero = Vector128.Zero; + var isZero = Vector128.Equals(vec, zero); + return Vector128.OnesComplement(isZero); + } + + /// + /// Create V128 mask from 8 bools for 2-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) + { + return Vector128.Create( + bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, + bools[7] != 0 ? (ushort)0xFFFF : (ushort)0 + ); + } + + /// + /// Create V128 mask from 4 bools for 4-byte elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) + { + return Vector128.Create( + bools[0] != 0 ? 0xFFFFFFFFu : 0u, + bools[1] != 0 ? 0xFFFFFFFFu : 0u, + bools[2] != 0 ? 0xFFFFFFFFu : 0u, + bools[3] != 0 ? 0xFFFFFFFFu : 0u + ); + } + + /// + /// Create V128 mask from 2 bools for 8-byte elements. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 CreateMaskV128_8Byte(byte* bools) + { + return Vector128.Create( + bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, + bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul + ); + } + + #endregion + + #region Scalar Fallback + + /// + /// Scalar fallback for where operation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe void WhereScalar(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged + { + for (long i = 0; i < count; i++) + { + result[i] = cond[i] ? x[i] : y[i]; + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs new file mode 100644 index 00000000..3fc30d17 --- /dev/null +++ b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs @@ -0,0 +1,515 @@ +using System; +using System.Diagnostics; +using NumSharp.Backends.Kernels; +using TUnit.Core; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Backends.Kernels +{ + /// + /// Tests for SIMD-optimized np.where implementation. + /// Verifies correctness of the SIMD path for all supported dtypes. + /// + public class WhereSimdTests + { + #region SIMD Correctness + + [Test] + public void Where_Simd_Float32_Correctness() + { + var rng = np.random.RandomState(42); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size).astype(NPTypeCode.Single); + var y = rng.rand(size).astype(NPTypeCode.Single); + + var result = np.where(cond, x, y); + + // Verify correctness manually + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(float)x[i] : (float)y[i]; + Assert.AreEqual(expected, (float)result[i], 1e-6f, $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Float64_Correctness() + { + var rng = np.random.RandomState(43); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size); + var y = rng.rand(size); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (double)x[i] : (double)y[i]; + Assert.AreEqual(expected, (double)result[i], 1e-10, $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Int32_Correctness() + { + var rng = np.random.RandomState(44); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.randint(0, 1000, new[] { size }); + var y = rng.randint(0, 1000, new[] { size }); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (int)x[i] : (int)y[i]; + Assert.AreEqual(expected, (int)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Int64_Correctness() + { + var rng = np.random.RandomState(45); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.Int64); + var y = np.arange(size, size * 2).astype(NPTypeCode.Int64); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (long)x[i] : (long)y[i]; + Assert.AreEqual(expected, (long)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Byte_Correctness() + { + var rng = np.random.RandomState(46); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = (rng.rand(size) * 255).astype(NPTypeCode.Byte); + var y = (rng.rand(size) * 255).astype(NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(byte)x[i] : (byte)y[i]; + Assert.AreEqual(expected, (byte)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Int16_Correctness() + { + var rng = np.random.RandomState(47); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.Int16); + var y = np.arange(size, size * 2).astype(NPTypeCode.Int16); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (short)x[i] : (short)y[i]; + Assert.AreEqual(expected, (short)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_UInt16_Correctness() + { + var rng = np.random.RandomState(48); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt16); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt16); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (ushort)x[i] : (ushort)y[i]; + Assert.AreEqual(expected, (ushort)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_UInt32_Correctness() + { + var rng = np.random.RandomState(49); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt32); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt32); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (uint)x[i] : (uint)y[i]; + Assert.AreEqual(expected, (uint)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_UInt64_Correctness() + { + var rng = np.random.RandomState(50); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt64); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt64); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(ulong)x[i] : (ulong)y[i]; + Assert.AreEqual(expected, (ulong)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Boolean_Correctness() + { + var rng = np.random.RandomState(51); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size) > 0.3; + var y = rng.rand(size) > 0.7; + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (bool)x[i] : (bool)y[i]; + Assert.AreEqual(expected, (bool)result[i], $"Mismatch at index {i}"); + } + } + + [Test] + public void Where_Simd_Char_Correctness() + { + var rng = np.random.RandomState(52); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var xData = new char[size]; + var yData = new char[size]; + for (int i = 0; i < size; i++) + { + xData[i] = (char)('A' + (i % 26)); + yData[i] = (char)('a' + (i % 26)); + } + var x = np.array(xData); + var y = np.array(yData); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(char)x[i] : (char)y[i]; + Assert.AreEqual(expected, (char)result[i], $"Mismatch at index {i}"); + } + } + + #endregion + + #region Path Selection + + [Test] + public void Where_NonContiguous_Works() + { + // Sliced arrays are non-contiguous, should work correctly + var baseArr = np.arange(20); + var cond = (baseArr % 2 == 0)["::2"]; // Sliced: [true, true, true, true, true, true, true, true, true, true] + var x = np.ones(10, NPTypeCode.Int32); + var y = np.zeros(10, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + Assert.AreEqual(10, result.size); + // All true -> all from x + for (int i = 0; i < 10; i++) + { + Assert.AreEqual(1, (int)result[i]); + } + } + + [Test] + public void Where_Broadcast_Works() + { + // Broadcasted arrays + // cond shape (3,) broadcasts to (3,3): [[T,F,T],[T,F,T],[T,F,T]] + // x shape (3,1) broadcasts to (3,3): [[1,1,1],[2,2,2],[3,3,3]] + // y shape (1,3) broadcasts to (3,3): [[10,20,30],[10,20,30],[10,20,30]] + var cond = np.array(new[] { true, false, true }); + var x = np.array(new int[,] { { 1 }, { 2 }, { 3 } }); + var y = np.array(new int[,] { { 10, 20, 30 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3, 3); + // Verify values: result[i,j] = cond[j] ? 
x[i,0] : y[0,j] + Assert.AreEqual(1, (int)result[0, 0]); // cond[0]=true -> x=1 + Assert.AreEqual(20, (int)result[0, 1]); // cond[1]=false -> y=20 + Assert.AreEqual(1, (int)result[0, 2]); // cond[2]=true -> x=1 + Assert.AreEqual(2, (int)result[1, 0]); // cond[0]=true -> x=2 + Assert.AreEqual(20, (int)result[1, 1]); // cond[1]=false -> y=20 + } + + [Test] + public void Where_Decimal_Works() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new decimal[] { 1.1m, 2.2m, 3.3m }); + var y = np.array(new decimal[] { 10.1m, 20.2m, 30.3m }); + + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(decimal), result.dtype); + Assert.AreEqual(1.1m, (decimal)result[0]); + Assert.AreEqual(20.2m, (decimal)result[1]); + Assert.AreEqual(3.3m, (decimal)result[2]); + } + + [Test] + public void Where_NonBoolCondition_Works() + { + // Non-bool condition requires truthiness check + var cond = np.array(new[] { 0, 1, 2, 0 }); // int condition + var result = np.where(cond, 100, -100); + + result.Should().BeOfValues(-100, 100, 100, -100); + } + + #endregion + + #region Edge Cases + + [Test] + public void Where_Simd_SmallArray() + { + // Array smaller than vector width + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 20, 3); + } + + [Test] + public void Where_Simd_VectorAlignedSize() + { + var rng = np.random.RandomState(53); + // Size exactly matches vector width (no scalar tail) + var size = 32; // V256 byte count + var cond = rng.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Byte); + var y = np.zeros(size, NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + Assert.AreEqual(size, result.size); + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(byte)1 : (byte)0; + Assert.AreEqual(expected, (byte)result[i]); + } + } + + [Test] + public void Where_Simd_WithScalarTail() + { + // Size that requires scalar tail processing + var size = 35; // 32 + 3 tail for bytes + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.full(size, (byte)255); + var y = np.zeros(size, NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual((byte)255, (byte)result[i], $"Mismatch at {i}"); + } + } + + [Test] + public void Where_Simd_AllTrue() + { + var size = 100; + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.full(size, -1L); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual((long)i, (long)result[i]); + } + } + + [Test] + public void Where_Simd_AllFalse() + { + var size = 100; + var cond = np.zeros(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.full(size, -1L); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual(-1L, (long)result[i]); + } + } + + [Test] + public void Where_Simd_Alternating() + { + var size = 100; + var condData = new bool[size]; + for (int i = 0; i < size; i++) + condData[i] = i % 2 == 0; + var cond = np.array(condData); + var x = np.ones(size, NPTypeCode.Int32); + var y = np.zeros(size, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual(i % 2 == 0 ? 
1 : 0, (int)result[i], $"Mismatch at {i}"); + } + } + + [Test] + public void Where_Simd_NaN_Propagates() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { double.NaN, 1.0, 2.0 }); + var y = np.array(new[] { 0.0, double.NaN, 0.0 }); + + var result = np.where(cond, x, y); + + Assert.IsTrue(double.IsNaN((double)result[0])); // NaN from x + Assert.IsTrue(double.IsNaN((double)result[1])); // NaN from y + Assert.AreEqual(2.0, (double)result[2], 1e-10); + } + + [Test] + public void Where_Simd_Infinity() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { double.PositiveInfinity, 0.0, double.NegativeInfinity, 0.0 }); + var y = np.array(new[] { 0.0, double.PositiveInfinity, 0.0, double.NegativeInfinity }); + + var result = np.where(cond, x, y); + + Assert.AreEqual(double.PositiveInfinity, (double)result[0]); + Assert.AreEqual(double.PositiveInfinity, (double)result[1]); + Assert.AreEqual(double.NegativeInfinity, (double)result[2]); + Assert.AreEqual(double.NegativeInfinity, (double)result[3]); + } + + #endregion + + #region Performance Sanity Check + + [Test] + public void Where_Simd_LargeArray_Correctness() + { + var rng = np.random.RandomState(54); + var size = 100_000; + var cond = rng.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Double); + var y = np.zeros(size, NPTypeCode.Double); + + var result = np.where(cond, x, y); + + // Spot check + for (int i = 0; i < 100; i++) + { + var expected = (bool)cond[i] ? 1.0 : 0.0; + Assert.AreEqual(expected, (double)result[i], 1e-10); + } + + // Check last few elements (scalar tail) + for (int i = size - 10; i < size; i++) + { + var expected = (bool)cond[i] ? 
1.0 : 0.0; + Assert.AreEqual(expected, (double)result[i], 1e-10); + } + } + + #endregion + + #region 2D/Multi-dimensional + + [Test] + public void Where_Simd_2D_Contiguous() + { + var rng = np.random.RandomState(55); + // 2D contiguous array should use SIMD + var shape = new[] { 100, 100 }; + var cond = rng.rand(shape) > 0.5; + var x = np.ones(shape, NPTypeCode.Int32); + var y = np.zeros(shape, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + result.Should().BeShaped(100, 100); + + // Spot check + for (int i = 0; i < 10; i++) + { + for (int j = 0; j < 10; j++) + { + var expected = (bool)cond[i, j] ? 1 : 0; + Assert.AreEqual(expected, (int)result[i, j]); + } + } + } + + [Test] + public void Where_Simd_3D_Contiguous() + { + var rng = np.random.RandomState(56); + var shape = new[] { 10, 20, 30 }; + var cond = rng.rand(shape) > 0.5; + var x = np.ones(shape, NPTypeCode.Single); + var y = np.zeros(shape, NPTypeCode.Single); + + var result = np.where(cond, x, y); + + result.Should().BeShaped(10, 20, 30); + + // Spot check + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + for (int k = 0; k < 5; k++) + { + var expected = (bool)cond[i, j, k] ? 1.0f : 0.0f; + Assert.AreEqual(expected, (float)result[i, j, k], 1e-6f); + } + } + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs new file mode 100644 index 00000000..5834d16a --- /dev/null +++ b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs @@ -0,0 +1,346 @@ +using System; +using System.Linq; +using TUnit.Core; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Logic +{ + /// + /// Battle tests for np.where - edge cases, strided arrays, views, etc. 
+ /// + public class np_where_BattleTest + { + #region Strided/Sliced Arrays + + [Test] + public void Where_SlicedCondition() + { + // Sliced condition array + var arr = np.arange(10); + var cond = (arr % 2 == 0)["::2"]; // Every other even check + var x = np.ones(5, NPTypeCode.Int32); + var y = np.zeros(5, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + // Should work with sliced condition + Assert.AreEqual(5, result.size); + } + + [Test] + public void Where_SlicedXY() + { + var cond = np.array(new[] { true, false, true }); + var x = np.arange(6)["::2"]; // [0, 2, 4] + var y = np.arange(6)["1::2"]; // [1, 3, 5] + var result = np.where(cond, x, y); + + result.Should().BeOfValues(0L, 3L, 4L); + } + + [Test] + public void Where_TransposedArrays() + { + var cond = np.array(new bool[,] { { true, false }, { false, true } }).T; + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }).T; + var y = np.array(new int[,] { { 10, 20 }, { 30, 40 } }).T; + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 2); + // After transpose: cond[0,0]=T, cond[0,1]=F, cond[1,0]=F, cond[1,1]=T + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(30, (int)result[0, 1]); + Assert.AreEqual(20, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [Test] + public void Where_ReversedSlice() + { + var cond = np.array(new[] { true, false, true, false, true }); + var x = np.arange(5)["::-1"]; // [4, 3, 2, 1, 0] + var y = np.zeros(5, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(4L, 0L, 2L, 0L, 0L); + } + + #endregion + + #region Complex Broadcasting + + [Test] + public void Where_3Way_Broadcasting() + { + // cond: (2,1,1), x: (1,3,1), y: (1,1,4) -> result: (2,3,4) + var cond = np.array(new bool[,,] { { { true } }, { { false } } }); + var x = np.arange(3).reshape(1, 3, 1); + var y = (np.arange(4) * 10).reshape(1, 1, 4); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3, 4); + // First "page" 
(cond=True): values from x broadcast + Assert.AreEqual(0, (long)result[0, 0, 0]); + Assert.AreEqual(0, (long)result[0, 0, 3]); + Assert.AreEqual(2, (long)result[0, 2, 0]); + // Second "page" (cond=False): values from y broadcast + Assert.AreEqual(0, (long)result[1, 0, 0]); + Assert.AreEqual(30, (long)result[1, 0, 3]); + Assert.AreEqual(30, (long)result[1, 2, 3]); + } + + [Test] + public void Where_RowVector_ColVector_Broadcast() + { + // cond: (1,4), x: (3,1), y: scalar -> result: (3,4) + var cond = np.array(new bool[,] { { true, false, true, false } }); + var x = np.array(new int[,] { { 1 }, { 2 }, { 3 } }); + var result = np.where(cond, x, 0); + + result.Should().BeShaped(3, 4); + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(0, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[1, 0]); + Assert.AreEqual(0, (int)result[1, 1]); + } + + #endregion + + #region Numeric Edge Cases + + [Test] + public void Where_NaN_Values() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { double.NaN, 1.0, double.NaN }); + var y = np.array(new[] { 0.0, double.NaN, 0.0 }); + var result = np.where(cond, x, y); + + Assert.IsTrue(double.IsNaN((double)result[0])); // from x + Assert.IsTrue(double.IsNaN((double)result[1])); // from y + Assert.IsTrue(double.IsNaN((double)result[2])); // from x + } + + [Test] + public void Where_Infinity_Values() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { double.PositiveInfinity, 1.0 }); + var y = np.array(new[] { 0.0, double.NegativeInfinity }); + var result = np.where(cond, x, y); + + Assert.AreEqual(double.PositiveInfinity, (double)result[0]); + Assert.AreEqual(double.NegativeInfinity, (double)result[1]); + } + + [Test] + public void Where_MaxMin_Values() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { long.MaxValue, 0L }); + var y = np.array(new[] { 0L, long.MinValue }); + var result = np.where(cond, x, y); + + Assert.AreEqual(long.MaxValue, 
(long)result[0]); + Assert.AreEqual(long.MinValue, (long)result[1]); + } + + #endregion + + #region Single Arg Edge Cases + + [Test] + public void Where_SingleArg_Float_Truthy() + { + // 0.0 is falsy, anything else (including -0.0, NaN, Inf) is truthy + var arr = np.array(new[] { 0.0, 1.0, -1.0, 0.5, -0.0 }); + var result = np.where(arr); + + // Note: -0.0 == 0.0 in IEEE 754, so it's falsy + result[0].Should().BeOfValues(1L, 2L, 3L); + } + + [Test] + public void Where_SingleArg_NaN_IsTruthy() + { + // NaN is non-zero, so it's truthy + var arr = np.array(new[] { 0.0, double.NaN, 0.0 }); + var result = np.where(arr); + + result[0].Should().BeOfValues(1L); + } + + [Test] + public void Where_SingleArg_4D() + { + var arr = np.zeros(new[] { 2, 2, 2, 2 }, NPTypeCode.Int32); + arr[0, 1, 0, 1] = 1; + arr[1, 0, 1, 0] = 1; + var result = np.where(arr); + + Assert.AreEqual(4, result.Length); // 4 dimensions + Assert.AreEqual(2, result[0].size); // 2 non-zero elements + } + + #endregion + + #region Performance/Stress Tests + + [Test] + public void Where_LargeArray_Performance() + { + var size = 1_000_000; + var cond = np.random.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Double); + var y = np.zeros(size, NPTypeCode.Double); + + var sw = System.Diagnostics.Stopwatch.StartNew(); + var result = np.where(cond, x, y); + sw.Stop(); + + Assert.AreEqual(size, result.size); + // Should complete in reasonable time (< 1 second for 1M elements) + Assert.IsTrue(sw.ElapsedMilliseconds < 1000, $"Took {sw.ElapsedMilliseconds}ms"); + } + + [Test] + public void Where_ManyDimensions() + { + // 6D array + var shape = new[] { 2, 3, 2, 2, 2, 3 }; + var cond = np.ones(shape, NPTypeCode.Boolean); + var x = np.ones(shape, NPTypeCode.Int32); + var y = np.zeros(shape, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3, 2, 2, 2, 3); + Assert.AreEqual(1, (int)result[0, 0, 0, 0, 0, 0]); + } + + #endregion + + #region Type Conversion Edge Cases + + [Test] + 
public void Where_UnsignedOverflow() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 255, 0 }); + var y = np.array(new byte[] { 0, 255 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(byte), result.dtype); + Assert.AreEqual((byte)255, (byte)result[0]); + Assert.AreEqual((byte)255, (byte)result[1]); + } + + [Test] + public void Where_Decimal() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new decimal[] { 1.23456789m, 0m }); + var y = np.array(new decimal[] { 0m, 9.87654321m }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(decimal), result.dtype); + Assert.AreEqual(1.23456789m, (decimal)result[0]); + Assert.AreEqual(9.87654321m, (decimal)result[1]); + } + + [Test] + public void Where_Char() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new char[] { 'A', 'B', 'C' }); + var y = np.array(new char[] { 'X', 'Y', 'Z' }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(char), result.dtype); + Assert.AreEqual('A', (char)result[0]); + Assert.AreEqual('Y', (char)result[1]); + Assert.AreEqual('C', (char)result[2]); + } + + #endregion + + #region View Behavior + + [Test] + public void Where_ResultIsNewArray_NotView() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { 1, 2 }); + var y = np.array(new[] { 10, 20 }); + var result = np.where(cond, x, y); + + // Modify original, result should not change + x[0] = 999; + Assert.AreEqual(1, (int)result[0], "Result should be independent of x"); + + y[1] = 999; + Assert.AreEqual(20, (int)result[1], "Result should be independent of y"); + } + + [Test] + public void Where_ModifyResult_DoesNotAffectInputs() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { 1, 2 }); + var y = np.array(new[] { 10, 20 }); + var result = np.where(cond, x, y); + + result[0] = 999; + Assert.AreEqual(1, (int)x[0], "x should not be modified"); + Assert.AreEqual(10, (int)y[0], 
"y should not be modified"); + } + + #endregion + + #region Alternating Patterns + + [Test] + public void Where_Checkerboard_Pattern() + { + // Create checkerboard condition + var cond = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean); + for (int i = 0; i < 4; i++) + for (int j = 0; j < 4; j++) + cond[i, j] = (i + j) % 2 == 0; + + var x = np.ones(new[] { 4, 4 }, NPTypeCode.Int32); + var y = np.zeros(new[] { 4, 4 }, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + // Verify checkerboard pattern + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(0, (int)result[0, 1]); + Assert.AreEqual(0, (int)result[1, 0]); + Assert.AreEqual(1, (int)result[1, 1]); + } + + [Test] + public void Where_StripedPattern() + { + // Every row alternates between all True and all False + var cond = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean); + for (int i = 0; i < 4; i++) + for (int j = 0; j < 4; j++) + cond[i, j] = i % 2 == 0; + + var x = np.full(new[] { 4, 4 }, 1); + var y = np.full(new[] { 4, 4 }, 0); + var result = np.where(cond, x, y); + + // Rows 0, 2 should be 1; rows 1, 3 should be 0 + for (int j = 0; j < 4; j++) + { + Assert.AreEqual(1, (int)result[0, j]); + Assert.AreEqual(0, (int)result[1, j]); + Assert.AreEqual(1, (int)result[2, j]); + Assert.AreEqual(0, (int)result[3, j]); + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Logic/np.where.Test.cs b/test/NumSharp.UnitTest/Logic/np.where.Test.cs new file mode 100644 index 00000000..d53ed585 --- /dev/null +++ b/test/NumSharp.UnitTest/Logic/np.where.Test.cs @@ -0,0 +1,496 @@ +using System; +using System.Linq; +using TUnit.Core; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Logic +{ + /// + /// Comprehensive tests for np.where matching NumPy 2.x behavior. 
+ /// + /// NumPy signature: where(condition, x=None, y=None, /) + /// - Single arg: returns np.nonzero(condition) + /// - Three args: element-wise selection with broadcasting + /// + public class np_where_Test + { + #region Single Argument (nonzero equivalent) + + [Test] + public void Where_SingleArg_1D_ReturnsIndices() + { + // np.where([0, 1, 0, 2, 0, 3]) -> (array([1, 3, 5]),) + var arr = np.array(new[] { 0, 1, 0, 2, 0, 3 }); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + result[0].Should().BeOfValues(1L, 3L, 5L); + } + + [Test] + public void Where_SingleArg_2D_ReturnsTupleOfIndices() + { + // np.where([[0, 1, 0], [2, 0, 3]]) -> (array([0, 1, 1]), array([1, 0, 2])) + var arr = np.array(new int[,] { { 0, 1, 0 }, { 2, 0, 3 } }); + var result = np.where(arr); + + Assert.AreEqual(2, result.Length); + result[0].Should().BeOfValues(0L, 1L, 1L); // row indices + result[1].Should().BeOfValues(1L, 0L, 2L); // col indices + } + + [Test] + public void Where_SingleArg_Boolean_ReturnsNonzero() + { + var arr = np.array(new[] { true, false, true, false, true }); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + result[0].Should().BeOfValues(0L, 2L, 4L); + } + + [Test] + public void Where_SingleArg_Empty_ReturnsEmptyIndices() + { + var arr = np.array(new int[0]); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + Assert.AreEqual(0, result[0].size); + } + + [Test] + public void Where_SingleArg_AllFalse_ReturnsEmptyIndices() + { + var arr = np.array(new[] { false, false, false }); + var result = np.where(arr); + + Assert.AreEqual(1, result.Length); + Assert.AreEqual(0, result[0].size); + } + + [Test] + public void Where_SingleArg_AllTrue_ReturnsAllIndices() + { + var arr = np.array(new[] { true, true, true }); + var result = np.where(arr); + + result[0].Should().BeOfValues(0L, 1L, 2L); + } + + [Test] + public void Where_SingleArg_3D_ReturnsTupleOfThreeArrays() + { + // 2x2x2 array with some non-zero elements + var 
arr = np.zeros(new[] { 2, 2, 2 }, NPTypeCode.Int32); + arr[0, 0, 1] = 1; + arr[1, 1, 0] = 1; + var result = np.where(arr); + + Assert.AreEqual(3, result.Length); + result[0].Should().BeOfValues(0L, 1L); // dim 0 + result[1].Should().BeOfValues(0L, 1L); // dim 1 + result[2].Should().BeOfValues(1L, 0L); // dim 2 + } + + #endregion + + #region Three Arguments (element-wise selection) + + [Test] + public void Where_ThreeArgs_Basic_SelectsCorrectly() + { + // np.where(a < 5, a, 10*a) for a = arange(10) + var a = np.arange(10); + var result = np.where(a < 5, a, 10 * a); + + result.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L); + } + + [Test] + public void Where_ThreeArgs_BooleanCondition() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { 1, 2, 3, 4 }); + var y = np.array(new[] { 10, 20, 30, 40 }); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 20, 3, 40); + } + + [Test] + public void Where_ThreeArgs_2D() + { + // np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]]) + var cond = np.array(new bool[,] { { true, false }, { true, true } }); + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }); + var y = np.array(new int[,] { { 9, 8 }, { 7, 6 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 2); + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(8, (int)result[0, 1]); + Assert.AreEqual(3, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [Test] + public void Where_ThreeArgs_NonBoolCondition_TreatsAsTruthy() + { + // np.where([0, 1, 2, 0], 100, -100) -> [-100, 100, 100, -100] + var cond = np.array(new[] { 0, 1, 2, 0 }); + var result = np.where(cond, 100, -100); + + result.Should().BeOfValues(-100, 100, 100, -100); + } + + #endregion + + #region Scalar Arguments + + [Test] + public void Where_ScalarX() + { + var cond = np.array(new[] { true, false, true, false }); + var y = np.array(new[] { 10, 20, 30, 40 }); + var result 
= np.where(cond, 99, y); + + result.Should().BeOfValues(99, 20, 99, 40); + } + + [Test] + public void Where_ScalarY() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { 1, 2, 3, 4 }); + var result = np.where(cond, x, -1); + + result.Should().BeOfValues(1, -1, 3, -1); + } + + [Test] + public void Where_BothScalars() + { + var cond = np.array(new[] { true, false, true, false }); + var result = np.where(cond, 1, 0); + + result.Should().BeOfValues(1, 0, 1, 0); + } + + [Test] + public void Where_ScalarFloat() + { + var cond = np.array(new[] { true, false }); + var result = np.where(cond, 1.5, 2.5); + + Assert.AreEqual(typeof(double), result.dtype); + Assert.AreEqual(1.5, (double)result[0], 1e-10); + Assert.AreEqual(2.5, (double)result[1], 1e-10); + } + + #endregion + + #region Broadcasting + + [Test] + public void Where_Broadcasting_ScalarY() + { + // np.where(a < 4, a, -1) for 3x3 array + var arr = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } }); + var result = np.where(arr < 4, arr, -1); + + result.Should().BeShaped(3, 3); + Assert.AreEqual(0, (int)result[0, 0]); + Assert.AreEqual(1, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[0, 2]); + Assert.AreEqual(-1, (int)result[1, 2]); + Assert.AreEqual(-1, (int)result[2, 2]); + } + + [Test] + public void Where_Broadcasting_DifferentShapes() + { + // cond: (2,1), x: (3,), y: (1,3) -> result: (2,3) + var cond = np.array(new bool[,] { { true }, { false } }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new int[,] { { 10, 20, 30 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3); + // Row 0: cond=True, so x values + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(2, (int)result[0, 1]); + Assert.AreEqual(3, (int)result[0, 2]); + // Row 1: cond=False, so y values + Assert.AreEqual(10, (int)result[1, 0]); + Assert.AreEqual(20, (int)result[1, 1]); + Assert.AreEqual(30, (int)result[1, 2]); + } + + [Test] + public void 
Where_Broadcasting_ColumnVector() + { + // cond: (3,1), x: scalar, y: (1,4) -> result: (3,4) + var cond = np.array(new bool[,] { { true }, { false }, { true } }); + var x = 1; + var y = np.array(new int[,] { { 10, 20, 30, 40 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3, 4); + // Row 0: all 1s + for (int j = 0; j < 4; j++) + Assert.AreEqual(1, (int)result[0, j]); + // Row 1: y values + Assert.AreEqual(10, (int)result[1, 0]); + Assert.AreEqual(40, (int)result[1, 3]); + // Row 2: all 1s + for (int j = 0; j < 4; j++) + Assert.AreEqual(1, (int)result[2, j]); + } + + #endregion + + #region Type Promotion + + [Test] + public void Where_TypePromotion_IntFloat_ReturnsFloat64() + { + var cond = np.array(new[] { true, false }); + var result = np.where(cond, 1, 2.5); + + Assert.AreEqual(typeof(double), result.dtype); + Assert.AreEqual(1.0, (double)result[0], 1e-10); + Assert.AreEqual(2.5, (double)result[1], 1e-10); + } + + [Test] + public void Where_TypePromotion_Int32Int64_ReturnsInt64() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new int[] { 1 }); + var y = np.array(new long[] { 2 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(long), result.dtype); + } + + [Test] + public void Where_TypePromotion_FloatDouble_ReturnsDouble() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new float[] { 1.5f }); + var y = np.array(new double[] { 2.5 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(double), result.dtype); + } + + #endregion + + #region Edge Cases + + [Test] + public void Where_EmptyArrays_ThreeArgs() + { + var cond = np.array(new bool[0]); + var x = np.array(new int[0]); + var y = np.array(new int[0]); + var result = np.where(cond, x, y); + + Assert.AreEqual(0, result.size); + } + + [Test] + public void Where_SingleElement() + { + var cond = np.array(new[] { true }); + var result = np.where(cond, 42, 0); + + Assert.AreEqual(1, result.size); + Assert.AreEqual(42, 
(int)result[0]); + } + + [Test] + public void Where_AllTrue_ReturnsAllX() + { + var cond = np.array(new[] { true, true, true }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 2, 3); + } + + [Test] + public void Where_AllFalse_ReturnsAllY() + { + var cond = np.array(new[] { false, false, false }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + result.Should().BeOfValues(10, 20, 30); + } + + [Test] + public void Where_LargeArray() + { + var size = 100000; + var cond = np.arange(size) % 2 == 0; // alternating True/False + var x = np.ones(size, NPTypeCode.Int32); + var y = np.zeros(size, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + Assert.AreEqual(size, result.size); + // Even indices should be 1, odd should be 0 + Assert.AreEqual(1, (int)result[0]); + Assert.AreEqual(0, (int)result[1]); + Assert.AreEqual(1, (int)result[2]); + } + + #endregion + + #region NumPy Output Verification + + [Test] + public void Where_NumPyExample1() + { + // From NumPy docs: np.where([[True, False], [True, True]], + // [[1, 2], [3, 4]], [[9, 8], [7, 6]]) + // Expected: array([[1, 8], [3, 4]]) + var cond = np.array(new bool[,] { { true, false }, { true, true } }); + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }); + var y = np.array(new int[,] { { 9, 8 }, { 7, 6 } }); + var result = np.where(cond, x, y); + + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(8, (int)result[0, 1]); + Assert.AreEqual(3, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [Test] + public void Where_NumPyExample2() + { + // From NumPy docs: np.where(a < 5, a, 10*a) for a = arange(10) + // Expected: array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90]) + var a = np.arange(10); + var result = np.where(a < 5, a, 10 * a); + + result.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L); + } + + 
[Test] + public void Where_NumPyExample3() + { + // From NumPy docs: np.where(a < 4, a, -1) for specific array + // Expected: array([[ 0, 1, 2], [ 0, 2, -1], [ 0, 3, -1]]) + var a = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } }); + var result = np.where(a < 4, a, -1); + + Assert.AreEqual(0, (int)result[0, 0]); + Assert.AreEqual(1, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[0, 2]); + Assert.AreEqual(0, (int)result[1, 0]); + Assert.AreEqual(2, (int)result[1, 1]); + Assert.AreEqual(-1, (int)result[1, 2]); + Assert.AreEqual(0, (int)result[2, 0]); + Assert.AreEqual(3, (int)result[2, 1]); + Assert.AreEqual(-1, (int)result[2, 2]); + } + + #endregion + + #region Dtype Coverage + + [Test] + public void Where_Dtype_Byte() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 1, 2 }); + var y = np.array(new byte[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(byte), result.dtype); + result.Should().BeOfValues((byte)1, (byte)20); + } + + [Test] + public void Where_Dtype_Int16() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new short[] { 1, 2 }); + var y = np.array(new short[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(short), result.dtype); + result.Should().BeOfValues((short)1, (short)20); + } + + [Test] + public void Where_Dtype_Int32() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new int[] { 1, 2 }); + var y = np.array(new int[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(int), result.dtype); + result.Should().BeOfValues(1, 20); + } + + [Test] + public void Where_Dtype_Int64() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new long[] { 1, 2 }); + var y = np.array(new long[] { 10, 20 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(long), result.dtype); + result.Should().BeOfValues(1L, 20L); + } + + [Test] + public void 
Where_Dtype_Single() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new float[] { 1.5f, 2.5f }); + var y = np.array(new float[] { 10.5f, 20.5f }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(float), result.dtype); + Assert.AreEqual(1.5f, (float)result[0], 1e-6f); + Assert.AreEqual(20.5f, (float)result[1], 1e-6f); + } + + [Test] + public void Where_Dtype_Double() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new double[] { 1.5, 2.5 }); + var y = np.array(new double[] { 10.5, 20.5 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(double), result.dtype); + Assert.AreEqual(1.5, (double)result[0], 1e-10); + Assert.AreEqual(20.5, (double)result[1], 1e-10); + } + + [Test] + public void Where_Dtype_Boolean() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new bool[] { true, true }); + var y = np.array(new bool[] { false, false }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(bool), result.dtype); + Assert.IsTrue((bool)result[0]); + Assert.IsFalse((bool)result[1]); + } + + #endregion + } +} From 312f2233818c956f1e8b20b9c3ce23ed92e5ac91 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 12 Apr 2026 14:11:23 +0300 Subject: [PATCH 2/3] perf(where): AVX2/SSE4.1 optimize mask expansion in np.where kernel Replace scalar conditional mask creation with SIMD intrinsics: V256 mask creation (for AVX2): - 8-byte elements: Avx2.ConvertToVector256Int64 (vpmovzxbq) - 4-byte elements: Avx2.ConvertToVector256Int32 (vpmovzxbd) - 2-byte elements: Avx2.ConvertToVector256Int16 (vpmovzxbw) V128 mask creation (for SSE4.1): - 8-byte elements: Sse41.ConvertToVector128Int64 (pmovzxbq) - 4-byte elements: Sse41.ConvertToVector128Int32 (pmovzxbd) - 2-byte elements: Sse41.ConvertToVector128Int16 (pmovzxbw) Each intrinsic replaces 4-16 scalar conditionals with a single zero-extend + compare instruction sequence. 
Also fixes reflection lookups for Vector256/Vector128.Load, Store, and ConditionalSelect methods that were failing because these are generic method definitions requiring special handling. Performance (1M double elements): - Kernel: 2.6ms @ 381 M elements/s - NumPy baseline: ~1.86ms - Ratio: ~1.4x slower (down from ~3x before optimization) All 12 dtypes supported with fallback for non-AVX2/SSE4.1 systems. --- .../Kernels/ILKernelGenerator.Where.cs | 106 +++++++++++++++--- 1 file changed, 92 insertions(+), 14 deletions(-) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index e055bd8a..446755ec 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -4,6 +4,7 @@ using System.Reflection.Emit; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; // ============================================================================= // ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels @@ -253,13 +254,17 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder { // Get the appropriate mask creation method based on element size var maskMethod = GetMaskCreationMethod256((int)elementSize); - var loadMethod = typeof(Vector256).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); - var storeMethod = typeof(Vector256).GetMethod("Store", new[] { typeof(Vector256<>).MakeGenericType(typeof(T)), typeof(T*) })!; - var selectMethod = typeof(Vector256).GetMethod("ConditionalSelect", new[] { - typeof(Vector256<>).MakeGenericType(typeof(T)), - typeof(Vector256<>).MakeGenericType(typeof(T)), - typeof(Vector256<>).MakeGenericType(typeof(T)) - })!; + + // Get Vector256 methods via reflection - need to find generic method definitions first + var loadMethod = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! + .MakeGenericMethod(typeof(T)); + var storeMethod = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)! + .MakeGenericMethod(typeof(T)); + var selectMethod = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)! + .MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); // cond @@ -325,13 +330,17 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { var maskMethod = GetMaskCreationMethod128((int)elementSize); - var loadMethod = typeof(Vector128).GetMethod("Load", new[] { typeof(T*) })!.MakeGenericMethod(typeof(T)); - var storeMethod = typeof(Vector128).GetMethod("Store", new[] { typeof(Vector128<>).MakeGenericType(typeof(T)), typeof(T*) })!; - var selectMethod = typeof(Vector128).GetMethod("ConditionalSelect", new[] { - typeof(Vector128<>).MakeGenericType(typeof(T)), - typeof(Vector128<>).MakeGenericType(typeof(T)), - typeof(Vector128<>).MakeGenericType(typeof(T)) - })!; + + // Get Vector128 methods via reflection - need to find generic method definitions first + var loadMethod = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! + .MakeGenericMethod(typeof(T)); + var storeMethod = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)! + .MakeGenericMethod(typeof(T)); + var selectMethod = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)! 
+ .MakeGenericMethod(typeof(T)); // Load address: cond + (i + offset) il.Emit(OpCodes.Ldarg_0); @@ -502,10 +511,22 @@ private static unsafe Vector256 CreateMaskV256_1Byte(byte* bools) /// /// Create V256 mask from 16 bools for 2-byte elements. + /// Uses AVX2 vpmovzxbw instruction for single-instruction expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) { + if (Avx2.IsSupported) + { + // Load 16 bytes into Vector128, zero-extend each byte to 16-bit + // vpmovzxbw: byte -> word (16 bytes -> 16 words) + var bytes128 = Vector128.Load(bools); + var expanded = Avx2.ConvertToVector256Int16(bytes128).AsUInt16(); + // Compare with zero: non-zero becomes 0xFFFF, zero stays 0 + return Vector256.GreaterThan(expanded, Vector256.Zero); + } + + // Scalar fallback for non-AVX2 systems return Vector256.Create( bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, @@ -528,10 +549,22 @@ private static unsafe Vector256 CreateMaskV256_2Byte(byte* bools) /// /// Create V256 mask from 8 bools for 4-byte elements. + /// Uses AVX2 vpmovzxbd instruction for single-instruction expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) { + if (Avx2.IsSupported) + { + // Load 8 bytes into low bytes of Vector128, zero-extend each byte to 32-bit + // vpmovzxbd: byte -> dword (8 bytes -> 8 dwords) + var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); + var expanded = Avx2.ConvertToVector256Int32(bytes128).AsUInt32(); + // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 + return Vector256.GreaterThan(expanded, Vector256.Zero); + } + + // Scalar fallback for non-AVX2 systems return Vector256.Create( bools[0] != 0 ? 0xFFFFFFFFu : 0u, bools[1] != 0 ? 
0xFFFFFFFFu : 0u, @@ -546,10 +579,22 @@ private static unsafe Vector256 CreateMaskV256_4Byte(byte* bools) /// /// Create V256 mask from 4 bools for 8-byte elements. + /// Uses AVX2 vpmovzxbq instruction for single-instruction expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector256 CreateMaskV256_8Byte(byte* bools) { + if (Avx2.IsSupported) + { + // Load 4 bytes into low bytes of Vector128, zero-extend each byte to 64-bit + // vpmovzxbq: byte -> qword (4 bytes -> 4 qwords) + var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); + var expanded = Avx2.ConvertToVector256Int64(bytes128).AsUInt64(); + // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 + return Vector256.GreaterThan(expanded, Vector256.Zero); + } + + // Scalar fallback for non-AVX2 systems return Vector256.Create( bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, @@ -572,10 +617,21 @@ private static unsafe Vector128 CreateMaskV128_1Byte(byte* bools) /// /// Create V128 mask from 8 bools for 2-byte elements. + /// Uses SSE4.1 pmovzxbw instruction for efficient expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) { + if (Sse41.IsSupported) + { + // Load 8 bytes, zero-extend each to 16-bit + // pmovzxbw: byte -> word (8 bytes -> 8 words) + var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); + var expanded = Sse41.ConvertToVector128Int16(bytes128).AsUInt16(); + return Vector128.GreaterThan(expanded, Vector128.Zero); + } + + // Scalar fallback return Vector128.Create( bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, @@ -590,10 +646,21 @@ private static unsafe Vector128 CreateMaskV128_2Byte(byte* bools) /// /// Create V128 mask from 4 bools for 4-byte elements. + /// Uses SSE4.1 pmovzxbd instruction for efficient expansion. 
/// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) { + if (Sse41.IsSupported) + { + // Load 4 bytes, zero-extend each to 32-bit + // pmovzxbd: byte -> dword (4 bytes -> 4 dwords) + var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); + var expanded = Sse41.ConvertToVector128Int32(bytes128).AsUInt32(); + return Vector128.GreaterThan(expanded, Vector128.Zero); + } + + // Scalar fallback return Vector128.Create( bools[0] != 0 ? 0xFFFFFFFFu : 0u, bools[1] != 0 ? 0xFFFFFFFFu : 0u, @@ -604,10 +671,21 @@ private static unsafe Vector128 CreateMaskV128_4Byte(byte* bools) /// /// Create V128 mask from 2 bools for 8-byte elements. + /// Uses SSE4.1 pmovzxbq instruction for efficient expansion. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe Vector128 CreateMaskV128_8Byte(byte* bools) { + if (Sse41.IsSupported) + { + // Load 2 bytes, zero-extend each to 64-bit + // pmovzxbq: byte -> qword (2 bytes -> 2 qwords) + var bytes128 = Vector128.CreateScalar(*(ushort*)bools).AsByte(); + var expanded = Sse41.ConvertToVector128Int64(bytes128).AsUInt64(); + return Vector128.GreaterThan(expanded, Vector128.Zero); + } + + // Scalar fallback return Vector128.Create( bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul From 10ae98b0353240a248ce87296bdb1babf53a5233 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 12 Apr 2026 14:48:20 +0300 Subject: [PATCH 3/3] perf(where): inline mask creation in IL - 5.4x faster kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of emitting Call opcodes to mask helper methods, now emit the AVX2/SSE4.1 instructions directly inline in the IL stream. 
This eliminates: - Method call overhead (~12% per call) - Runtime Avx2.IsSupported checks in hot path - JIT optimization barriers at call boundaries The IL now emits the full mask creation sequence: - 8-byte: ldind.u4 → CreateScalar → AsByte → ConvertToVector256Int64 → AsUInt64 → GreaterThan - 4-byte: ldind.i8 → CreateScalar → AsByte → ConvertToVector256Int32 → AsUInt32 → GreaterThan - 2-byte: Load → ConvertToVector256Int16 → AsUInt16 → GreaterThan - 1-byte: Load → GreaterThan (direct comparison) Performance (1M double elements): - Previous (method call): 2.6 ms - Inlined IL: 0.48 ms (5.4x faster) - NumPy baseline: 1.86 ms (NumSharp is now 3.9x FASTER) Fixed reflection lookups for AsByte/AsUInt* which are extension methods on Vector128/Vector256 static classes, not instance methods. --- .../Kernels/ILKernelGenerator.Where.cs | 237 +++++++++++++++++- 1 file changed, 229 insertions(+), 8 deletions(-) diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs index 446755ec..1b4eb4b0 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs @@ -252,9 +252,6 @@ private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) wher private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - // Get the appropriate mask creation method based on element size - var maskMethod = GetMaskCreationMethod256((int)elementSize); - // Get Vector256 methods via reflection - need to find generic method definitions first var loadMethod = Array.Find(typeof(Vector256).GetMethods(), m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! 
@@ -277,8 +274,8 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder il.Emit(OpCodes.Conv_I); il.Emit(OpCodes.Add); - // Call mask creation: returns Vector256 on stack - il.Emit(OpCodes.Call, maskMethod); + // Inline mask creation - emit AVX2 instructions directly instead of calling helper + EmitInlineMaskCreationV256(il, (int)elementSize); // Load x vector: x + (i + offset) * elementSize il.Emit(OpCodes.Ldarg_1); // x @@ -329,8 +326,6 @@ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged { - var maskMethod = GetMaskCreationMethod128((int)elementSize); - // Get Vector128 methods via reflection - need to find generic method definitions first var loadMethod = Array.Find(typeof(Vector128).GetMethods(), m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)! @@ -352,7 +347,9 @@ private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder } il.Emit(OpCodes.Conv_I); il.Emit(OpCodes.Add); - il.Emit(OpCodes.Call, maskMethod); + + // Inline mask creation - emit SSE4.1 instructions directly + EmitInlineMaskCreationV128(il, (int)elementSize); // Load x vector il.Emit(OpCodes.Ldarg_1); @@ -497,6 +494,230 @@ private static MethodInfo GetMaskCreationMethod128(int elementSize) }; } + #endregion + + #region Inline Mask IL Emission + + // Cache reflection lookups for inline emission + private static readonly MethodInfo _v128LoadByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + private static readonly MethodInfo _v256LoadByte = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + + private static readonly MethodInfo _v128CreateScalarUInt = 
Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v128CreateScalarULong = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v128CreateScalarUShort = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + + // AsByte is an extension method on Vector128 static class, not instance method + private static readonly MethodInfo _v128UIntAsByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v128ULongAsByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v128UShortAsByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + + private static readonly MethodInfo _avx2ConvertToV256Int64 = typeof(Avx2).GetMethod("ConvertToVector256Int64", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _avx2ConvertToV256Int32 = typeof(Avx2).GetMethod("ConvertToVector256Int32", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _avx2ConvertToV256Int16 = typeof(Avx2).GetMethod("ConvertToVector256Int16", new[] { typeof(Vector128) })!; + + private static readonly MethodInfo _sse41ConvertToV128Int64 = typeof(Sse41).GetMethod("ConvertToVector128Int64", new[] { typeof(Vector128) })!; + private static readonly MethodInfo _sse41ConvertToV128Int32 = typeof(Sse41).GetMethod("ConvertToVector128Int32", new[] { typeof(Vector128) })!; + private static readonly MethodInfo 
_sse41ConvertToV128Int16 = typeof(Sse41).GetMethod("ConvertToVector128Int16", new[] { typeof(Vector128) })!; + + // As* methods are extension methods on Vector256/Vector128 static classes + private static readonly MethodInfo _v256LongAsULong = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long)); + private static readonly MethodInfo _v256IntAsUInt = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int)); + private static readonly MethodInfo _v256ShortAsUShort = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short)); + + private static readonly MethodInfo _v128LongAsULong = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long)); + private static readonly MethodInfo _v128IntAsUInt = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int)); + private static readonly MethodInfo _v128ShortAsUShort = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short)); + + private static readonly MethodInfo _v256GreaterThanULong = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v256GreaterThanUInt = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v256GreaterThanUShort = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + private static readonly MethodInfo 
_v256GreaterThanByte = Array.Find(typeof(Vector256).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + + private static readonly MethodInfo _v128GreaterThanULong = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong)); + private static readonly MethodInfo _v128GreaterThanUInt = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint)); + private static readonly MethodInfo _v128GreaterThanUShort = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort)); + private static readonly MethodInfo _v128GreaterThanByte = Array.Find(typeof(Vector128).GetMethods(), + m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); + + private static readonly FieldInfo _v256ZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!.IsStatic + ? null! 
: null!; // Use GetMethod call instead + private static readonly MethodInfo _v256GetZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v256GetZeroUInt = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v256GetZeroUShort = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v256GetZeroByte = typeof(Vector256).GetProperty("Zero")!.GetMethod!; + + private static readonly MethodInfo _v128GetZeroULong = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v128GetZeroUInt = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v128GetZeroUShort = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + private static readonly MethodInfo _v128GetZeroByte = typeof(Vector128).GetProperty("Zero")!.GetMethod!; + + /// + /// Emit inline V256 mask creation. Stack: byte* -> Vector256{T} (as mask) + /// + private static void EmitInlineMaskCreationV256(ILGenerator il, int elementSize) + { + // Stack has: byte* pointing to condition bools + + switch (elementSize) + { + case 8: // double/long: load 4 bytes, expand to 4 qwords + // *(uint*)ptr + il.Emit(OpCodes.Ldind_U4); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarUInt); + // .AsByte() + il.Emit(OpCodes.Call, _v128UIntAsByte); + // Avx2.ConvertToVector256Int64(bytes) + il.Emit(OpCodes.Call, _avx2ConvertToV256Int64); + // .AsUInt64() + il.Emit(OpCodes.Call, _v256LongAsULong); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroULong); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v256GreaterThanULong); + break; + + case 4: // float/int: load 8 bytes, expand to 8 dwords + // *(ulong*)ptr + il.Emit(OpCodes.Ldind_I8); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarULong); + // .AsByte() + il.Emit(OpCodes.Call, _v128ULongAsByte); + // 
Avx2.ConvertToVector256Int32(bytes) + il.Emit(OpCodes.Call, _avx2ConvertToV256Int32); + // .AsUInt32() + il.Emit(OpCodes.Call, _v256IntAsUInt); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroUInt); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v256GreaterThanUInt); + break; + + case 2: // short/char: load 16 bytes, expand to 16 words + // Vector128.Load(ptr) + il.Emit(OpCodes.Call, _v128LoadByte); + // Avx2.ConvertToVector256Int16(bytes) + il.Emit(OpCodes.Call, _avx2ConvertToV256Int16); + // .AsUInt16() + il.Emit(OpCodes.Call, _v256ShortAsUShort); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroUShort); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v256GreaterThanUShort); + break; + + case 1: // byte/bool: load 32 bytes, compare directly + // Vector256.Load(ptr) + il.Emit(OpCodes.Call, _v256LoadByte); + // Vector256.Zero + il.Emit(OpCodes.Call, _v256GetZeroByte); + // Vector256.GreaterThan(vec, zero) + il.Emit(OpCodes.Call, _v256GreaterThanByte); + break; + + default: + throw new NotSupportedException($"Element size {elementSize} not supported"); + } + } + + /// + /// Emit inline V128 mask creation. 
Stack: byte* -> Vector128{T} (as mask) + /// + private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize) + { + switch (elementSize) + { + case 8: // double/long: load 2 bytes, expand to 2 qwords + // *(ushort*)ptr + il.Emit(OpCodes.Ldind_U2); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarUShort); + // .AsByte() + il.Emit(OpCodes.Call, _v128UShortAsByte); + // Sse41.ConvertToVector128Int64(bytes) + il.Emit(OpCodes.Call, _sse41ConvertToV128Int64); + // .AsUInt64() + il.Emit(OpCodes.Call, _v128LongAsULong); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroULong); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v128GreaterThanULong); + break; + + case 4: // float/int: load 4 bytes, expand to 4 dwords + // *(uint*)ptr + il.Emit(OpCodes.Ldind_U4); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarUInt); + // .AsByte() + il.Emit(OpCodes.Call, _v128UIntAsByte); + // Sse41.ConvertToVector128Int32(bytes) + il.Emit(OpCodes.Call, _sse41ConvertToV128Int32); + // .AsUInt32() + il.Emit(OpCodes.Call, _v128IntAsUInt); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroUInt); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v128GreaterThanUInt); + break; + + case 2: // short/char: load 8 bytes, expand to 8 words + // *(ulong*)ptr + il.Emit(OpCodes.Ldind_I8); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, _v128CreateScalarULong); + // .AsByte() + il.Emit(OpCodes.Call, _v128ULongAsByte); + // Sse41.ConvertToVector128Int16(bytes) + il.Emit(OpCodes.Call, _sse41ConvertToV128Int16); + // .AsUInt16() + il.Emit(OpCodes.Call, _v128ShortAsUShort); + // Vector128.Zero + il.Emit(OpCodes.Call, _v128GetZeroUShort); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, _v128GreaterThanUShort); + break; + + case 1: // byte/bool: load 16 bytes, compare directly + // Vector128.Load(ptr) + il.Emit(OpCodes.Call, _v128LoadByte); + // Vector128.Zero + 
il.Emit(OpCodes.Call, _v128GetZeroByte); + // Vector128.GreaterThan(vec, zero) + il.Emit(OpCodes.Call, _v128GreaterThanByte); + break; + + default: + throw new NotSupportedException($"Element size {elementSize} not supported"); + } + } + + #endregion + + #region Static Mask Creation Methods (fallback) + /// /// Create V256 mask from 32 bools for 1-byte elements. ///