futures_executor/thread_pool.rs

use crate::enter;
use crate::unpark_mutex::UnparkMutex;
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_task::{waker_ref, ArcWake};
use futures_task::{FutureObj, Spawn, SpawnError};
use futures_util::future::FutureExt;
use std::boxed::Box;
use std::cmp;
use std::fmt;
use std::format;
use std::io;
use std::string::String;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

/// A general-purpose thread pool for scheduling tasks that poll futures to
/// completion.
///
/// The thread pool multiplexes any number of tasks onto a fixed number of
/// worker threads.
///
/// This type is a clonable handle to the threadpool itself.
/// Cloning it will only create a new reference, not a new threadpool.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
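///
/// A small sketch of the handle semantics described above: cloning only creates
/// another handle to the same pool of worker threads.
///
/// ```
/// # {
/// use futures::executor::ThreadPool;
///
/// let pool = ThreadPool::new().unwrap();
/// let handle = pool.clone();
/// handle.spawn_ok(async { /* ... */ });
/// # }
/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
/// ```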
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    state: Arc<PoolState>,
}

/// Thread pool configuration object.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
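///
/// A minimal sketch of customizing and creating a pool with the builder methods
/// defined below:
///
/// ```
/// # {
/// use futures::executor::ThreadPool;
///
/// let pool = ThreadPool::builder()
///     .pool_size(2)
///     .name_prefix("pool-")
///     .create()
///     .unwrap();
/// pool.spawn_ok(async { /* ... */ });
/// # }
/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
/// ```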
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    pool_size: usize,
    stack_size: usize,
    name_prefix: Option<String>,
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}

#[allow(dead_code)]
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}

/// State shared between all `ThreadPool` handles and worker threads: one mpsc
/// channel (each end behind a mutex), a count of live handles, and the pool size.
struct PoolState {
    tx: Mutex<mpsc::Sender<Message>>,
    rx: Mutex<mpsc::Receiver<Message>>,
    cnt: AtomicUsize,
    size: usize,
}

impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
    }
}

impl fmt::Debug for ThreadPoolBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPoolBuilder")
            .field("pool_size", &self.pool_size)
            .field("name_prefix", &self.name_prefix)
            .finish()
    }
}

/// Messages sent over the pool's channel to the worker threads.
enum Message {
    Run(Task),
    Close,
}

impl ThreadPool {
    /// Creates a new thread pool with the default configuration.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn new() -> Result<Self, io::Error> {
        ThreadPoolBuilder::new().create()
    }

    /// Create a default thread pool configuration, which can then be customized.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder::new()
    }

    /// Spawns a future that will be run to completion.
    ///
    /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
    /// >           it is guaranteed to always succeed.
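    ///
    /// A brief sketch, assuming the usual `futures::task::FutureObj` re-export
    /// (this mirrors what `spawn_ok` below does internally):
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    /// use futures::task::FutureObj;
    ///
    /// let pool = ThreadPool::new().unwrap();
    /// let future = FutureObj::new(Box::new(async { /* ... */ }));
    /// pool.spawn_obj_ok(future);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```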
    pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
        let task = Task {
            future,
            wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
            exec: self.clone(),
        };
        self.state.send(Message::Run(task));
    }

    /// Spawns a task that polls the given future with output `()` to
    /// completion.
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    ///
    /// let pool = ThreadPool::new().unwrap();
    ///
    /// let future = async { /* ... */ };
    /// pool.spawn_ok(future);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```
    ///
    /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
    /// >           it is guaranteed to always succeed.
    pub fn spawn_ok<Fut>(&self, future: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.spawn_obj_ok(FutureObj::new(Box::new(future)))
    }
}

impl Spawn for ThreadPool {
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        self.spawn_obj_ok(future);
        Ok(())
    }
}

impl PoolState {
    fn send(&self, msg: Message) {
        self.tx.lock().unwrap().send(msg).unwrap();
    }

    /// Worker thread body: run the start hook, process messages until a `Close`
    /// arrives, then run the stop hook.
    fn work(
        &self,
        idx: usize,
        after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
        before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    ) {
        let _scope = enter().unwrap();
        if let Some(after_start) = after_start {
            after_start(idx);
        }
        loop {
            let msg = self.rx.lock().unwrap().recv().unwrap();
            match msg {
                Message::Run(task) => task.run(),
                Message::Close => break,
            }
        }
        if let Some(before_stop) = before_stop {
            before_stop(idx);
        }
    }
}

impl Clone for ThreadPool {
    fn clone(&self) -> Self {
        // Count live handles so the last one to be dropped can shut the pool down.
        self.state.cnt.fetch_add(1, Ordering::Relaxed);
        Self { state: self.state.clone() }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Dropping the last handle tells every worker thread to exit.
        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
            for _ in 0..self.state.size {
                self.state.send(Message::Close);
            }
        }
    }
}

impl ThreadPoolBuilder {
    /// Create a default thread pool configuration.
    ///
    /// See the other methods on this type for details on the defaults.
    pub fn new() -> Self {
        Self {
            pool_size: cmp::max(1, num_cpus::get()),
            stack_size: 0,
            name_prefix: None,
            after_start: None,
            before_stop: None,
        }
    }

    /// Set the size of the `ThreadPool` that will be created.
    ///
    /// The size of a thread pool is the number of worker threads spawned. By
    /// default, this is equal to the number of CPU cores.
    ///
    /// # Panics
    ///
    /// Panics if `pool_size == 0`.
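    ///
    /// A minimal sketch:
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    ///
    /// // Two worker threads instead of the default (one per CPU core).
    /// let pool = ThreadPool::builder().pool_size(2).create().unwrap();
    /// pool.spawn_ok(async { /* ... */ });
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```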
    pub fn pool_size(&mut self, size: usize) -> &mut Self {
        assert!(size > 0);
        self.pool_size = size;
        self
    }

    /// Set the stack size, in bytes, of worker threads in the pool.
    ///
    /// By default, worker threads use Rust's standard stack size.
    pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
        self.stack_size = stack_size;
        self
    }

    /// Set the thread name prefix of the `ThreadPool` that will be created.
    ///
    /// The prefix is used for generating thread names. For example, if the prefix is
    /// `my-pool-`, then threads in the pool will get names like `my-pool-0`,
    /// `my-pool-1`, and so on.
    ///
    /// By default, worker threads are assigned Rust's standard thread name.
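    ///
    /// A small sketch (the worker names can be observed from the threads themselves):
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    ///
    /// let pool = ThreadPool::builder()
    ///     .pool_size(2)
    ///     .name_prefix("my-pool-")
    ///     // Prints something like `Some("my-pool-0")` and `Some("my-pool-1")`.
    ///     .after_start(|_| println!("{:?}", std::thread::current().name()))
    ///     .create()
    ///     .unwrap();
    /// # drop(pool);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```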
    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
        self.name_prefix = Some(name_prefix.into());
        self
    }

    /// Execute the closure `f` immediately after each worker thread is started,
    /// but before running any tasks on it.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all worker threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
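    ///
    /// A hedged sketch of using the hook for simple bookkeeping (mirroring the test
    /// at the bottom of this module):
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    /// use std::sync::mpsc;
    ///
    /// let (tx, rx) = mpsc::sync_channel(2);
    /// let pool = ThreadPool::builder()
    ///     .pool_size(2)
    ///     // Each worker sends its index once, right after it starts.
    ///     .after_start(move |idx| tx.send(idx).unwrap())
    ///     .create()
    ///     .unwrap();
    /// assert_eq!(rx.iter().take(2).count(), 2);
    /// # drop(pool);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```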
    pub fn after_start<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.after_start = Some(Arc::new(f));
        self
    }

    /// Execute closure `f` just prior to shutting down each worker thread.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
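    ///
    /// A hedged sketch: the workers shut down (and run this hook) once the last
    /// `ThreadPool` handle is dropped.
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    /// use std::sync::mpsc;
    ///
    /// let (tx, rx) = mpsc::sync_channel(2);
    /// let pool = ThreadPool::builder()
    ///     .pool_size(2)
    ///     .before_stop(move |idx| tx.send(idx).unwrap())
    ///     .create()
    ///     .unwrap();
    /// drop(pool); // ask the workers to shut down
    /// assert_eq!(rx.iter().take(2).count(), 2);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```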
    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.before_stop = Some(Arc::new(f));
        self
    }

    /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
    pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
        let (tx, rx) = mpsc::channel();
        let pool = ThreadPool {
            state: Arc::new(PoolState {
                tx: Mutex::new(tx),
                rx: Mutex::new(rx),
                cnt: AtomicUsize::new(1),
                size: self.pool_size,
            }),
        };

        // Spawn the worker threads; each gets the shared state and its own copies of the hooks.
        for counter in 0..self.pool_size {
            let state = pool.state.clone();
            let after_start = self.after_start.clone();
            let before_stop = self.before_stop.clone();
            let mut thread_builder = thread::Builder::new();
            if let Some(ref name_prefix) = self.name_prefix {
                thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
            }
            if self.stack_size > 0 {
                thread_builder = thread_builder.stack_size(self.stack_size);
            }
            thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
        }
        Ok(pool)
    }
}

impl Default for ThreadPoolBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// A task responsible for polling a future to completion.
struct Task {
    future: FutureObj<'static, ()>,
    exec: ThreadPool,
    wake_handle: Arc<WakeHandle>,
}

/// The data a task's `Waker` points at: the task's `UnparkMutex` plus the pool
/// onto which the task should be rescheduled when woken.
struct WakeHandle {
    mutex: UnparkMutex<Task>,
    exec: ThreadPool,
}

impl Task {
    /// Actually run the task (invoking `poll` on the future) on the current
    /// thread.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // Safety: The ownership of this `Task` object is evidence that
        // we are in the `POLLING`/`REPOLL` state for the mutex.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    Ok(()) => return, // we've waited
                    Err(task) => {
                        // someone's notified us
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}

impl fmt::Debug for Task {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Task").field("contents", &"...").finish()
    }
}

impl ArcWake for WakeHandle {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // If the task is currently parked, `notify` hands it back so it can be
        // re-enqueued on the pool; if it is being polled right now, `notify`
        // flags it for a repoll instead and returns `Err`.
        if let Ok(task) = arc_self.mutex.notify() {
            arc_self.exec.state.send(Message::Run(task))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drop_after_start() {
        {
            let (tx, rx) = mpsc::sync_channel(2);
            let _cpu_pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| tx.send(1).unwrap())
                .create()
                .unwrap();

            // Once the `ThreadPoolBuilder` has been dropped and each worker has run
            // (and released) its `after_start` hook, every clone of `tx` is gone, so
            // `rx` can be drained as a finite iterator.
            let count = rx.into_iter().count();
            assert_eq!(count, 2);
        }
        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    }
}