码迷,mamicode.com
首页 > 编程语言 > 详细

一个简单缩略版的python 线程池实现

时间:2016-07-24 13:28:09      阅读:1002      评论:0      收藏:0      [点我收藏+]

标签:

  1 # -*- coding: utf-8 -*-
  2 
  3 import threading
  4 import queue
  5 import itertools
  6 import os
  7 import time
  8 
  9 
 10 RUN = 0
 11 CLOSE = 1
 12 TERMINATE = 2
 13 job_counter = itertools.count()
 14 
 15 
 16 class Pool(object):
 17 
 18     def __init__(self, max_thread_num=None):
 19         self.__chk_thread_num(max_thread_num)
 20         self._setup_queues()
 21         self._cache = {}                    # 存储任务运行结果ApplyResult对象的cache
 22         self._state = RUN
 23         self._max_num = max_thread_num      # 线程上限
 24         self._pool = []                     # 真.线程池!
 25         self._add_thread_to_pool()
 26 
 27         # 监控结果cache的handle线程
 28         self._worker_handler = threading.Thread(
 29             target=Pool._handle_workers,
 30             args=(self, )
 31             )
 32         self.__init_handler(self._worker_handler)
 33 
 34         # 监控输入任务的taskqueue,写入 inqueue的handle线程
 35         self._task_handler = threading.Thread(
 36             target=Pool._handle_tasks,
 37             args=(self._taskqueue, self._quick_put, self._outqueue,
 38                   self._pool, self._cache)
 39             )
 40         self.__init_handler(self._task_handler)
 41 
 42         # 监控ouqueue,写入运行结果的handle线程
 43         self._result_handler = threading.Thread(
 44             target=Pool._handle_results,
 45             args=(self._quick_get, self._cache)
 46             )
 47         self.__init_handler(self._result_handler)
 48 
 49     def __chk_thread_num(self, max_thread_num):
 50         # 检查最大线程数是否合法
 51         if max_thread_num is None:
 52             max_thread_num = os.cpu_count() or 1
 53         if max_thread_num < 1:
 54             raise ValueError("Number of thread should bigger than 1")
 55 
 56     @staticmethod
 57     def __init_handler(handler):
 58         handler.daemon = True
 59         handler._state = RUN
 60         handler.start()
 61 
 62     def _add_thread_to_pool(self):
 63         # 给线程池补充线程
 64         for i in range(self._max_num - len(self._pool)):
 65             thread_worker = self.Process(target=self.worker)
 66             self._pool.append(thread_worker)
 67             thread_worker.daemon = True
 68             thread_worker.start()
 69 
 70     def _setup_queues(self):
 71         self._inqueue = queue.Queue()           # worker的输入队列,保护task信息
 72         self._outqueue = queue.Queue()          # worker的输出队列,包含task运行结果
 73         self._quick_put = self._inqueue.put
 74         self._quick_get = self._outqueue.get
 75         self._taskqueue = queue.Queue()         # 线程池获得的原始taskqueue
 76 
 77     def Process(self, *args, **kwds):
 78         return threading.Thread(*args, **kwds)
 79 
 80     def apply(self, func, args=(), kwds={}):
 81         # apply 直接调用get,
 82         # 阻塞到任务运行完,再直接返回
 83         assert self._state == RUN
 84         return self.apply_async(func, args, kwds).get()
 85 
 86     def apply_async(self, func, args=(), kwds={}, callback=None):
 87         # 异步调用,将任务信息放到_taskqueue里之后执行,
 88         # 返回的是异步调用结果对象,可以通过该对象再取真正的结果
 89         if self._state != RUN:
 90             raise ValueError("Pool not running")
 91         result = ApplyResult(self._cache, callback)
 92 
 93         self._taskqueue.put((result._job, None, func, args, kwds))
 94         return result
 95 
 96     @staticmethod
 97     def _handle_workers(pool):
 98         thread = threading.current_thread()
 99 
100         # 当线程池被close或teminate后
101         # 间隔检查存储结果cache的情况
102         # 若cache为空则发送None通知 tasks handle停止
103         while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
104             time.sleep(0.1)
105         pool._taskqueue.put(None)
106 
107     @staticmethod
108     def _handle_tasks(taskqueue, put, outqueue, pool, cache):
109         thread = threading.current_thread()
110         try:
111             # taskqueue.get到None时结束循环
112             for task_info in iter(taskqueue.get, None):
113                 if thread._state:
114                     break
115                 try:
116                     put(task_info)
117                 except Exception as e:
118                     job, ind = task_info[:2]
119                     try:
120                         cache[job]._set((False, e))
121                     except KeyError:
122                         pass
123         except Exception as ex:
124             job, ind = (0, 0)
125             if job in cache:
126                 cache[job]._set((False, ex))
127 
128         try:
129             outqueue.put(None) # 通知result handle结束
130             # 通知pool里的所有worker结束
131             for p in pool:
132                 put(None)
133         except OSError:
134             print(task handler got OSError when sending sentinels)
135 
136     @staticmethod
137     def _handle_results(get, cache):
138         thread = threading.current_thread()
139         while 1:
140             try:
141                 task_rtn = get()
142             except (OSError, EOFError):
143                 return
144 
145             if thread._state:
146                 assert thread._state == TERMINATE
147                 break
148 
149             if task_rtn is None:
150                 break
151 
152             job, i, obj = task_rtn
153             try:
154                 cache[job]._set(obj)
155             except KeyError:
156                 pass
157 
158         # 将剩余在cache里的结果全部处理完
159         while cache and thread._state != TERMINATE:
160             try:
161                 task_rtn = get()
162             except (OSError, EOFError):
163                 return
164 
165             if task_rtn is None:
166                 continue
167             job, i, obj = task_rtn
168             try:
169                 cache[job]._set(obj)
170             except KeyError:
171                 pass
172 
173     def close(self):
174         if self._state == RUN:
175             self._state = CLOSE
176             self._worker_handler._state = CLOSE
177 
178     def terminate(self):
179         self._state = TERMINATE
180         self._worker_handler._state = TERMINATE
181         self._terminate_pool(self._inqueue, self._outqueue, self._pool,
182                   self._worker_handler, self._task_handler,
183                   self._result_handler, self._cache)
184 
185     def join(self):
186         assert self._state in (CLOSE, TERMINATE)
187         self._worker_handler.join()
188         self._task_handler.join()
189         self._result_handler.join()
190         for p in self._pool:
191             p.join()
192 
193     @staticmethod
194     def _help_stuff_finish(inqueue, size):
195         # 清空inqueue,放入标志None,通知pool里的所有worker结束
196         with inqueue.not_empty:
197             inqueue.queue.clear()
198             inqueue.queue.extend([None] * size)
199             inqueue.not_empty.notify_all()
200 
201     @classmethod
202     def _terminate_pool(cls, inqueue, outqueue, pool,
203                         worker_handler, task_handler, result_handler, cache):
204         worker_handler._state = TERMINATE
205         task_handler._state = TERMINATE
206 
207         cls._help_stuff_finish(inqueue, len(pool))
208 
209         assert result_handler.is_alive() or len(cache) == 0
210 
211         result_handler._state = TERMINATE
212         outqueue.put(None)                  # 终止标志
213 
214         # 等到三个监控handle都运行终止
215         # 防止有worker还没运行结束
216         for handler in (worker_handler, task_handler, result_handler):
217             if threading.current_thread() is not handler:
218                 handler.join()
219 
220     def worker(self):
221         # worker从inqueue中获取任务信息并执行
222         # 将结果写入outqueue
223         while 1:
224             try:
225                 task_info = self._inqueue.get()
226             except (EOFError, OSError):
227                 break
228 
229             if task_info is None:
230                 break
231 
232             job, i, func, args, kwds = task_info
233             try:
234                 result = (True, func(*args, **kwds))
235             except Exception as e:
236                 print(Exception occurred: %s\n%s % (e, e.__traceback__))
237                 result = (False, e)
238             try:
239                 self._outqueue.put((job, i, result))
240             except Exception as e:
241                 err_msg = "Exception occurred while sending %s: %s" % (result[1], e)
242                 print(err_msg)
243                 self._outqueue.put((job, i, (False, err_msg)))
244 
245 
class ApplyResult(object):
    """Handle for the eventual outcome of one task submitted to the pool.

    On creation the object registers itself in *cache* under a fresh job id;
    ``_set`` later stores the outcome, wakes any waiter and removes the entry.
    """

    def __init__(self, cache, callback):
        """Register a new pending result.

        Args:
            cache: dict mapping job id to pending ``ApplyResult``.
            callback: optional callable invoked with the value on success.
        """
        self._event = threading.Event()
        self._job = next(job_counter)
        self._cache = cache
        self._callback = callback
        self._success = False
        self._value = None
        cache[self._job] = self

    def ready(self):
        """Return True once the outcome has been stored."""
        return self._event.is_set()

    def wait(self, timeout=None):
        """Block until the outcome is stored or *timeout* seconds elapse."""
        self._event.wait(timeout)

    def get(self, timeout=None):
        """Return the task's value, waiting up to *timeout* seconds for it.

        Raises:
            TimeoutError: if no outcome arrived within *timeout*.
            Exception: whatever was stored as the outcome of a failed task.
        """
        self.wait(timeout)
        if not self.ready():
            # Without this guard the code below would try to raise the
            # placeholder value None.
            raise TimeoutError("Timeout!")
        if self._success:
            return self._value
        raise self._value

    def _set(self, obj):
        """Store the ``(success, value)`` pair *obj* and release waiters.

        The callback, if any, runs only for successful outcomes and before
        waiters are woken.
        """
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        self._event.set()
        # pop() rather than del: tolerate an entry that is already gone.
        self._cache.pop(self._job, None)

 

参考 multiprocessing 的 Pool 实现,把其中的进程换成线程后得到的一个简化(劣化)版线程池。

一个简单缩略版的python 线程池实现

标签:

原文地址:http://www.cnblogs.com/nigel-woo/p/5700548.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!