1
2
3
4
5
6
7
8
9
10
11
12
13
14 package org.eclipse.jetty.servlets;
15
16 import java.io.IOException;
17 import java.util.HashSet;
18 import java.util.Queue;
19 import java.util.StringTokenizer;
20 import java.util.concurrent.ConcurrentHashMap;
21 import java.util.concurrent.ConcurrentLinkedQueue;
22 import java.util.concurrent.Semaphore;
23 import java.util.concurrent.TimeUnit;
24
25 import javax.servlet.Filter;
26 import javax.servlet.FilterChain;
27 import javax.servlet.FilterConfig;
28 import javax.servlet.ServletContext;
29 import javax.servlet.ServletException;
30 import javax.servlet.ServletRequest;
31 import javax.servlet.ServletResponse;
32 import javax.servlet.http.HttpServletRequest;
33 import javax.servlet.http.HttpServletResponse;
34 import javax.servlet.http.HttpSession;
35 import javax.servlet.http.HttpSessionBindingEvent;
36 import javax.servlet.http.HttpSessionBindingListener;
37
38 import org.eclipse.jetty.continuation.Continuation;
39 import org.eclipse.jetty.continuation.ContinuationListener;
40 import org.eclipse.jetty.continuation.ContinuationSupport;
41 import org.eclipse.jetty.util.log.Log;
42 import org.eclipse.jetty.util.thread.Timeout;
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95 public class DoSFilter implements Filter
96 {
97 final static String __TRACKER = "DoSFilter.Tracker";
98 final static String __THROTTLED = "DoSFilter.Throttled";
99
100 final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
101 final static int __DEFAULT_DELAY_MS = 100;
102 final static int __DEFAULT_THROTTLE = 5;
103 final static int __DEFAULT_WAIT_MS=50;
104 final static long __DEFAULT_THROTTLE_MS = 30000L;
105 final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
106 final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
107
108 final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
109 final static String DELAY_MS_INIT_PARAM = "delayMs";
110 final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
111 final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
112 final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
113 final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
114 final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
115 final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
116 final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
117 final static String REMOTE_PORT_INIT_PARAM="remotePort";
118 final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
119
120 final static int USER_AUTH = 2;
121 final static int USER_SESSION = 2;
122 final static int USER_IP = 1;
123 final static int USER_UNKNOWN = 0;
124
125 ServletContext _context;
126
127 protected long _delayMs;
128 protected long _throttleMs;
129 protected long _waitMs;
130 protected long _maxRequestMs;
131 protected long _maxIdleTrackerMs;
132 protected boolean _insertHeaders;
133 protected boolean _trackSessions;
134 protected boolean _remotePort;
135 protected Semaphore _passes;
136 protected Queue<Continuation>[] _queue;
137 protected ContinuationListener[] _listener;
138
139 protected int _maxRequestsPerSec;
140 protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
141 private HashSet<String> _whitelist;
142
143 private final Timeout _requestTimeoutQ = new Timeout();
144 private final Timeout _trackerTimeoutQ = new Timeout();
145
146 private Thread _timerThread;
147 private volatile boolean _running;
148
149 public void init(FilterConfig filterConfig)
150 {
151 _context = filterConfig.getServletContext();
152
153 _queue = new Queue[getMaxPriority() + 1];
154 _listener = new ContinuationListener[getMaxPriority() + 1];
155 for (int p = 0; p < _queue.length; p++)
156 {
157 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
158
159 final int priority=p;
160 _listener[p] = new ContinuationListener()
161 {
162 public void onComplete(Continuation continuation)
163 {}
164
165 public void onTimeout(Continuation continuation)
166 {
167 _queue[priority].remove(continuation);
168 }
169 };
170 }
171
172 int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
173 if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
174 baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
175 _maxRequestsPerSec = baseRateLimit;
176
177 long delay = __DEFAULT_DELAY_MS;
178 if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
179 delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
180 _delayMs = delay;
181
182 int passes = __DEFAULT_THROTTLE;
183 if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
184 passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
185 _passes = new Semaphore(passes,true);
186
187 long wait = __DEFAULT_WAIT_MS;
188 if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
189 wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
190 _waitMs = wait;
191
192 long suspend = __DEFAULT_THROTTLE_MS;
193 if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
194 suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
195 _throttleMs = suspend;
196
197 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
198 if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
199 maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
200 _maxRequestMs = maxRequestMs;
201
202 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
203 if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
204 maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
205 _maxIdleTrackerMs = maxIdleTrackerMs;
206
207 String whitelistString = "";
208 if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
209 whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
210
211
212 if (whitelistString.length() == 0 )
213 _whitelist = new HashSet<String>();
214 else
215 {
216 StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
217 _whitelist = new HashSet<String>(tokenizer.countTokens());
218 while (tokenizer.hasMoreTokens())
219 _whitelist.add(tokenizer.nextToken().trim());
220
221 Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
222 }
223
224 String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
225 _insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
226
227 tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
228 _trackSessions = tmp==null || Boolean.parseBoolean(tmp);
229
230 tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
231 _remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
232
233 _requestTimeoutQ.setNow();
234 _requestTimeoutQ.setDuration(_maxRequestMs);
235
236 _trackerTimeoutQ.setNow();
237 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
238
239 _running=true;
240 _timerThread = (new Thread()
241 {
242 public void run()
243 {
244 try
245 {
246 while (_running)
247 {
248 synchronized (_requestTimeoutQ)
249 {
250 _requestTimeoutQ.setNow();
251 _requestTimeoutQ.tick();
252
253 _trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
254 _trackerTimeoutQ.tick();
255 }
256 try
257 {
258 Thread.sleep(100);
259 }
260 catch (InterruptedException e)
261 {
262 Log.ignore(e);
263 }
264 }
265 }
266 finally
267 {
268 Log.info("DoSFilter timer exited");
269 }
270 }
271 });
272 _timerThread.start();
273 }
274
275
276 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
277 {
278 final HttpServletRequest srequest = (HttpServletRequest)request;
279 final HttpServletResponse sresponse = (HttpServletResponse)response;
280
281 final long now=_requestTimeoutQ.getNow();
282
283
284 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
285
286 if (tracker==null)
287 {
288
289
290
291 tracker = getRateTracker(request);
292
293
294 final boolean overRateLimit = tracker.isRateExceeded(now);
295
296
297 if (!overRateLimit)
298 {
299 doFilterChain(filterchain,srequest,sresponse);
300 return;
301 }
302
303
304 Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
305
306
307 switch((int)_delayMs)
308 {
309 case -1:
310 {
311
312 if (_insertHeaders)
313 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
314 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
315 return;
316 }
317 case 0:
318 {
319
320 request.setAttribute(__TRACKER,tracker);
321 break;
322 }
323 default:
324 {
325
326 if (_insertHeaders)
327 ((HttpServletResponse)response).addHeader("DoSFilter","delayed");
328 Continuation continuation = ContinuationSupport.getContinuation(request);
329 request.setAttribute(__TRACKER,tracker);
330 if (_delayMs > 0)
331 continuation.setTimeout(_delayMs);
332 continuation.suspend();
333 return;
334 }
335 }
336 }
337
338
339 boolean accepted = false;
340 try
341 {
342
343 accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
344
345 if (!accepted)
346 {
347
348 final Continuation continuation = ContinuationSupport.getContinuation(request);
349
350 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
351 if (throttled!=Boolean.TRUE && _throttleMs>0)
352 {
353 int priority = getPriority(request,tracker);
354 request.setAttribute(__THROTTLED,Boolean.TRUE);
355 if (_insertHeaders)
356 ((HttpServletResponse)response).addHeader("DoSFilter","throttled");
357 if (_throttleMs > 0)
358 continuation.setTimeout(_throttleMs);
359 continuation.suspend();
360
361 continuation.addContinuationListener(_listener[priority]);
362 _queue[priority].add(continuation);
363 return;
364 }
365
366 else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
367 {
368
369 _passes.acquire();
370 accepted = true;
371 }
372 }
373
374
375 if (accepted)
376
377 doFilterChain(filterchain,srequest,sresponse);
378 else
379 {
380
381 if (_insertHeaders)
382 ((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
383 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
384 }
385 }
386 catch (InterruptedException e)
387 {
388 _context.log("DoS",e);
389 ((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
390 }
391 finally
392 {
393 if (accepted)
394 {
395
396 for (int p = _queue.length; p-- > 0;)
397 {
398 Continuation continuation = _queue[p].poll();
399 if (continuation != null && continuation.isSuspended())
400 {
401 continuation.resume();
402 break;
403 }
404 }
405 _passes.release();
406 }
407 }
408 }
409
410
411
412
413
414
415
416
417 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
418 throws IOException, ServletException
419 {
420 final Thread thread=Thread.currentThread();
421
422 final Timeout.Task requestTimeout = new Timeout.Task()
423 {
424 public void expired()
425 {
426 closeConnection(request, response, thread);
427 }
428 };
429
430 try
431 {
432 synchronized (_requestTimeoutQ)
433 {
434 _requestTimeoutQ.schedule(requestTimeout);
435 }
436 chain.doFilter(request,response);
437 }
438 finally
439 {
440 synchronized (_requestTimeoutQ)
441 {
442 requestTimeout.cancel();
443 }
444 }
445 }
446
447
448
449
450
451
452
453
454 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
455 {
456
457 if( !response.isCommitted() )
458 {
459 response.setHeader("Connection", "close");
460 }
461 try
462 {
463 try
464 {
465 response.getWriter().close();
466 }
467 catch (IllegalStateException e)
468 {
469 response.getOutputStream().close();
470 }
471 }
472 catch (IOException e)
473 {
474 Log.warn(e);
475 }
476
477
478 thread.interrupt();
479 }
480
481
482
483
484
485
486
487
488 protected int getPriority(ServletRequest request, RateTracker tracker)
489 {
490 if (extractUserId(request)!=null)
491 return USER_AUTH;
492 if (tracker!=null)
493 return tracker.getType();
494 return USER_UNKNOWN;
495 }
496
497
498
499
500 protected int getMaxPriority()
501 {
502 return USER_AUTH;
503 }
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521 public RateTracker getRateTracker(ServletRequest request)
522 {
523 HttpServletRequest srequest = (HttpServletRequest)request;
524
525 String loadId;
526 final int type;
527
528 loadId = extractUserId(request);
529 HttpSession session=srequest.getSession(false);
530 if (_trackSessions && session!=null && !session.isNew())
531 {
532 loadId=session.getId();
533 type = USER_SESSION;
534 }
535 else
536 {
537 loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
538 type = USER_IP;
539 }
540
541 RateTracker tracker=_rateTrackers.get(loadId);
542
543 if (tracker==null)
544 {
545 RateTracker t;
546 if (_whitelist.contains(request.getRemoteAddr()))
547 {
548 t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
549 }
550 else
551 {
552 t = new RateTracker(loadId,type,_maxRequestsPerSec);
553 }
554
555 tracker=_rateTrackers.putIfAbsent(loadId,t);
556 if (tracker==null)
557 tracker=t;
558
559 if (type == USER_IP)
560 {
561
562 synchronized (_trackerTimeoutQ)
563 {
564 _trackerTimeoutQ.schedule(tracker);
565 }
566 }
567 else if (session!=null)
568
569 session.setAttribute(__TRACKER,tracker);
570 }
571
572 return tracker;
573 }
574
575 public void destroy()
576 {
577 _running=false;
578 _timerThread.interrupt();
579 synchronized (_requestTimeoutQ)
580 {
581 _requestTimeoutQ.cancelAll();
582 _trackerTimeoutQ.cancelAll();
583 }
584 }
585
586
587
588
589
590
591
592
593 protected String extractUserId(ServletRequest request)
594 {
595 return null;
596 }
597
598
599
600
601
602 class RateTracker extends Timeout.Task implements HttpSessionBindingListener
603 {
604 protected final String _id;
605 protected final int _type;
606 protected final long[] _timestamps;
607 protected int _next;
608
609 public RateTracker(String id, int type,int maxRequestsPerSecond)
610 {
611 _id = id;
612 _type = type;
613 _timestamps=new long[maxRequestsPerSecond];
614 _next=0;
615 }
616
617
618
619
620 public boolean isRateExceeded(long now)
621 {
622 final long last;
623 synchronized (this)
624 {
625 last=_timestamps[_next];
626 _timestamps[_next]=now;
627 _next= (_next+1)%_timestamps.length;
628 }
629
630 boolean exceeded=last!=0 && (now-last)<1000L;
631 return exceeded;
632 }
633
634
635 public String getId()
636 {
637 return _id;
638 }
639
640 public int getType()
641 {
642 return _type;
643 }
644
645
646 public void valueBound(HttpSessionBindingEvent event)
647 {
648 }
649
650 public void valueUnbound(HttpSessionBindingEvent event)
651 {
652 _rateTrackers.remove(_id);
653 }
654
655 public void expired()
656 {
657 long now = _trackerTimeoutQ.getNow();
658 int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
659 long last=_timestamps[latestIndex];
660 boolean hasRecentRequest = last != 0 && (now-last)<1000L;
661
662 if (hasRecentRequest)
663 reschedule();
664 else
665 _rateTrackers.remove(_id);
666 }
667 }
668
669 class FixedRateTracker extends RateTracker
670 {
671 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
672 {
673 super(id,type,numRecentRequestsTracked);
674 }
675
676 public boolean isRateExceeded(long now)
677 {
678
679
680
681 synchronized (this)
682 {
683 _timestamps[_next]=now;
684 _next= (_next+1)%_timestamps.length;
685 }
686
687 return false;
688 }
689 }
690 }