1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.execution;
17
18 import java.io.IOException;
19 import java.lang.reflect.Method;
20 import java.util.HashSet;
21 import java.util.List;
22 import java.util.Set;
23 import java.util.concurrent.ConcurrentMap;
24 import java.util.concurrent.Executor;
25 import java.util.concurrent.Executors;
26 import java.util.concurrent.LinkedBlockingQueue;
27 import java.util.concurrent.RejectedExecutionException;
28 import java.util.concurrent.RejectedExecutionHandler;
29 import java.util.concurrent.ThreadFactory;
30 import java.util.concurrent.ThreadPoolExecutor;
31 import java.util.concurrent.TimeUnit;
32 import java.util.concurrent.atomic.AtomicLong;
33
34 import org.jboss.netty.buffer.ChannelBuffer;
35 import org.jboss.netty.channel.Channel;
36 import org.jboss.netty.channel.ChannelEvent;
37 import org.jboss.netty.channel.ChannelFuture;
38 import org.jboss.netty.channel.ChannelHandlerContext;
39 import org.jboss.netty.channel.ChannelState;
40 import org.jboss.netty.channel.ChannelStateEvent;
41 import org.jboss.netty.channel.Channels;
42 import org.jboss.netty.channel.MessageEvent;
43 import org.jboss.netty.channel.WriteCompletionEvent;
44 import org.jboss.netty.logging.InternalLogger;
45 import org.jboss.netty.logging.InternalLoggerFactory;
46 import org.jboss.netty.util.DefaultObjectSizeEstimator;
47 import org.jboss.netty.util.ObjectSizeEstimator;
48 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
49 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final Limiter totalLimiter;
152
153 private volatile boolean notifyOnShutdown;
154
155
156
157
158
159
160
161
162
163
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170
171
172
173
174
175
176
177
178
179
180
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(
186 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187 Executors.defaultThreadFactory());
188 }
189
190
191
192
193
194
195
196
197
198
199
200
201
202 public MemoryAwareThreadPoolExecutor(
203 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205
206 this(
207 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208 new DefaultObjectSizeEstimator(), threadFactory);
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224 public MemoryAwareThreadPoolExecutor(
225 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227 ThreadFactory threadFactory) {
228
229 super(corePoolSize, corePoolSize, keepAliveTime, unit,
230 new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231
232 if (objectSizeEstimator == null) {
233 throw new NullPointerException("objectSizeEstimator");
234 }
235 if (maxChannelMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxChannelMemorySize: " + maxChannelMemorySize);
238 }
239 if (maxTotalMemorySize < 0) {
240 throw new IllegalArgumentException(
241 "maxTotalMemorySize: " + maxTotalMemorySize);
242 }
243
244
245
246 try {
247 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248 m.invoke(this, Boolean.TRUE);
249 } catch (Throwable t) {
250
251 logger.debug(
252 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253 "supported in this platform.");
254 }
255
256 settings = new Settings(
257 objectSizeEstimator, maxChannelMemorySize);
258
259 if (maxTotalMemorySize == 0) {
260 totalLimiter = null;
261 } else {
262 totalLimiter = new Limiter(maxTotalMemorySize);
263 }
264
265
266 misuseDetector.increase();
267 }
268
269 @Override
270 protected void terminated() {
271 super.terminated();
272 misuseDetector.decrease();
273 }
274
275
276
277
278 @Override
279 public List<Runnable> shutdownNow() {
280 return shutdownNow(notifyOnShutdown);
281 }
282
283
284
285
286
287
288
289
290
291
292
293
294
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317
318
319 channels.add(event.getChannel());
320 }
321 }
322
323
324 if (channels != null) {
325 for (Channel channel: channels) {
326 Channels.fireExceptionCaughtLater(channel, cause);
327 }
328 }
329 return tasks;
330 }
331
332
333
334
335 public ObjectSizeEstimator getObjectSizeEstimator() {
336 return settings.objectSizeEstimator;
337 }
338
339
340
341
342 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
343 if (objectSizeEstimator == null) {
344 throw new NullPointerException("objectSizeEstimator");
345 }
346
347 settings = new Settings(
348 objectSizeEstimator,
349 settings.maxChannelMemorySize);
350 }
351
352
353
354
355 public long getMaxChannelMemorySize() {
356 return settings.maxChannelMemorySize;
357 }
358
359
360
361
362
363 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
364 if (maxChannelMemorySize < 0) {
365 throw new IllegalArgumentException(
366 "maxChannelMemorySize: " + maxChannelMemorySize);
367 }
368
369 if (getTaskCount() > 0) {
370 throw new IllegalStateException(
371 "can't be changed after a task is executed");
372 }
373
374 settings = new Settings(
375 settings.objectSizeEstimator,
376 maxChannelMemorySize);
377 }
378
379
380
381
382 public long getMaxTotalMemorySize() {
383 return totalLimiter.limit;
384 }
385
386
387
388
389
390 @Deprecated
391 public void setMaxTotalMemorySize(long maxTotalMemorySize) {
392 if (maxTotalMemorySize < 0) {
393 throw new IllegalArgumentException(
394 "maxTotalMemorySize: " + maxTotalMemorySize);
395 }
396
397 if (getTaskCount() > 0) {
398 throw new IllegalStateException(
399 "can't be changed after a task is executed");
400 }
401 }
402
403
404
405
406
407
408
409
410
411
412
413 public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
414 this.notifyOnShutdown = notifyOnShutdown;
415 }
416
417
418
419
420
421 public boolean getNotifyChannelFuturesOnShutdown() {
422 return notifyOnShutdown;
423 }
424
425
426
427 @Override
428 public void execute(Runnable command) {
429 if (command instanceof ChannelDownstreamEventRunnable) {
430 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
431 }
432 if (!(command instanceof ChannelEventRunnable)) {
433 command = new MemoryAwareRunnable(command);
434 }
435
436 increaseCounter(command);
437 doExecute(command);
438 }
439
440
441
442
443
444 protected void doExecute(Runnable task) {
445 doUnorderedExecute(task);
446 }
447
448
449
450
451 protected final void doUnorderedExecute(Runnable task) {
452 super.execute(task);
453 }
454
455 @Override
456 public boolean remove(Runnable task) {
457 boolean removed = super.remove(task);
458 if (removed) {
459 decreaseCounter(task);
460 }
461 return removed;
462 }
463
464 @Override
465 protected void beforeExecute(Thread t, Runnable r) {
466 super.beforeExecute(t, r);
467 decreaseCounter(r);
468 }
469
470 protected void increaseCounter(Runnable task) {
471 if (!shouldCount(task)) {
472 return;
473 }
474
475 Settings settings = this.settings;
476 long maxChannelMemorySize = settings.maxChannelMemorySize;
477
478 int increment = settings.objectSizeEstimator.estimateSize(task);
479
480 if (task instanceof ChannelEventRunnable) {
481 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
482 eventTask.estimatedSize = increment;
483 Channel channel = eventTask.getEvent().getChannel();
484 long channelCounter = getChannelCounter(channel).addAndGet(increment);
485
486 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
487 if (channel.isReadable()) {
488
489 ChannelHandlerContext ctx = eventTask.getContext();
490 if (ctx.getHandler() instanceof ExecutionHandler) {
491
492 ctx.setAttachment(Boolean.TRUE);
493 }
494 channel.setReadable(false);
495 }
496 }
497 } else {
498 ((MemoryAwareRunnable) task).estimatedSize = increment;
499 }
500
501 if (totalLimiter != null) {
502 totalLimiter.increase(increment);
503 }
504 }
505
506 protected void decreaseCounter(Runnable task) {
507 if (!shouldCount(task)) {
508 return;
509 }
510
511 Settings settings = this.settings;
512 long maxChannelMemorySize = settings.maxChannelMemorySize;
513
514 int increment;
515 if (task instanceof ChannelEventRunnable) {
516 increment = ((ChannelEventRunnable) task).estimatedSize;
517 } else {
518 increment = ((MemoryAwareRunnable) task).estimatedSize;
519 }
520
521 if (totalLimiter != null) {
522 totalLimiter.decrease(increment);
523 }
524
525 if (task instanceof ChannelEventRunnable) {
526 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
527 Channel channel = eventTask.getEvent().getChannel();
528 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
529
530 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
531 if (!channel.isReadable()) {
532
533 ChannelHandlerContext ctx = eventTask.getContext();
534 if (ctx.getHandler() instanceof ExecutionHandler) {
535
536
537
538
539
540 if (ctx.getAttachment() != null) {
541
542 ctx.setAttachment(null);
543 channel.setReadable(true);
544 }
545 } else {
546 channel.setReadable(true);
547 }
548 }
549 }
550 }
551 }
552
553 private AtomicLong getChannelCounter(Channel channel) {
554 AtomicLong counter = channelCounters.get(channel);
555 if (counter == null) {
556 counter = new AtomicLong();
557 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
558 if (oldCounter != null) {
559 counter = oldCounter;
560 }
561 }
562
563
564 if (!channel.isOpen()) {
565 channelCounters.remove(channel);
566 }
567 return counter;
568 }
569
570
571
572
573
574
575
576 protected boolean shouldCount(Runnable task) {
577 if (task instanceof ChannelUpstreamEventRunnable) {
578 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
579 ChannelEvent e = r.getEvent();
580 if (e instanceof WriteCompletionEvent) {
581 return false;
582 } else if (e instanceof ChannelStateEvent) {
583 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
584 return false;
585 }
586 }
587 }
588 return true;
589 }
590
591 private static final class Settings {
592 final ObjectSizeEstimator objectSizeEstimator;
593 final long maxChannelMemorySize;
594
595 Settings(ObjectSizeEstimator objectSizeEstimator,
596 long maxChannelMemorySize) {
597 this.objectSizeEstimator = objectSizeEstimator;
598 this.maxChannelMemorySize = maxChannelMemorySize;
599 }
600 }
601
602 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
603 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
604 try {
605 final Thread t = new Thread(r, "Temporary task executor");
606 t.start();
607 } catch (Throwable e) {
608 throw new RejectedExecutionException(
609 "Failed to start a new thread", e);
610 }
611 }
612 }
613
614 private static final class MemoryAwareRunnable implements Runnable {
615 final Runnable task;
616 int estimatedSize;
617
618 MemoryAwareRunnable(Runnable task) {
619 this.task = task;
620 }
621
622 public void run() {
623 task.run();
624 }
625 }
626
627
628 private static class Limiter {
629
630 final long limit;
631 private long counter;
632 private int waiters;
633
634 Limiter(long limit) {
635 this.limit = limit;
636 }
637
638 synchronized void increase(long amount) {
639 while (counter >= limit) {
640 waiters ++;
641 try {
642 wait();
643 } catch (InterruptedException e) {
644 Thread.currentThread().interrupt();
645 } finally {
646 waiters --;
647 }
648 }
649 counter += amount;
650 }
651
652 synchronized void decrease(long amount) {
653 counter -= amount;
654 if (counter < limit && waiters > 0) {
655 notifyAll();
656 }
657 }
658 }
659 }