1use crate::enter;
2use crate::unpark_mutex::UnparkMutex;
3use futures_core::future::Future;
4use futures_core::task::{Context, Poll};
5use futures_task::{waker_ref, ArcWake};
6use futures_task::{FutureObj, Spawn, SpawnError};
7use futures_util::future::FutureExt;
8use std::boxed::Box;
9use std::cmp;
10use std::fmt;
11use std::format;
12use std::io;
13use std::string::String;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::sync::mpsc;
16use std::sync::{Arc, Mutex};
17use std::thread;
18
/// A general-purpose thread pool that runs spawned futures to completion on
/// a fixed set of worker threads.
///
/// This type is a cheap, cloneable handle: every clone shares the same
/// `PoolState` and therefore the same workers. The shared `cnt` counter tracks
/// how many handles are alive. When the last handle is dropped, each worker is
/// sent a shutdown message.
///
/// Spawned tasks and their wake handles each own a `ThreadPool` handle of
/// their own. As a result, the workers keep running while any such task is
/// still alive.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    state: Arc<PoolState>,
}
34
/// Configuration used to build a [`ThreadPool`].
///
/// Obtain one with `ThreadPoolBuilder::new()` or `ThreadPool::builder()`.
/// Adjust it through the `&mut Self`-returning setters, then call `create`.
/// The builder is not consumed, so it can create several pools.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    // Number of worker threads to spawn. It is never 0: `new` clamps it to at
    // least 1 and the `pool_size` setter asserts `size > 0`.
    pool_size: usize,
    // Per-worker stack size in bytes. 0 means "don't set it", which leaves the
    // `std::thread` default in place.
    stack_size: usize,
    // When set, worker `i` is named `"{prefix}{i}"`.
    name_prefix: Option<String>,
    // Hook each worker calls with its index before it processes any task.
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    // Hook each worker calls with its index after receiving the shutdown message.
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}
47
// Compile-time assertion: the impl below only type-checks when `ThreadPool` is
// `Send + Sync`, so losing either auto trait becomes a build error. The trait
// is never used at runtime, which is why `dead_code` is allowed.
#[allow(dead_code)]
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}
51
/// State shared by every `ThreadPool` handle and every worker thread.
struct PoolState {
    // Sending half of the work queue. Every `send` goes through this lock.
    // NOTE(review): the `Mutex` is presumably there because `mpsc::Sender`
    // was not `Sync` on older toolchains — confirm before removing it.
    tx: Mutex<mpsc::Sender<Message>>,
    // Receiving half of the work queue, shared by all workers. Each worker
    // locks it only for the duration of a single `recv` call.
    rx: Mutex<mpsc::Receiver<Message>>,
    // Number of live `ThreadPool` handles. `create` initializes it to 1,
    // `clone` increments it, and `drop` decrements it.
    cnt: AtomicUsize,
    // Number of workers requested at creation. `drop` sends this many `Close`
    // messages.
    size: usize,
}
58
59impl fmt::Debug for ThreadPool {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
62 }
63}
64
65impl fmt::Debug for ThreadPoolBuilder {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 f.debug_struct("ThreadPoolBuilder")
68 .field("pool_size", &self.pool_size)
69 .field("name_prefix", &self.name_prefix)
70 .finish()
71 }
72}
73
/// Messages carried by the pool's shared work queue.
enum Message {
    /// Run this task on whichever worker receives the message.
    Run(Task),
    /// Make the receiving worker leave its loop and exit. `ThreadPool::drop`
    /// sends one per worker.
    Close,
}
78
79impl ThreadPool {
80 pub fn new() -> Result<Self, io::Error> {
86 ThreadPoolBuilder::new().create()
87 }
88
89 pub fn builder() -> ThreadPoolBuilder {
95 ThreadPoolBuilder::new()
96 }
97
98 pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
103 let task = Task {
104 future,
105 wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
106 exec: self.clone(),
107 };
108 self.state.send(Message::Run(task));
109 }
110
111 pub fn spawn_ok<Fut>(&self, future: Fut)
129 where
130 Fut: Future<Output = ()> + Send + 'static,
131 {
132 self.spawn_obj_ok(FutureObj::new(Box::new(future)))
133 }
134}
135
136impl Spawn for ThreadPool {
137 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
138 self.spawn_obj_ok(future);
139 Ok(())
140 }
141}
142
143impl PoolState {
144 fn send(&self, msg: Message) {
145 self.tx.lock().unwrap().send(msg).unwrap();
146 }
147
148 fn work(
149 &self,
150 idx: usize,
151 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
152 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
153 ) {
154 let _scope = enter().unwrap();
155 if let Some(after_start) = after_start {
156 after_start(idx);
157 }
158 loop {
159 let msg = self.rx.lock().unwrap().recv().unwrap();
160 match msg {
161 Message::Run(task) => task.run(),
162 Message::Close => break,
163 }
164 }
165 if let Some(before_stop) = before_stop {
166 before_stop(idx);
167 }
168 }
169}
170
171impl Clone for ThreadPool {
172 fn clone(&self) -> Self {
173 self.state.cnt.fetch_add(1, Ordering::Relaxed);
174 Self { state: self.state.clone() }
175 }
176}
177
178impl Drop for ThreadPool {
179 fn drop(&mut self) {
180 if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
181 for _ in 0..self.state.size {
182 self.state.send(Message::Close);
183 }
184 }
185 }
186}
187
188impl ThreadPoolBuilder {
189 pub fn new() -> Self {
193 Self {
194 pool_size: cmp::max(1, num_cpus::get()),
195 stack_size: 0,
196 name_prefix: None,
197 after_start: None,
198 before_stop: None,
199 }
200 }
201
202 pub fn pool_size(&mut self, size: usize) -> &mut Self {
211 assert!(size > 0);
212 self.pool_size = size;
213 self
214 }
215
216 pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
220 self.stack_size = stack_size;
221 self
222 }
223
224 pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
231 self.name_prefix = Some(name_prefix.into());
232 self
233 }
234
235 pub fn after_start<F>(&mut self, f: F) -> &mut Self
245 where
246 F: Fn(usize) + Send + Sync + 'static,
247 {
248 self.after_start = Some(Arc::new(f));
249 self
250 }
251
252 pub fn before_stop<F>(&mut self, f: F) -> &mut Self
261 where
262 F: Fn(usize) + Send + Sync + 'static,
263 {
264 self.before_stop = Some(Arc::new(f));
265 self
266 }
267
268 pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
270 let (tx, rx) = mpsc::channel();
271 let pool = ThreadPool {
272 state: Arc::new(PoolState {
273 tx: Mutex::new(tx),
274 rx: Mutex::new(rx),
275 cnt: AtomicUsize::new(1),
276 size: self.pool_size,
277 }),
278 };
279
280 for counter in 0..self.pool_size {
281 let state = pool.state.clone();
282 let after_start = self.after_start.clone();
283 let before_stop = self.before_stop.clone();
284 let mut thread_builder = thread::Builder::new();
285 if let Some(ref name_prefix) = self.name_prefix {
286 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
287 }
288 if self.stack_size > 0 {
289 thread_builder = thread_builder.stack_size(self.stack_size);
290 }
291 thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
292 }
293 Ok(pool)
294 }
295}
296
297impl Default for ThreadPoolBuilder {
298 fn default() -> Self {
299 Self::new()
300 }
301}
302
/// A spawned future together with what is needed to run and reschedule it.
///
/// Fields are dropped in declaration order.
struct Task {
    // The future being driven to completion.
    future: FutureObj<'static, ()>,
    // Handle to the owning pool. Holding it keeps the pool's handle count
    // above zero while this task exists.
    exec: ThreadPool,
    // Shared wake state. Wakers built from it hand the task back to the pool
    // via `WakeHandle::wake_by_ref`.
    wake_handle: Arc<WakeHandle>,
}
309
/// State shared between a `Task` and every waker created for it.
struct WakeHandle {
    // Parking slot for the `Task` between polls: `Task::run` parks it with
    // `wait`, and `wake_by_ref` retrieves it with `notify`.
    // NOTE(review): the state machine arbitrating concurrent polls and wakes
    // lives in `unpark_mutex.rs`.
    mutex: UnparkMutex<Task>,
    // Pool that a woken task is resubmitted to.
    exec: ThreadPool,
}
314
impl Task {
    /// Runs the task on the current thread. It polls the future until the
    /// future either completes or is successfully parked to wait for a wake-up.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        // The waker given to the future is backed by this task's `WakeHandle`,
        // so waking it ends up in `WakeHandle::wake_by_ref`.
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // SAFETY: NOTE(review): `start_poll`, `complete` and `wait` are unsafe
        // `UnparkMutex` methods whose contract is defined in `unpark_mutex.rs`.
        // This code presumably relies on owning the `Task` as proof that the
        // mutex is in its polling state. The `Task` is either freshly spawned
        // by `spawn_obj_ok` or handed back by `notify` in `wake_by_ref`.
        // Confirm against that module.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    // Finished: mark the mutex complete and stop.
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                // Still pending: rebuild the task and try to park it.
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    // Parked. A later wake resubmits it to the pool.
                    Ok(()) => return,
                    // The mutex handed the task straight back (presumably a
                    // wake-up arrived during the poll), so poll again.
                    Err(task) => {
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}
347
348impl fmt::Debug for Task {
349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 f.debug_struct("Task").field("contents", &"...").finish()
351 }
352}
353
354impl ArcWake for WakeHandle {
355 fn wake_by_ref(arc_self: &Arc<Self>) {
356 if let Ok(task) = arc_self.mutex.notify() {
357 arc_self.exec.state.send(Message::Run(task))
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that the builder does not keep the `after_start` hook (and
    // everything it captures) alive after `create` returns.
    #[test]
    fn test_drop_after_start() {
        {
            // Capacity 2 gives one slot per worker, so neither hook call blocks
            // even before `rx` is read.
            let (tx, rx) = mpsc::sync_channel(2);
            // The builder is a temporary dropped at the end of this statement.
            // After that, the only clones of the hook (and thus of `tx`) are
            // the ones handed to each worker.
            let _cpu_pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| tx.send(1).unwrap())
                .create()
                .unwrap();

            // Each worker drops its hook right after calling it, so `tx` is
            // gone once both have run. Only then does this iterator end; if
            // anything kept the hook alive, this line would hang.
            let count = rx.into_iter().count();
            assert_eq!(count, 2);
        }
        // NOTE(review): presumably gives the workers time to receive `Close`
        // and exit after the pool is dropped (e.g. for leak checkers). A sleep
        // does not guarantee that they have.
        std::thread::sleep(std::time::Duration::from_millis(500));
    }
}