1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.execution;
17
18 import java.lang.reflect.Method;
19 import java.util.concurrent.ConcurrentMap;
20 import java.util.concurrent.Executor;
21 import java.util.concurrent.Executors;
22 import java.util.concurrent.RejectedExecutionException;
23 import java.util.concurrent.RejectedExecutionHandler;
24 import java.util.concurrent.Semaphore;
25 import java.util.concurrent.ThreadFactory;
26 import java.util.concurrent.ThreadPoolExecutor;
27 import java.util.concurrent.TimeUnit;
28 import java.util.concurrent.atomic.AtomicLong;
29
30 import org.jboss.netty.buffer.ChannelBuffer;
31 import org.jboss.netty.channel.Channel;
32 import org.jboss.netty.channel.ChannelEvent;
33 import org.jboss.netty.channel.ChannelHandlerContext;
34 import org.jboss.netty.channel.ChannelState;
35 import org.jboss.netty.channel.ChannelStateEvent;
36 import org.jboss.netty.channel.MessageEvent;
37 import org.jboss.netty.channel.WriteCompletionEvent;
38 import org.jboss.netty.logging.InternalLogger;
39 import org.jboss.netty.logging.InternalLoggerFactory;
40 import org.jboss.netty.util.DefaultObjectSizeEstimator;
41 import org.jboss.netty.util.ObjectSizeEstimator;
42 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
43 import org.jboss.netty.util.internal.LinkedTransferQueue;
44 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final AtomicLong totalCounter = new AtomicLong();
152
153 private final Semaphore semaphore = new Semaphore(0);
154
155
156
157
158
159
160
161
162
163
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170
171
172
173
174
175
176
177
178
179
180
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit, Executors.defaultThreadFactory());
186 }
187
188
189
190
191
192
193
194
195
196
197
198
199
200 public MemoryAwareThreadPoolExecutor(
201 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
202 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
203
204 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit, new DefaultObjectSizeEstimator(), threadFactory);
205 }
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220 public MemoryAwareThreadPoolExecutor(
221 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
222 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
223 ThreadFactory threadFactory) {
224
225 super(corePoolSize, corePoolSize, keepAliveTime, unit,
226 new LinkedTransferQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
227
228 if (objectSizeEstimator == null) {
229 throw new NullPointerException("objectSizeEstimator");
230 }
231 if (maxChannelMemorySize < 0) {
232 throw new IllegalArgumentException(
233 "maxChannelMemorySize: " + maxChannelMemorySize);
234 }
235 if (maxTotalMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxTotalMemorySize: " + maxTotalMemorySize);
238 }
239
240
241
242 try {
243 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
244 m.invoke(this, Boolean.TRUE);
245 } catch (Throwable t) {
246
247 logger.debug(
248 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
249 "supported in this platform.");
250 }
251
252 settings = new Settings(
253 objectSizeEstimator, maxChannelMemorySize, maxTotalMemorySize);
254
255
256 misuseDetector.increase();
257 }
258
259 @Override
260 protected void terminated() {
261 super.terminated();
262 misuseDetector.decrease();
263 }
264
265
266
267
268 public ObjectSizeEstimator getObjectSizeEstimator() {
269 return settings.objectSizeEstimator;
270 }
271
272
273
274
275 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
276 if (objectSizeEstimator == null) {
277 throw new NullPointerException("objectSizeEstimator");
278 }
279
280 settings = new Settings(
281 objectSizeEstimator,
282 settings.maxChannelMemorySize, settings.maxTotalMemorySize);
283 }
284
285
286
287
288 public long getMaxChannelMemorySize() {
289 return settings.maxChannelMemorySize;
290 }
291
292
293
294
295
296 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
297 if (maxChannelMemorySize < 0) {
298 throw new IllegalArgumentException(
299 "maxChannelMemorySize: " + maxChannelMemorySize);
300 }
301
302 if (getTaskCount() > 0) {
303 throw new IllegalStateException(
304 "can't be changed after a task is executed");
305 }
306
307 settings = new Settings(
308 settings.objectSizeEstimator,
309 maxChannelMemorySize, settings.maxTotalMemorySize);
310 }
311
312
313
314
315 public long getMaxTotalMemorySize() {
316 return settings.maxTotalMemorySize;
317 }
318
319
320
321
322
323 public void setMaxTotalMemorySize(long maxTotalMemorySize) {
324 if (maxTotalMemorySize < 0) {
325 throw new IllegalArgumentException(
326 "maxTotalMemorySize: " + maxTotalMemorySize);
327 }
328
329 if (getTaskCount() > 0) {
330 throw new IllegalStateException(
331 "can't be changed after a task is executed");
332 }
333
334 settings = new Settings(
335 settings.objectSizeEstimator,
336 settings.maxChannelMemorySize, maxTotalMemorySize);
337 }
338
339 @Override
340 public void execute(Runnable command) {
341 if (!(command instanceof ChannelEventRunnable)) {
342 command = new MemoryAwareRunnable(command);
343 }
344
345 boolean pause = increaseCounter(command);
346 doExecute(command);
347 if (pause) {
348
349 semaphore.acquireUninterruptibly();
350 }
351 }
352
353
354
355
356
357 protected void doExecute(Runnable task) {
358 doUnorderedExecute(task);
359 }
360
361
362
363
364 protected final void doUnorderedExecute(Runnable task) {
365 super.execute(task);
366 }
367
368 @Override
369 public boolean remove(Runnable task) {
370 boolean removed = super.remove(task);
371 if (removed) {
372 decreaseCounter(task);
373 }
374 return removed;
375 }
376
377 @Override
378 protected void beforeExecute(Thread t, Runnable r) {
379 super.beforeExecute(t, r);
380 decreaseCounter(r);
381 }
382
383 protected boolean increaseCounter(Runnable task) {
384 if (!shouldCount(task)) {
385 return false;
386 }
387
388 Settings settings = this.settings;
389 long maxTotalMemorySize = settings.maxTotalMemorySize;
390 long maxChannelMemorySize = settings.maxChannelMemorySize;
391
392 int increment = settings.objectSizeEstimator.estimateSize(task);
393 long totalCounter = this.totalCounter.addAndGet(increment);
394
395 if (task instanceof ChannelEventRunnable) {
396 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
397 eventTask.estimatedSize = increment;
398 Channel channel = eventTask.getEvent().getChannel();
399 long channelCounter = getChannelCounter(channel).addAndGet(increment);
400
401 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
402 if (channel.isReadable()) {
403
404 ChannelHandlerContext ctx = eventTask.getContext();
405 if (ctx.getHandler() instanceof ExecutionHandler) {
406
407 ctx.setAttachment(Boolean.TRUE);
408 }
409 channel.setReadable(false);
410 }
411 }
412 } else {
413 ((MemoryAwareRunnable) task).estimatedSize = increment;
414 }
415
416
417 return maxTotalMemorySize != 0 && totalCounter >= maxTotalMemorySize;
418 }
419
420 protected void decreaseCounter(Runnable task) {
421 if (!shouldCount(task)) {
422 return;
423 }
424
425 Settings settings = this.settings;
426 long maxTotalMemorySize = settings.maxTotalMemorySize;
427 long maxChannelMemorySize = settings.maxChannelMemorySize;
428
429 int increment;
430 if (task instanceof ChannelEventRunnable) {
431 increment = ((ChannelEventRunnable) task).estimatedSize;
432 } else {
433 increment = ((MemoryAwareRunnable) task).estimatedSize;
434 }
435
436 long totalCounter = this.totalCounter.addAndGet(-increment);
437
438
439 if (maxTotalMemorySize != 0 && totalCounter + increment >= maxTotalMemorySize) {
440
441 while (semaphore.hasQueuedThreads()) {
442 semaphore.release();
443 }
444 }
445
446 if (task instanceof ChannelEventRunnable) {
447 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
448 Channel channel = eventTask.getEvent().getChannel();
449 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
450
451 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
452 if (!channel.isReadable()) {
453
454 ChannelHandlerContext ctx = eventTask.getContext();
455 if (ctx.getHandler() instanceof ExecutionHandler) {
456
457 ctx.setAttachment(null);
458 }
459 channel.setReadable(true);
460 }
461 }
462 }
463 }
464
465 private AtomicLong getChannelCounter(Channel channel) {
466 AtomicLong counter = channelCounters.get(channel);
467 if (counter == null) {
468 counter = new AtomicLong();
469 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
470 if (oldCounter != null) {
471 counter = oldCounter;
472 }
473 }
474
475
476 if (!channel.isOpen()) {
477 channelCounters.remove(channel);
478 }
479 return counter;
480 }
481
482
483
484
485
486
487
488 protected boolean shouldCount(Runnable task) {
489 if (task instanceof ChannelEventRunnable) {
490 ChannelEventRunnable r = (ChannelEventRunnable) task;
491 ChannelEvent e = r.getEvent();
492 if (e instanceof WriteCompletionEvent) {
493 return false;
494 } else if (e instanceof ChannelStateEvent) {
495 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
496 return false;
497 }
498 }
499 }
500 return true;
501 }
502
503 private static final class Settings {
504 final ObjectSizeEstimator objectSizeEstimator;
505 final long maxChannelMemorySize;
506 final long maxTotalMemorySize;
507
508 Settings(ObjectSizeEstimator objectSizeEstimator,
509 long maxChannelMemorySize, long maxTotalMemorySize) {
510 this.objectSizeEstimator = objectSizeEstimator;
511 this.maxChannelMemorySize = maxChannelMemorySize;
512 this.maxTotalMemorySize = maxTotalMemorySize;
513 }
514 }
515
516 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
517 NewThreadRunsPolicy() {
518 super();
519 }
520
521 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
522 try {
523 final Thread t = new Thread(r, "Temporary task executor");
524 t.start();
525 } catch (Throwable e) {
526 throw new RejectedExecutionException(
527 "Failed to start a new thread", e);
528 }
529 }
530 }
531
532 private static final class MemoryAwareRunnable implements Runnable {
533 final Runnable task;
534 int estimatedSize;
535
536 MemoryAwareRunnable(Runnable task) {
537 this.task = task;
538 }
539
540 public void run() {
541 task.run();
542 }
543 }
544 }