diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs
new file mode 100644
index 00000000..a361534a
--- /dev/null
+++ b/src/NumSharp.Core/APIs/np.where.cs
@@ -0,0 +1,199 @@
+using System;
+using NumSharp.Backends.Kernels;
+using NumSharp.Generic;
+
+namespace NumSharp
+{
+ public static partial class np
+ {
+ /// <summary>
+ /// Single-argument form of np.where: forwards to <c>nonzero(condition)</c>.
+ /// </summary>
+ /// <param name="condition">Array whose non-zero elements are located.</param>
+ /// <returns>Tuple of arrays with indices where condition is non-zero (equivalent to np.nonzero).</returns>
+ /// <remarks>
+ /// https://numpy.org/doc/stable/reference/generated/numpy.where.html
+ /// </remarks>
+ public static NDArray[] where(NDArray condition)
+     => nonzero(condition);
+
+ /// <summary>
+ /// Return elements chosen from `x` or `y` depending on `condition`.
+ /// </summary>
+ /// <param name="condition">Where True, yield `x`, otherwise yield `y`.</param>
+ /// <param name="x">Values from which to choose where condition is True.</param>
+ /// <param name="y">Values from which to choose where condition is False.</param>
+ /// <returns>An array with elements from `x` where `condition` is True, and elements from `y` elsewhere.</returns>
+ /// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
+ public static NDArray where(NDArray condition, NDArray x, NDArray y)
+ {
+     // Broadcast all three arrays to a common shape
+     var broadcasted = broadcast_arrays(condition, x, y);
+     var cond = broadcasted[0];
+     var xArr = broadcasted[1];
+     var yArr = broadcasted[2];
+
+     // Determine output dtype from x and y (type promotion)
+     var outType = _FindCommonType(xArr, yArr);
+     // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides)
+     var result = empty(cond.shape, outType);
+
+     // Handle empty arrays - nothing to iterate
+     if (result.size == 0)
+         return result;
+
+     // IL kernel fast path: bool condition, all arrays contiguous, and x/y already stored as outType.
+     // The kernel reinterprets the x/y buffers as outType*, so a dtype mismatch must not take this path.
+     bool canUseKernel = ILKernelGenerator.Enabled &&
+         cond.typecode == NPTypeCode.Boolean &&
+         xArr.typecode == outType &&
+         yArr.typecode == outType &&
+         cond.Shape.IsContiguous &&
+         xArr.Shape.IsContiguous &&
+         yArr.Shape.IsContiguous;
+
+     if (canUseKernel)
+     {
+         WhereKernelDispatch(cond, xArr, yArr, result, outType);
+         return result;
+     }
+
+     // Iterator fallback: non-contiguous or broadcasted (stride=0) arrays, non-bool condition, mixed x/y dtypes
+     switch (outType)
+     {
+         case NPTypeCode.Boolean:
+             WhereImpl<bool>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Byte:
+             WhereImpl<byte>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Int16:
+             WhereImpl<short>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.UInt16:
+             WhereImpl<ushort>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Int32:
+             WhereImpl<int>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.UInt32:
+             WhereImpl<uint>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Int64:
+             WhereImpl<long>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.UInt64:
+             WhereImpl<ulong>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Char:
+             WhereImpl<char>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Single:
+             WhereImpl<float>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Double:
+             WhereImpl<double>(cond, xArr, yArr, result);
+             break;
+         case NPTypeCode.Decimal:
+             WhereImpl<decimal>(cond, xArr, yArr, result);
+             break;
+         default: throw new NotSupportedException($"Type {outType} not supported for np.where");
+     }
+     return result;
+ }
+
+ /// <summary>
+ /// Return elements chosen from `x` or `y` depending on `condition`.
+ /// Convenience overload: `x` is passed through <c>asanyarray</c> before selection.
+ /// </summary>
+ /// <param name="x">Scalar (or array-like) value used where condition is True.</param>
+ /// <returns>Result of the array overload called with the converted `x`.</returns>
+ public static NDArray where(NDArray condition, object x, NDArray y)
+     => where(condition, asanyarray(x), y);
+
+ /// <summary>
+ /// Return elements chosen from `x` or `y` depending on `condition`.
+ /// Convenience overload: `y` is passed through <c>asanyarray</c> before selection.
+ /// </summary>
+ /// <param name="y">Scalar (or array-like) value used where condition is False.</param>
+ /// <returns>Result of the array overload called with the converted `y`.</returns>
+ public static NDArray where(NDArray condition, NDArray x, object y)
+     => where(condition, x, asanyarray(y));
+
+ /// <summary>
+ /// Return elements chosen from `x` or `y` depending on `condition`.
+ /// Convenience overload: both `x` and `y` are passed through <c>asanyarray</c> before selection.
+ /// </summary>
+ /// <param name="x">Scalar (or array-like) value used where condition is True.</param>
+ /// <param name="y">Scalar (or array-like) value used where condition is False.</param>
+ public static NDArray where(NDArray condition, object x, object y)
+     => where(condition, asanyarray(x), asanyarray(y));
+
+ private static void WhereImpl(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged
+ {
+ // Element-wise select through iterators so broadcasted/strided arrays are walked correctly; all four iterators advance in lockstep
+ using var condIter = cond.AsIterator();
+ using var xIter = x.AsIterator();
+ using var yIter = y.AsIterator();
+ using var resultIter = result.AsIterator();
+
+ while (condIter.HasNext()) // only cond drives the loop; assumes x, y and result hold the same number of elements — TODO confirm
+ {
+ var c = condIter.MoveNext();
+ var xVal = xIter.MoveNext(); // x and y are both advanced every step, whichever one is selected
+ var yVal = yIter.MoveNext();
+ resultIter.MoveNextReference() = c ? xVal : yVal; // write through the reference handed out by the result iterator
+ }
+ }
+
+ /// <summary>IL kernel dispatch for contiguous arrays; uses IL-generated kernels with SIMD optimization.</summary>
+ /// <remarks>The x, y and result buffers are reinterpreted as <paramref name="outType"/>, so callers must pass arrays stored with that dtype.</remarks>
+ /// <exception cref="NotSupportedException">Thrown when <paramref name="outType"/> has no kernel, instead of leaving result unwritten.</exception>
+ private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType)
+ {
+     var condPtr = (bool*)cond.Address;
+     var count = result.size;
+
+     switch (outType)
+     {
+         case NPTypeCode.Boolean:
+             ILKernelGenerator.WhereExecute(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count);
+             break;
+         case NPTypeCode.Byte:
+             ILKernelGenerator.WhereExecute(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count);
+             break;
+         case NPTypeCode.Int16:
+             ILKernelGenerator.WhereExecute(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count);
+             break;
+         case NPTypeCode.UInt16:
+             ILKernelGenerator.WhereExecute(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count);
+             break;
+         case NPTypeCode.Int32:
+             ILKernelGenerator.WhereExecute(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count);
+             break;
+         case NPTypeCode.UInt32:
+             ILKernelGenerator.WhereExecute(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count);
+             break;
+         case NPTypeCode.Int64:
+             ILKernelGenerator.WhereExecute(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count);
+             break;
+         case NPTypeCode.UInt64:
+             ILKernelGenerator.WhereExecute(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count);
+             break;
+         case NPTypeCode.Char:
+             ILKernelGenerator.WhereExecute(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count);
+             break;
+         case NPTypeCode.Single:
+             ILKernelGenerator.WhereExecute(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count);
+             break;
+         case NPTypeCode.Double:
+             ILKernelGenerator.WhereExecute(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count);
+             break;
+         case NPTypeCode.Decimal:
+             ILKernelGenerator.WhereExecute(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count);
+             break;
+         default: throw new NotSupportedException($"Type {outType} not supported for np.where");
+     }
+ }
+ }
+}
diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs
new file mode 100644
index 00000000..1b4eb4b0
--- /dev/null
+++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs
@@ -0,0 +1,934 @@
+using System;
+using System.Collections.Concurrent;
+using System.Reflection;
+using System.Reflection.Emit;
+using System.Runtime.CompilerServices;
+using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.X86;
+
+// =============================================================================
+// ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels
+// =============================================================================
+//
+// RESPONSIBILITY:
+// - Generate optimized kernels for conditional selection
+// - result[i] = cond[i] ? x[i] : y[i]
+//
+// ARCHITECTURE:
+// Uses IL emission to generate type-specific kernels at runtime.
+// The challenge is bool mask expansion: condition is bool[] (1 byte per element),
+// but x/y can be any dtype (1-8 bytes per element on the SIMD path; wider types such as decimal use the scalar loop only).
+//
+// | Element Size | V256 Elements | Bools to Load |
+// |--------------|---------------|---------------|
+// | 1 byte | 32 | 32 |
+// | 2 bytes | 16 | 16 |
+// | 4 bytes | 8 | 8 |
+// | 8 bytes | 4 | 4 |
+//
+// KERNEL TYPES:
+// - WhereKernel: Main kernel delegate (cond*, x*, y*, result*, count)
+//
+// =============================================================================
+
+namespace NumSharp.Backends.Kernels
+{
+ /// <summary>
+ /// Delegate for where operation kernels; parameters are the condition, x, y and result pointers plus the element count.
+ /// </summary>
+ public unsafe delegate void WhereKernel(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged;
+
+ public static partial class ILKernelGenerator
+ {
+ /// <summary>
+ /// Cache of IL-generated where kernels.
+ /// Key: the element type (typeof(T)) the kernel was generated for; values are cast back to WhereKernel on lookup.
+ /// </summary>
+ private static readonly ConcurrentDictionary _whereKernelCache = new();
+
+ #region Public API
+
+ /// <summary>
+ /// Get or generate an IL-based where kernel for the specified type.
+ /// Returns null if IL generation is disabled or fails.
+ /// </summary>
+ public static WhereKernel? GetWhereKernel() where T : unmanaged
+ {
+ if (!Enabled)
+ return null;
+
+ var type = typeof(T);
+
+ if (_whereKernelCache.TryGetValue(type, out var cached)) // fast path: kernel already generated for this type
+ return (WhereKernel)cached;
+
+ var kernel = TryGenerateWhereKernel();
+ if (kernel == null) // a failed generation is not cached, so it is attempted again on the next call
+ return null;
+
+ if (_whereKernelCache.TryAdd(type, kernel))
+ return kernel;
+
+ return (WhereKernel)_whereKernelCache[type]; // TryAdd failed: an entry for this type already exists, so return that one
+ }
+
+ /// <summary>
+ /// Execute where operation using the IL-generated kernel, or fall back to WhereScalar when no kernel is available.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void WhereExecute(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged
+ {
+ if (count == 0) // nothing to do; also avoids a kernel lookup/generation for empty input
+ return;
+
+ var kernel = GetWhereKernel();
+ if (kernel != null)
+ {
+ kernel(cond, x, y, result, count);
+ }
+ else
+ {
+ // Fallback to scalar loop (GetWhereKernel returned null: generation disabled or failed)
+ WhereScalar(cond, x, y, result, count);
+ }
+ }
+
+ #endregion
+
+ #region Kernel Generation
+
+ private static WhereKernel? TryGenerateWhereKernel() where T : unmanaged // best-effort wrapper: returns null instead of throwing
+ {
+ try
+ {
+ return GenerateWhereKernelIL();
+ }
+ catch (Exception ex) // any generation failure is reported via Debug output only and turned into null
+ {
+ System.Diagnostics.Debug.WriteLine($"[ILKernel] TryGenerateWhereKernel<{typeof(T).Name}>: {ex.GetType().Name}: {ex.Message}");
+ return null;
+ }
+ }
+
+ private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmanaged
+ {
+ int elementSize = Unsafe.SizeOf();
+
+ // SIMD is only considered for element sizes up to 8 bytes; larger element types get the scalar loop alone
+ bool canSimd = elementSize <= 8 && IsSimdSupported();
+
+ var dm = new DynamicMethod(
+ name: $"IL_Where_{typeof(T).Name}",
+ returnType: typeof(void),
+ parameterTypes: new[] { typeof(bool*), typeof(T*), typeof(T*), typeof(T*), typeof(long) },
+ owner: typeof(ILKernelGenerator),
+ skipVisibility: true
+ );
+
+ var il = dm.GetILGenerator();
+
+ // Locals
+ var locI = il.DeclareLocal(typeof(long)); // loop counter shared by the SIMD loops and the scalar remainder loop
+
+ // Labels
+ var lblScalarLoop = il.DefineLabel();
+ var lblScalarLoopEnd = il.DefineLabel();
+
+ // i = 0
+ il.Emit(OpCodes.Ldc_I8, 0L);
+ il.Emit(OpCodes.Stloc, locI);
+
+ if (canSimd && VectorBits >= 128)
+ {
+ // Generate SIMD path. NOTE(review): the emitted mask code calls Avx2/Sse41 intrinsics for 2/4/8-byte elements — confirm VectorBits/IsSimdSupported imply those ISAs
+ EmitWhereSIMDLoop(il, locI);
+ }
+
+ // Scalar loop for remainder (the whole range when no SIMD path was emitted)
+ il.MarkLabel(lblScalarLoop);
+
+ // if (i >= count) goto end
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldarg, 4); // count
+ il.Emit(OpCodes.Bge, lblScalarLoopEnd);
+
+ // result[i] = cond[i] ? x[i] : y[i]
+ EmitWhereScalarElement(il, locI);
+
+ // i++
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldc_I8, 1L);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Stloc, locI);
+
+ il.Emit(OpCodes.Br, lblScalarLoop);
+
+ il.MarkLabel(lblScalarLoopEnd);
+ il.Emit(OpCodes.Ret);
+
+ return (WhereKernel)dm.CreateDelegate(typeof(WhereKernel));
+ }
+
+ private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI) where T : unmanaged
+ {
+ long elementSize = Unsafe.SizeOf();
+ long vectorCount = VectorBits >= 256 ? (32 / elementSize) : (16 / elementSize); // elements per 32-byte or 16-byte vector
+ long unrollFactor = 4;
+ long unrollStep = vectorCount * unrollFactor;
+ bool useV256 = VectorBits >= 256;
+
+ var locUnrollEnd = il.DeclareLocal(typeof(long));
+ var locVectorEnd = il.DeclareLocal(typeof(long));
+
+ var lblUnrollLoop = il.DefineLabel();
+ var lblUnrollLoopEnd = il.DefineLabel();
+ var lblVectorLoop = il.DefineLabel();
+ var lblVectorLoopEnd = il.DefineLabel();
+
+ // unrollEnd = count - unrollStep: the last index at which a full 4-vector iteration still fits (negative for short inputs)
+ il.Emit(OpCodes.Ldarg, 4); // count
+ il.Emit(OpCodes.Ldc_I8, unrollStep);
+ il.Emit(OpCodes.Sub);
+ il.Emit(OpCodes.Stloc, locUnrollEnd);
+
+ // vectorEnd = count - vectorCount: the last index at which a single full vector still fits
+ il.Emit(OpCodes.Ldarg, 4); // count
+ il.Emit(OpCodes.Ldc_I8, vectorCount);
+ il.Emit(OpCodes.Sub);
+ il.Emit(OpCodes.Stloc, locVectorEnd);
+
+ // ========== 4x UNROLLED SIMD LOOP ==========
+ il.MarkLabel(lblUnrollLoop);
+
+ // if (i > unrollEnd) goto UnrollLoopEnd (signed compare, so a negative unrollEnd skips the loop entirely)
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldloc, locUnrollEnd);
+ il.Emit(OpCodes.Bgt, lblUnrollLoopEnd);
+
+ // Process 4 vectors per iteration; the unrolling happens at emit time, each body using a fixed element offset from i
+ for (long u = 0; u < unrollFactor; u++)
+ {
+ long offset = vectorCount * u;
+ if (useV256)
+ EmitWhereV256BodyWithOffset(il, locI, elementSize, offset);
+ else
+ EmitWhereV128BodyWithOffset(il, locI, elementSize, offset);
+ }
+
+ // i += unrollStep
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldc_I8, unrollStep);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Stloc, locI);
+
+ il.Emit(OpCodes.Br, lblUnrollLoop);
+
+ il.MarkLabel(lblUnrollLoopEnd);
+
+ // ========== REMAINDER SIMD LOOP (1 vector at a time) ==========
+ il.MarkLabel(lblVectorLoop);
+
+ // if (i > vectorEnd) goto VectorLoopEnd; any elements left after this loop are handled by the caller's scalar loop
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldloc, locVectorEnd);
+ il.Emit(OpCodes.Bgt, lblVectorLoopEnd);
+
+ // Process 1 vector
+ if (useV256)
+ EmitWhereV256BodyWithOffset(il, locI, elementSize, 0L);
+ else
+ EmitWhereV128BodyWithOffset(il, locI, elementSize, 0L);
+
+ // i += vectorCount
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldc_I8, vectorCount);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Stloc, locI);
+
+ il.Emit(OpCodes.Br, lblVectorLoop);
+
+ il.MarkLabel(lblVectorLoopEnd);
+ }
+
+ private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged
+ {
+ // Get Vector256 methods via reflection - need to find generic method definitions first, then close them over T
+ var loadMethod = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!
+ .MakeGenericMethod(typeof(T));
+ var storeMethod = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!
+ .MakeGenericMethod(typeof(T));
+ var selectMethod = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!
+ .MakeGenericMethod(typeof(T));
+
+ // Load address: cond + (i + offset) — no scaling because each condition bool is 1 byte
+ il.Emit(OpCodes.Ldarg_0); // cond
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+
+ // Inline mask creation: consumes the bool pointer on the stack and leaves the lane mask; NOTE(review): mask is built as an unsigned-integer vector but consumed as a vector of T — confirm this is intended
+ EmitInlineMaskCreationV256(il, (int)elementSize);
+
+ // Load x vector: x + (i + offset) * elementSize
+ il.Emit(OpCodes.Ldarg_1); // x
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Call, loadMethod);
+
+ // Load y vector: y + (i + offset) * elementSize
+ il.Emit(OpCodes.Ldarg_2); // y
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Call, loadMethod);
+
+ // Stack: mask, xVec, yVec
+ // ConditionalSelect(mask, x, y) leaves the selected vector on the stack
+ il.Emit(OpCodes.Call, selectMethod);
+
+ // Store result: push result + (i + offset) * elementSize on top of the selected vector, then call Store
+ il.Emit(OpCodes.Ldarg_3); // result
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Call, storeMethod);
+ }
+
+ private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged
+ {
+ // Get Vector128 methods via reflection - need to find generic method definitions first, then close them over T
+ var loadMethod = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!
+ .MakeGenericMethod(typeof(T));
+ var storeMethod = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!
+ .MakeGenericMethod(typeof(T));
+ var selectMethod = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!
+ .MakeGenericMethod(typeof(T));
+
+ // Load address: cond + (i + offset) — no scaling because each condition bool is 1 byte
+ il.Emit(OpCodes.Ldarg_0);
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+
+ // Inline mask creation: consumes the bool pointer on the stack and leaves the lane mask (uses Sse41 widening for 2/4/8-byte elements)
+ EmitInlineMaskCreationV128(il, (int)elementSize);
+
+ // Load x vector: x + (i + offset) * elementSize
+ il.Emit(OpCodes.Ldarg_1);
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Call, loadMethod);
+
+ // Load y vector: y + (i + offset) * elementSize
+ il.Emit(OpCodes.Ldarg_2);
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Call, loadMethod);
+
+ // ConditionalSelect(mask, xVec, yVec): the three operands are on the stack in that order
+ il.Emit(OpCodes.Call, selectMethod);
+
+ // Store: push result + (i + offset) * elementSize on top of the selected vector, then call Store
+ il.Emit(OpCodes.Ldarg_3);
+ il.Emit(OpCodes.Ldloc, locI);
+ if (offset > 0)
+ {
+ il.Emit(OpCodes.Ldc_I8, offset);
+ il.Emit(OpCodes.Add);
+ }
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Call, storeMethod);
+ }
+
+ private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder locI) where T : unmanaged
+ {
+ long elementSize = Unsafe.SizeOf();
+ var typeCode = GetNPTypeCode();
+
+ // Emits IL for one element: result[i] = cond[i] ? x[i] : y[i]
+ var lblFalse = il.DefineLabel();
+ var lblEnd = il.DefineLabel();
+
+ // Load result address first (result + i * elementSize) so it sits below the selected value for the final store
+ il.Emit(OpCodes.Ldarg_3);
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+
+ // Load cond[i]: cond + i (bool is 1 byte)
+ il.Emit(OpCodes.Ldarg_0);
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ il.Emit(OpCodes.Ldind_U1); // Load bool as byte
+
+ // if (!cond[i]) goto lblFalse (any non-zero byte counts as true)
+ il.Emit(OpCodes.Brfalse, lblFalse);
+
+ // True branch: load x[i]
+ il.Emit(OpCodes.Ldarg_1);
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ EmitLoadIndirect(il, typeCode);
+ il.Emit(OpCodes.Br, lblEnd);
+
+ // False branch: load y[i]
+ il.MarkLabel(lblFalse);
+ il.Emit(OpCodes.Ldarg_2);
+ il.Emit(OpCodes.Ldloc, locI);
+ il.Emit(OpCodes.Ldc_I8, elementSize);
+ il.Emit(OpCodes.Mul);
+ il.Emit(OpCodes.Conv_I);
+ il.Emit(OpCodes.Add);
+ EmitLoadIndirect(il, typeCode);
+
+ il.MarkLabel(lblEnd);
+ // Stack on both paths: result_ptr, value
+ EmitStoreIndirect(il, typeCode);
+ }
+
+ private static NPTypeCode GetNPTypeCode() where T : unmanaged // maps the CLR element type T to its NPTypeCode
+ {
+ if (typeof(T) == typeof(bool)) return NPTypeCode.Boolean;
+ if (typeof(T) == typeof(byte)) return NPTypeCode.Byte;
+ if (typeof(T) == typeof(short)) return NPTypeCode.Int16;
+ if (typeof(T) == typeof(ushort)) return NPTypeCode.UInt16;
+ if (typeof(T) == typeof(int)) return NPTypeCode.Int32;
+ if (typeof(T) == typeof(uint)) return NPTypeCode.UInt32;
+ if (typeof(T) == typeof(long)) return NPTypeCode.Int64;
+ if (typeof(T) == typeof(ulong)) return NPTypeCode.UInt64;
+ if (typeof(T) == typeof(char)) return NPTypeCode.Char;
+ if (typeof(T) == typeof(float)) return NPTypeCode.Single;
+ if (typeof(T) == typeof(double)) return NPTypeCode.Double;
+ if (typeof(T) == typeof(decimal)) return NPTypeCode.Decimal;
+ return NPTypeCode.Empty; // T is none of the twelve types above
+ }
+
+ #endregion
+
+ #region Mask Creation Methods
+
+ private static MethodInfo GetMaskCreationMethod256(int elementSize)
+ {
+     switch (elementSize) // element width in bytes selects the matching CreateMaskV256_* helper
+     {
+         case 1: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         case 2: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         case 4: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         case 8: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         default: throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where");
+     }
+ }
+
+ private static MethodInfo GetMaskCreationMethod128(int elementSize)
+ {
+     switch (elementSize) // element width in bytes selects the matching CreateMaskV128_* helper
+     {
+         case 1: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         case 2: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         case 4: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         case 8: return typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!;
+         default: throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where");
+     }
+ }
+
+ #endregion
+
+ #region Inline Mask IL Emission
+
+ // Cache reflection lookups for inline emission; resolved once in static field initializers, and a failed lookup (null from Array.Find/GetMethod) surfaces during type initialization or later at emit time
+ private static readonly MethodInfo _v128LoadByte = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte));
+ private static readonly MethodInfo _v256LoadByte = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "Load" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte));
+
+ private static readonly MethodInfo _v128CreateScalarUInt = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint));
+ private static readonly MethodInfo _v128CreateScalarULong = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong));
+ private static readonly MethodInfo _v128CreateScalarUShort = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "CreateScalar" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort));
+
+ // AsByte is an extension method on Vector128 static class, not instance method, so it is looked up there and closed over the source element type
+ private static readonly MethodInfo _v128UIntAsByte = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint));
+ private static readonly MethodInfo _v128ULongAsByte = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong));
+ private static readonly MethodInfo _v128UShortAsByte = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "AsByte" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort));
+
+ private static readonly MethodInfo _avx2ConvertToV256Int64 = typeof(Avx2).GetMethod("ConvertToVector256Int64", new[] { typeof(Vector128) })!;
+ private static readonly MethodInfo _avx2ConvertToV256Int32 = typeof(Avx2).GetMethod("ConvertToVector256Int32", new[] { typeof(Vector128) })!;
+ private static readonly MethodInfo _avx2ConvertToV256Int16 = typeof(Avx2).GetMethod("ConvertToVector256Int16", new[] { typeof(Vector128) })!;
+
+ private static readonly MethodInfo _sse41ConvertToV128Int64 = typeof(Sse41).GetMethod("ConvertToVector128Int64", new[] { typeof(Vector128) })!;
+ private static readonly MethodInfo _sse41ConvertToV128Int32 = typeof(Sse41).GetMethod("ConvertToVector128Int32", new[] { typeof(Vector128) })!;
+ private static readonly MethodInfo _sse41ConvertToV128Int16 = typeof(Sse41).GetMethod("ConvertToVector128Int16", new[] { typeof(Vector128) })!;
+
+ // As* methods are extension methods on Vector256/Vector128 static classes; here they reinterpret the signed widened vectors as unsigned
+ private static readonly MethodInfo _v256LongAsULong = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long));
+ private static readonly MethodInfo _v256IntAsUInt = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int));
+ private static readonly MethodInfo _v256ShortAsUShort = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short));
+
+ private static readonly MethodInfo _v128LongAsULong = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "AsUInt64" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(long));
+ private static readonly MethodInfo _v128IntAsUInt = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "AsUInt32" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(int));
+ private static readonly MethodInfo _v128ShortAsUShort = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "AsUInt16" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(short));
+
+ private static readonly MethodInfo _v256GreaterThanULong = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong));
+ private static readonly MethodInfo _v256GreaterThanUInt = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint));
+ private static readonly MethodInfo _v256GreaterThanUShort = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort));
+ private static readonly MethodInfo _v256GreaterThanByte = Array.Find(typeof(Vector256).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte));
+
+ private static readonly MethodInfo _v128GreaterThanULong = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ulong));
+ private static readonly MethodInfo _v128GreaterThanUInt = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(uint));
+ private static readonly MethodInfo _v128GreaterThanUShort = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(ushort));
+ private static readonly MethodInfo _v128GreaterThanByte = Array.Find(typeof(Vector128).GetMethods(),
+ m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte));
+
+ private static readonly FieldInfo _v256ZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!.IsStatic
+ ? null! : null!; // NOTE(review): both branches are null!, so this field is always null; the _v256GetZero* getters below are what the emitters call — looks like a leftover
+ private static readonly MethodInfo _v256GetZeroULong = typeof(Vector256).GetProperty("Zero")!.GetMethod!;
+ private static readonly MethodInfo _v256GetZeroUInt = typeof(Vector256).GetProperty("Zero")!.GetMethod!;
+ private static readonly MethodInfo _v256GetZeroUShort = typeof(Vector256).GetProperty("Zero")!.GetMethod!;
+ private static readonly MethodInfo _v256GetZeroByte = typeof(Vector256).GetProperty("Zero")!.GetMethod!;
+
+ private static readonly MethodInfo _v128GetZeroULong = typeof(Vector128).GetProperty("Zero")!.GetMethod!;
+ private static readonly MethodInfo _v128GetZeroUInt = typeof(Vector128).GetProperty("Zero")!.GetMethod!;
+ private static readonly MethodInfo _v128GetZeroUShort = typeof(Vector128).GetProperty("Zero")!.GetMethod!;
+ private static readonly MethodInfo _v128GetZeroByte = typeof(Vector128).GetProperty("Zero")!.GetMethod!;
+
+ /// <summary>
+ /// Emit inline V256 mask creation. Stack: byte* -> Vector256{T} (as mask)
+ /// </summary>
+ private static void EmitInlineMaskCreationV256(ILGenerator il, int elementSize)
+ {
+ // Stack has: byte* pointing to condition bools; each case reads exactly as many bools as there are lanes, widens them, then compares against zero
+
+ switch (elementSize)
+ {
+ case 8: // double/long: load 4 bytes, expand to 4 qwords
+ // *(uint*)ptr
+ il.Emit(OpCodes.Ldind_U4);
+ // Vector128.CreateScalar(value)
+ il.Emit(OpCodes.Call, _v128CreateScalarUInt);
+ // .AsByte()
+ il.Emit(OpCodes.Call, _v128UIntAsByte);
+ // Avx2.ConvertToVector256Int64(bytes)
+ il.Emit(OpCodes.Call, _avx2ConvertToV256Int64);
+ // .AsUInt64()
+ il.Emit(OpCodes.Call, _v256LongAsULong);
+ // Vector256.Zero
+ il.Emit(OpCodes.Call, _v256GetZeroULong);
+ // Vector256.GreaterThan(expanded, zero)
+ il.Emit(OpCodes.Call, _v256GreaterThanULong);
+ break;
+
+ case 4: // float/int: load 8 bytes, expand to 8 dwords
+ // *(ulong*)ptr (Ldind_I8 loads the 8 raw bytes; signedness is irrelevant here)
+ il.Emit(OpCodes.Ldind_I8);
+ // Vector128.CreateScalar(value)
+ il.Emit(OpCodes.Call, _v128CreateScalarULong);
+ // .AsByte()
+ il.Emit(OpCodes.Call, _v128ULongAsByte);
+ // Avx2.ConvertToVector256Int32(bytes)
+ il.Emit(OpCodes.Call, _avx2ConvertToV256Int32);
+ // .AsUInt32()
+ il.Emit(OpCodes.Call, _v256IntAsUInt);
+ // Vector256.Zero
+ il.Emit(OpCodes.Call, _v256GetZeroUInt);
+ // Vector256.GreaterThan(expanded, zero)
+ il.Emit(OpCodes.Call, _v256GreaterThanUInt);
+ break;
+
+ case 2: // short/char: load 16 bytes, expand to 16 words
+ // Vector128.Load(ptr)
+ il.Emit(OpCodes.Call, _v128LoadByte);
+ // Avx2.ConvertToVector256Int16(bytes)
+ il.Emit(OpCodes.Call, _avx2ConvertToV256Int16);
+ // .AsUInt16()
+ il.Emit(OpCodes.Call, _v256ShortAsUShort);
+ // Vector256.Zero
+ il.Emit(OpCodes.Call, _v256GetZeroUShort);
+ // Vector256.GreaterThan(expanded, zero)
+ il.Emit(OpCodes.Call, _v256GreaterThanUShort);
+ break;
+
+ case 1: // byte/bool: load 32 bytes, compare directly (no widening needed)
+ // Vector256.Load(ptr)
+ il.Emit(OpCodes.Call, _v256LoadByte);
+ // Vector256.Zero
+ il.Emit(OpCodes.Call, _v256GetZeroByte);
+ // Vector256.GreaterThan(vec, zero)
+ il.Emit(OpCodes.Call, _v256GreaterThanByte);
+ break;
+
+ default:
+ throw new NotSupportedException($"Element size {elementSize} not supported");
+ }
+ }
+
+ /// <summary>
+ /// Emit inline V128 mask creation. Stack: byte* -> Vector128{T} (as mask)
+ /// </summary>
+ private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize)
+ {
+ switch (elementSize)
+ {
+ case 8: // double/long: load 2 bytes, expand to 2 qwords
+ // *(ushort*)ptr
+ il.Emit(OpCodes.Ldind_U2);
+ // Vector128.CreateScalar(value)
+ il.Emit(OpCodes.Call, _v128CreateScalarUShort);
+ // .AsByte()
+ il.Emit(OpCodes.Call, _v128UShortAsByte);
+ // Sse41.ConvertToVector128Int64(bytes)
+ il.Emit(OpCodes.Call, _sse41ConvertToV128Int64);
+ // .AsUInt64()
+ il.Emit(OpCodes.Call, _v128LongAsULong);
+ // Vector128.Zero
+ il.Emit(OpCodes.Call, _v128GetZeroULong);
+ // Vector128.GreaterThan(expanded, zero)
+ il.Emit(OpCodes.Call, _v128GreaterThanULong);
+ break;
+
+ case 4: // float/int: load 4 bytes, expand to 4 dwords
+ // *(uint*)ptr
+ il.Emit(OpCodes.Ldind_U4);
+ // Vector128.CreateScalar(value)
+ il.Emit(OpCodes.Call, _v128CreateScalarUInt);
+ // .AsByte()
+ il.Emit(OpCodes.Call, _v128UIntAsByte);
+ // Sse41.ConvertToVector128Int32(bytes)
+ il.Emit(OpCodes.Call, _sse41ConvertToV128Int32);
+ // .AsUInt32()
+ il.Emit(OpCodes.Call, _v128IntAsUInt);
+ // Vector128.Zero
+ il.Emit(OpCodes.Call, _v128GetZeroUInt);
+ // Vector128.GreaterThan(expanded, zero)
+ il.Emit(OpCodes.Call, _v128GreaterThanUInt);
+ break;
+
+ case 2: // short/char: load 8 bytes, expand to 8 words
+ // *(ulong*)ptr (Ldind_I8 loads the 8 raw bytes; signedness is irrelevant here)
+ il.Emit(OpCodes.Ldind_I8);
+ // Vector128.CreateScalar(value)
+ il.Emit(OpCodes.Call, _v128CreateScalarULong);
+ // .AsByte()
+ il.Emit(OpCodes.Call, _v128ULongAsByte);
+ // Sse41.ConvertToVector128Int16(bytes)
+ il.Emit(OpCodes.Call, _sse41ConvertToV128Int16);
+ // .AsUInt16()
+ il.Emit(OpCodes.Call, _v128ShortAsUShort);
+ // Vector128.Zero
+ il.Emit(OpCodes.Call, _v128GetZeroUShort);
+ // Vector128.GreaterThan(expanded, zero)
+ il.Emit(OpCodes.Call, _v128GreaterThanUShort);
+ break;
+
+ case 1: // byte/bool: load 16 bytes, compare directly (no widening needed)
+ // Vector128.Load(ptr)
+ il.Emit(OpCodes.Call, _v128LoadByte);
+ // Vector128.Zero
+ il.Emit(OpCodes.Call, _v128GetZeroByte);
+ // Vector128.GreaterThan(vec, zero)
+ il.Emit(OpCodes.Call, _v128GreaterThanByte);
+ break;
+
+ default:
+ throw new NotSupportedException($"Element size {elementSize} not supported");
+ }
+ }
+
+ #endregion
+
+ #region Static Mask Creation Methods (fallback)
+
+ /// <summary>
+ /// Create V256 mask from 32 bools for 1-byte elements: a non-zero byte becomes 0xFF, zero stays 0.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector256<byte> CreateMaskV256_1Byte(byte* bools)
+ {
+     var vec = Vector256.Load(bools);
+     var zero = Vector256<byte>.Zero;
+     var isZero = Vector256.Equals(vec, zero); // all-ones where the bool is false
+     return Vector256.OnesComplement(isZero);  // invert: all-ones where the bool is true
+ }
+
+ /// <summary>
+ /// Create V256 mask from 16 bools for 2-byte elements.
+ /// Uses AVX2 vpmovzxbw instruction for single-instruction expansion.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector256<ushort> CreateMaskV256_2Byte(byte* bools)
+ {
+     if (Avx2.IsSupported)
+     {
+         // Load 16 bytes into Vector128, zero-extend each byte to 16-bit
+         // vpmovzxbw: byte -> word (16 bytes -> 16 words)
+         var bytes128 = Vector128.Load(bools);
+         var expanded = Avx2.ConvertToVector256Int16(bytes128).AsUInt16();
+         // Compare with zero: non-zero becomes 0xFFFF, zero stays 0
+         return Vector256.GreaterThan(expanded, Vector256<ushort>.Zero);
+     }
+
+     // Scalar fallback for non-AVX2 systems: build each 16-bit lane by hand
+     return Vector256.Create(
+         bools[0] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[1] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[2] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[3] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[4] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[5] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[6] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[7] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[8] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[9] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[10] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[11] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[12] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[13] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[14] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[15] != 0 ? (ushort)0xFFFF : (ushort)0
+     );
+ }
+
+ /// <summary>
+ /// Create V256 mask from 8 bools for 4-byte elements.
+ /// Uses AVX2 vpmovzxbd instruction for single-instruction expansion.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector256<uint> CreateMaskV256_4Byte(byte* bools)
+ {
+     if (Avx2.IsSupported)
+     {
+         // Load 8 bytes into low bytes of Vector128, zero-extend each byte to 32-bit
+         // vpmovzxbd: byte -> dword (8 bytes -> 8 dwords)
+         var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte();
+         var expanded = Avx2.ConvertToVector256Int32(bytes128).AsUInt32();
+         // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0
+         return Vector256.GreaterThan(expanded, Vector256<uint>.Zero);
+     }
+
+     // Scalar fallback for non-AVX2 systems: build each 32-bit lane by hand
+     return Vector256.Create(
+         bools[0] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[1] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[2] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[3] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[4] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[5] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[6] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[7] != 0 ? 0xFFFFFFFFu : 0u
+     );
+ }
+
+ /// <summary>
+ /// Create V256 mask from 4 bools for 8-byte elements.
+ /// Uses AVX2 vpmovzxbq instruction for single-instruction expansion.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector256<ulong> CreateMaskV256_8Byte(byte* bools)
+ {
+     if (Avx2.IsSupported)
+     {
+         // Load 4 bytes into low bytes of Vector128, zero-extend each byte to 64-bit
+         // vpmovzxbq: byte -> qword (4 bytes -> 4 qwords)
+         var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte();
+         var expanded = Avx2.ConvertToVector256Int64(bytes128).AsUInt64();
+         // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0
+         return Vector256.GreaterThan(expanded, Vector256<ulong>.Zero);
+     }
+
+     // Scalar fallback for non-AVX2 systems: build each 64-bit lane by hand
+     return Vector256.Create(
+         bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul,
+         bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul,
+         bools[2] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul,
+         bools[3] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul
+     );
+ }
+
+ /// <summary>
+ /// Create V128 mask from 16 bools for 1-byte elements: a non-zero byte becomes 0xFF, zero stays 0.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector128<byte> CreateMaskV128_1Byte(byte* bools)
+ {
+     var vec = Vector128.Load(bools);
+     var zero = Vector128<byte>.Zero;
+     var isZero = Vector128.Equals(vec, zero); // all-ones where the bool is false
+     return Vector128.OnesComplement(isZero);  // invert: all-ones where the bool is true
+ }
+
+ /// <summary>
+ /// Create V128 mask from 8 bools for 2-byte elements.
+ /// Uses SSE4.1 pmovzxbw instruction for efficient expansion.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector128<ushort> CreateMaskV128_2Byte(byte* bools)
+ {
+     if (Sse41.IsSupported)
+     {
+         // Load 8 bytes, zero-extend each to 16-bit
+         // pmovzxbw: byte -> word (8 bytes -> 8 words)
+         var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte();
+         var expanded = Sse41.ConvertToVector128Int16(bytes128).AsUInt16();
+         return Vector128.GreaterThan(expanded, Vector128<ushort>.Zero);
+     }
+
+     // Scalar fallback: build each 16-bit lane by hand
+     return Vector128.Create(
+         bools[0] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[1] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[2] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[3] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[4] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[5] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[6] != 0 ? (ushort)0xFFFF : (ushort)0,
+         bools[7] != 0 ? (ushort)0xFFFF : (ushort)0
+     );
+ }
+
+ /// <summary>
+ /// Create V128 mask from 4 bools for 4-byte elements.
+ /// Uses SSE4.1 pmovzxbd instruction for efficient expansion.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector128<uint> CreateMaskV128_4Byte(byte* bools)
+ {
+     if (Sse41.IsSupported)
+     {
+         // Load 4 bytes, zero-extend each to 32-bit
+         // pmovzxbd: byte -> dword (4 bytes -> 4 dwords)
+         var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte();
+         var expanded = Sse41.ConvertToVector128Int32(bytes128).AsUInt32();
+         return Vector128.GreaterThan(expanded, Vector128<uint>.Zero);
+     }
+
+     // Scalar fallback: build each 32-bit lane by hand
+     return Vector128.Create(
+         bools[0] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[1] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[2] != 0 ? 0xFFFFFFFFu : 0u,
+         bools[3] != 0 ? 0xFFFFFFFFu : 0u
+     );
+ }
+
+ /// <summary>
+ /// Create V128 mask from 2 bools for 8-byte elements.
+ /// Uses SSE4.1 pmovzxbq instruction for efficient expansion.
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector128<ulong> CreateMaskV128_8Byte(byte* bools)
+ {
+     if (Sse41.IsSupported)
+     {
+         // Load 2 bytes, zero-extend each to 64-bit
+         // pmovzxbq: byte -> qword (2 bytes -> 2 qwords)
+         var bytes128 = Vector128.CreateScalar(*(ushort*)bools).AsByte();
+         var expanded = Sse41.ConvertToVector128Int64(bytes128).AsUInt64();
+         return Vector128.GreaterThan(expanded, Vector128<ulong>.Zero);
+     }
+
+     // Scalar fallback: build each 64-bit lane by hand
+     return Vector128.Create(
+         bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul,
+         bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul
+     );
+ }
+
+ #endregion
+
+ #region Scalar Fallback
+
+ /// <summary>
+ /// Scalar fallback for where operation: result[i] = cond[i] ? x[i] : y[i] for every i in [0, count).
+ /// </summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe void WhereScalar<T>(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged
+ {
+     for (long i = 0; i < count; i++)
+     {
+         result[i] = cond[i] ? x[i] : y[i];
+     }
+ }
+
+ #endregion
+ }
+}
diff --git a/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs
new file mode 100644
index 00000000..3fc30d17
--- /dev/null
+++ b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs
@@ -0,0 +1,515 @@
+using System;
+using System.Diagnostics;
+using NumSharp.Backends.Kernels;
+using TUnit.Core;
+using NumSharp.UnitTest.Utilities;
+using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert;
+
+namespace NumSharp.UnitTest.Backends.Kernels
+{
+ /// <summary>
+ /// Tests for SIMD-optimized np.where implementation.
+ /// Verifies correctness of the SIMD path for all supported dtypes.
+ /// </summary>
+ public class WhereSimdTests
+ {
+ #region SIMD Correctness
+
+ [Test]
+ public void Where_Simd_Float32_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(42);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = random.rand(count).astype(NPTypeCode.Single);
+     var whenFalse = random.rand(count).astype(NPTypeCode.Single);
+     // Select between the two float32 arrays through the boolean mask.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     // Recompute every element with a scalar ternary and compare.
+     for (var i = 0; i < count; ++i)
+     {
+         float expected = (bool)mask[i] ? (float)whenTrue[i] : (float)whenFalse[i];
+         Assert.AreEqual(expected, (float)picked[i], 1e-6f, $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Float64_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(43);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = random.rand(count);
+     var whenFalse = random.rand(count);
+     // rand() output is used as-is; elements are read back as double below.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         double expected = (bool)mask[i] ? (double)whenTrue[i] : (double)whenFalse[i];
+         Assert.AreEqual(expected, (double)picked[i], 1e-10, $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Int32_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(44);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = random.randint(0, 1000, new[] { count });
+     var whenFalse = random.randint(0, 1000, new[] { count });
+     // Both branches come from randint(0, 1000); elements are read back as int.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         int expected = (bool)mask[i] ? (int)whenTrue[i] : (int)whenFalse[i];
+         Assert.AreEqual(expected, (int)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Int64_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(45);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.arange(count).astype(NPTypeCode.Int64);
+     var whenFalse = np.arange(count, count * 2).astype(NPTypeCode.Int64);
+     // The two ranges are disjoint, so a wrong branch is always detectable.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         long expected = (bool)mask[i] ? (long)whenTrue[i] : (long)whenFalse[i];
+         Assert.AreEqual(expected, (long)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Byte_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(46);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = (random.rand(count) * 255).astype(NPTypeCode.Byte);
+     var whenFalse = (random.rand(count) * 255).astype(NPTypeCode.Byte);
+     // Random values scaled by 255 and converted to Byte feed both branches.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         byte expected = (bool)mask[i] ? (byte)whenTrue[i] : (byte)whenFalse[i];
+         Assert.AreEqual(expected, (byte)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Int16_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(47);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.arange(count).astype(NPTypeCode.Int16);
+     var whenFalse = np.arange(count, count * 2).astype(NPTypeCode.Int16);
+     // Int16 inputs built from two consecutive ranges.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         short expected = (bool)mask[i] ? (short)whenTrue[i] : (short)whenFalse[i];
+         Assert.AreEqual(expected, (short)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_UInt16_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(48);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.arange(count).astype(NPTypeCode.UInt16);
+     var whenFalse = np.arange(count, count * 2).astype(NPTypeCode.UInt16);
+     // UInt16 inputs built from two consecutive ranges.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         ushort expected = (bool)mask[i] ? (ushort)whenTrue[i] : (ushort)whenFalse[i];
+         Assert.AreEqual(expected, (ushort)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_UInt32_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(49);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.arange(count).astype(NPTypeCode.UInt32);
+     var whenFalse = np.arange(count, count * 2).astype(NPTypeCode.UInt32);
+     // UInt32 inputs built from two consecutive ranges.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         uint expected = (bool)mask[i] ? (uint)whenTrue[i] : (uint)whenFalse[i];
+         Assert.AreEqual(expected, (uint)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_UInt64_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(50);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.arange(count).astype(NPTypeCode.UInt64);
+     var whenFalse = np.arange(count, count * 2).astype(NPTypeCode.UInt64);
+     // UInt64 inputs built from two consecutive ranges.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         ulong expected = (bool)mask[i] ? (ulong)whenTrue[i] : (ulong)whenFalse[i];
+         Assert.AreEqual(expected, (ulong)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Boolean_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(51);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = random.rand(count) > 0.3;
+     var whenFalse = random.rand(count) > 0.7;
+     // Condition and both value arrays are results of threshold comparisons.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         bool expected = (bool)mask[i] ? (bool)whenTrue[i] : (bool)whenFalse[i];
+         Assert.AreEqual(expected, (bool)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Char_Correctness()
+ {
+     int count = 1000;
+     var random = np.random.RandomState(52);
+     var mask = random.rand(count) > 0.5;
+     char[] upper = new char[count];
+     char[] lower = new char[count];
+     for (var i = 0; i < count; ++i)
+     {
+         upper[i] = (char)('A' + (i % 26));
+         lower[i] = (char)('a' + (i % 26));
+     }
+     var whenTrue = np.array(upper);
+     var whenFalse = np.array(lower);
+     // Upper-case letters are taken where the mask is set, lower-case elsewhere.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         char expected = (bool)mask[i] ? (char)whenTrue[i] : (char)whenFalse[i];
+         Assert.AreEqual(expected, (char)picked[i], $"Mismatch at index {i}");
+     }
+ }
+
+ #endregion
+
+ #region Path Selection
+
+ [Test]
+ public void Where_NonContiguous_Works()
+ {
+     // The condition is a step-2 slice of a comparison result rather than a fresh array.
+     var source = np.arange(20);
+     var mask = (source % 2 == 0)["::2"]; // picks indices 0, 2, ..., 18: ten values, all true
+     var whenTrue = np.ones(10, NPTypeCode.Int32);
+     var whenFalse = np.zeros(10, NPTypeCode.Int32);
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(10, picked.size);
+     // Every condition element is true, so every output must come from the ones array.
+     for (var i = 0; i < 10; ++i)
+     {
+         Assert.AreEqual(1, (int)picked[i]);
+     }
+ }
+
+ [Test]
+ public void Where_Broadcast_Works()
+ {
+     // All three operands have different shapes and broadcast to (3,3):
+     //   mask      (3,)  -> rows [T,F,T]
+     //   whenTrue  (3,1) -> rows [1,1,1], [2,2,2], [3,3,3]
+     //   whenFalse (1,3) -> rows [10,20,30]
+     var mask = np.array(new[] { true, false, true });
+     var whenTrue = np.array(new int[,] { { 1 }, { 2 }, { 3 } });
+     var whenFalse = np.array(new int[,] { { 10, 20, 30 } });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeShaped(3, 3);
+     // Expected: picked[r,c] = mask[c] ? whenTrue[r,0] : whenFalse[0,c]
+     Assert.AreEqual(1, (int)picked[0, 0]); // mask[0] set   -> 1 from whenTrue row 0
+     Assert.AreEqual(20, (int)picked[0, 1]); // mask[1] clear -> 20 from whenFalse
+     Assert.AreEqual(1, (int)picked[0, 2]); // mask[2] set   -> 1 from whenTrue row 0
+     Assert.AreEqual(2, (int)picked[1, 0]); // mask[0] set   -> 2 from whenTrue row 1
+     Assert.AreEqual(20, (int)picked[1, 1]); // mask[1] clear -> 20 from whenFalse
+ }
+
+ [Test]
+ public void Where_Decimal_Works()
+ {
+     var mask = np.array(new[] { true, false, true });
+     var whenTrue = np.array(new decimal[] { 1.1m, 2.2m, 3.3m });
+     var whenFalse = np.array(new decimal[] { 10.1m, 20.2m, 30.3m });
+     // Decimal inputs on both branches; the result dtype must stay decimal.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(typeof(decimal), picked.dtype);
+     Assert.AreEqual(1.1m, (decimal)picked[0]);
+     Assert.AreEqual(20.2m, (decimal)picked[1]);
+     Assert.AreEqual(3.3m, (decimal)picked[2]);
+ }
+
+ [Test]
+ public void Where_NonBoolCondition_Works()
+ {
+     // An integer condition: zero selects the second value, non-zero the first.
+     var mask = np.array(new[] { 0, 1, 2, 0 }); // not a bool array
+     var picked = np.where(mask, 100, -100);
+
+     picked.Should().BeOfValues(-100, 100, 100, -100);
+ }
+
+ #endregion
+
+ #region Edge Cases
+
+ [Test]
+ public void Where_Simd_SmallArray()
+ {
+     // Only three elements: fewer than any vector width used by the kernels.
+     var mask = np.array(new[] { true, false, true });
+     var whenTrue = np.array(new[] { 1, 2, 3 });
+     var whenFalse = np.array(new[] { 10, 20, 30 });
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeOfValues(1, 20, 3);
+ }
+
+ [Test]
+ public void Where_Simd_VectorAlignedSize()
+ {
+     var random = np.random.RandomState(53);
+     // Element count equals one full vector, so no remainder is left over.
+     int count = 32; // bytes in a 256-bit vector
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.ones(count, NPTypeCode.Byte);
+     var whenFalse = np.zeros(count, NPTypeCode.Byte);
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(count, picked.size);
+     for (var i = 0; i < count; ++i)
+     {
+         byte expected = (bool)mask[i] ? (byte)1 : (byte)0;
+         Assert.AreEqual(expected, (byte)picked[i]);
+     }
+ }
+
+ [Test]
+ public void Where_Simd_WithScalarTail()
+ {
+     // Element count is not a multiple of the vector width, leaving a remainder.
+     int count = 35; // 32 bytes + 3 left over
+     var mask = np.ones(count, NPTypeCode.Boolean);
+     var whenTrue = np.full(count, (byte)255);
+     var whenFalse = np.zeros(count, NPTypeCode.Byte);
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         Assert.AreEqual((byte)255, (byte)picked[i], $"Mismatch at {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_AllTrue()
+ {
+     int count = 100;
+     var mask = np.ones(count, NPTypeCode.Boolean);
+     var whenTrue = np.arange(count);
+     var whenFalse = np.full(count, -1L);
+     // With an all-true mask the output must reproduce the arange values.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         Assert.AreEqual((long)i, (long)picked[i]);
+     }
+ }
+
+ [Test]
+ public void Where_Simd_AllFalse()
+ {
+     int count = 100;
+     var mask = np.zeros(count, NPTypeCode.Boolean);
+     var whenTrue = np.arange(count);
+     var whenFalse = np.full(count, -1L);
+     // With an all-false mask every output must be the -1 fill value.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         Assert.AreEqual(-1L, (long)picked[i]);
+     }
+ }
+
+ [Test]
+ public void Where_Simd_Alternating()
+ {
+     int count = 100;
+     bool[] flags = new bool[count];
+     for (var i = 0; i < count; ++i)
+         flags[i] = i % 2 == 0;
+     var mask = np.array(flags);
+     var whenTrue = np.ones(count, NPTypeCode.Int32);
+     var whenFalse = np.zeros(count, NPTypeCode.Int32);
+     // Even indices are true, odd indices false.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     for (var i = 0; i < count; ++i)
+     {
+         Assert.AreEqual(i % 2 == 0 ? 1 : 0, (int)picked[i], $"Mismatch at {i}");
+     }
+ }
+
+ [Test]
+ public void Where_Simd_NaN_Propagates()
+ {
+     var mask = np.array(new[] { true, false, true });
+     var whenTrue = np.array(new[] { double.NaN, 1.0, 2.0 });
+     var whenFalse = np.array(new[] { 0.0, double.NaN, 0.0 });
+     // A NaN sits on the selected side at index 0 and at index 1.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.IsTrue(double.IsNaN((double)picked[0])); // taken from whenTrue
+     Assert.IsTrue(double.IsNaN((double)picked[1])); // taken from whenFalse
+     Assert.AreEqual(2.0, (double)picked[2], 1e-10);
+ }
+
+ [Test]
+ public void Where_Simd_Infinity()
+ {
+     var mask = np.array(new[] { true, false, true, false });
+     var whenTrue = np.array(new[] { double.PositiveInfinity, 0.0, double.NegativeInfinity, 0.0 });
+     var whenFalse = np.array(new[] { 0.0, double.PositiveInfinity, 0.0, double.NegativeInfinity });
+     // The selected side holds an infinity at every index.
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(double.PositiveInfinity, (double)picked[0]);
+     Assert.AreEqual(double.PositiveInfinity, (double)picked[1]);
+     Assert.AreEqual(double.NegativeInfinity, (double)picked[2]);
+     Assert.AreEqual(double.NegativeInfinity, (double)picked[3]);
+ }
+
+ #endregion
+
+ #region Performance Sanity Check
+
+ [Test]
+ public void Where_Simd_LargeArray_Correctness()
+ {
+     int count = 100_000;
+     var random = np.random.RandomState(54);
+     var mask = random.rand(count) > 0.5;
+     var whenTrue = np.ones(count, NPTypeCode.Double);
+     var whenFalse = np.zeros(count, NPTypeCode.Double);
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     // Sample the first 100 elements
+     for (var i = 0; i < 100; ++i)
+     {
+         double expected = (bool)mask[i] ? 1.0 : 0.0;
+         Assert.AreEqual(expected, (double)picked[i], 1e-10);
+     }
+
+     // Sample the final 10 elements (the end of the buffer)
+     for (var i = count - 10; i < count; ++i)
+     {
+         double expected = (bool)mask[i] ? 1.0 : 0.0;
+         Assert.AreEqual(expected, (double)picked[i], 1e-10);
+     }
+ }
+
+ #endregion
+
+ #region 2D/Multi-dimensional
+
+ [Test]
+ public void Where_Simd_2D_Contiguous()
+ {
+     var random = np.random.RandomState(55);
+     // Freshly created 100x100 operands, intended to exercise the contiguous path.
+     int[] dims = new[] { 100, 100 };
+     var mask = random.rand(dims) > 0.5;
+     var whenTrue = np.ones(dims, NPTypeCode.Int32);
+     var whenFalse = np.zeros(dims, NPTypeCode.Int32);
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeShaped(100, 100);
+
+     // Sample the top-left 10x10 corner
+     for (var i = 0; i < 10; ++i)
+     {
+         for (var j = 0; j < 10; ++j)
+         {
+             int expected = (bool)mask[i, j] ? 1 : 0;
+             Assert.AreEqual(expected, (int)picked[i, j]);
+         }
+     }
+ }
+
+ [Test]
+ public void Where_Simd_3D_Contiguous()
+ {
+     var random = np.random.RandomState(56);
+     int[] dims = new[] { 10, 20, 30 };
+     var mask = random.rand(dims) > 0.5;
+     var whenTrue = np.ones(dims, NPTypeCode.Single);
+     var whenFalse = np.zeros(dims, NPTypeCode.Single);
+
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeShaped(10, 20, 30);
+
+     // Sample the 5x5x5 corner at the origin
+     for (var i = 0; i < 5; ++i)
+     {
+         for (var j = 0; j < 5; ++j)
+         {
+             for (var k = 0; k < 5; ++k)
+             {
+                 float expected = (bool)mask[i, j, k] ? 1.0f : 0.0f;
+                 Assert.AreEqual(expected, (float)picked[i, j, k], 1e-6f);
+             }
+         }
+     }
+ }
+
+ #endregion
+ }
+}
diff --git a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs
new file mode 100644
index 00000000..5834d16a
--- /dev/null
+++ b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs
@@ -0,0 +1,346 @@
+using System;
+using System.Linq;
+using TUnit.Core;
+using NumSharp.UnitTest.Utilities;
+using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert;
+
+namespace NumSharp.UnitTest.Logic
+{
+ /// <summary>
+ /// Battle tests for np.where - edge cases, strided arrays, views, etc.
+ /// </summary>
+ public class np_where_BattleTest
+ {
+ #region Strided/Sliced Arrays
+
+ [Test]
+ public void Where_SlicedCondition()
+ {
+     // Condition is a step-2 slice of (arr % 2 == 0): indices 0, 2, 4, 6, 8, all even -> all true
+     var arr = np.arange(10);
+     var cond = (arr % 2 == 0)["::2"]; // Every other even check
+     var x = np.ones(5, NPTypeCode.Int32);
+     var y = np.zeros(5, NPTypeCode.Int32);
+     var result = np.where(cond, x, y);
+     // Check the values as well as the size: a wrong stride would still yield 5 elements
+     Assert.AreEqual(5, result.size);
+     result.Should().BeOfValues(1, 1, 1, 1, 1);
+ }
+
+ [Test]
+ public void Where_SlicedXY()
+ {
+     var mask = np.array(new[] { true, false, true });
+     var evens = np.arange(6)["::2"]; // 0, 2, 4
+     var odds = np.arange(6)["1::2"]; // 1, 3, 5
+     var picked = np.where(mask, evens, odds);
+
+     picked.Should().BeOfValues(0L, 3L, 4L);
+ }
+
+ [Test]
+ public void Where_TransposedArrays()
+ {
+     var mask = np.array(new bool[,] { { true, false }, { false, true } }).T;
+     var whenTrue = np.array(new int[,] { { 1, 2 }, { 3, 4 } }).T;
+     var whenFalse = np.array(new int[,] { { 10, 20 }, { 30, 40 } }).T;
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeShaped(2, 2);
+     // Transposed mask is T on the diagonal and F off it; off-diagonal values come from whenFalse.T
+     Assert.AreEqual(1, (int)picked[0, 0]);
+     Assert.AreEqual(30, (int)picked[0, 1]);
+     Assert.AreEqual(20, (int)picked[1, 0]);
+     Assert.AreEqual(4, (int)picked[1, 1]);
+ }
+
+ [Test]
+ public void Where_ReversedSlice()
+ {
+     var mask = np.array(new[] { true, false, true, false, true });
+     var descending = np.arange(5)["::-1"]; // 4, 3, 2, 1, 0
+     var zeros = np.zeros(5, NPTypeCode.Int64);
+     var picked = np.where(mask, descending, zeros);
+
+     picked.Should().BeOfValues(4L, 0L, 2L, 0L, 0L);
+ }
+
+ #endregion
+
+ #region Complex Broadcasting
+
+ [Test]
+ public void Where_3Way_Broadcasting()
+ {
+     // Shapes: mask (2,1,1), whenTrue (1,3,1), whenFalse (1,1,4) -> output (2,3,4)
+     var mask = np.array(new bool[,,] { { { true } }, { { false } } });
+     var whenTrue = np.arange(3).reshape(1, 3, 1);
+     var whenFalse = (np.arange(4) * 10).reshape(1, 1, 4);
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeShaped(2, 3, 4);
+     // Index 0 on the first axis (mask true): whenTrue, varying along the middle axis
+     Assert.AreEqual(0, (long)picked[0, 0, 0]);
+     Assert.AreEqual(0, (long)picked[0, 0, 3]);
+     Assert.AreEqual(2, (long)picked[0, 2, 0]);
+     // Index 1 on the first axis (mask false): whenFalse, varying along the last axis
+     Assert.AreEqual(0, (long)picked[1, 0, 0]);
+     Assert.AreEqual(30, (long)picked[1, 0, 3]);
+     Assert.AreEqual(30, (long)picked[1, 2, 3]);
+ }
+
+ [Test]
+ public void Where_RowVector_ColVector_Broadcast()
+ {
+     // Shapes: mask (1,4), column (3,1), plus a scalar 0 -> output (3,4)
+     var mask = np.array(new bool[,] { { true, false, true, false } });
+     var column = np.array(new int[,] { { 1 }, { 2 }, { 3 } });
+     var picked = np.where(mask, column, 0);
+
+     picked.Should().BeShaped(3, 4);
+     Assert.AreEqual(1, (int)picked[0, 0]);
+     Assert.AreEqual(0, (int)picked[0, 1]);
+     Assert.AreEqual(2, (int)picked[1, 0]);
+     Assert.AreEqual(0, (int)picked[1, 1]);
+ }
+
+ #endregion
+
+ #region Numeric Edge Cases
+
+ [Test]
+ public void Where_NaN_Values()
+ {
+     var mask = np.array(new[] { true, false, true });
+     var whenTrue = np.array(new[] { double.NaN, 1.0, double.NaN });
+     var whenFalse = np.array(new[] { 0.0, double.NaN, 0.0 });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.IsTrue(double.IsNaN((double)picked[0])); // selected from whenTrue
+     Assert.IsTrue(double.IsNaN((double)picked[1])); // selected from whenFalse
+     Assert.IsTrue(double.IsNaN((double)picked[2])); // selected from whenTrue
+ }
+
+ [Test]
+ public void Where_Infinity_Values()
+ {
+     var mask = np.array(new[] { true, false });
+     var whenTrue = np.array(new[] { double.PositiveInfinity, 1.0 });
+     var whenFalse = np.array(new[] { 0.0, double.NegativeInfinity });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(double.PositiveInfinity, (double)picked[0]);
+     Assert.AreEqual(double.NegativeInfinity, (double)picked[1]);
+ }
+
+ [Test]
+ public void Where_MaxMin_Values()
+ {
+     var mask = np.array(new[] { true, false });
+     var whenTrue = np.array(new[] { long.MaxValue, 0L });
+     var whenFalse = np.array(new[] { 0L, long.MinValue });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(long.MaxValue, (long)picked[0]);
+     Assert.AreEqual(long.MinValue, (long)picked[1]);
+ }
+
+ #endregion
+
+ #region Single Arg Edge Cases
+
+ [Test]
+ public void Where_SingleArg_Float_Truthy()
+ {
+     // Only values that compare equal to zero (0.0 and -0.0) are falsy; everything else is truthy
+     var arr = np.array(new[] { 0.0, 1.0, -1.0, 0.5, -0.0 });
+     var result = np.where(arr);
+     Assert.AreEqual(1, result.Length); // 1-D input -> one index array
+     // Note: -0.0 == 0.0 in IEEE 754, so index 4 is falsy and must not be reported
+     result[0].Should().BeOfValues(1L, 2L, 3L);
+ }
+
+ [Test]
+ public void Where_SingleArg_NaN_IsTruthy()
+ {
+     // NaN does not equal zero, so its index must be reported.
+     var values = np.array(new[] { 0.0, double.NaN, 0.0 });
+     var indices = np.where(values);
+
+     indices[0].Should().BeOfValues(1L);
+ }
+
+ [Test]
+ public void Where_SingleArg_4D()
+ {
+     var values = np.zeros(new[] { 2, 2, 2, 2 }, NPTypeCode.Int32);
+     values[0, 1, 0, 1] = 1;
+     values[1, 0, 1, 0] = 1;
+     var indices = np.where(values);
+
+     Assert.AreEqual(4, indices.Length); // one index array per axis
+     Assert.AreEqual(2, indices[0].size); // two cells were set to 1
+ }
+
+ #endregion
+
+ #region Performance/Stress Tests
+
+ [Test]
+ public void Where_LargeArray_Performance()
+ {
+     var size = 1_000_000;
+     var cond = np.random.RandomState(57).rand(size) > 0.5; // fixed seed: reproducible input
+     var x = np.ones(size, NPTypeCode.Double);
+     var y = np.zeros(size, NPTypeCode.Double);
+     np.where(cond, x, y); // untimed warm-up so one-time JIT/setup cost is not measured
+     var sw = System.Diagnostics.Stopwatch.StartNew();
+     var result = np.where(cond, x, y);
+     sw.Stop();
+
+     Assert.AreEqual(size, result.size);
+     // Should complete in reasonable time (< 1 second for 1M elements)
+     Assert.IsTrue(sw.ElapsedMilliseconds < 1000, $"Took {sw.ElapsedMilliseconds}ms");
+ }
+
+ [Test]
+ public void Where_ManyDimensions()
+ {
+     // Six axes
+     int[] dims = new[] { 2, 3, 2, 2, 2, 3 };
+     var mask = np.ones(dims, NPTypeCode.Boolean);
+     var whenTrue = np.ones(dims, NPTypeCode.Int32);
+     var whenFalse = np.zeros(dims, NPTypeCode.Int32);
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked.Should().BeShaped(2, 3, 2, 2, 2, 3);
+     Assert.AreEqual(1, (int)picked[0, 0, 0, 0, 0, 0]);
+ }
+
+ #endregion
+
+ #region Type Conversion Edge Cases
+
+ [Test]
+ public void Where_UnsignedOverflow()
+ {
+     var mask = np.array(new[] { true, false });
+     var whenTrue = np.array(new byte[] { 255, 0 });
+     var whenFalse = np.array(new byte[] { 0, 255 });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(typeof(byte), picked.dtype);
+     Assert.AreEqual((byte)255, (byte)picked[0]);
+     Assert.AreEqual((byte)255, (byte)picked[1]);
+ }
+
+ [Test]
+ public void Where_Decimal()
+ {
+     var mask = np.array(new[] { true, false });
+     var whenTrue = np.array(new decimal[] { 1.23456789m, 0m });
+     var whenFalse = np.array(new decimal[] { 0m, 9.87654321m });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(typeof(decimal), picked.dtype);
+     Assert.AreEqual(1.23456789m, (decimal)picked[0]);
+     Assert.AreEqual(9.87654321m, (decimal)picked[1]);
+ }
+
+ [Test]
+ public void Where_Char()
+ {
+     var mask = np.array(new[] { true, false, true });
+     var whenTrue = np.array(new char[] { 'A', 'B', 'C' });
+     var whenFalse = np.array(new char[] { 'X', 'Y', 'Z' });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     Assert.AreEqual(typeof(char), picked.dtype);
+     Assert.AreEqual('A', (char)picked[0]);
+     Assert.AreEqual('Y', (char)picked[1]);
+     Assert.AreEqual('C', (char)picked[2]);
+ }
+
+ #endregion
+
+ #region View Behavior
+
+ [Test]
+ public void Where_ResultIsNewArray_NotView()
+ {
+     var mask = np.array(new[] { true, false });
+     var whenTrue = np.array(new[] { 1, 2 });
+     var whenFalse = np.array(new[] { 10, 20 });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     // Overwrite the inputs afterwards; the already-computed output must keep its values
+     whenTrue[0] = 999;
+     Assert.AreEqual(1, (int)picked[0], "Result should be independent of x");
+
+     whenFalse[1] = 999;
+     Assert.AreEqual(20, (int)picked[1], "Result should be independent of y");
+ }
+
+ [Test]
+ public void Where_ModifyResult_DoesNotAffectInputs()
+ {
+     var mask = np.array(new[] { true, false });
+     var whenTrue = np.array(new[] { 1, 2 });
+     var whenFalse = np.array(new[] { 10, 20 });
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     picked[0] = 999;
+     Assert.AreEqual(1, (int)whenTrue[0], "x should not be modified");
+     Assert.AreEqual(10, (int)whenFalse[0], "y should not be modified");
+ }
+
+ #endregion
+
+ #region Alternating Patterns
+
+ [Test]
+ public void Where_Checkerboard_Pattern()
+ {
+     // Cell (row, col) is true when row + col is even
+     var mask = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean);
+     for (var row = 0; row < 4; ++row)
+         for (var col = 0; col < 4; ++col)
+             mask[row, col] = (row + col) % 2 == 0;
+
+     var whenTrue = np.ones(new[] { 4, 4 }, NPTypeCode.Int32);
+     var whenFalse = np.zeros(new[] { 4, 4 }, NPTypeCode.Int32);
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     // The top-left 2x2 corner must reproduce the alternation
+     Assert.AreEqual(1, (int)picked[0, 0]);
+     Assert.AreEqual(0, (int)picked[0, 1]);
+     Assert.AreEqual(0, (int)picked[1, 0]);
+     Assert.AreEqual(1, (int)picked[1, 1]);
+ }
+
+ [Test]
+ public void Where_StripedPattern()
+ {
+     // Even rows are entirely true, odd rows entirely false
+     var mask = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean);
+     for (var row = 0; row < 4; ++row)
+         for (var col = 0; col < 4; ++col)
+             mask[row, col] = row % 2 == 0;
+
+     var whenTrue = np.full(new[] { 4, 4 }, 1);
+     var whenFalse = np.full(new[] { 4, 4 }, 0);
+     var picked = np.where(mask, whenTrue, whenFalse);
+
+     // Expect 1 across rows 0 and 2, and 0 across rows 1 and 3
+     for (var col = 0; col < 4; ++col)
+     {
+         Assert.AreEqual(1, (int)picked[0, col]);
+         Assert.AreEqual(0, (int)picked[1, col]);
+         Assert.AreEqual(1, (int)picked[2, col]);
+         Assert.AreEqual(0, (int)picked[3, col]);
+     }
+ }
+
+ #endregion
+ }
+}
diff --git a/test/NumSharp.UnitTest/Logic/np.where.Test.cs b/test/NumSharp.UnitTest/Logic/np.where.Test.cs
new file mode 100644
index 00000000..d53ed585
--- /dev/null
+++ b/test/NumSharp.UnitTest/Logic/np.where.Test.cs
@@ -0,0 +1,496 @@
+using System;
+using System.Linq;
+using TUnit.Core;
+using NumSharp.UnitTest.Utilities;
+using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert;
+
+namespace NumSharp.UnitTest.Logic
+{
+ /// <summary>
+ /// Comprehensive tests for np.where matching NumPy 2.x behavior.
+ ///
+ /// NumPy signature: where(condition, x=None, y=None, /)
+ /// - Single arg: returns np.nonzero(condition)
+ /// - Three args: element-wise selection with broadcasting
+ /// </summary>
+ public class np_where_Test
+ {
+ #region Single Argument (nonzero equivalent)
+
+ [Test]
+ public void Where_SingleArg_1D_ReturnsIndices()
+ {
+     // NumPy: np.where([0, 1, 0, 2, 0, 3]) -> (array([1, 3, 5]),)
+     // A 1-D input produces a one-element tuple holding the non-zero positions.
+     var indices = np.where(np.array(new int[] { 0, 1, 0, 2, 0, 3 }));
+
+     Assert.AreEqual(1, indices.Length);
+     indices[0].Should().BeOfValues(1L, 3L, 5L);
+ }
+
+ [Test]
+ public void Where_SingleArg_2D_ReturnsTupleOfIndices()
+ {
+     // NumPy: np.where([[0, 1, 0], [2, 0, 3]]) -> (array([0, 1, 1]), array([1, 0, 2]))
+     var matrix = np.array(new int[,] { { 0, 1, 0 }, { 2, 0, 3 } });
+     var perAxis = np.where(matrix);
+     // One index array per dimension; the hits are (0,1), (1,0) and (1,2).
+     Assert.AreEqual(2, perAxis.Length);
+     perAxis[0].Should().BeOfValues(0L, 1L, 1L); // axis 0 (rows)
+     perAxis[1].Should().BeOfValues(1L, 0L, 2L); // axis 1 (columns)
+ }
+
+ [Test]
+ public void Where_SingleArg_Boolean_ReturnsNonzero()
+ {
+     // For a bool input the reported positions are those holding True.
+     var flags = np.array(new bool[] { true, false, true, false, true });
+     var hits = np.where(flags);
+     Assert.AreEqual(1, hits.Length);
+     hits[0].Should().BeOfValues(0L, 2L, 4L);
+ }
+
+ [Test]
+ public void Where_SingleArg_Empty_ReturnsEmptyIndices()
+ {
+     // A zero-length 1-D input still yields one (empty) index array.
+     var indices = np.where(np.array(new int[0]));
+
+     Assert.AreEqual(1, indices.Length);
+     Assert.AreEqual(0, indices[0].size);
+ }
+
+ [Test]
+ public void Where_SingleArg_AllFalse_ReturnsEmptyIndices()
+ {
+     // With no True entries the single index array must be empty.
+     var noneSet = np.array(new bool[] { false, false, false });
+     var indices = np.where(noneSet);
+     Assert.AreEqual(1, indices.Length);
+     Assert.AreEqual(0, indices[0].size);
+ }
+
+ [Test]
+ public void Where_SingleArg_AllTrue_ReturnsAllIndices()
+ {
+     // Every position is True, so every index 0..2 must be reported.
+     var result = np.where(np.array(new[] { true, true, true }));
+     Assert.AreEqual(1, result.Length); // 1-D input -> exactly one index array
+     result[0].Should().BeOfValues(0L, 1L, 2L);
+ }
+
+ [Test]
+ public void Where_SingleArg_3D_ReturnsTupleOfThreeArrays()
+ {
+     // Only [0, 0, 1] and [1, 1, 0] are set in an otherwise all-zero 2x2x2 cube.
+     var cube = np.zeros(new[] { 2, 2, 2 }, NPTypeCode.Int32);
+     cube[0, 0, 1] = 1;
+     cube[1, 1, 0] = 1;
+
+     var coords = np.where(cube);
+     Assert.AreEqual(3, coords.Length);
+     coords[0].Should().BeOfValues(0L, 1L); // first coordinate of each hit
+     coords[1].Should().BeOfValues(0L, 1L); // second coordinate of each hit
+     coords[2].Should().BeOfValues(1L, 0L); // third coordinate of each hit
+ }
+
+ #endregion
+
+ #region Three Arguments (element-wise selection)
+
+ [Test]
+ public void Where_ThreeArgs_Basic_SelectsCorrectly()
+ {
+     // Values below 5 pass through unchanged; the rest are replaced by ten times themselves.
+     var values = np.arange(10);
+     var belowFive = values < 5;
+     var selected = np.where(belowFive, values, 10 * values);
+     selected.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L);
+ }
+
+ [Test]
+ public void Where_ThreeArgs_BooleanCondition()
+ {
+     // True picks from the first value array, False from the second, position by position.
+     var mask = np.array(new bool[] { true, false, true, false });
+     var onTrue = np.array(new int[] { 1, 2, 3, 4 });
+     var onFalse = np.array(new int[] { 10, 20, 30, 40 });
+
+     np.where(mask, onTrue, onFalse).Should().BeOfValues(1, 20, 3, 40);
+ }
+
+ [Test]
+ public void Where_ThreeArgs_2D()
+ {
+     // np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]])
+     var mask = np.array(new bool[,] { { true, false }, { true, true } });
+     var onTrue = np.array(new int[,] { { 1, 2 }, { 3, 4 } });
+     var onFalse = np.array(new int[,] { { 9, 8 }, { 7, 6 } });
+     var selected = np.where(mask, onTrue, onFalse);
+     selected.Should().BeShaped(2, 2);
+     // Only position [0, 1] is False, so that is the only cell taken from the second array.
+     var expected = new int[,] { { 1, 8 }, { 3, 4 } };
+     for (int r = 0; r < 2; r++)
+         for (int c = 0; c < 2; c++)
+             Assert.AreEqual(expected[r, c], (int)selected[r, c]);
+ }
+
+ [Test]
+ public void Where_ThreeArgs_NonBoolCondition_TreatsAsTruthy()
+ {
+     // np.where([0, 1, 2, 0], 100, -100) -> [-100, 100, 100, -100]
+     // An integer condition counts as True wherever it is non-zero.
+     var numericCond = np.array(new int[] { 0, 1, 2, 0 });
+     var selected = np.where(numericCond, 100, -100);
+     selected.Should().BeOfValues(-100, 100, 100, -100);
+ }
+
+ #endregion
+
+ #region Scalar Arguments
+
+ [Test]
+ public void Where_ScalarX()
+ {
+     // A scalar first choice is used at every True position.
+     var mask = np.array(new bool[] { true, false, true, false });
+     var fallback = np.array(new int[] { 10, 20, 30, 40 });
+
+     np.where(mask, 99, fallback).Should().BeOfValues(99, 20, 99, 40);
+ }
+
+ [Test]
+ public void Where_ScalarY()
+ {
+     // A scalar second choice is used at every False position.
+     var mask = np.array(new bool[] { true, false, true, false });
+     var primary = np.array(new int[] { 1, 2, 3, 4 });
+
+     np.where(mask, primary, -1).Should().BeOfValues(1, -1, 3, -1);
+ }
+
+ [Test]
+ public void Where_BothScalars()
+ {
+     // Two scalar choices: each output element is 1 where the mask is True, otherwise 0.
+     var mask = np.array(new bool[] { true, false, true, false });
+     var selected = np.where(mask, 1, 0);
+     selected.Should().BeOfValues(1, 0, 1, 0);
+ }
+
+ [Test]
+ public void Where_ScalarFloat()
+ {
+     // Two double scalars must give a double-typed result carrying both values.
+     const double tolerance = 1e-10;
+     var selected = np.where(np.array(new bool[] { true, false }), 1.5, 2.5);
+     Assert.AreEqual(typeof(double), selected.dtype);
+     Assert.AreEqual(1.5, (double)selected[0], tolerance);
+     Assert.AreEqual(2.5, (double)selected[1], tolerance);
+ }
+
+ #endregion
+
+ #region Broadcasting
+
+ [Test]
+ public void Where_Broadcasting_ScalarY()
+ {
+     // np.where(a < 4, a, -1): entries that fail the comparison are replaced by the scalar -1.
+     var grid = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } });
+     var small = grid < 4;
+     var selected = np.where(small, grid, -1);
+     selected.Should().BeShaped(3, 3);
+     Assert.AreEqual(0, (int)selected[0, 0]);
+     Assert.AreEqual(1, (int)selected[0, 1]);
+     Assert.AreEqual(2, (int)selected[0, 2]);
+     Assert.AreEqual(-1, (int)selected[1, 2]); // source value 4 is not < 4
+     Assert.AreEqual(-1, (int)selected[2, 2]); // source value 6 is not < 4
+ }
+
+ [Test]
+ public void Where_Broadcasting_DifferentShapes()
+ {
+     // Shapes (2,1), (3,) and (1,3) broadcast together to (2,3).
+     var mask = np.array(new bool[,] { { true }, { false } });
+     var onTrue = np.array(new int[] { 1, 2, 3 });
+     var onFalse = np.array(new int[,] { { 10, 20, 30 } });
+
+     var selected = np.where(mask, onTrue, onFalse);
+
+     selected.Should().BeShaped(2, 3);
+     // Row 0 has a True condition -> first array; row 1 is False -> second array.
+     var expected = new int[,] { { 1, 2, 3 }, { 10, 20, 30 } };
+     for (int r = 0; r < 2; r++)
+     {
+         for (int c = 0; c < 3; c++)
+             Assert.AreEqual(expected[r, c], (int)selected[r, c]);
+     }
+ }
+
+ [Test]
+ public void Where_Broadcasting_ColumnVector()
+ {
+     // Shapes (3,1), scalar and (1,4) broadcast together to (3,4).
+     var columnMask = np.array(new bool[,] { { true }, { false }, { true } });
+     var scalar = 1;
+     var rowValues = np.array(new int[,] { { 10, 20, 30, 40 } });
+     var selected = np.where(columnMask, scalar, rowValues);
+     selected.Should().BeShaped(3, 4);
+
+     // Row 0 is True: every column holds the scalar.
+     for (int col = 0; col < 4; col++)
+         Assert.AreEqual(1, (int)selected[0, col]);
+     // Row 1 is False: values come from the row vector (first and last column checked).
+     Assert.AreEqual(10, (int)selected[1, 0]);
+     Assert.AreEqual(40, (int)selected[1, 3]);
+     // Row 2 is True again.
+     for (int col = 0; col < 4; col++)
+         Assert.AreEqual(1, (int)selected[2, col]);
+ }
+
+ #endregion
+
+ #region Type Promotion
+
+ [Test]
+ public void Where_TypePromotion_IntFloat_ReturnsFloat64()
+ {
+     // Mixing an int scalar with a double scalar must promote the result to double.
+     const double tolerance = 1e-10;
+     var promoted = np.where(np.array(new bool[] { true, false }), 1, 2.5);
+     Assert.AreEqual(typeof(double), promoted.dtype);
+     Assert.AreEqual(1.0, (double)promoted[0], tolerance);
+     Assert.AreEqual(2.5, (double)promoted[1], tolerance);
+ }
+
+ [Test]
+ public void Where_TypePromotion_Int32Int64_ReturnsInt64()
+ {
+     // One-element int and long arrays broadcast against the two-element condition.
+     var cond = np.array(new[] { true, false });
+     var result = np.where(cond, np.array(new int[] { 1 }), np.array(new long[] { 2 }));
+     Assert.AreEqual(typeof(long), result.dtype);
+     // Pin the values as well, so a promotion that alters data cannot pass on dtype alone.
+     result.Should().BeOfValues(1L, 2L);
+ }
+
+ [Test]
+ public void Where_TypePromotion_FloatDouble_ReturnsDouble()
+ {
+     // One-element float and double arrays broadcast to the condition; values are pinned, not just the dtype.
+     var cond = np.array(new[] { true, false });
+     var result = np.where(cond, np.array(new float[] { 1.5f }), np.array(new double[] { 2.5 }));
+     Assert.AreEqual(typeof(double), result.dtype);
+     Assert.AreEqual(1.5, (double)result[0], 1e-10); // 1.5f is exactly representable as a double
+     Assert.AreEqual(2.5, (double)result[1], 1e-10);
+ }
+
+ #endregion
+
+ #region Edge Cases
+
+ [Test]
+ public void Where_EmptyArrays_ThreeArgs()
+ {
+     // Three zero-length inputs must produce a zero-length output.
+     var emptyMask = np.array(new bool[0]);
+     var emptyTrue = np.array(new int[0]);
+     var emptyFalse = np.array(new int[0]);
+
+     Assert.AreEqual(0, np.where(emptyMask, emptyTrue, emptyFalse).size);
+ }
+
+ [Test]
+ public void Where_SingleElement()
+ {
+     // Smallest non-empty case: one True element selects the first scalar.
+     var single = np.array(new bool[] { true });
+     var selected = np.where(single, 42, 0);
+     Assert.AreEqual(1, selected.size);
+     Assert.AreEqual(42, (int)selected[0]);
+ }
+
+ [Test]
+ public void Where_AllTrue_ReturnsAllX()
+ {
+     // An all-True condition must reproduce the first value array.
+     var allSet = np.array(new bool[] { true, true, true });
+     var onTrue = np.array(new int[] { 1, 2, 3 });
+     var onFalse = np.array(new int[] { 10, 20, 30 });
+
+     np.where(allSet, onTrue, onFalse).Should().BeOfValues(1, 2, 3);
+ }
+
+ [Test]
+ public void Where_AllFalse_ReturnsAllY()
+ {
+     // An all-False condition must reproduce the second value array.
+     var noneSet = np.array(new bool[] { false, false, false });
+     var onTrue = np.array(new int[] { 1, 2, 3 });
+     var onFalse = np.array(new int[] { 10, 20, 30 });
+
+     np.where(noneSet, onTrue, onFalse).Should().BeOfValues(10, 20, 30);
+ }
+
+ [Test]
+ public void Where_LargeArray()
+ {
+     var size = 100000; // large input; both ends of the output are spot-checked below
+     var cond = np.arange(size) % 2 == 0; // True at even indices, False at odd ones
+     var result = np.where(cond, np.ones(size, NPTypeCode.Int32), np.zeros(size, NPTypeCode.Int32));
+     Assert.AreEqual(size, result.size);
+     // Head of the array: even indices are 1, odd indices are 0.
+     Assert.AreEqual(1, (int)result[0]);
+     Assert.AreEqual(0, (int)result[1]);
+     Assert.AreEqual(1, (int)result[2]);
+     // Tail as well: the last elements are where faulty remainder handling would show up.
+     Assert.AreEqual(1, (int)result[size - 2]);
+     Assert.AreEqual(0, (int)result[size - 1]);
+ }
+
+ #endregion
+
+ #region NumPy Output Verification
+
+ [Test]
+ public void Where_NumPyExample1()
+ {
+     // NumPy documentation example:
+     //   np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]])
+     //   -> array([[1, 8], [3, 4]])
+     var mask = np.array(new bool[,] { { true, false }, { true, true } });
+     var onTrue = np.array(new int[,] { { 1, 2 }, { 3, 4 } });
+     var onFalse = np.array(new int[,] { { 9, 8 }, { 7, 6 } });
+     var selected = np.where(mask, onTrue, onFalse);
+
+     var expected = new int[,] { { 1, 8 }, { 3, 4 } };
+     for (int r = 0; r < 2; r++)
+         for (int c = 0; c < 2; c++)
+             Assert.AreEqual(expected[r, c], (int)selected[r, c]);
+ }
+
+ [Test]
+ public void Where_NumPyExample2()
+ {
+     // NumPy documentation example, with a = arange(10):
+     //   np.where(a < 5, a, 10*a) -> array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
+     var a = np.arange(10);
+     var isLow = a < 5;
+
+     np.where(isLow, a, 10 * a).Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L);
+ }
+
+ [Test]
+ public void Where_NumPyExample3()
+ {
+     // NumPy documentation example:
+     //   np.where(a < 4, a, -1) -> array([[ 0, 1, 2], [ 0, 2, -1], [ 0, 3, -1]])
+     var a = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } });
+     var selected = np.where(a < 4, a, -1);
+     // Compare every cell, row by row, against the documented output.
+     var expected = new int[,]
+     {
+         { 0, 1, 2 },
+         { 0, 2, -1 },
+         { 0, 3, -1 },
+     };
+     for (int r = 0; r < 3; r++)
+         for (int c = 0; c < 3; c++)
+             Assert.AreEqual(expected[r, c], (int)selected[r, c]);
+ }
+
+ #endregion
+
+ #region Dtype Coverage
+
+ [Test]
+ public void Where_Dtype_Byte()
+ {
+     // byte inputs on both sides: the result must stay byte.
+     var mask = np.array(new bool[] { true, false });
+     var onTrue = np.array(new byte[] { 1, 2 });
+     var onFalse = np.array(new byte[] { 10, 20 });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(byte), selected.dtype);
+     selected.Should().BeOfValues((byte)1, (byte)20);
+ }
+
+ [Test]
+ public void Where_Dtype_Int16()
+ {
+     // short inputs on both sides: the result must stay short.
+     var mask = np.array(new bool[] { true, false });
+     var onTrue = np.array(new short[] { 1, 2 });
+     var onFalse = np.array(new short[] { 10, 20 });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(short), selected.dtype);
+     selected.Should().BeOfValues((short)1, (short)20);
+ }
+
+ [Test]
+ public void Where_Dtype_Int32()
+ {
+     // int inputs on both sides: the result must stay int.
+     var mask = np.array(new bool[] { true, false });
+     var onTrue = np.array(new[] { 1, 2 });
+     var onFalse = np.array(new[] { 10, 20 });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(int), selected.dtype);
+     selected.Should().BeOfValues(1, 20);
+ }
+
+ [Test]
+ public void Where_Dtype_Int64()
+ {
+     // long inputs on both sides: the result must stay long.
+     var mask = np.array(new bool[] { true, false });
+     var onTrue = np.array(new[] { 1L, 2L });
+     var onFalse = np.array(new[] { 10L, 20L });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(long), selected.dtype);
+     selected.Should().BeOfValues(1L, 20L);
+ }
+
+ [Test]
+ public void Where_Dtype_Single()
+ {
+     // float inputs on both sides: the result must stay float.
+     var mask = np.array(new bool[] { true, false });
+     var onTrue = np.array(new[] { 1.5f, 2.5f });
+     var onFalse = np.array(new[] { 10.5f, 20.5f });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(float), selected.dtype);
+     Assert.AreEqual(1.5f, (float)selected[0], 1e-6f);
+     Assert.AreEqual(20.5f, (float)selected[1], 1e-6f);
+ }
+
+ [Test]
+ public void Where_Dtype_Double()
+ {
+     // double inputs on both sides: the result must stay double.
+     var mask = np.array(new bool[] { true, false });
+     var onTrue = np.array(new[] { 1.5, 2.5 });
+     var onFalse = np.array(new[] { 10.5, 20.5 });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(double), selected.dtype);
+     Assert.AreEqual(1.5, (double)selected[0], 1e-10);
+     Assert.AreEqual(20.5, (double)selected[1], 1e-10);
+ }
+
+ [Test]
+ public void Where_Dtype_Boolean()
+ {
+     // bool value arrays: the result must stay bool and carry the chosen flags.
+     var mask = np.array(new[] { true, false });
+     var onTrue = np.array(new[] { true, true });
+     var onFalse = np.array(new[] { false, false });
+     var selected = np.where(mask, onTrue, onFalse);
+     Assert.AreEqual(typeof(bool), selected.dtype);
+     Assert.IsTrue((bool)selected[0]);
+     Assert.IsFalse((bool)selected[1]);
+ }
+
+ #endregion
+ }
+}