@@ -678,17 +678,136 @@ def test_stack(cpp, dtype):
678678 arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (4 )]
679679 assert_bit_aligned (cpp .stack (arrays ), np .stack (arrays ), "stack" )
680680
681- def test_concatenate (cpp , dtype ):
681+ def test_concatenate_1d (cpp , dtype ):
682682 arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (3 )]
683- assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate" )
683+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate 1d" )
684+
685+ def test_concatenate_2d_axis0 (cpp , dtype ):
686+ arrays = [random_array ((2 , 3 ), seed = i , dtype = dtype ) for i in range (3 )]
687+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concatenate 2d axis=0" )
688+ # Verify default axis=0
689+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate 2d default axis" )
690+
691+ def test_concatenate_2d_axis1 (cpp , dtype ):
692+ arrays = [random_array ((3 , 2 ), seed = i , dtype = dtype ) for i in range (3 )]
693+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concatenate 2d axis=1" )
694+
695+ def test_concatenate_2d_axis_neg1 (cpp , dtype ):
696+ arrays = [random_array ((3 , 2 ), seed = i , dtype = dtype ) for i in range (3 )]
697+ assert_bit_aligned (cpp .concatenate (arrays , - 1 ), np .concatenate (arrays , axis = - 1 ), "concatenate 2d axis=-1" )
698+
699+ def test_concatenate_3d_axis0 (cpp , dtype ):
700+ arrays = [random_array ((2 , 3 , 4 ), seed = i , dtype = dtype ) for i in range (2 )]
701+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concatenate 3d axis=0" )
702+
703+ def test_concatenate_3d_axis1 (cpp , dtype ):
704+ arrays = [random_array ((3 , 2 , 4 ), seed = i , dtype = dtype ) for i in range (2 )]
705+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concatenate 3d axis=1" )
706+
707+ def test_concatenate_3d_axis2 (cpp , dtype ):
708+ arrays = [random_array ((3 , 4 , 2 ), seed = i , dtype = dtype ) for i in range (2 )]
709+ assert_bit_aligned (cpp .concatenate (arrays , 2 ), np .concatenate (arrays , axis = 2 ), "concatenate 3d axis=2" )
710+
711+ def test_concatenate_two_arrays (cpp , dtype ):
712+ arrays = [random_array ((5 ,), seed = 0 , dtype = dtype ), random_array ((7 ,), seed = 1 , dtype = dtype )]
713+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate two" )
714+
715+ def test_concatenate_single (cpp , dtype ):
716+ arrays = [random_array ((5 ,), dtype = dtype )]
717+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate single" )
684718
685719def test_vstack (cpp , dtype ):
686720 arrays = [random_array ((1 , 3 ), seed = i , dtype = dtype ) for i in range (4 )]
687721 assert_bit_aligned (cpp .vstack (arrays ), np .vstack (arrays ), "vstack" )
688722
723+ def test_vstack_1d (cpp , dtype ):
724+ arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (4 )]
725+ assert_bit_aligned (cpp .vstack (arrays ), np .vstack (arrays ), "vstack 1d" )
726+
689727def test_hstack (cpp , dtype ):
690728 arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (3 )]
691- assert_bit_aligned (cpp .hstack (arrays ), np .hstack (arrays ), "hstack" )
729+ assert_bit_aligned (cpp .hstack (arrays ), np .hstack (arrays ), "hstack 1d" )
730+
731+ def test_hstack_2d (cpp , dtype ):
732+ arrays = [random_array ((3 , 2 ), seed = i , dtype = dtype ) for i in range (3 )]
733+ assert_bit_aligned (cpp .hstack (arrays ), np .hstack (arrays ), "hstack 2d" )
734+
735+ # -- Concatenate complex / edge-case tests ----------------------------------
736+
737+ def test_concatenate_4d_axis0 (cpp , dtype ):
738+ arrays = [random_array ((2 , 3 , 4 , 5 ), seed = i , dtype = dtype ) for i in range (2 )]
739+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concatenate 4d axis=0" )
740+
741+ def test_concatenate_4d_axis2 (cpp , dtype ):
742+ arrays = [random_array ((2 , 3 , 2 , 5 ), seed = i , dtype = dtype ) for i in range (2 )]
743+ assert_bit_aligned (cpp .concatenate (arrays , 2 ), np .concatenate (arrays , axis = 2 ), "concatenate 4d axis=2" )
744+
745+ def test_concatenate_4d_axis_neg2 (cpp , dtype ):
746+ arrays = [random_array ((2 , 3 , 2 , 5 ), seed = i , dtype = dtype ) for i in range (2 )]
747+ assert_bit_aligned (cpp .concatenate (arrays , - 2 ), np .concatenate (arrays , axis = - 2 ), "concatenate 4d axis=-2" )
748+
749+ def test_concatenate_unequal_axis_sizes (cpp , dtype ):
750+ """Concatenate arrays of different sizes along the concatenation axis."""
751+ a = random_array ((3 , 2 ), seed = 1 , dtype = dtype )
752+ b = random_array ((3 , 4 ), seed = 2 , dtype = dtype )
753+ c = random_array ((3 , 1 ), seed = 3 , dtype = dtype )
754+ assert_bit_aligned (cpp .concatenate ([a , b , c ], 1 ),
755+ np .concatenate ([a , b , c ], axis = 1 ), "concat unequal axis sizes" )
756+
757+ def test_concatenate_many_arrays (cpp , dtype ):
758+ """Concatenate 10 arrays along axis=0."""
759+ arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (10 )]
760+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concat 10 arrays" )
761+
762+ def test_concatenate_large_3d (cpp , dtype ):
763+ """Large 3D concatenation along middle axis."""
764+ arrays = [random_array ((50 , 20 , 30 ), seed = i , dtype = dtype ) for i in range (3 )]
765+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concat large 3d axis=1" )
766+
767+ def test_concatenate_large_2d_axis0 (cpp , dtype ):
768+ """Large 2D concatenation — 500 rows each, 4 arrays."""
769+ arrays = [random_array ((500 , 10 ), seed = i , dtype = dtype ) for i in range (4 )]
770+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concat large 2d axis=0" )
771+
772+ def test_concatenate_large_2d_axis1 (cpp , dtype ):
773+ """Large 2D concatenation — 500 cols each, 3 arrays."""
774+ arrays = [random_array ((10 , 500 ), seed = i , dtype = dtype ) for i in range (3 )]
775+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concat large 2d axis=1" )
776+
777+ def test_concatenate_identity (cpp , dtype ):
778+ """Concatenating a single array returns identical copy."""
779+ a = random_array ((3 , 4 ), seed = 42 , dtype = dtype )
780+ assert_bit_aligned (cpp .concatenate ([a ], 0 ), np .concatenate ([a ], axis = 0 ), "concat identity" )
781+ assert_bit_aligned (cpp .concatenate ([a ], 1 ), np .concatenate ([a ], axis = 1 ), "concat identity axis=1" )
782+
783+ def test_concatenate_zeros (cpp , dtype ):
784+ """Concatenate arrays of zeros."""
785+ a = np .zeros ((2 , 3 ), dtype = dtype )
786+ b = np .zeros ((2 , 5 ), dtype = dtype )
787+ assert_bit_aligned (cpp .concatenate ([a , b ], 1 ), np .concatenate ([a , b ], axis = 1 ), "concat zeros" )
788+
789+ def test_concatenate_ones (cpp , dtype ):
790+ """Concatenate arrays of ones."""
791+ a = np .ones ((3 , 2 ), dtype = dtype )
792+ b = np .ones ((5 , 2 ), dtype = dtype )
793+ assert_bit_aligned (cpp .concatenate ([a , b ], 0 ), np .concatenate ([a , b ], axis = 0 ), "concat ones" )
794+
795+ def test_concatenate_3d_axis_neg2 (cpp , dtype ):
796+ """3D concatenate along axis=-2 (middle axis)."""
797+ arrays = [random_array ((2 , 3 , 4 ), seed = i , dtype = dtype ) for i in range (3 )]
798+ assert_bit_aligned (cpp .concatenate (arrays , - 2 ), np .concatenate (arrays , axis = - 2 ), "concat 3d axis=-2" )
799+
800+ def test_concatenate_3d_axis_neg3 (cpp , dtype ):
801+ """3D concatenate along axis=-3 (first axis)."""
802+ arrays = [random_array ((2 , 3 , 4 ), seed = i , dtype = dtype ) for i in range (2 )]
803+ assert_bit_aligned (cpp .concatenate (arrays , - 3 ), np .concatenate (arrays , axis = - 3 ), "concat 3d axis=-3" )
804+
805+ def test_concatenate_5d (cpp , dtype ):
806+ """5D concatenate along various axes."""
807+ arrays = [random_array ((2 , 3 , 2 , 3 , 2 ), seed = i , dtype = dtype ) for i in range (2 )]
808+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concat 5d axis=0" )
809+ assert_bit_aligned (cpp .concatenate (arrays , 2 ), np .concatenate (arrays , axis = 2 ), "concat 5d axis=2" )
810+ assert_bit_aligned (cpp .concatenate (arrays , - 1 ), np .concatenate (arrays , axis = - 1 ), "concat 5d axis=-1" )
692811
693812def test_where_scalar (cpp , dtype ):
694813 cond = np .array ([True , False , True , False , True ])
0 commit comments