View Javadoc

1   // ========================================================================
2   // Copyright 2009 Mort Bay Consulting Pty. Ltd.
3   // ------------------------------------------------------------------------
4   // Licensed under the Apache License, Version 2.0 (the "License");
5   // you may not use this file except in compliance with the License.
6   // You may obtain a copy of the License at 
7   // http://www.apache.org/licenses/LICENSE-2.0
8   // Unless required by applicable law or agreed to in writing, software
9   // distributed under the License is distributed on an "AS IS" BASIS,
10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  // See the License for the specific language governing permissions and
12  // limitations under the License.
13  //========================================================================
14  
15  package org.mortbay.servlet;
16  
17  import java.io.IOException;
18  import java.util.HashSet;
19  import java.util.Queue;
20  import java.util.StringTokenizer;
21  import java.util.Timer;
22  import java.util.concurrent.ConcurrentHashMap;
23  import java.util.concurrent.Semaphore;
24  import java.util.concurrent.TimeUnit;
25  
26  import javax.servlet.Filter;
27  import javax.servlet.FilterChain;
28  import javax.servlet.FilterConfig;
29  import javax.servlet.ServletContext;
30  import javax.servlet.ServletException;
31  import javax.servlet.ServletRequest;
32  import javax.servlet.ServletResponse;
33  import javax.servlet.http.HttpServletRequest;
34  import javax.servlet.http.HttpServletResponse;
35  import javax.servlet.http.HttpSession;
36  import javax.servlet.http.HttpSessionBindingEvent;
37  import javax.servlet.http.HttpSessionBindingListener;
38  
39  import org.mortbay.log.Log;
40  import org.mortbay.thread.Timeout;
41  import org.mortbay.util.ArrayQueue;
42  import org.mortbay.util.ajax.Continuation;
43  import org.mortbay.util.ajax.ContinuationSupport;
44  
45  /**
46   * Denial of Service filter
47   * 
48   * <p>
49   * This filter is based on the {@link QoSFilter}. it is useful for limiting
50   * exposure to abuse from request flooding, whether malicious, or as a result of
51   * a misconfigured client.
52   * <p>
53   * The filter keeps track of the number of requests from a connection per
54   * second. If a limit is exceeded, the request is either rejected, delayed, or
55   * throttled.
56   * <p>
57   * When a request is throttled, it is placed in a priority queue. Priority is
58   * given first to authenticated users and users with an HttpSession, then
59   * connections which can be identified by their IP addresses. Connections with
60   * no way to identify them are given lowest priority.
61   * <p>
62   * The {@link #extractUserId(ServletRequest request)} function should be
63   * implemented, in order to uniquely identify authenticated users.
64   * <p>
65   * The following init parameters control the behavior of the filter:
66   * 
67   * maxRequestsPerSec    the maximum number of requests from a connection per
68   *                      second. Requests in excess of this are first delayed, 
69   *                      then throttled.
70   * 
71   * delayMs              is the delay given to all requests over the rate limit, 
72   *                      before they are considered at all. -1 means just reject request, 
73   *                      0 means no delay, otherwise it is the delay.
74   * 
75   * maxWaitMs            how long to blocking wait for the throttle semaphore.
76   * 
77   * throttledRequests    is the number of requests over the rate limit able to be
78   *                      considered at once.
79   * 
80   * throttleMs           how long to async wait for semaphore.
81   * 
82   * maxRequestMs         how long to allow this request to run.
83   * 
84   * maxIdleTrackerMs     how long to keep track of request rates for a connection, 
85   *                      before deciding that the user has gone away, and discarding it
86   * 
87   * insertHeaders        if true , insert the DoSFilter headers into the response. Defaults to true.
88   * 
89   * trackSessions        if true, usage rate is tracked by session if a session exists. Defaults to true.
90   * 
91   * remotePort           if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
92   * 
93   * ipWhitelist          a comma-separated list of IP addresses that will not be rate limited
94   */
95  
96  public class DoSFilter implements Filter
97  {
98      final static String __TRACKER = "DoSFilter.Tracker";
99      final static String __THROTTLED = "DoSFilter.Throttled";
100 
101     final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
102     final static int __DEFAULT_DELAY_MS = 100;
103     final static int __DEFAULT_THROTTLE = 5;
104     final static int __DEFAULT_WAIT_MS=50;
105     final static long __DEFAULT_THROTTLE_MS = 30000L;
106     final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
107     final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
108 
109     final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
110     final static String DELAY_MS_INIT_PARAM = "delayMs";
111     final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
112     final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
113     final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
114     final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
115     final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
116     final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
117     final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
118     final static String REMOTE_PORT_INIT_PARAM="remotePort";
119     final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
120 
121     final static int USER_AUTH = 2;
122     final static int USER_SESSION = 2;
123     final static int USER_IP = 1;
124     final static int USER_UNKNOWN = 0;
125 
126     ServletContext _context;
127 
128     protected long _delayMs;
129     protected long _throttleMs;
130     protected long _waitMs;
131     protected long _maxRequestMs;
132     protected long _maxIdleTrackerMs;
133     protected boolean _insertHeaders;
134     protected boolean _trackSessions;
135     protected boolean _remotePort;
136     protected Semaphore _passes;
137     protected Queue<Continuation>[] _queue;
138 
139     protected int _maxRequestsPerSec;
140     protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
141     private HashSet<String> _whitelist; 
142     
143     private final Timeout _requestTimeoutQ = new Timeout();
144     private final Timeout _trackerTimeoutQ = new Timeout();
145 
146     private Thread _timerThread;
147     private volatile boolean _running;
148 
149     public void init(FilterConfig filterConfig)
150     {
151         _context = filterConfig.getServletContext();
152 
153         _queue = new Queue[getMaxPriority() + 1];
154         for (int p = 0; p < _queue.length; p++)
155             _queue[p] = new ArrayQueue<Continuation>();
156 
157         int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
158         if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
159             baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
160         _maxRequestsPerSec = baseRateLimit;
161 
162         long delay = __DEFAULT_DELAY_MS;
163         if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
164             delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
165         _delayMs = delay;
166 
167         int passes = __DEFAULT_THROTTLE;
168         if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
169             passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
170         _passes = new Semaphore(passes,true);
171 
172         long wait = __DEFAULT_WAIT_MS;
173         if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
174             wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
175         _waitMs = wait;
176 
177         long suspend = __DEFAULT_THROTTLE_MS;
178         if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
179             suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
180         _throttleMs = suspend;
181 
182         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
183         if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
184             maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
185         _maxRequestMs = maxRequestMs;
186 
187         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
188         if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
189             maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
190         _maxIdleTrackerMs = maxIdleTrackerMs;
191         
192         String whitelistString = "";
193         if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
194             whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
195         
196         // empty 
197         if (whitelistString.length() == 0 )
198             _whitelist = new HashSet<String>();
199         else
200         {
201             StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
202             _whitelist = new HashSet<String>(tokenizer.countTokens());
203             while (tokenizer.hasMoreTokens())
204                 _whitelist.add(tokenizer.nextToken().trim());
205             
206             Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
207         }
208 
209         String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
210         _insertHeaders = tmp==null || Boolean.parseBoolean(tmp); 
211         
212         tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
213         _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
214         
215         tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
216         _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
217 
218         _requestTimeoutQ.setNow();
219         _requestTimeoutQ.setDuration(_maxRequestMs);
220         
221         _trackerTimeoutQ.setNow();
222         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
223         
224         _running=true;
225         _timerThread = (new Thread()
226         {
227             public void run()
228             {
229                 try
230                 {
231                     while (_running)
232                     {
233                         synchronized (_requestTimeoutQ)
234                         {
235                             _requestTimeoutQ.setNow();
236                             _requestTimeoutQ.tick();
237 
238                             _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
239                             _trackerTimeoutQ.tick();
240                         }
241                         try
242                         {
243                             Thread.sleep(100);
244                         }
245                         catch (InterruptedException e)
246                         {
247                             Log.ignore(e);
248                         }
249                     }
250                 }
251                 finally
252                 {
253                     Log.debug("DoSFilter timer exited ");
254                 }
255             }
256         });
257         _timerThread.start();
258     }
259     
260 
261     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
262     {
263         final HttpServletRequest srequest = (HttpServletRequest)request;
264         final HttpServletResponse sresponse = (HttpServletResponse)response;
265         
266         final long now=_requestTimeoutQ.getNow();
267         
268         // Look for the rate tracker for this request
269         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
270             
271         if (tracker==null)
272         {
273             // This is the first time we have seen this request.
274             
275             // get a rate tracker associated with this request, and record one hit
276             tracker = getRateTracker(request);
277             
278             // Calculate the rate and check it is over the allowed limit
279             final boolean overRateLimit = tracker.isRateExceeded(now);
280 
281             // pass it through if  we are not currently over the rate limit
282             if (!overRateLimit)
283             {
284                 doFilterChain(filterchain,srequest,sresponse);
285                 return;
286             }   
287             
288             // We are over the limit.
289             Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
290             
291             // So either reject it, delay it or throttle it
292             switch((int)_delayMs)
293             {
294                 case -1: 
295                 {
296                     // Reject this request
297                     if (_insertHeaders)
298                         ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
299                     ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
300                     return;
301                 }
302                 case 0:
303                 {
304                     // fall through to throttle code
305                     request.setAttribute(__TRACKER,tracker);
306                     break;
307                 }
308                 default:
309                 {
310                     // insert a delay before throttling the request
311                     if (_insertHeaders)
312                         ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
313                     Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
314                     request.setAttribute(__TRACKER,tracker);
315                     continuation.suspend(_delayMs);
316                     // can fall through if this was a waiting continuation
317                 }
318             }
319         }
320 
321         // Throttle the request
322         boolean accepted = false;
323         try
324         {
325             // check if we can afford to accept another request at this time
326             accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
327 
328             if (!accepted)
329             {
330                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
331 
332                 final Continuation continuation = ContinuationSupport.getContinuation((HttpServletRequest)request,this);
333                 
334                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
335                 if (throttled!=Boolean.TRUE && _throttleMs>0)
336                 {
337                     int priority = getPriority(request,tracker);
338                     request.setAttribute(__THROTTLED,Boolean.TRUE);
339                     if (_insertHeaders)
340                         ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
341                     synchronized (this)
342                     {
343                         _queue[priority].add(continuation);
344                         continuation.reset();
345                         if(continuation.suspend(_throttleMs))
346                         {
347                             // handle waiting continuation strangeness
348                             // continuation was waiting and was resumed.
349                             _passes.acquire();
350                             accepted = true;
351                         }
352                         // can fall through if this was a waiting continuation
353                     }
354                 }
355                 // else were we resumed?
356                 else if (continuation.isResumed())
357                 {
358                     // we were resumed and somebody stole our pass, so we wait for the next one.
359                     _passes.acquire();
360                     accepted = true;
361                 }
362             }
363             
364             // if we were accepted (either immediately or after throttle)
365             if (accepted)       
366                 // call the chain
367                 doFilterChain(filterchain,srequest,sresponse);
368             else                
369             {
370                 // fail the request
371                 if (_insertHeaders)
372                     ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
373                 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
374             }
375         }
376         catch (InterruptedException e)
377         {
378             _context.log("DoS",e);
379             ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
380         }
381         finally
382         {
383             if (accepted)
384             {
385                 // wake up the next highest priority request.
386                 synchronized (_queue)
387                 {
388                     for (int p = _queue.length; p-- > 0;)
389                     {
390                         Continuation continuation = _queue[p].poll();
391 
392                         if (continuation != null)
393                         {
394                             continuation.resume();
395                             break;
396                         }
397                     }
398                 }
399                 _passes.release();
400             }
401         }
402     }
403 
404     /**
405      * @param chain
406      * @param request
407      * @param response
408      * @throws IOException
409      * @throws ServletException
410      */
411     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) 
412         throws IOException, ServletException
413     {
414         final Thread thread=Thread.currentThread();
415         
416         final Timeout.Task requestTimeout = new Timeout.Task()
417         {
418             public void expired()
419             {
420                 closeConnection(request, response, thread);
421             }
422         };
423 
424         try
425         {
426             synchronized (_requestTimeoutQ)
427             {
428                 _requestTimeoutQ.schedule(requestTimeout);
429             }
430             chain.doFilter(request,response);
431         }
432         finally
433         {
434             synchronized (_requestTimeoutQ)
435             {
436                 requestTimeout.cancel();
437             }
438         }
439     }
440 
441     /**
442      * Takes drastic measures to return this response and stop this thread.
443      * Due to the way the connection is interrupted, may return mixed up headers.
444      * @param request current request
445      * @param response current response, which must be stopped
446      * @param thread the handling thread
447      */
448     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
449     {
450         // take drastic measures to return this response and stop this thread.
451         if( !response.isCommitted() )
452         {
453             response.setHeader("Connection", "close");
454         }
455         try 
456         {
457             try
458             {
459                 response.getWriter().close();
460             }
461             catch (IllegalStateException e)
462             {
463                 response.getOutputStream().close();
464             }
465         }
466         catch (IOException e)
467         {
468             Log.warn(e);
469         }
470         
471         // interrupt the handling thread
472         thread.interrupt();
473     }
474         
475     /**
476      * Get priority for this request, based on user type
477      * 
478      * @param request
479      * @param tracker
480      * @return priority
481      */
482     protected int getPriority(ServletRequest request, RateTracker tracker)
483     {
484         if (extractUserId(request)!=null)
485             return USER_AUTH;
486         if (tracker!=null)
487             return tracker.getType();
488         return USER_UNKNOWN;
489     }
490 
491     /**
492      * @return the maximum priority that we can assign to a request
493      */
494     protected int getMaxPriority()
495     {
496         return USER_AUTH;
497     }
498 
499     /**
500      * Return a request rate tracker associated with this connection; keeps
501      * track of this connection's request rate. If this is not the first request
502      * from this connection, return the existing object with the stored stats.
503      * If it is the first request, then create a new request tracker.
504      * 
505      * Assumes that each connection has an identifying characteristic, and goes
506      * through them in order, taking the first that matches: user id (logged
507      * in), session id, client IP address. Unidentifiable connections are lumped
508      * into one.
509      * 
510      * When a session expires, its rate tracker is automatically deleted.
511      * 
512      * @param request
513      * @return the request rate tracker for the current connection
514      */
515     public RateTracker getRateTracker(ServletRequest request)
516     {
517         HttpServletRequest srequest = (HttpServletRequest)request;
518 
519         String loadId;
520         final int type;
521         
522         loadId = extractUserId(request);
523         HttpSession session=srequest.getSession(false);
524         if (_trackSessions && session!=null && !session.isNew())
525         {
526             loadId=session.getId();
527             type = USER_SESSION;
528         }
529         else
530         {
531             loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
532             type = USER_IP;
533         }
534 
535         RateTracker tracker=_rateTrackers.get(loadId);
536         
537         if (tracker==null)
538         {
539             RateTracker t;
540             if (_whitelist.contains(request.getRemoteAddr()))
541             {
542                 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
543             }
544             else
545             {
546                 t = new RateTracker(loadId,type,_maxRequestsPerSec);
547             }
548             
549             tracker=_rateTrackers.putIfAbsent(loadId,t);
550             if (tracker==null)
551                 tracker=t;
552             
553             if (type == USER_IP)
554             {
555                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
556                 synchronized (_trackerTimeoutQ)
557                 {
558                     _trackerTimeoutQ.schedule(tracker);
559                 }
560             }
561             else if (session!=null)
562                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
563                 session.setAttribute(__TRACKER,tracker);
564         }
565 
566         return tracker;
567     }
568 
569     public void destroy()
570     {
571         _running=false;
572         _timerThread.interrupt();
573         synchronized (_requestTimeoutQ)
574         {
575             _requestTimeoutQ.cancelAll();
576             _trackerTimeoutQ.cancelAll();
577         }
578     }
579 
580     /**
581      * Returns the user id, used to track this connection.
582      * This SHOULD be overridden by subclasses.
583      * 
584      * @param request
585      * @return a unique user id, if logged in; otherwise null.
586      */
587     protected String extractUserId(ServletRequest request)
588     {
589         return null;
590     }
591 
592     /**
593      * A RateTracker is associated with a connection, and stores request rate
594      * data.
595      */
596     class RateTracker extends Timeout.Task implements HttpSessionBindingListener
597     {
598         protected final String _id;
599         protected final int _type;
600         protected final long[] _timestamps;
601         protected int _next;
602         
603         public RateTracker(String id, int type,int maxRequestsPerSecond)
604         {
605             _id = id;
606             _type = type;
607             _timestamps=new long[maxRequestsPerSecond];
608             _next=0;
609         }
610 
611         /**
612          * @return the current calculated request rate over the last second
613          */
614         public boolean isRateExceeded(long now)
615         {
616             final long last;
617             synchronized (this)
618             {
619                 last=_timestamps[_next];
620                 _timestamps[_next]=now;
621                 _next= (_next+1)%_timestamps.length;
622             }
623 
624             boolean exceeded=last!=0 && (now-last)<1000L;
625             // System.err.println("rateExceeded? "+last+" "+(now-last)+" "+exceeded);
626             return exceeded;
627         }
628 
629 
630         public String getId()
631         {
632             return _id;
633         }
634 
635         public int getType()
636         {
637             return _type;
638         }
639 
640         
641         public void valueBound(HttpSessionBindingEvent event)
642         {
643         }
644 
645         public void valueUnbound(HttpSessionBindingEvent event)
646         {
647             _rateTrackers.remove(_id);
648         }
649         
650         public void expired()
651         {
652             long now = _trackerTimeoutQ.getNow();
653             int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length; 
654             long last=_timestamps[latestIndex];
655             boolean hasRecentRequest = last != 0 && (now-last)<1000L;
656             
657             if (hasRecentRequest)
658                 reschedule();
659             else
660                 _rateTrackers.remove(_id);
661         }
662         
663         public String toString()
664         {
665             return "RateTracker/"+_id+"/"+_type;
666         }
667     }
668     
669     class FixedRateTracker extends RateTracker
670     {
671         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
672         {
673             super(id,type,numRecentRequestsTracked);
674         }
675 
676         public boolean isRateExceeded(long now)
677         {
678             // rate limit is never exceeded, but we keep track of the request timestamps
679             // so that we know whether there was recent activity on this tracker
680             // and whether it should be expired
681             synchronized (this)
682             {
683                 _timestamps[_next]=now;
684                 _next= (_next+1)%_timestamps.length;
685             }
686 
687             return false;
688         }     
689         
690         public String toString()
691         {
692             return "Fixed"+super.toString();
693         }   
694     }
695 }