diff --git a/docs/api/rust_backend.md b/docs/api/rust_backend.md index 938e1af1..9547ca14 100644 --- a/docs/api/rust_backend.md +++ b/docs/api/rust_backend.md @@ -15,3 +15,16 @@ r = RustNotify(['first/path', 'second/path'], False, False, 0) changes = r.watch(1_600, 50, 100, None) print(changes) ``` + +Or using `RustNotify` as a context manager: + +```py +title="Rust backend context manager example" +from watchfiles._rust_notify import RustNotify + +with RustNotify(['first/path', 'second/path'], False, False, 0) as r: + changes = r.watch(1_600, 50, 100, None) + print(changes) +``` + +(See the documentation on `close` above for when the context manager or `close` method are required.) diff --git a/src/lib.rs b/src/lib.rs index 6320ba31..f81c7876 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ const CHANGE_DELETED: u8 = 3; #[derive(Debug)] enum WatcherEnum { + None, Poll(PollWatcher), Recommended(RecommendedWatcher), } @@ -157,13 +158,16 @@ impl RustNotify { timeout_ms: u64, stop_event: PyObject, ) -> PyResult { + if matches!(self.watcher, WatcherEnum::None) { + return Err(PyRuntimeError::new_err("RustNotify watcher closed")); + } let stop_event_is_set: Option<&PyAny> = match stop_event.is_none(py) { true => None, false => { let event: &PyAny = stop_event.extract(py)?; let func: &PyAny = event.getattr("is_set")?.extract()?; if !func.is_callable() { - return Err(PyTypeError::new_err("'stop_event.is_set' must be callable".to_string())); + return Err(PyTypeError::new_err("'stop_event.is_set' must be callable")); } Some(func) } @@ -228,10 +232,25 @@ impl RustNotify { Ok(py_changes) } + /// https://github.com/PyO3/pyo3/issues/1205#issuecomment-1164096251 for advice on `__enter__` + pub fn __enter__(slf: Py) -> Py { + slf + } + + pub fn close(&mut self) { + self.watcher = WatcherEnum::None; + } + + pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) { + self.close(); + } + pub fn __repr__(&self) -> PyResult { Ok(format!("RustNotify({:#?})", self.watcher)) } +} +impl RustNotify { fn clear(&self) { self.changes.lock().unwrap().clear(); } diff --git a/tests/conftest.py b/tests/conftest.py index 696ae567..a6370c21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,6 +79,15 @@ def watch(self, debounce_ms: int, step_ms: int, timeout_ms: int, cancel_event): self.watch_count += 1 return change + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def close(self): + pass + if TYPE_CHECKING: from typing import Literal, Protocol diff --git a/tests/test_rust_notify.py b/tests/test_rust_notify.py index c075607a..775b4572 100644 --- a/tests/test_rust_notify.py +++ b/tests/test_rust_notify.py @@ -16,10 +16,16 @@ def test_add(test_dir: Path): assert watcher.watch(200, 50, 500, None) == {(1, str(test_dir / 'new_file.txt'))} -def test_recommended_repr(test_dir: Path): +def test_close(test_dir: Path): watcher = RustNotify([str(test_dir)], True, False, 0) assert repr(watcher).startswith('RustNotify(Recommended(\n') + watcher.close() + + assert repr(watcher) == 'RustNotify(None)' + with pytest.raises(RuntimeError, match='RustNotify watcher closed'): + watcher.watch(200, 50, 500, None) + def test_modify_write(test_dir: Path): watcher = RustNotify([str(test_dir)], True, False, 0) diff --git a/tests/test_watch.py b/tests/test_watch.py index f6da3f5b..33e804f4 100644 --- a/tests/test_watch.py +++ b/tests/test_watch.py @@ -201,6 +201,15 @@ def watch(self, *args): self.i += 1 return {(Change.added, 'spam.py')} + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def close(self): + pass + async def test_awatch_interrupt_raise(mocker, caplog): mocker.patch('watchfiles.main.RustNotify', return_value=MockRustNotifyRaise()) diff --git a/watchfiles/_rust_notify.pyi b/watchfiles/_rust_notify.pyi index 911f3d05..4cd5d6eb 100644 --- a/watchfiles/_rust_notify.pyi +++ b/watchfiles/_rust_notify.pyi @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional, Protocol, Set, Tuple, Union +from typing import Any, List, Literal, Optional, Protocol, Set, Tuple, Union __all__ = 'RustNotify', 'WatchfilesRustInternalError' @@ -10,7 +10,7 @@ class AbstractEvent(Protocol): class RustNotify: """ Interface to the Rust [notify](https://crates.io/crates/notify) crate which does - the heavy lifting of watching for file changes and grouping them into a single event. + the heavy lifting of watching for file changes and grouping them into events. """ def __init__(self, watch_paths: List[str], debug: bool, force_polling: bool, poll_delay_ms: int) -> None: @@ -56,6 +56,30 @@ class RustNotify: `'signal'` if a signal was received, `'stop'` if the `stop_event` was set, or `'timeout'` if `timeout_ms` was exceeded. """ + def __enter__(self) -> 'RustNotify': + """ + Does nothing, but allows `RustNotify` to be used as a context manager. + + Note: the watching thead is created when an instance is initiated, not on + `__enter__`. + """ + def __exit__(self, *args: Any) -> None: + """ + Calls close. + """ + def close(self) -> None: + """ + Stops the watching thread. After `close` is called, the RustNotify instance can no + longer be used, calls to [`watch`][watchfiles.RustNotify.watch] will raise a `RuntimeError`. + + Note: `close` is not required, just deleting the `RustNotify` instance will kill the thread + implicitly. + + As per samuelcolvin/watchfiles#163 `close()` is only required because in the + event of an error, the traceback in `sys.exc_info` keeps a reference to `watchfiles.watch`'s + frame, so you can't rely on the `RustNotify` object being deleted, and thereby stopping + the watching thread. + """ class WatchfilesRustInternalError(RuntimeError): """ diff --git a/watchfiles/main.py b/watchfiles/main.py index 5bc144d1..277ec245 100644 --- a/watchfiles/main.py +++ b/watchfiles/main.py @@ -100,27 +100,27 @@ def watch( print(changes) ``` """ - watcher = RustNotify([str(p) for p in paths], debug, force_polling, poll_delay_ms) - while True: - raw_changes = watcher.watch(debounce, step, rust_timeout, stop_event) - if raw_changes == 'timeout': - if yield_on_timeout: - yield set() - else: - logger.debug('rust notify timeout, continuing') - elif raw_changes == 'signal': - if raise_interrupt: - raise KeyboardInterrupt - else: - logger.warning('KeyboardInterrupt caught, stopping watch') + with RustNotify([str(p) for p in paths], debug, force_polling, poll_delay_ms) as watcher: + while True: + raw_changes = watcher.watch(debounce, step, rust_timeout, stop_event) + if raw_changes == 'timeout': + if yield_on_timeout: + yield set() + else: + logger.debug('rust notify timeout, continuing') + elif raw_changes == 'signal': + if raise_interrupt: + raise KeyboardInterrupt + else: + logger.warning('KeyboardInterrupt caught, stopping watch') + return + elif raw_changes == 'stop': return - elif raw_changes == 'stop': - return - else: - changes = _prep_changes(raw_changes, watch_filter) - if changes: - _log_changes(changes) - yield changes + else: + changes = _prep_changes(raw_changes, watch_filter) + if changes: + _log_changes(changes) + yield changes async def awatch( # noqa C901 @@ -214,35 +214,35 @@ async def stop_soon(): else: stop_event_ = stop_event - watcher = RustNotify([str(p) for p in paths], debug, force_polling, poll_delay_ms) - timeout = _calc_async_timeout(rust_timeout) - CancelledError = anyio.get_cancelled_exc_class() - - while True: - async with anyio.create_task_group() as tg: - try: - raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, timeout, stop_event_) - except (CancelledError, KeyboardInterrupt): - stop_event_.set() - # suppressing KeyboardInterrupt wouldn't stop it getting raised by the top level asyncio.run call - raise - tg.cancel_scope.cancel() - - if raw_changes == 'timeout': - if yield_on_timeout: - yield set() + with RustNotify([str(p) for p in paths], debug, force_polling, poll_delay_ms) as watcher: + timeout = _calc_async_timeout(rust_timeout) + CancelledError = anyio.get_cancelled_exc_class() + + while True: + async with anyio.create_task_group() as tg: + try: + raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, timeout, stop_event_) + except (CancelledError, KeyboardInterrupt): + stop_event_.set() + # suppressing KeyboardInterrupt wouldn't stop it getting raised by the top level asyncio.run call + raise + tg.cancel_scope.cancel() + + if raw_changes == 'timeout': + if yield_on_timeout: + yield set() + else: + logger.debug('rust notify timeout, continuing') + elif raw_changes == 'stop': + return + elif raw_changes == 'signal': + # in theory the watch thread should never get a signal + raise RuntimeError('watch thread unexpectedly received a signal') else: - logger.debug('rust notify timeout, continuing') - elif raw_changes == 'stop': - return - elif raw_changes == 'signal': - # in theory the watch thread should never get a signal - raise RuntimeError('watch thread unexpectedly received a signal') - else: - changes = _prep_changes(raw_changes, watch_filter) - if changes: - _log_changes(changes) - yield changes + changes = _prep_changes(raw_changes, watch_filter) + if changes: + _log_changes(changes) + yield changes def _prep_changes(