Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-128384: Use contextvar for catch_warnings() #128463

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Doc/library/threading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ since it is impossible to detect the termination of alien threads.


.. class:: Thread(group=None, target=None, name=None, args=(), kwargs={}, *, \
daemon=None)
daemon=None, context="inherit")

This constructor should always be called with keyword arguments. Arguments
are:
Expand All @@ -359,6 +359,10 @@ since it is impossible to detect the termination of alien threads.
If ``None`` (the default), the daemonic property is inherited from the
current thread.

*context* is the `contextvars.Context` value to use while running the thread.
The default is to inherit the context of the caller of :meth:`~Thread.start`.
If set to ``None``, the context will be empty.

If the subclass overrides the constructor, it must make sure to invoke the
base class constructor (``Thread.__init__()``) before doing anything else to
the thread.
Expand All @@ -369,6 +373,10 @@ since it is impossible to detect the termination of alien threads.
.. versionchanged:: 3.10
Use the *target* name if *name* argument is omitted.

.. versionchanged:: 3.14
Added the *context* parameter. Previously threads always ran with an empty
context.

.. method:: start()

Start the thread's activity.
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_global_objects_fini_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Include/internal/pycore_global_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ struct _Py_global_strings {
STRUCT_FOR_ID(_type_)
STRUCT_FOR_ID(_uninitialized_submodules)
STRUCT_FOR_ID(_warn_unawaited_coroutine)
STRUCT_FOR_ID(_warnings_context)
STRUCT_FOR_ID(_xoptions)
STRUCT_FOR_ID(abs_tol)
STRUCT_FOR_ID(access)
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_runtime_init_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Include/internal/pycore_unicodeobject_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Include/internal/pycore_warnings.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct _warnings_runtime_state {
PyObject *filters; /* List */
PyObject *once_registry; /* Dict */
PyObject *default_action; /* String */
PyMutex mutex;
_PyRecursiveMutex lock;
long filters_version;
};

Expand Down
43 changes: 43 additions & 0 deletions Lib/test/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,49 @@ def sub(num):
tp.shutdown()
self.assertEqual(results, list(range(10)))

@isolated_context
@threading_helper.requires_working_threading()
def test_context_thread_inherit(self):
import threading

cvar = contextvars.ContextVar('cvar')

# By default, the context of the caller is inheritied
def run_inherit():
self.assertEqual(cvar.get(), 1)

cvar.set(1)
thread = threading.Thread(target=run_inherit)
thread.start()
thread.join()

# If context=None is passed, the thread has an empty context
def run_empty():
with self.assertRaises(LookupError):
cvar.get()

thread = threading.Thread(target=run_empty, context=None)
thread.start()
thread.join()

# An explicit Context value can also be passed
custom_ctx = contextvars.Context()
custom_var = None

def setup_context():
nonlocal custom_var
custom_var = contextvars.ContextVar('custom')
custom_var.set(2)

custom_ctx.run(setup_context)

def run_custom():
self.assertEqual(custom_var.get(), 2)

thread = threading.Thread(target=run_custom, context=custom_ctx)
thread.start()
thread.join()


# HAMT Tests

Expand Down
4 changes: 2 additions & 2 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,8 +1725,8 @@ def test_threading(self):
self.finish1 = threading.Event()
self.finish2 = threading.Event()

th1 = threading.Thread(target=thfunc1, args=(self,))
th2 = threading.Thread(target=thfunc2, args=(self,))
th1 = threading.Thread(target=thfunc1, args=(self,), context=None)
th2 = threading.Thread(target=thfunc2, args=(self,), context=None)

th1.start()
th2.start()
Expand Down
48 changes: 24 additions & 24 deletions Lib/test/test_warnings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,14 @@ def warnings_state(module):
except NameError:
pass
original_warnings = warning_tests.warnings
original_filters = module.filters
saved_context, context = module._new_context()
try:
module.filters = original_filters[:]
module.simplefilter("once")
warning_tests.warnings = module
yield
finally:
warning_tests.warnings = original_warnings
module.filters = original_filters
module._set_context(saved_context)


class TestWarning(Warning):
Expand Down Expand Up @@ -336,15 +335,15 @@ def test_filterwarnings_duplicate_filters(self):
with original_warnings.catch_warnings(module=self.module):
self.module.resetwarnings()
self.module.filterwarnings("error", category=UserWarning)
self.assertEqual(len(self.module.filters), 1)
self.assertEqual(len(self.module._get_filters()), 1)
self.module.filterwarnings("ignore", category=UserWarning)
self.module.filterwarnings("error", category=UserWarning)
self.assertEqual(
len(self.module.filters), 2,
len(self.module._get_filters()), 2,
"filterwarnings inserted duplicate filter"
)
self.assertEqual(
self.module.filters[0][0], "error",
self.module._get_filters()[0][0], "error",
"filterwarnings did not promote filter to "
"the beginning of list"
)
Expand All @@ -353,15 +352,15 @@ def test_simplefilter_duplicate_filters(self):
with original_warnings.catch_warnings(module=self.module):
self.module.resetwarnings()
self.module.simplefilter("error", category=UserWarning)
self.assertEqual(len(self.module.filters), 1)
self.assertEqual(len(self.module._get_filters()), 1)
self.module.simplefilter("ignore", category=UserWarning)
self.module.simplefilter("error", category=UserWarning)
self.assertEqual(
len(self.module.filters), 2,
len(self.module._get_filters()), 2,
"simplefilter inserted duplicate filter"
)
self.assertEqual(
self.module.filters[0][0], "error",
self.module._get_filters()[0][0], "error",
"simplefilter did not promote filter to the beginning of list"
)

Expand All @@ -373,7 +372,7 @@ def test_append_duplicate(self):
self.module.simplefilter("error", append=True)
self.module.simplefilter("ignore", append=True)
self.module.warn("test_append_duplicate", category=UserWarning)
self.assertEqual(len(self.module.filters), 2,
self.assertEqual(len(self.module._get_filters()), 2,
"simplefilter inserted duplicate filter"
)
self.assertEqual(len(w), 0,
Expand Down Expand Up @@ -1049,11 +1048,11 @@ def test_issue31416(self):
# bad warnings.filters or warnings.defaultaction.
wmod = self.module
with original_warnings.catch_warnings(module=wmod):
wmod.filters = [(None, None, Warning, None, 0)]
wmod._get_filters()[:] = [(None, None, Warning, None, 0)]
with self.assertRaises(TypeError):
wmod.warn_explicit('foo', Warning, 'bar', 1)

wmod.filters = []
wmod._get_filters()[:] = []
with support.swap_attr(wmod, 'defaultaction', None), \
self.assertRaises(TypeError):
wmod.warn_explicit('foo', Warning, 'bar', 1)
Expand Down Expand Up @@ -1191,17 +1190,17 @@ class CatchWarningTests(BaseTest):

def test_catch_warnings_restore(self):
wmod = self.module
orig_filters = wmod.filters
orig_filters = wmod._get_filters()
orig_showwarning = wmod.showwarning
# Ensure both showwarning and filters are restored when recording
with wmod.catch_warnings(module=wmod, record=True):
wmod.filters = wmod.showwarning = object()
self.assertIs(wmod.filters, orig_filters)
wmod.get_context()._filters = wmod.showwarning = object()
self.assertIs(wmod._get_filters(), orig_filters)
self.assertIs(wmod.showwarning, orig_showwarning)
# Same test, but with recording disabled
with wmod.catch_warnings(module=wmod, record=False):
wmod.filters = wmod.showwarning = object()
self.assertIs(wmod.filters, orig_filters)
wmod.get_context()._filters = wmod.showwarning = object()
self.assertIs(wmod._get_filters(), orig_filters)
self.assertIs(wmod.showwarning, orig_showwarning)

def test_catch_warnings_recording(self):
Expand Down Expand Up @@ -1240,21 +1239,21 @@ def test_catch_warnings_reentry_guard(self):

def test_catch_warnings_defaults(self):
wmod = self.module
orig_filters = wmod.filters
orig_filters = wmod._get_filters()
orig_showwarning = wmod.showwarning
# Ensure default behaviour is not to record warnings
with wmod.catch_warnings(module=wmod) as w:
self.assertIsNone(w)
self.assertIs(wmod.showwarning, orig_showwarning)
self.assertIsNot(wmod.filters, orig_filters)
self.assertIs(wmod.filters, orig_filters)
self.assertIsNot(wmod._get_filters(), orig_filters)
self.assertIs(wmod._get_filters(), orig_filters)
if wmod is sys.modules['warnings']:
# Ensure the default module is this one
with wmod.catch_warnings() as w:
self.assertIsNone(w)
self.assertIs(wmod.showwarning, orig_showwarning)
self.assertIsNot(wmod.filters, orig_filters)
self.assertIs(wmod.filters, orig_filters)
self.assertIsNot(wmod._get_filters(), orig_filters)
self.assertIs(wmod._get_filters(), orig_filters)

def test_record_override_showwarning_before(self):
# Issue #28835: If warnings.showwarning() was overridden, make sure
Expand Down Expand Up @@ -1406,7 +1405,7 @@ def test_default_filter_configuration(self):
code = "import sys; sys.modules.pop('warnings', None); sys.modules['_warnings'] = None; "
else:
code = ""
code += "import warnings; [print(f) for f in warnings.filters]"
code += "import warnings; [print(f) for f in warnings._get_filters()]"

rc, stdout, stderr = assert_python_ok("-c", code, __isolated=True)
stdout_lines = [line.strip() for line in stdout.splitlines()]
Expand Down Expand Up @@ -1521,7 +1520,7 @@ def test_late_resource_warning(self):
self.assertTrue(err.startswith(expected), ascii(err))


class DeprecatedTests(unittest.TestCase):
class DeprecatedTests(PyPublicAPITests):
def test_dunder_deprecated(self):
@deprecated("A will go away soon")
class A:
Expand Down Expand Up @@ -1821,6 +1820,7 @@ async def coro(self):
self.assertFalse(inspect.iscoroutinefunction(Cls.sync))
self.assertTrue(inspect.iscoroutinefunction(Cls.coro))


def setUpModule():
py_warnings.onceregistry.clear()
c_warnings.onceregistry.clear()
Expand Down
26 changes: 23 additions & 3 deletions Lib/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys as _sys
import _thread
import warnings
import contextvars as _contextvars


from time import monotonic as _time
from _weakrefset import WeakSet
Expand Down Expand Up @@ -871,7 +873,7 @@ class Thread:
_initialized = False

def __init__(self, group=None, target=None, name=None,
args=(), kwargs=None, *, daemon=None):
args=(), kwargs=None, *, daemon=None, context='inherit'):
"""This constructor should always be called with keyword arguments. Arguments are:

*group* should be None; reserved for future extension when a ThreadGroup
Expand All @@ -888,6 +890,10 @@ class is implemented.
*kwargs* is a dictionary of keyword arguments for the target
invocation. Defaults to {}.

*context* is the contextvars.Context value to use for the thread. The default
is to inherit the context of the caller. Set to None to start with an empty
context.

If a subclass overrides the constructor, it must make sure to invoke
the base class constructor (Thread.__init__()) before doing anything
else to the thread.
Expand Down Expand Up @@ -917,6 +923,7 @@ class is implemented.
self._daemonic = daemon
else:
self._daemonic = current_thread().daemon
self._context = context
self._ident = None
if _HAVE_THREAD_NATIVE_ID:
self._native_id = None
Expand Down Expand Up @@ -972,9 +979,15 @@ def start(self):

with _active_limbo_lock:
_limbo[self] = self

if self._context == 'inherit':
# No context provided, inherit the context of the caller.
self._context = _contextvars.copy_context()

try:
# Start joinable thread
_start_joinable_thread(self._bootstrap, handle=self._handle,
_start_joinable_thread(self._bootstrap,
handle=self._handle,
daemon=self.daemon)
except Exception:
with _active_limbo_lock:
Expand Down Expand Up @@ -1050,8 +1063,15 @@ def _bootstrap_inner(self):
if _profile_hook:
_sys.setprofile(_profile_hook)


try:
self.run()
if self._context is None:
# Run with empty context, matching behaviour of
# threading.local and older versions of Python.
self.run()
else:
# Run with the provided or the inherited context.
self._context.run(self.run)
except:
self._invoke_excepthook(self)
finally:
Expand Down
Loading
Loading