001/* 002 * The contents of this file are subject to the terms of the Common Development and 003 * Distribution License (the License). You may not use this file except in compliance with the 004 * License. 005 * 006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the 007 * specific language governing permission and limitations under the License. 008 * 009 * When distributing Covered Software, include this CDDL Header Notice in each file and include 010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL 011 * Header, with the fields enclosed by brackets [] replaced by your own identifying 012 * information: "Portions copyright [year] [name of copyright owner]". 013 * 014 * Copyright 2015 ForgeRock AS. 015 */ 016package org.forgerock.openig.filter; 017 018import static java.util.concurrent.TimeUnit.MILLISECONDS; 019import static java.util.concurrent.TimeUnit.NANOSECONDS; 020import static java.util.concurrent.TimeUnit.SECONDS; 021import static org.forgerock.openig.el.Bindings.bindings; 022import static org.forgerock.openig.util.JsonValues.asExpression; 023import static org.forgerock.util.time.Duration.duration; 024 025import org.forgerock.guava.common.base.Ticker; 026import org.forgerock.guava.common.cache.CacheBuilder; 027import org.forgerock.guava.common.cache.CacheLoader; 028import org.forgerock.guava.common.cache.CacheStats; 029import org.forgerock.guava.common.cache.LoadingCache; 030import org.forgerock.http.Filter; 031import org.forgerock.http.Handler; 032import org.forgerock.http.protocol.Request; 033import org.forgerock.http.protocol.Response; 034import org.forgerock.http.protocol.Status; 035import org.forgerock.json.JsonValue; 036import org.forgerock.openig.el.Expression; 037import org.forgerock.openig.heap.GenericHeapObject; 038import org.forgerock.openig.heap.GenericHeaplet; 039import org.forgerock.openig.heap.HeapException; 040import org.forgerock.openig.heap.Keys; 041import org.forgerock.openig.http.Responses; 042import org.forgerock.services.context.Context; 043import org.forgerock.util.Reject; 044import org.forgerock.util.promise.NeverThrowsException; 045import org.forgerock.util.promise.Promise; 046import org.forgerock.util.promise.Promises; 047import org.forgerock.util.time.Duration; 048import org.forgerock.util.time.TimeService; 049 050/** 051 * This filter allows to limit the output rate to the specified handler. If the output rate is over, there a response 052 * with status 429 (Too Many Requests) is sent. 053 */ 054public class ThrottlingFilter extends GenericHeapObject implements Filter { 055 056 static final String DEFAULT_PARTITION_KEY = ""; 057 058 private final LoadingCache<String, TokenBucket> buckets; 059 private final Expression<String> partitionKey; 060 061 /** 062 * Constructs a ThrottlingFilter. 063 * 064 * @param time 065 * the time service. 066 * @param numberOfRequests 067 * the maximum of requests that can be filtered out during the duration. 068 * @param duration 069 * the duration of the sliding window. 070 * @param partitionKey 071 * the optional expression that tells in which bucket we have to take some token to count the output 072 * rate. 073 */ 074 public ThrottlingFilter(final TimeService time, 075 final int numberOfRequests, 076 final Duration duration, 077 final Expression<String> partitionKey) { 078 Reject.ifNull(partitionKey); 079 this.buckets = setupBuckets(time, numberOfRequests, duration); 080 this.partitionKey = partitionKey; 081 082 // Force to load the TokenBucket of the DEFAULT_PARTITION_KEY in order to validate the input parameters. 083 // If the parameters are not valid that will throw some unchecked exceptions. 084 buckets.getUnchecked(DEFAULT_PARTITION_KEY); 085 } 086 087 /** 088 * Returns the statistics of the underlying cache. This method must only be used for unit testing. 089 * 090 * @return the cache statistics 091 */ 092 CacheStats getBucketsStats() { 093 return buckets.stats(); 094 } 095 096 private LoadingCache<String, TokenBucket> setupBuckets(final TimeService time, 097 final int numberOfRequests, 098 final Duration duration) { 099 100 CacheLoader<String, TokenBucket> loader = new CacheLoader<String, TokenBucket>() { 101 @Override 102 public TokenBucket load(String key) { 103 return new TokenBucket(time, numberOfRequests, duration); 104 } 105 }; 106 107 // Wrap our TimeService so we can play with the time in our tests 108 Ticker ticker = new Ticker() { 109 110 @Override 111 public long read() { 112 return NANOSECONDS.convert(time.now(), MILLISECONDS); 113 } 114 }; 115 // Let's give some arbitrary delay for the eviction 116 long expire = duration.to(MILLISECONDS) + 3; 117 return CacheBuilder.newBuilder() 118 .ticker(ticker) 119 .expireAfterAccess(expire, MILLISECONDS) 120 .recordStats() 121 .build(loader); 122 } 123 124 @Override 125 public Promise<Response, NeverThrowsException> filter(Context context, Request request, Handler next) { 126 final String key = partitionKey.eval(bindings(context, request)); 127 if (key == null) { 128 logger.error("Did not expect a null value for the partitionKey after evaluated the expression : " 129 + partitionKey); 130 return Promises.newResultPromise(Responses.newInternalServerError()); 131 } 132 133 return filter(buckets.getUnchecked(key), context, request, next); 134 } 135 136 private Promise<Response, NeverThrowsException> filter(TokenBucket bucket, 137 Context context, 138 Request request, 139 Handler next) { 140 final long delay = bucket.tryConsume(); 141 if (delay <= 0) { 142 return next.handle(context, request); 143 } else { 144 // http://tools.ietf.org/html/rfc6585#section-4 145 Response result = new Response(Status.TOO_MANY_REQUESTS); 146 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.37 147 result.getHeaders().add("Retry-After", computeRetryAfter(delay)); 148 return Promises.newResultPromise(result); 149 } 150 } 151 152 private String computeRetryAfter(final long delay) { 153 // According to the Javadoc of TimeUnit.convert : 999 ms => 0 sec, but we want to answer 1 sec. 154 // 999 + 999 = 1998 => 1 second 155 // 1000 + 999 = 1999 => 1 second 156 // 1001 + 999 = 2000 => 2 seconds 157 return Long.toString(SECONDS.convert(delay + 999L, MILLISECONDS)); 158 } 159 160 /** 161 * Creates and initializes a throttling filter in a heap environment. 162 */ 163 public static class Heaplet extends GenericHeaplet { 164 @Override 165 public Object create() throws HeapException { 166 TimeService time = heap.get(Keys.TIME_SERVICE_HEAP_KEY, TimeService.class); 167 168 JsonValue rate = config.get("rate").required(); 169 170 Integer numberOfRequests = rate.get("numberOfRequests").required().asInteger(); 171 Duration duration = duration(rate.get("duration").required().asString()); 172 Expression<String> partitionKey = asExpression(config.get("partitionKey").defaultTo(DEFAULT_PARTITION_KEY), 173 String.class); 174 175 return new ThrottlingFilter(time, numberOfRequests, duration, partitionKey); 176 } 177 } 178}