@@ -14,6 +14,9 @@ use numpy::ToPyArray;
1414use numpy:: { PyArray2 , PyReadonlyArray2 } ;
1515use numpy:: IntoPyArray ;
1616
17+ use rayon:: prelude:: * ;
18+ use rayon:: ThreadPoolBuilder ;
19+
1720#[ pyfunction]
1821fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
1922 match array. as_slice ( ) {
@@ -480,7 +483,7 @@ pub fn prepare_array_for_axis<'py>(
480483
481484#[ pyfunction]
482485#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
483- pub fn first_true_2d < ' py > (
486+ pub fn first_true_2d_a < ' py > (
484487 py : Python < ' py > ,
485488 array : PyReadonlyArray2 < ' py , bool > ,
486489 forward : bool ,
@@ -558,6 +561,114 @@ pub fn first_true_2d<'py>(
558561 Ok ( PyArray1 :: from_vec ( py, result) . to_owned ( ) )
559562}
560563
564+ #[ pyfunction]
565+ #[ pyo3( signature = ( array, * , forward=true , axis) ) ]
566+ pub fn first_true_2d < ' py > (
567+ py : Python < ' py > ,
568+ array : PyReadonlyArray2 < ' py , bool > ,
569+ forward : bool ,
570+ axis : isize ,
571+ ) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
572+ let prepared = prepare_array_for_axis ( py, array, axis) ?;
573+ let data = prepared. data ;
574+ let rows = prepared. nrows ;
575+ let row_len = prepared. ncols ;
576+
577+ let mut result = vec ! [ -1isize ; rows] ;
578+
579+ // Dynamically select thread count
580+ let max_threads = if rows < 100 {
581+ 1
582+ } else if rows < 1000 {
583+ 2
584+ } else if rows < 10000 {
585+ 4
586+ } else {
587+ 16
588+ } ;
589+
590+ py. allow_threads ( || {
591+ let base_ptr = data. as_ptr ( ) as usize ;
592+ const LANES : usize = 32 ;
593+ let ones = u8x32:: splat ( 1 ) ;
594+
595+ let process_row = |row : usize | -> isize {
596+ let ptr = ( base_ptr + row * row_len) as * const u8 ;
597+ let mut found = -1isize ;
598+
599+ unsafe {
600+ if forward {
601+ let mut i = 0 ;
602+ while i + LANES <= row_len {
603+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
604+ let vec = u8x32:: from ( * chunk) ;
605+ if vec. cmp_eq ( ones) . any ( ) {
606+ break ;
607+ }
608+ i += LANES ;
609+ }
610+ while i < row_len {
611+ if * ptr. add ( i) != 0 {
612+ found = i as isize ;
613+ break ;
614+ }
615+ i += 1 ;
616+ }
617+ } else {
618+ let mut i = row_len;
619+ while i >= LANES {
620+ i -= LANES ;
621+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
622+ let vec = u8x32:: from ( * chunk) ;
623+ if vec. cmp_eq ( ones) . any ( ) {
624+ for j in ( i..i + LANES ) . rev ( ) {
625+ if * ptr. add ( j) != 0 {
626+ found = j as isize ;
627+ break ;
628+ }
629+ }
630+ break ;
631+ }
632+ }
633+ if i > 0 && i < LANES {
634+ for j in ( 0 ..i) . rev ( ) {
635+ if * ptr. add ( j) != 0 {
636+ found = j as isize ;
637+ break ;
638+ }
639+ }
640+ }
641+ }
642+ }
643+
644+ found
645+ } ;
646+
647+ if max_threads == 1 {
648+ // Single-threaded path
649+ for row in 0 ..rows {
650+ result[ row] = process_row ( row) ;
651+ }
652+ } else {
653+ // Multi-threaded path with Rayon
654+ let pool = rayon:: ThreadPoolBuilder :: new ( )
655+ . num_threads ( max_threads)
656+ . build ( )
657+ . unwrap ( ) ;
658+
659+ pool. install ( || {
660+ result. par_iter_mut ( ) . enumerate ( ) . for_each ( |( row, out) | {
661+ * out = process_row ( row) ;
662+ } ) ;
663+ } ) ;
664+ }
665+ } ) ;
666+
667+ Ok ( PyArray1 :: from_vec ( py, result) )
668+ }
669+
670+
671+
561672//------------------------------------------------------------------------------
562673
563674#[ pymodule]
0 commit comments