66```
77"""
88
9+ import json
910import os
1011import shutil
1112import tempfile
@@ -170,12 +171,7 @@ def test_whole_workflow(self):
170171
171172 # Define a simple kernel directly in the test function
172173 @triton .jit
173- def test_kernel (
174- x_ptr ,
175- y_ptr ,
176- n_elements ,
177- BLOCK_SIZE : tl .constexpr ,
178- ):
174+ def test_kernel (x_ptr , y_ptr , n_elements , BLOCK_SIZE : tl .constexpr ):
179175 pid = tl .program_id (axis = 0 )
180176 block_start = pid * BLOCK_SIZE
181177 offsets = block_start + tl .arange (0 , BLOCK_SIZE )
@@ -189,48 +185,147 @@ def test_kernel(
189185 def run_test_kernel (x ):
190186 n_elements = x .numel ()
191187 y = torch .empty_like (x )
192- BLOCK_SIZE = 256 # Smaller block size for simplicity
188+ BLOCK_SIZE = 256
193189 grid = (triton .cdiv (n_elements , BLOCK_SIZE ),)
194190 test_kernel [grid ](x , y , n_elements , BLOCK_SIZE )
195191 return y
196192
193+ # Set up test environment
197194 temp_dir = tempfile .mkdtemp ()
198- print (f"Temporary directory: { temp_dir } " )
199195 temp_dir_logs = os .path .join (temp_dir , "logs" )
200- os .makedirs (temp_dir_logs , exist_ok = True )
201196 temp_dir_parsed = os .path .join (temp_dir , "parsed_output" )
197+ os .makedirs (temp_dir_logs , exist_ok = True )
202198 os .makedirs (temp_dir_parsed , exist_ok = True )
199+ print (f"Temporary directory: { temp_dir } " )
203200
204- tritonparse .structured_logging .init (temp_dir_logs )
201+ # Initialize logging
202+ tritonparse .structured_logging .init (temp_dir_logs , enable_trace_launch = True )
205203
206- # Generate some triton compilation activity to create log files
204+ # Generate test data and run kernels
207205 torch .manual_seed (0 )
208206 size = (512 , 512 ) # Smaller size for faster testing
209207 x = torch .randn (size , device = self .cuda_device , dtype = torch .float32 )
210- run_test_kernel (x ) # Run the simple kernel
208+
209+ # Run kernel twice to generate compilation and launch events
210+ run_test_kernel (x )
211+ run_test_kernel (x )
211212 torch .cuda .synchronize ()
212213
213- # Check that temp_dir_logs folder has content
214+ # Verify log directory
214215 assert os .path .exists (
215216 temp_dir_logs
216217 ), f"Log directory { temp_dir_logs } does not exist."
217218 log_files = os .listdir (temp_dir_logs )
218- assert (
219- len (log_files ) > 0
220- ), f"No log files found in { temp_dir_logs } . Expected log files to be generated during Triton compilation."
219+ assert len (log_files ) > 0 , (
220+ f"No log files found in { temp_dir_logs } . "
221+ "Expected log files to be generated during Triton compilation."
222+ )
221223 print (f"Found { len (log_files )} log files in { temp_dir_logs } : { log_files } " )
222224
225+ def parse_log_line (line : str , line_num : int ) -> dict | None :
226+ """Parse a single log line and extract event data"""
227+ try :
228+ return json .loads (line .strip ())
229+ except json .JSONDecodeError as e :
230+ print (f" Line { line_num } : JSON decode error - { e } " )
231+ return None
232+
233+ def process_event_data (
234+ event_data : dict , line_num : int , event_counts : dict
235+ ) -> None :
236+ """Process event data and update counts"""
237+ try :
238+ event_type = event_data .get ("event_type" )
239+ if event_type is None :
240+ return
241+
242+ if event_type in event_counts :
243+ event_counts [event_type ] += 1
244+ print (
245+ f" Line { line_num } : event_type = '{ event_type } ' (count: { event_counts [event_type ]} )"
246+ )
247+ else :
248+ print (
249+ f" Line { line_num } : event_type = '{ event_type } ' (not tracked)"
250+ )
251+ except (KeyError , TypeError ) as e :
252+ print (f" Line { line_num } : Data structure error - { e } " )
253+
254+ def count_events_in_file (file_path : str , event_counts : dict ) -> None :
255+ """Count events in a single log file"""
256+ print (f"Checking event types in: { os .path .basename (file_path )} " )
257+
258+ with open (file_path , "r" ) as f :
259+ for line_num , line in enumerate (f , 1 ):
260+ event_data = parse_log_line (line , line_num )
261+ if event_data :
262+ process_event_data (event_data , line_num , event_counts )
263+
264+ def check_event_type_counts_in_logs (log_dir : str ) -> dict :
265+ """Count 'launch' and unique 'compilation' events in all log files"""
266+ event_counts = {"launch" : 0 }
267+ # Track unique compilation hashes
268+ compilation_hashes = set ()
269+
270+ for log_file in os .listdir (log_dir ):
271+ if log_file .endswith (".ndjson" ):
272+ log_file_path = os .path .join (log_dir , log_file )
273+ with open (log_file_path , "r" ) as f :
274+ for line_num , line in enumerate (f , 1 ):
275+ try :
276+ event_data = json .loads (line .strip ())
277+ event_type = event_data .get ("event_type" )
278+ if event_type == "launch" :
279+ event_counts ["launch" ] += 1
280+ print (
281+ f" Line { line_num } : event_type = 'launch' (count: { event_counts ['launch' ]} )"
282+ )
283+ elif event_type == "compilation" :
284+ # Extract hash from compilation metadata
285+ compilation_hash = (
286+ event_data .get ("payload" , {})
287+ .get ("metadata" , {})
288+ .get ("hash" )
289+ )
290+ if compilation_hash :
291+ compilation_hashes .add (compilation_hash )
292+ print (
293+ f" Line { line_num } : event_type = 'compilation' (unique hash: { compilation_hash [:8 ]} ...)"
294+ )
295+ except (json .JSONDecodeError , KeyError , TypeError ) as e :
296+ print (f" Line { line_num } : Error processing line - { e } " )
297+
298+ # Add the count of unique compilation hashes to the event_counts
299+ event_counts ["compilation" ] = len (compilation_hashes )
300+ print (
301+ f"Event type counts: { event_counts } (unique compilation hashes: { len (compilation_hashes )} )"
302+ )
303+ return event_counts
304+
305+ # Verify event counts
306+ event_counts = check_event_type_counts_in_logs (temp_dir_logs )
307+ assert (
308+ event_counts ["compilation" ] == 1
309+ ), f"Expected 1 unique 'compilation' hash, found { event_counts ['compilation' ]} "
310+ assert (
311+ event_counts ["launch" ] == 2
312+ ), f"Expected 2 'launch' events, found { event_counts ['launch' ]} "
313+ print (
314+ "✓ Verified correct event type counts: 1 unique compilation hash, 2 launch events"
315+ )
316+
317+ # Test parsing functionality
223318 tritonparse .utils .unified_parse (
224319 source = temp_dir_logs , out = temp_dir_parsed , overwrite = True
225320 )
226-
227- # Clean up temporary directory
228321 try :
229- # Check that parsed output directory has files
322+ # Verify parsing output
230323 parsed_files = os .listdir (temp_dir_parsed )
231324 assert len (parsed_files ) > 0 , "No files found in parsed output directory"
232325 finally :
326+ # Clean up
233327 shutil .rmtree (temp_dir )
328+ print ("✓ Cleaned up temporary directory" )
234329
235330
236331if __name__ == "__main__" :
0 commit comments