1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package org.mortbay.servlet;
16
17 import java.io.IOException;
18 import java.util.HashSet;
19 import java.util.Queue;
20 import java.util.StringTokenizer;
21 import java.util.Timer;
22 import java.util.concurrent.ConcurrentHashMap;
23 import java.util.concurrent.Semaphore;
24 import java.util.concurrent.TimeUnit;
25
26 import javax.servlet.Filter;
27 import javax.servlet.FilterChain;
28 import javax.servlet.FilterConfig;
29 import javax.servlet.ServletContext;
30 import javax.servlet.ServletException;
31 import javax.servlet.ServletRequest;
32 import javax.servlet.ServletResponse;
33 import javax.servlet.http.HttpServletRequest;
34 import javax.servlet.http.HttpServletResponse;
35 import javax.servlet.http.HttpSession;
36 import javax.servlet.http.HttpSessionBindingEvent;
37 import javax.servlet.http.HttpSessionBindingListener;
38
39 import org.mortbay.log.Log;
40 import org.mortbay.thread.Timeout;
41 import org.mortbay.util.ArrayQueue;
42 import org.mortbay.util.ajax.Continuation;
43 import org.mortbay.util.ajax.ContinuationSupport;
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96 public class DoSFilter implements Filter
97 {
98 final static String __TRACKER = "DoSFilter.Tracker";
99 final static String __THROTTLED = "DoSFilter.Throttled";
100
101 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
102 final static int __DEFAULT_DELAY_MS = 100;
103 final static int __DEFAULT_THROTTLE = 5;
104 final static int __DEFAULT_WAIT_MS=50;
105 final static long __DEFAULT_THROTTLE_MS = 30000L;
106 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
107 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
108
109 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
110 final static String DELAY_MS_INIT_PARAM = "delayMs";
111 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
112 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
113 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
114 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
115 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
116 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
117 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
118 final static String REMOTE_PORT_INIT_PARAM="remotePort";
119 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
120
121 final static int USER_AUTH = 2;
122 final static int USER_SESSION = 2;
123 final static int USER_IP = 1;
124 final static int USER_UNKNOWN = 0;
125
126 ServletContext _context;
127
128 protected long _delayMs;
129 protected long _throttleMs;
130 protected long _waitMs;
131 protected long _maxRequestMs;
132 protected long _maxIdleTrackerMs;
133 protected boolean _insertHeaders;
134 protected boolean _trackSessions;
135 protected boolean _remotePort;
136 protected Semaphore _passes;
137 protected Queue<Continuation>[] _queue;
138
139 protected int _maxRequestsPerSec;
140 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
141 private HashSet<String> _whitelist;
142
143 private final Timeout _requestTimeoutQ = new Timeout();
144 private final Timeout _trackerTimeoutQ = new Timeout();
145
146 private Thread _timerThread;
147 private volatile boolean _running;
148
149 public void init(FilterConfig filterConfig)
150 {
151 _context = filterConfig.getServletContext();
152
153 _queue = new Queue[getMaxPriority() + 1];
154 for (int p = 0; p < _queue.length; p++)
155 _queue[p] = new ArrayQueue<Continuation>();
156
157 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
158 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
159 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
160 _maxRequestsPerSec = baseRateLimit;
161
162 long delay = __DEFAULT_DELAY_MS;
163 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
164 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
165 _delayMs = delay;
166
167 int passes = __DEFAULT_THROTTLE;
168 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
169 passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
170 _passes = new Semaphore(passes,true);
171
172 long wait = __DEFAULT_WAIT_MS;
173 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
174 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
175 _waitMs = wait;
176
177 long suspend = __DEFAULT_THROTTLE_MS;
178 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
179 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
180 _throttleMs = suspend;
181
182 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
183 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
184 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
185 _maxRequestMs = maxRequestMs;
186
187 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
188 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
189 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
190 _maxIdleTrackerMs = maxIdleTrackerMs;
191
192 String whitelistString = "";
193 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
194 whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
195
196
197 if (whitelistString.length() == 0 )
198 _whitelist = new HashSet<String>();
199 else
200 {
201 StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
202 _whitelist = new HashSet<String>(tokenizer.countTokens());
203 while (tokenizer.hasMoreTokens())
204 _whitelist.add(tokenizer.nextToken().trim());
205
206 Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
207 }
208
209 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
210 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
211
212 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
213 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
214
215 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
216 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
217
218 _requestTimeoutQ.setNow();
219 _requestTimeoutQ.setDuration(_maxRequestMs);
220
221 _trackerTimeoutQ.setNow();
222 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
223
224 _running=true;
225 _timerThread = (new Thread()
226 {
227 public void run()
228 {
229 try
230 {
231 while (_running)
232 {
233 synchronized (_requestTimeoutQ)
234 {
235 _requestTimeoutQ.setNow();
236 _requestTimeoutQ.tick();
237
238 _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
239 _trackerTimeoutQ.tick();
240 }
241 try
242 {
243 Thread.sleep(100);
244 }
245 catch (InterruptedException e)
246 {
247 Log.ignore(e);
248 }
249 }
250 }
251 finally
252 {
253 Log.debug("DoSFilter timer exited ");
254 }
255 }
256 });
257 _timerThread.start();
258 }
259
260
261 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
262 {
263 final HttpServletRequest srequest = (HttpServletRequest)request;
264 final HttpServletResponse sresponse = (HttpServletResponse)response;
265
266 final long now=_requestTimeoutQ.getNow();
267
268
269 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
270
271 if (tracker==null)
272 {
273
274
275
276 tracker = getRateTracker(request);
277
278
279 final boolean overRateLimit = tracker.isRateExceeded(now);
280
281
282 if (!overRateLimit)
283 {
284 doFilterChain(filterchain,srequest,sresponse);
285 return;
286 }
287
288
289 Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
290
291
292 switch((int)_delayMs)
293 {
294 case -1:
295 {
296
297 if (_insertHeaders)
298 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
299 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
300 return;
301 }
302 case 0:
303 {
304
305 request.setAttribute(__TRACKER,tracker);
306 break;
307 }
308 default:
309 {
310
311 if (_insertHeaders)
312 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
313 Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
314 request.setAttribute(__TRACKER,tracker);
315 continuation.suspend(_delayMs);
316
317 }
318 }
319 }
320
321
322 boolean accepted = false;
323 try
324 {
325
326 accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
327
328 if (!accepted)
329 {
330
331
332 final Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
333
334 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
335 if (throttled!=Boolean.TRUE && _throttleMs>0)
336 {
337 int priority = getPriority(request,tracker);
338 request.setAttribute(__THROTTLED,Boolean.TRUE);
339 if (_insertHeaders)
340 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
341 synchronized (this)
342 {
343 _queue[priority].add(continuation);
344 continuation.reset();
345 if(continuation.suspend(_throttleMs))
346 {
347
348
349 _passes.acquire();
350 accepted = true;
351 }
352
353 }
354 }
355
356 else if (continuation.isResumed())
357 {
358
359 _passes.acquire();
360 accepted = true;
361 }
362 }
363
364
365 if (accepted)
366
367 doFilterChain(filterchain,srequest,sresponse);
368 else
369 {
370
371 if (_insertHeaders)
372 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
373 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
374 }
375 }
376 catch (InterruptedException e)
377 {
378 _context.log("DoS",e);
379 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
380 }
381 finally
382 {
383 if (accepted)
384 {
385
386 synchronized (_queue)
387 {
388 for (int p = _queue.length; p-- > 0;)
389 {
390 Continuation continuation = _queue[p].poll();
391
392 if (continuation != null)
393 {
394 continuation.resume();
395 break;
396 }
397 }
398 }
399 _passes.release();
400 }
401 }
402 }
403
404
405
406
407
408
409
410
411 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
412 throws IOException, ServletException
413 {
414 final Thread thread=Thread.currentThread();
415
416 final Timeout.Task requestTimeout = new Timeout.Task()
417 {
418 public void expired()
419 {
420 closeConnection(request, response, thread);
421 }
422 };
423
424 try
425 {
426 synchronized (_requestTimeoutQ)
427 {
428 _requestTimeoutQ.schedule(requestTimeout);
429 }
430 chain.doFilter(request,response);
431 }
432 finally
433 {
434 synchronized (_requestTimeoutQ)
435 {
436 requestTimeout.cancel();
437 }
438 }
439 }
440
441
442
443
444
445
446
447
448 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
449 {
450
451 if( !response.isCommitted() )
452 {
453 response.setHeader("Connection", "close");
454 }
455 try
456 {
457 try
458 {
459 response.getWriter().close();
460 }
461 catch (IllegalStateException e)
462 {
463 response.getOutputStream().close();
464 }
465 }
466 catch (IOException e)
467 {
468 Log.warn(e);
469 }
470
471
472 thread.interrupt();
473 }
474
475
476
477
478
479
480
481
482 protected int getPriority(ServletRequest request, RateTracker tracker)
483 {
484 if (extractUserId(request)!=null)
485 return USER_AUTH;
486 if (tracker!=null)
487 return tracker.getType();
488 return USER_UNKNOWN;
489 }
490
491
492
493
494 protected int getMaxPriority()
495 {
496 return USER_AUTH;
497 }
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515 public RateTracker getRateTracker(ServletRequest request)
516 {
517 HttpServletRequest srequest = (HttpServletRequest)request;
518
519 String loadId;
520 final int type;
521
522 loadId = extractUserId(request);
523 HttpSession session=srequest.getSession(false);
524 if (_trackSessions && session!=null && !session.isNew())
525 {
526 loadId=session.getId();
527 type = USER_SESSION;
528 }
529 else
530 {
531 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
532 type = USER_IP;
533 }
534
535 RateTracker tracker=_rateTrackers.get(loadId);
536
537 if (tracker==null)
538 {
539 RateTracker t;
540 if (_whitelist.contains(request.getRemoteAddr()))
541 {
542 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
543 }
544 else
545 {
546 t = new RateTracker(loadId,type,_maxRequestsPerSec);
547 }
548
549 tracker=_rateTrackers.putIfAbsent(loadId,t);
550 if (tracker==null)
551 tracker=t;
552
553 if (type == USER_IP)
554 {
555
556 synchronized (_trackerTimeoutQ)
557 {
558 _trackerTimeoutQ.schedule(tracker);
559 }
560 }
561 else if (session!=null)
562
563 session.setAttribute(__TRACKER,tracker);
564 }
565
566 return tracker;
567 }
568
569 public void destroy()
570 {
571 _running=false;
572 _timerThread.interrupt();
573 synchronized (_requestTimeoutQ)
574 {
575 _requestTimeoutQ.cancelAll();
576 _trackerTimeoutQ.cancelAll();
577 }
578 }
579
580
581
582
583
584
585
586
587 protected String extractUserId(ServletRequest request)
588 {
589 return null;
590 }
591
592
593
594
595
596 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
597 {
598 protected final String _id;
599 protected final int _type;
600 protected final long[] _timestamps;
601 protected int _next;
602
603 public RateTracker(String id, int type,int maxRequestsPerSecond)
604 {
605 _id = id;
606 _type = type;
607 _timestamps=new long[maxRequestsPerSecond];
608 _next=0;
609 }
610
611
612
613
614 public boolean isRateExceeded(long now)
615 {
616 final long last;
617 synchronized (this)
618 {
619 last=_timestamps[_next];
620 _timestamps[_next]=now;
621 _next= (_next+1)%_timestamps.length;
622 }
623
624 boolean exceeded=last!=0 && (now-last)<1000L;
625
626 return exceeded;
627 }
628
629
630 public String getId()
631 {
632 return _id;
633 }
634
635 public int getType()
636 {
637 return _type;
638 }
639
640
641 public void valueBound(HttpSessionBindingEvent event)
642 {
643 }
644
645 public void valueUnbound(HttpSessionBindingEvent event)
646 {
647 _rateTrackers.remove(_id);
648 }
649
650 public void expired()
651 {
652 long now = _trackerTimeoutQ.getNow();
653 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
654 long last=_timestamps[latestIndex];
655 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
656
657 if (hasRecentRequest)
658 reschedule();
659 else
660 _rateTrackers.remove(_id);
661 }
662
663 public String toString()
664 {
665 return "RateTracker/"+_id+"/"+_type;
666 }
667 }
668
669 class FixedRateTracker extends RateTracker
670 {
671 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
672 {
673 super(id,type,numRecentRequestsTracked);
674 }
675
676 public boolean isRateExceeded(long now)
677 {
678
679
680
681 synchronized (this)
682 {
683 _timestamps[_next]=now;
684 _next= (_next+1)%_timestamps.length;
685 }
686
687 return false;
688 }
689
690 public String toString()
691 {
692 return "Fixed"+super.toString();
693 }
694 }
695 }