View Javadoc

1   // ========================================================================
2   // Copyright (c) 2009 Mort Bay Consulting Pty. Ltd.
3   // ------------------------------------------------------------------------
4   // All rights reserved. This program and the accompanying materials
5   // are made available under the terms of the Eclipse Public License v1.0
6   // and Apache License v2.0 which accompanies this distribution.
7   // The Eclipse Public License is available at 
8   // http://www.eclipse.org/legal/epl-v10.html
9   // The Apache License v2.0 is available at
10  // http://www.opensource.org/licenses/apache2.0.php
11  // You may elect to redistribute this code under either of these licenses. 
12  // ========================================================================
13  
14  package org.eclipse.jetty.servlets;
15  
16  import java.io.IOException;
17  import java.util.HashSet;
18  import java.util.Queue;
19  import java.util.StringTokenizer;
20  import java.util.concurrent.ConcurrentHashMap;
21  import java.util.concurrent.ConcurrentLinkedQueue;
22  import java.util.concurrent.Semaphore;
23  import java.util.concurrent.TimeUnit;
24  
25  import javax.servlet.Filter;
26  import javax.servlet.FilterChain;
27  import javax.servlet.FilterConfig;
28  import javax.servlet.ServletContext;
29  import javax.servlet.ServletException;
30  import javax.servlet.ServletRequest;
31  import javax.servlet.ServletResponse;
32  import javax.servlet.http.HttpServletRequest;
33  import javax.servlet.http.HttpServletResponse;
34  import javax.servlet.http.HttpSession;
35  import javax.servlet.http.HttpSessionBindingEvent;
36  import javax.servlet.http.HttpSessionBindingListener;
37  
38  import org.eclipse.jetty.continuation.Continuation;
39  import org.eclipse.jetty.continuation.ContinuationListener;
40  import org.eclipse.jetty.continuation.ContinuationSupport;
41  import org.eclipse.jetty.util.log.Log;
42  import org.eclipse.jetty.util.thread.Timeout;
43  
44  /**
45   * Denial of Service filter
46   * 
47   * <p>
48   * This filter is based on the {@link QoSFilter}. it is useful for limiting
49   * exposure to abuse from request flooding, whether malicious, or as a result of
50   * a misconfigured client.
51   * <p>
52   * The filter keeps track of the number of requests from a connection per
53   * second. If a limit is exceeded, the request is either rejected, delayed, or
54   * throttled.
55   * <p>
56   * When a request is throttled, it is placed in a priority queue. Priority is
57   * given first to authenticated users and users with an HttpSession, then
58   * connections which can be identified by their IP addresses. Connections with
59   * no way to identify them are given lowest priority.
60   * <p>
61   * The {@link #extractUserId(ServletRequest request)} function should be
62   * implemented, in order to uniquely identify authenticated users.
63   * <p>
64   * The following init parameters control the behavior of the filter:
65   * 
66   * maxRequestsPerSec    the maximum number of requests from a connection per
67   *                      second. Requests in excess of this are first delayed, 
68   *                      then throttled.
69   * 
70   * delayMs              is the delay given to all requests over the rate limit, 
71   *                      before they are considered at all. -1 means just reject request, 
72   *                      0 means no delay, otherwise it is the delay.
73   * 
74   * maxWaitMs            how long to blocking wait for the throttle semaphore.
75   * 
76   * throttledRequests    is the number of requests over the rate limit able to be
77   *                      considered at once.
78   * 
79   * throttleMs           how long to async wait for semaphore.
80   * 
81   * maxRequestMs         how long to allow this request to run.
82   * 
83   * maxIdleTrackerMs     how long to keep track of request rates for a connection, 
84   *                      before deciding that the user has gone away, and discarding it
85   * 
86   * insertHeaders        if true , insert the DoSFilter headers into the response. Defaults to true.
87   * 
88   * trackSessions        if true, usage rate is tracked by session if a session exists. Defaults to true.
89   * 
90   * remotePort           if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
91   * 
92   * ipWhitelist          a comma-separated list of IP addresses that will not be rate limited
93   */
94  
95  public class DoSFilter implements Filter
96  {
97      final static String __TRACKER = "DoSFilter.Tracker";
98      final static String __THROTTLED = "DoSFilter.Throttled";
99  
100     final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
101     final static int __DEFAULT_DELAY_MS = 100;
102     final static int __DEFAULT_THROTTLE = 5;
103     final static int __DEFAULT_WAIT_MS=50;
104     final static long __DEFAULT_THROTTLE_MS = 30000L;
105     final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
106     final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
107 
108     final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
109     final static String DELAY_MS_INIT_PARAM = "delayMs";
110     final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
111     final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
112     final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
113     final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
114     final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
115     final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
116     final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
117     final static String REMOTE_PORT_INIT_PARAM="remotePort";
118     final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
119 
120     final static int USER_AUTH = 2;
121     final static int USER_SESSION = 2;
122     final static int USER_IP = 1;
123     final static int USER_UNKNOWN = 0;
124 
125     ServletContext _context;
126 
127     protected long _delayMs;
128     protected long _throttleMs;
129     protected long _waitMs;
130     protected long _maxRequestMs;
131     protected long _maxIdleTrackerMs;
132     protected boolean _insertHeaders;
133     protected boolean _trackSessions;
134     protected boolean _remotePort;
135     protected Semaphore _passes;
136     protected Queue<Continuation>[] _queue;
137     protected ContinuationListener[] _listener;
138 
139     protected int _maxRequestsPerSec;
140     protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
141     private HashSet<String> _whitelist; 
142     
143     private final Timeout _requestTimeoutQ = new Timeout();
144     private final Timeout _trackerTimeoutQ = new Timeout();
145 
146     private Thread _timerThread;
147     private volatile boolean _running;
148 
149     public void init(FilterConfig filterConfig)
150     {
151         _context = filterConfig.getServletContext();
152 
153         _queue = new Queue[getMaxPriority() + 1];
154         _listener = new ContinuationListener[getMaxPriority() + 1];
155         for (int p = 0; p < _queue.length; p++)
156         {
157             _queue[p] = new ConcurrentLinkedQueue<Continuation>();
158             
159             final int priority=p;
160             _listener[p] = new ContinuationListener()
161             {
162                 public void onComplete(Continuation continuation)
163                 {}
164 
165                 public void onTimeout(Continuation continuation)
166                 {
167                     _queue[priority].remove(continuation);
168                 }
169             };
170         }
171 
172         int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
173         if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
174             baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
175         _maxRequestsPerSec = baseRateLimit;
176 
177         long delay = __DEFAULT_DELAY_MS;
178         if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
179             delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
180         _delayMs = delay;
181 
182         int passes = __DEFAULT_THROTTLE;
183         if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
184             passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
185         _passes = new Semaphore(passes,true);
186 
187         long wait = __DEFAULT_WAIT_MS;
188         if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
189             wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
190         _waitMs = wait;
191 
192         long suspend = __DEFAULT_THROTTLE_MS;
193         if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
194             suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
195         _throttleMs = suspend;
196 
197         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
198         if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
199             maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
200         _maxRequestMs = maxRequestMs;
201 
202         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
203         if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
204             maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
205         _maxIdleTrackerMs = maxIdleTrackerMs;
206         
207         String whitelistString = "";
208         if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
209             whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
210         
211         // empty 
212         if (whitelistString.length() == 0 )
213             _whitelist = new HashSet<String>();
214         else
215         {
216             StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
217             _whitelist = new HashSet<String>(tokenizer.countTokens());
218             while (tokenizer.hasMoreTokens())
219                 _whitelist.add(tokenizer.nextToken().trim());
220             
221             Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
222         }
223 
224         String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
225         _insertHeaders = tmp==null || Boolean.parseBoolean(tmp); 
226         
227         tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
228         _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
229         
230         tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
231         _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
232 
233         _requestTimeoutQ.setNow();
234         _requestTimeoutQ.setDuration(_maxRequestMs);
235         
236         _trackerTimeoutQ.setNow();
237         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
238         
239         _running=true;
240         _timerThread = (new Thread()
241         {
242             public void run()
243             {
244                 try
245                 {
246                     while (_running)
247                     {
248                         synchronized (_requestTimeoutQ)
249                         {
250                             _requestTimeoutQ.setNow();
251                             _requestTimeoutQ.tick();
252 
253                             _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
254                             _trackerTimeoutQ.tick();
255                         }
256                         try
257                         {
258                             Thread.sleep(100);
259                         }
260                         catch (InterruptedException e)
261                         {
262                             Log.ignore(e);
263                         }
264                     }
265                 }
266                 finally
267                 {
268                     Log.info("DoSFilter timer exited");
269                 }
270             }
271         });
272         _timerThread.start();
273     }
274     
275 
276     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
277     {
278         final HttpServletRequest srequest = (HttpServletRequest)request;
279         final HttpServletResponse sresponse = (HttpServletResponse)response;
280         
281         final long now=_requestTimeoutQ.getNow();
282         
283         // Look for the rate tracker for this request
284         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
285             
286         if (tracker==null)
287         {
288             // This is the first time we have seen this request.
289             
290             // get a rate tracker associated with this request, and record one hit
291             tracker = getRateTracker(request);
292             
293             // Calculate the rate and check it is over the allowed limit
294             final boolean overRateLimit = tracker.isRateExceeded(now);
295 
296             // pass it through if  we are not currently over the rate limit
297             if (!overRateLimit)
298             {
299                 doFilterChain(filterchain,srequest,sresponse);
300                 return;
301             }   
302             
303             // We are over the limit.
304             Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
305             
306             // So either reject it, delay it or throttle it
307             switch((int)_delayMs)
308             {
309                 case -1: 
310                 {
311                     // Reject this request
312                     if (_insertHeaders)
313                         ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
314                     ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
315                     return;
316                 }
317                 case 0:
318                 {
319                     // fall through to throttle code
320                     request.setAttribute(__TRACKER,tracker);
321                     break;
322                 }
323                 default:
324                 {
325                     // insert a delay before throttling the request
326                     if (_insertHeaders)
327                         ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
328                     Continuation continuation = ContinuationSupport.getContinuation(request);
329                     request.setAttribute(__TRACKER,tracker);
330                     if (_delayMs > 0)
331                         continuation.setTimeout(_delayMs);
332                     continuation.suspend();
333                     return;
334                 }
335             }
336         }
337 
338         // Throttle the request
339         boolean accepted = false;
340         try
341         {
342             // check if we can afford to accept another request at this time
343             accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
344 
345             if (!accepted)
346             {
347                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
348                 final Continuation continuation = ContinuationSupport.getContinuation(request);
349                 
350                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
351                 if (throttled!=Boolean.TRUE && _throttleMs>0)
352                 {
353                     int priority = getPriority(request,tracker);
354                     request.setAttribute(__THROTTLED,Boolean.TRUE);
355                     if (_insertHeaders)
356                         ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
357                     if (_throttleMs > 0)
358                         continuation.setTimeout(_throttleMs);
359                     continuation.suspend();
360 
361                     continuation.addContinuationListener(_listener[priority]);
362                     _queue[priority].add(continuation);
363                     return;
364                 }
365                 // else were we resumed?
366                 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
367                 {
368                     // we were resumed and somebody stole our pass, so we wait for the next one.
369                     _passes.acquire();
370                     accepted = true;
371                 }
372             }
373             
374             // if we were accepted (either immediately or after throttle)
375             if (accepted)       
376                 // call the chain
377                 doFilterChain(filterchain,srequest,sresponse);
378             else                
379             {
380                 // fail the request
381                 if (_insertHeaders)
382                     ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
383                 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
384             }
385         }
386         catch (InterruptedException e)
387         {
388             _context.log("DoS",e);
389             ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
390         }
391         finally
392         {
393             if (accepted)
394             {
395                 // wake up the next highest priority request.
396                 for (int p = _queue.length; p-- > 0;)
397                 {
398                     Continuation continuation = _queue[p].poll();
399                     if (continuation != null && continuation.isSuspended())
400                     {
401                         continuation.resume();
402                         break;
403                     }
404                 }
405                 _passes.release();
406             }
407         }
408     }
409 
410     /**
411      * @param chain
412      * @param request
413      * @param response
414      * @throws IOException
415      * @throws ServletException
416      */
417     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) 
418         throws IOException, ServletException
419     {
420         final Thread thread=Thread.currentThread();
421         
422         final Timeout.Task requestTimeout = new Timeout.Task()
423         {
424             public void expired()
425             {
426                 closeConnection(request, response, thread);
427             }
428         };
429 
430         try
431         {
432             synchronized (_requestTimeoutQ)
433             {
434                 _requestTimeoutQ.schedule(requestTimeout);
435             }
436             chain.doFilter(request,response);
437         }
438         finally
439         {
440             synchronized (_requestTimeoutQ)
441             {
442                 requestTimeout.cancel();
443             }
444         }
445     }
446 
447     /**
448      * Takes drastic measures to return this response and stop this thread.
449      * Due to the way the connection is interrupted, may return mixed up headers.
450      * @param request current request
451      * @param response current response, which must be stopped
452      * @param thread the handling thread
453      */
454     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
455     {
456         // take drastic measures to return this response and stop this thread.
457         if( !response.isCommitted() )
458         {
459             response.setHeader("Connection", "close");
460         }
461         try 
462         {
463             try
464             {
465                 response.getWriter().close();
466             }
467             catch (IllegalStateException e)
468             {
469                 response.getOutputStream().close();
470             }
471         }
472         catch (IOException e)
473         {
474             Log.warn(e);
475         }
476         
477         // interrupt the handling thread
478         thread.interrupt();
479     }
480         
481     /**
482      * Get priority for this request, based on user type
483      * 
484      * @param request
485      * @param tracker
486      * @return priority
487      */
488     protected int getPriority(ServletRequest request, RateTracker tracker)
489     {
490         if (extractUserId(request)!=null)
491             return USER_AUTH;
492         if (tracker!=null)
493             return tracker.getType();
494         return USER_UNKNOWN;
495     }
496 
497     /**
498      * @return the maximum priority that we can assign to a request
499      */
500     protected int getMaxPriority()
501     {
502         return USER_AUTH;
503     }
504 
505     /**
506      * Return a request rate tracker associated with this connection; keeps
507      * track of this connection's request rate. If this is not the first request
508      * from this connection, return the existing object with the stored stats.
509      * If it is the first request, then create a new request tracker.
510      * 
511      * Assumes that each connection has an identifying characteristic, and goes
512      * through them in order, taking the first that matches: user id (logged
513      * in), session id, client IP address. Unidentifiable connections are lumped
514      * into one.
515      * 
516      * When a session expires, its rate tracker is automatically deleted.
517      * 
518      * @param request
519      * @return the request rate tracker for the current connection
520      */
521     public RateTracker getRateTracker(ServletRequest request)
522     {
523         HttpServletRequest srequest = (HttpServletRequest)request;
524 
525         String loadId;
526         final int type;
527         
528         loadId = extractUserId(request);
529         HttpSession session=srequest.getSession(false);
530         if (_trackSessions && session!=null && !session.isNew())
531         {
532             loadId=session.getId();
533             type = USER_SESSION;
534         }
535         else
536         {
537             loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
538             type = USER_IP;
539         }
540 
541         RateTracker tracker=_rateTrackers.get(loadId);
542         
543         if (tracker==null)
544         {
545             RateTracker t;
546             if (_whitelist.contains(request.getRemoteAddr()))
547             {
548                 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
549             }
550             else
551             {
552                 t = new RateTracker(loadId,type,_maxRequestsPerSec);
553             }
554             
555             tracker=_rateTrackers.putIfAbsent(loadId,t);
556             if (tracker==null)
557                 tracker=t;
558             
559             if (type == USER_IP)
560             {
561                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
562                 synchronized (_trackerTimeoutQ)
563                 {
564                     _trackerTimeoutQ.schedule(tracker);
565                 }
566             }
567             else if (session!=null)
568                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
569                 session.setAttribute(__TRACKER,tracker);
570         }
571 
572         return tracker;
573     }
574 
575     public void destroy()
576     {
577         _running=false;
578         _timerThread.interrupt();
579         synchronized (_requestTimeoutQ)
580         {
581             _requestTimeoutQ.cancelAll();
582             _trackerTimeoutQ.cancelAll();
583         }
584     }
585 
586     /**
587      * Returns the user id, used to track this connection.
588      * This SHOULD be overridden by subclasses.
589      * 
590      * @param request
591      * @return a unique user id, if logged in; otherwise null.
592      */
593     protected String extractUserId(ServletRequest request)
594     {
595         return null;
596     }
597 
598     /**
599      * A RateTracker is associated with a connection, and stores request rate
600      * data.
601      */
602     class RateTracker extends Timeout.Task implements HttpSessionBindingListener
603     {
604         protected final String _id;
605         protected final int _type;
606         protected final long[] _timestamps;
607         protected int _next;
608         
609         public RateTracker(String id, int type,int maxRequestsPerSecond)
610         {
611             _id = id;
612             _type = type;
613             _timestamps=new long[maxRequestsPerSecond];
614             _next=0;
615         }
616 
617         /**
618          * @return the current calculated request rate over the last second
619          */
620         public boolean isRateExceeded(long now)
621         {
622             final long last;
623             synchronized (this)
624             {
625                 last=_timestamps[_next];
626                 _timestamps[_next]=now;
627                 _next= (_next+1)%_timestamps.length;
628             }
629 
630             boolean exceeded=last!=0 && (now-last)<1000L;
631             return exceeded;
632         }
633 
634 
635         public String getId()
636         {
637             return _id;
638         }
639 
640         public int getType()
641         {
642             return _type;
643         }
644 
645         
646         public void valueBound(HttpSessionBindingEvent event)
647         {
648         }
649 
650         public void valueUnbound(HttpSessionBindingEvent event)
651         {
652             _rateTrackers.remove(_id);
653         }
654         
655         public void expired()
656         {
657             long now = _trackerTimeoutQ.getNow();
658             int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length; 
659             long last=_timestamps[latestIndex];
660             boolean hasRecentRequest = last != 0 && (now-last)<1000L;
661             
662             if (hasRecentRequest)
663                 reschedule();
664             else
665                 _rateTrackers.remove(_id);
666         }
667     }
668     
669     class FixedRateTracker extends RateTracker
670     {
671         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
672         {
673             super(id,type,numRecentRequestsTracked);
674         }
675 
676         public boolean isRateExceeded(long now)
677         {
678             // rate limit is never exceeded, but we keep track of the request timestamps
679             // so that we know whether there was recent activity on this tracker
680             // and whether it should be expired
681             synchronized (this)
682             {
683                 _timestamps[_next]=now;
684                 _next= (_next+1)%_timestamps.length;
685             }
686 
687             return false;
688         }        
689     }
690 }