@@ -7,15 +7,17 @@ use wide::*;
77// use std::simd::Simd;
88// use std::simd::cmp::SimdPartialEq;
99
10+ use numpy:: ndarray:: { Array2 , ArrayView2 } ;
11+ use numpy:: IntoPyArray ;
1012use numpy:: PyArray1 ;
1113use numpy:: PyArrayMethods ;
1214use numpy:: PyUntypedArrayMethods ;
1315use numpy:: ToPyArray ;
1416use numpy:: { PyArray2 , PyReadonlyArray2 } ;
15- use numpy:: IntoPyArray ;
1617
1718use rayon:: prelude:: * ;
1819use rayon:: ThreadPoolBuilder ;
20+ use std:: sync:: Arc ;
1921
2022#[ pyfunction]
2123fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
@@ -393,32 +395,17 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
393395// }
394396// }
395397
396-
397- // use numpy::{PyReadonlyArray2, IntoPyArray, PyArray2};
398- // use pyo3::prelude::*;
399-
400398pub struct PreparedBool2D < ' py > {
401- pub data : & ' py [ u8 ] , // flat contiguous buffer
402- pub nrows : usize , // number of logical rows
399+ pub data : & ' py [ u8 ] , // contiguous byte slice (bool as u8)
400+ pub nrows : usize ,
403401 pub ncols : usize ,
404- _keepalive : Option < Bound < ' py , PyAny > > , // holds any copied/transposed buffer
402+ _keepalive : Option < Arc < Array2 < bool > > > , // holds owned data if needed
405403}
406404
407405pub fn prepare_array_for_axis < ' py > (
408- py : Python < ' py > ,
409406 array : PyReadonlyArray2 < ' py , bool > ,
410407 axis : isize ,
411408) -> PyResult < PreparedBool2D < ' py > > {
412-
413- // let shape = array.shape();
414- // let slice = array.as_slice().unwrap();
415- // return Ok(PreparedBool2D {
416- // data: unsafe { std::mem::transmute(slice) }, // &[bool] → &[u8]
417- // nrows: shape[0],
418- // ncols: shape[1],
419- // _keepalive: None,
420- // });
421-
422409 if axis != 0 && axis != 1 {
423410 return Err ( PyValueError :: new_err ( "axis must be 0 or 1" ) ) ;
424411 }
@@ -459,117 +446,122 @@ pub fn prepare_array_for_axis<'py>(
459446 }
460447 }
461448
462- // Case 3: fallback — create a new C-contiguous owned array
463- let prepared_array : Bound < ' py , PyArray2 < bool > > = if axis == 0 {
464- array_view. reversed_axes ( ) . as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py )
449+ // Case 3: fallback — make ndarray owned copy, but no PyArray!
450+ let array_owned : Array2 < bool > = if axis == 0 {
451+ array_view. reversed_axes ( ) . as_standard_layout ( ) . to_owned ( )
465452 } else {
466- array_view. as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py )
453+ array_view. as_standard_layout ( ) . to_owned ( )
467454 } ;
468455
469- let array_view = unsafe { prepared_array. as_array ( ) } ;
470- let prepared_slice = array_view
471- . as_slice_memory_order ( )
472- . expect ( "Newly allocated array must be contiguous" ) ;
456+ let slice = array_owned
457+ . as_slice_memory_order ( )
458+ . expect ( "newly allocated Array2 must be contiguous" ) ;
473459
474460 Ok ( PreparedBool2D {
475- data : unsafe { std:: mem:: transmute ( prepared_slice ) } ,
461+ data : unsafe { std:: mem:: transmute ( slice ) } ,
476462 nrows,
477463 ncols,
478- _keepalive : Some ( prepared_array . into_any ( ) ) ,
464+ _keepalive : Some ( Arc :: new ( array_owned ) ) ,
479465 } )
480466}
481467
482-
483-
484468#[ pyfunction]
485469#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
486- pub fn first_true_2d_a < ' py > (
470+ pub fn first_true_2d < ' py > (
487471 py : Python < ' py > ,
488472 array : PyReadonlyArray2 < ' py , bool > ,
489473 forward : bool ,
490474 axis : isize ,
491475) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
492-
493- let prepared = prepare_array_for_axis ( py, array, axis) ?;
476+ let prepared = prepare_array_for_axis ( array, axis) ?;
494477 let data = prepared. data ;
495478 let rows = prepared. nrows ;
496479 let row_len = prepared. ncols ;
497480
498- let mut result = vec ! [ -1isize ; rows] ;
481+ let pyarray = unsafe { PyArray1 :: < isize > :: new ( py, [ rows] , false ) } ;
482+ let result = unsafe { pyarray. as_slice_mut ( ) . unwrap ( ) } ;
483+ result. fill ( -1 ) ;
499484
500- py. allow_threads ( || {
501- const LANES : usize = 32 ;
502- let ones = u8x32:: splat ( 1 ) ;
503- let base_ptr = data. as_ptr ( ) ;
485+ // let mut result = vec![-1isize; rows];
486+
487+ // py.allow_threads(|| {
488+ const LANES : usize = 32 ;
489+ let ones = u8x32:: splat ( 1 ) ;
490+ let base_ptr = data. as_ptr ( ) ;
491+ let mut i;
504492
493+ if forward {
505494 for row in 0 ..rows {
506495 let ptr = unsafe { base_ptr. add ( row * row_len) } ;
507- if forward {
508- // Forward search
509- let mut i = 0 ;
510- unsafe {
511- while i + LANES <= row_len {
512- let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
513- let vec = u8x32:: from ( * chunk) ;
514- if vec. cmp_eq ( ones) . any ( ) {
515- break ;
516- }
517- i += LANES ;
496+ i = 0 ;
497+ unsafe {
498+ while i + LANES <= row_len {
499+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
500+ let vec = u8x32:: from ( * chunk) ;
501+ if vec. cmp_eq ( ones) . any ( ) {
502+ break ;
518503 }
519- while i < row_len {
520- if * ptr . add ( i ) != 0 {
521- result [ row ] = i as isize ;
522- break ;
523- }
524- i += 1 ;
504+ i += LANES ;
505+ }
506+ while i < row_len {
507+ if * ptr . add ( i ) != 0 {
508+ result [ row ] = i as isize ;
509+ break ;
525510 }
511+ i += 1 ;
526512 }
527- } else {
528- // Backward search
529- let mut i = row_len;
530- unsafe {
531- // Process LANES bytes at a time with SIMD (backwards)
532- while i >= LANES {
533- i -= LANES ;
513+ }
514+ }
515+ } else {
516+ // Backward search
517+ for row in 0 ..rows {
518+ let ptr = unsafe { base_ptr. add ( row * row_len) } ;
534519
535- let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
536- let vec = u8x32:: from ( * chunk) ;
537- if vec. cmp_eq ( ones) . any ( ) {
538- // Found a true in this chunk, search backwards within it
539- for j in ( i..i + LANES ) . rev ( ) {
540- if * ptr. add ( j) != 0 {
541- result[ row] = j as isize ;
542- break ;
543- }
544- }
545- break ;
546- }
547- }
548- // Handle remaining bytes at the beginning
549- if i > 0 && i < LANES {
550- for j in ( 0 ..i) . rev ( ) {
520+ i = row_len;
521+ unsafe {
522+ // Process LANES bytes at a time with SIMD (backwards)
523+ while i >= LANES {
524+ i -= LANES ;
525+
526+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
527+ let vec = u8x32:: from ( * chunk) ;
528+ if vec. cmp_eq ( ones) . any ( ) {
529+ // Found a true in this chunk, search backwards within it
530+ for j in ( i..i + LANES ) . rev ( ) {
551531 if * ptr. add ( j) != 0 {
552532 result[ row] = j as isize ;
553533 break ;
554534 }
555535 }
536+ break ;
537+ }
538+ }
539+ // Handle remaining bytes at the beginning
540+ if i > 0 && i < LANES {
541+ for j in ( 0 ..i) . rev ( ) {
542+ if * ptr. add ( j) != 0 {
543+ result[ row] = j as isize ;
544+ break ;
545+ }
556546 }
557547 }
558548 }
559549 }
560- } ) ;
561- Ok ( PyArray1 :: from_vec ( py, result) . to_owned ( ) )
550+ }
551+ // });
552+ // Ok(PyArray1::from_vec(py, result).to_owned())
553+ Ok ( pyarray)
562554}
563555
564556#[ pyfunction]
565557#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
566- pub fn first_true_2d < ' py > (
558+ pub fn first_true_2d_b < ' py > (
567559 py : Python < ' py > ,
568560 array : PyReadonlyArray2 < ' py , bool > ,
569561 forward : bool ,
570562 axis : isize ,
571563) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
572- let prepared = prepare_array_for_axis ( py , array, axis) ?;
564+ let prepared = prepare_array_for_axis ( array, axis) ?;
573565 let data = prepared. data ;
574566 let rows = prepared. nrows ;
575567 let row_len = prepared. ncols ;
@@ -580,9 +572,9 @@ pub fn first_true_2d<'py>(
580572 let max_threads = if rows < 100 {
581573 1
582574 } else if rows < 1000 {
583- 2
575+ 1
584576 } else if rows < 10000 {
585- 4
577+ 1
586578 } else {
587579 16
588580 } ;
@@ -667,8 +659,6 @@ pub fn first_true_2d<'py>(
667659 Ok ( PyArray1 :: from_vec ( py, result) )
668660}
669661
670-
671-
672662//------------------------------------------------------------------------------
673663
674664#[ pymodule]
0 commit comments