Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronized diagnostics worker threads #190

Merged
merged 9 commits into from
Jan 27, 2025
35 changes: 24 additions & 11 deletions src/lang/diagnostics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ use tracing::{error, trace};

use self::project_diagnostics::ProjectDiagnostics;
use self::refresh::{clear_old_diagnostics, refresh_diagnostics};
use self::trigger::trigger;
use crate::lang::diagnostics::file_batches::{batches, find_primary_files, find_secondary_files};
use crate::lang::lsp::LsProtoGroup;
use crate::server::client::Notifier;
use crate::server::panic::cancelled_anyhow;
use crate::server::schedule::thread::task_progress_monitor::TaskHandle;
use crate::server::schedule::thread::{self, JoinHandle, ThreadPriority};
use crate::server::trigger;
use crate::state::{State, StateSnapshot};

mod file_batches;
mod file_diagnostics;
mod lsp;
mod project_diagnostics;
mod refresh;
mod trigger;

/// Schedules refreshing of diagnostics in a background thread.
///
Expand All @@ -42,7 +42,7 @@ pub struct DiagnosticsController {
impl DiagnosticsController {
/// Creates a new diagnostics controller.
pub fn new(notifier: Notifier) -> Self {
let (trigger, receiver) = trigger();
let (trigger, receiver) = trigger::trigger();
let (thread, parallelism) = DiagnosticsControllerThread::spawn(receiver, notifier);
Self {
trigger,
Expand All @@ -64,6 +64,7 @@ struct DiagnosticsControllerThread {
notifier: Notifier,
pool: thread::Pool,
project_diagnostics: ProjectDiagnostics,
worker_handles: Vec<TaskHandle>,
}

impl DiagnosticsControllerThread {
Expand All @@ -73,11 +74,12 @@ impl DiagnosticsControllerThread {
receiver: trigger::Receiver<StateSnapshots>,
notifier: Notifier,
) -> (JoinHandle, NonZero<usize>) {
let this = Self {
let mut this = Self {
receiver,
notifier,
pool: thread::Pool::new(),
project_diagnostics: ProjectDiagnostics::new(),
worker_handles: Vec::new(),
};

let parallelism = this.pool.parallelism();
Expand All @@ -91,8 +93,9 @@ impl DiagnosticsControllerThread {
}

/// Runs diagnostics controller's event loop.
fn event_loop(&self) {
fn event_loop(&mut self) {
while let Some(state_snapshots) = self.receiver.wait() {
assert!(self.worker_handles.is_empty());
if let Err(err) = catch_unwind(AssertUnwindSafe(|| {
self.diagnostics_controller_tick(state_snapshots);
})) {
Expand All @@ -103,20 +106,22 @@ impl DiagnosticsControllerThread {
error!("caught panic while refreshing diagnostics");
}
}

self.join_and_clear_workers();
}
}

/// Runs a single tick of the diagnostics controller's event loop.
#[tracing::instrument(skip_all)]
fn diagnostics_controller_tick(&self, state_snapshots: StateSnapshots) {
fn diagnostics_controller_tick(&mut self, state_snapshots: StateSnapshots) {
let (state, primary_snapshots, secondary_snapshots) = state_snapshots.split();

let primary_set = find_primary_files(&state.db, &state.open_files);
let primary: Vec<_> = primary_set.iter().copied().collect();
self.spawn_refresh_worker(&primary, primary_snapshots);
self.spawn_refresh_workers(&primary, primary_snapshots);

let secondary = find_secondary_files(&state.db, &primary_set);
piotmag769 marked this conversation as resolved.
Show resolved Hide resolved
self.spawn_refresh_worker(&secondary, secondary_snapshots);
self.spawn_refresh_workers(&secondary, secondary_snapshots);

let files_to_preserve: HashSet<Url> = primary
.into_iter()
Expand All @@ -131,11 +136,11 @@ impl DiagnosticsControllerThread {

/// Shortcut for spawning a worker task which does the boilerplate around cloning state parts
/// and catching panics.
fn spawn_worker(&self, f: impl FnOnce(ProjectDiagnostics, Notifier) + Send + 'static) {
fn spawn_worker(&mut self, f: impl FnOnce(ProjectDiagnostics, Notifier) + Send + 'static) {
let project_diagnostics = self.project_diagnostics.clone();
let notifier = self.notifier.clone();
let worker_fn = move || f(project_diagnostics, notifier);
self.pool.spawn(ThreadPriority::Worker, move || {
let worker_handle = self.pool.spawn_with_tracking(ThreadPriority::Worker, move || {
if let Err(err) = catch_unwind(AssertUnwindSafe(worker_fn)) {
if let Ok(err) = cancelled_anyhow(err, "diagnostics worker has been cancelled") {
trace!("{err:?}");
Expand All @@ -144,10 +149,11 @@ impl DiagnosticsControllerThread {
}
}
});
self.worker_handles.push(worker_handle);
}

/// Makes batches out of `files` and spawns workers to run [`refresh_diagnostics`] on them.
fn spawn_refresh_worker(&self, files: &[FileId], state_snapshots: Vec<StateSnapshot>) {
fn spawn_refresh_workers(&mut self, files: &[FileId], state_snapshots: Vec<StateSnapshot>) {
let files_batches = batches(files, self.pool.parallelism());
assert_eq!(files_batches.len(), state_snapshots.len());
for (batch, state) in zip(files_batches, state_snapshots) {
Expand All @@ -162,6 +168,13 @@ impl DiagnosticsControllerThread {
});
}
}

fn join_and_clear_workers(&mut self) {
for handle in self.worker_handles.iter() {
handle.join();
}
self.worker_handles.clear();
}
}

/// Holds multiple snapshots of the state.
Expand Down
1 change: 1 addition & 0 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod commands;
pub mod connection;
pub mod panic;
pub mod schedule;
pub mod trigger;

mod routing;
pub use routing::{notification, request};
2 changes: 1 addition & 1 deletion src/server/schedule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<'s> Scheduler<'s> {
self.background_pool.spawn(ThreadPriority::Worker, task);
}
BackgroundSchedule::LatencySensitive => {
self.background_pool.spawn(ThreadPriority::LatencySensitive, task)
self.background_pool.spawn(ThreadPriority::LatencySensitive, task);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/server/schedule/thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use std::fmt;

mod pool;
mod priority;
pub mod task_progress_monitor;

pub use self::pool::Pool;
pub use self::priority::ThreadPriority;
Expand Down
43 changes: 39 additions & 4 deletions src/server/schedule/thread/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use std::thread::available_parallelism;

use crossbeam::channel;

use super::{Builder, JoinHandle, ThreadPriority};
use super::{Builder, JoinHandle, ThreadPriority, task_progress_monitor};
use crate::server::schedule::thread::task_progress_monitor::TaskHandle;

pub struct Pool {
// `_handles` is never read: the field is present
Expand Down Expand Up @@ -98,14 +99,48 @@ impl Pool {
where
F: FnOnce() + Send + 'static,
{
let f = Box::new(move || {
self.spawn_internal(priority, f, false);
}

pub fn spawn_with_tracking<F>(&self, priority: ThreadPriority, f: F) -> TaskHandle
where
F: FnOnce() + Send + 'static,
{
self.spawn_internal(priority, f, true).unwrap()
}

fn spawn_internal<F>(
&self,
priority: ThreadPriority,
f: F,
enable_tracking: bool,
) -> Option<TaskHandle>
where
F: FnOnce() + Send + 'static,
{
let untracked_f = move || {
if cfg!(debug_assertions) {
priority.assert_is_used_on_current_thread();
}
f();
});
};

if enable_tracking {
let (tracker, handle) = task_progress_monitor::task_progress_monitor();
let tracked_f = Box::new(move || {
untracked_f();
tracker.signal_finish();
});
self.send_job(priority, tracked_f);
Some(handle)
} else {
self.send_job(priority, Box::new(untracked_f));
None
}
}

let job = Job { requested_priority: priority, f };
fn send_job(&self, priority: ThreadPriority, f: Box<dyn FnOnce() + Send + 'static>) {
let job = Job { requested_priority: priority, f: Box::new(f) };
self.job_sender.send(job).unwrap();
}

Expand Down
25 changes: 25 additions & 0 deletions src/server/schedule/thread/task_progress_monitor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::server::trigger;

pub struct TaskHandle(trigger::Receiver<()>);

pub struct TaskTracker(trigger::Sender<()>);

impl TaskTracker {
/// Signals that a task finished executing.
pub fn signal_finish(&self) {
self.0.activate(());
}
}

impl TaskHandle {
/// Waits until tasks finishes executing.
pub fn join(&self) {
self.0.wait();
}
}

/// Creates single message channel for making it possible to wait for finishing tasks execution.
pub fn task_progress_monitor() -> (TaskTracker, TaskHandle) {
let (sender, receiver) = trigger::trigger::<()>();
(TaskTracker(sender), TaskHandle(receiver))
}
File renamed without changes.
File renamed without changes.
Loading