@@ -12,8 +12,11 @@ use std::{
1212
1313use crate :: { DType , Result } ;
1414
15- use petgraph:: dot:: { Config , Dot } ;
1615use petgraph:: Graph as PetGraph ;
16+ use petgraph:: {
17+ dot:: { Config , Dot } ,
18+ graph:: NodeIndex ,
19+ } ;
1720
1821#[ derive( Clone ) ]
1922pub struct Graph < T : DType > {
@@ -48,48 +51,69 @@ impl<T: DType> Graph<T> {
4851 next
4952 }
5053
51- /// Export this computational graph as a petgraph::Graph where nodes are operation labels.
5254 pub fn to_petgraph ( & self ) -> PetGraph < String , ( ) > {
5355 let ops = self . data . read ( ) . unwrap ( ) ;
5456 let mut g = PetGraph :: < String , ( ) > :: new ( ) ;
55- let mut nodes = Vec :: with_capacity ( ops. len ( ) ) ;
56- // Add nodes with labels
57+ // map from op‐index → Some(node) if we created a node, or None if it was a NoOp
58+ let mut idx_map: Vec < Option < NodeIndex > > = Vec :: with_capacity ( ops. len ( ) ) ;
59+
60+ // 1) Add only non‐NoOp nodes
5761 for op in ops. iter ( ) {
58- let label = match op {
59- Op :: Fill { v } => format ! ( "Fill({:?})" , v) ,
60- Op :: Arange { start, step } => format ! ( "Arange(start={:?}, step={:?})" , start, step) ,
61- Op :: BinaryOp { operator, .. } => format ! ( "BinOp({})" , operator. as_c_op( ) ) ,
62- Op :: UnaryOp { operator, .. } => format ! ( "UnOp({:?})" , operator) ,
63- Op :: FusedMulAdd { .. } => "FMA" . to_string ( ) ,
64- Op :: NoOp => "NoOp" . to_string ( ) ,
65- } ;
66- nodes. push ( g. add_node ( label) ) ;
62+ match op {
63+ Op :: NoOp => {
64+ idx_map. push ( None ) ;
65+ }
66+ _ => {
67+ let label = match op {
68+ Op :: Fill { v } => format ! ( "Fill({:?})" , v) ,
69+ Op :: Arange { start, step } => {
70+ format ! ( "Arange(start={:?}, step={:?})" , start, step)
71+ }
72+ Op :: BinaryOp { operator, .. } => format ! ( "BinOp({})" , operator. as_c_op( ) ) ,
73+ Op :: UnaryOp { operator, .. } => format ! ( "UnOp({:?})" , operator) ,
74+ Op :: FusedMulAdd { .. } => "FMA" . to_string ( ) ,
75+ // we already matched NoOp above
76+ _ => unreachable ! ( ) ,
77+ } ;
78+ let node = g. add_node ( label) ;
79+ idx_map. push ( Some ( node) ) ;
80+ }
81+ }
6782 }
68- // Add edges to represent data dependencies
83+
84+ // 2) Walk ops again and only connect edges for those dst nodes that exist
6985 for ( i, op) in ops. iter ( ) . enumerate ( ) {
70- let dst = nodes[ i] ;
86+ // if this op was NoOp, skip entirely
87+ let dst = match idx_map[ i] {
88+ Some ( dst) => dst,
89+ None => continue ,
90+ } ;
7191 match op {
7292 Op :: BinaryOp { l_id, r_id, .. } => {
73- let src_l = nodes[ usize:: from ( l_id) ] ;
74- let src_r = nodes[ usize:: from ( r_id) ] ;
75- g. add_edge ( src_l, dst, ( ) ) ;
76- g. add_edge ( src_r, dst, ( ) ) ;
93+ if let Some ( src) = idx_map[ usize:: from ( l_id) ] {
94+ g. add_edge ( src, dst, ( ) ) ;
95+ }
96+ if let Some ( src) = idx_map[ usize:: from ( r_id) ] {
97+ g. add_edge ( src, dst, ( ) ) ;
98+ }
7799 }
78100 Op :: UnaryOp { v_id, .. } => {
79- let src = nodes[ usize:: from ( v_id) ] ;
80- g. add_edge ( src, dst, ( ) ) ;
101+ if let Some ( src) = idx_map[ usize:: from ( v_id) ] {
102+ g. add_edge ( src, dst, ( ) ) ;
103+ }
81104 }
82105 Op :: FusedMulAdd { a_id, b_id, c_id } => {
83- let src_a = nodes[ usize:: from ( a_id) ] ;
84- let src_b = nodes[ usize:: from ( b_id) ] ;
85- let src_c = nodes[ usize:: from ( c_id) ] ;
86- g. add_edge ( src_a, dst, ( ) ) ;
87- g. add_edge ( src_b, dst, ( ) ) ;
88- g. add_edge ( src_c, dst, ( ) ) ;
106+ for src_id in [ a_id, b_id, c_id] {
107+ if let Some ( src) = idx_map[ usize:: from ( src_id) ] {
108+ g. add_edge ( src, dst, ( ) ) ;
109+ }
110+ }
89111 }
112+ // NoOp and Fill/Arange don’t create incoming edges
90113 _ => { }
91114 }
92115 }
116+
93117 g
94118 }
95119
0 commit comments