77import inspect
88import re
99import sys
10+ from collections .abc import Sequence
1011from dataclasses import dataclass
1112from importlib .abc import Loader
1213from pathlib import Path
2425 from .config import ExamplesConfig
2526 from .find_examples import CodeExample
2627
27- __all__ = 'run_code' , 'InsertPrintStatements'
28+ __all__ = 'run_code' , 'InsertPrintStatements' , 'IncludePrint'
2829
2930parent_frame_id = 4 if sys .version_info >= (3 , 8 ) else 3
31+ IncludePrint = Callable [[Path , inspect .FrameInfo , Sequence [Any ]], bool ]
3032
3133
3234def run_code (
@@ -37,6 +39,7 @@ def run_code(
3739 config : ExamplesConfig ,
3840 enable_print_mock : bool ,
3941 print_callback : Callable [[str ], str ] | None ,
42+ include_print : IncludePrint | None ,
4043 module_globals : dict [str , Any ] | None ,
4144 call : str | None ,
4245) -> tuple [InsertPrintStatements , dict [str , Any ]]:
@@ -49,6 +52,7 @@ def run_code(
4952 config: The `ExamplesConfig` to use.
5053 enable_print_mock: If True, mock the `print` function.
5154 print_callback: If not None, a callback to call on `print`.
55+ include_print: If not None, a function to call to determine if the print statement should be included.
5256 module_globals: The extra globals to add before calling the module.
5357 call: If not None, a (coroutine) function to call in the module.
5458
@@ -63,7 +67,7 @@ def run_code(
6367 module = importlib .util .module_from_spec (spec )
6468
6569 # does nothing if insert_print_statements is False
66- insert_print = InsertPrintStatements (python_file , config , enable_print_mock , print_callback )
70+ insert_print = InsertPrintStatements (python_file , config , enable_print_mock , print_callback , include_print )
6771
6872 if module_globals :
6973 module .__dict__ .update (module_globals )
@@ -141,26 +145,40 @@ def not_print(*args):
141145
142146
143147class MockPrintFunction :
144- def __init__ (self , file : Path ) -> None :
148+ __slots__ = 'file' , 'statements' , 'include_print'
149+
150+ def __init__ (self , file : Path , include_print : IncludePrint | None ) -> None :
145151 self .file = file
146152 self .statements : list [PrintStatement ] = []
153+ self .include_print = include_print
147154
148155 def __call__ (self , * args : Any , sep : str = ' ' , ** kwargs : Any ) -> None :
149156 frame = inspect .stack ()[parent_frame_id ]
150157
151- if self .file . samefile (frame . filename ):
158+ if self ._include_file (frame , args ):
152159 # -1 to account for the line number being 1-indexed
153160 s = PrintStatement (frame .lineno , sep , [Arg (arg ) for arg in args ])
154161 self .statements .append (s )
155162
163+ def _include_file (self , frame : inspect .FrameInfo , args : Sequence [Any ]) -> bool :
164+ if self .include_print :
165+ return self .include_print (self .file , frame , args )
166+ else :
167+ return self .file .samefile (frame .filename )
168+
156169
157170class InsertPrintStatements :
158171 def __init__ (
159- self , python_path : Path , config : ExamplesConfig , enable : bool , print_callback : Callable [[str ], str ] | None
172+ self ,
173+ python_path : Path ,
174+ config : ExamplesConfig ,
175+ enable : bool ,
176+ print_callback : Callable [[str ], str ] | None ,
177+ include_print : IncludePrint | None ,
160178 ):
161179 self .file = python_path
162180 self .config = config
163- self .print_func = MockPrintFunction (python_path ) if enable else None
181+ self .print_func = MockPrintFunction (python_path , include_print ) if enable else None
164182 self .print_callback = print_callback
165183 self .patch = None
166184
0 commit comments