001/*
002 * The contents of this file are subject to the terms of the Common Development and
003 * Distribution License (the License). You may not use this file except in compliance with the
004 * License.
005 *
006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
007 * specific language governing permission and limitations under the License.
008 *
009 * When distributing Covered Software, include this CDDL Header Notice in each file and include
010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
011 * Header, with the fields enclosed by brackets [] replaced by your own identifying
012 * information: "Portions copyright [year] [name of copyright owner]".
013 *
014 * Copyright 2015 ForgeRock AS.
015 */
016package org.forgerock.openig.filter;
017
018import static java.util.concurrent.TimeUnit.MILLISECONDS;
019import static java.util.concurrent.TimeUnit.NANOSECONDS;
020import static java.util.concurrent.TimeUnit.SECONDS;
021import static org.forgerock.openig.el.Bindings.bindings;
022import static org.forgerock.openig.util.JsonValues.asExpression;
023import static org.forgerock.util.time.Duration.duration;
024
025import org.forgerock.guava.common.base.Ticker;
026import org.forgerock.guava.common.cache.CacheBuilder;
027import org.forgerock.guava.common.cache.CacheLoader;
028import org.forgerock.guava.common.cache.CacheStats;
029import org.forgerock.guava.common.cache.LoadingCache;
030import org.forgerock.http.Filter;
031import org.forgerock.http.Handler;
032import org.forgerock.http.protocol.Request;
033import org.forgerock.http.protocol.Response;
034import org.forgerock.http.protocol.Status;
035import org.forgerock.json.JsonValue;
036import org.forgerock.openig.el.Expression;
037import org.forgerock.openig.heap.GenericHeapObject;
038import org.forgerock.openig.heap.GenericHeaplet;
039import org.forgerock.openig.heap.HeapException;
040import org.forgerock.openig.heap.Keys;
041import org.forgerock.openig.http.Responses;
042import org.forgerock.services.context.Context;
043import org.forgerock.util.Reject;
044import org.forgerock.util.promise.NeverThrowsException;
045import org.forgerock.util.promise.Promise;
046import org.forgerock.util.promise.Promises;
047import org.forgerock.util.time.Duration;
048import org.forgerock.util.time.TimeService;
049
050/**
051 * This filter allows to limit the output rate to the specified handler. If the output rate is over, there a response
052 * with status 429 (Too Many Requests) is sent.
053 */
054public class ThrottlingFilter extends GenericHeapObject implements Filter {
055
056    static final String DEFAULT_PARTITION_KEY = "";
057
058    private final LoadingCache<String, TokenBucket> buckets;
059    private final Expression<String> partitionKey;
060
061    /**
062     * Constructs a ThrottlingFilter.
063     *
064     * @param time
065     *            the time service.
066     * @param numberOfRequests
067     *            the maximum of requests that can be filtered out during the duration.
068     * @param duration
069     *            the duration of the sliding window.
070     * @param partitionKey
071     *            the optional expression that tells in which bucket we have to take some token to count the output
072     *            rate.
073     */
074    public ThrottlingFilter(final TimeService time,
075                            final int numberOfRequests,
076                            final Duration duration,
077                            final Expression<String> partitionKey) {
078        Reject.ifNull(partitionKey);
079        this.buckets = setupBuckets(time, numberOfRequests, duration);
080        this.partitionKey = partitionKey;
081
082        // Force to load the TokenBucket of the DEFAULT_PARTITION_KEY in order to validate the input parameters.
083        // If the parameters are not valid that will throw some unchecked exceptions.
084        buckets.getUnchecked(DEFAULT_PARTITION_KEY);
085    }
086
087    /**
088     * Returns the statistics of the underlying cache. This method must only be used for unit testing.
089     *
090     * @return the cache statistics
091     */
092    CacheStats getBucketsStats() {
093        return buckets.stats();
094    }
095
096    private LoadingCache<String, TokenBucket> setupBuckets(final TimeService time,
097                                                           final int numberOfRequests,
098                                                           final Duration duration) {
099
100        CacheLoader<String, TokenBucket> loader = new CacheLoader<String, TokenBucket>() {
101            @Override
102            public TokenBucket load(String key) {
103                return new TokenBucket(time, numberOfRequests, duration);
104            }
105        };
106
107        // Wrap our TimeService so we can play with the time in our tests
108        Ticker ticker = new Ticker() {
109
110            @Override
111            public long read() {
112                return NANOSECONDS.convert(time.now(), MILLISECONDS);
113            }
114        };
115        // Let's give some arbitrary delay for the eviction
116        long expire = duration.to(MILLISECONDS) + 3;
117        return CacheBuilder.newBuilder()
118                           .ticker(ticker)
119                           .expireAfterAccess(expire, MILLISECONDS)
120                           .recordStats()
121                           .build(loader);
122    }
123
124    @Override
125    public Promise<Response, NeverThrowsException> filter(Context context, Request request, Handler next) {
126        final String key = partitionKey.eval(bindings(context, request));
127        if (key == null) {
128            logger.error("Did not expect a null value for the partitionKey after evaluated the expression : "
129                    + partitionKey);
130            return Promises.newResultPromise(Responses.newInternalServerError());
131        }
132
133        return filter(buckets.getUnchecked(key), context, request, next);
134    }
135
136    private Promise<Response, NeverThrowsException> filter(TokenBucket bucket,
137                                                           Context context,
138                                                           Request request,
139                                                           Handler next) {
140        final long delay = bucket.tryConsume();
141        if (delay <= 0) {
142            return next.handle(context, request);
143        } else {
144            // http://tools.ietf.org/html/rfc6585#section-4
145            Response result = new Response(Status.TOO_MANY_REQUESTS);
146            // http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.37
147            result.getHeaders().add("Retry-After", computeRetryAfter(delay));
148            return Promises.newResultPromise(result);
149        }
150    }
151
152    private String computeRetryAfter(final long delay) {
153        // According to the Javadoc of TimeUnit.convert : 999 ms => 0 sec, but we want to answer 1 sec.
154        //  999 + 999 = 1998 => 1 second
155        // 1000 + 999 = 1999 => 1 second
156        // 1001 + 999 = 2000 => 2 seconds
157        return Long.toString(SECONDS.convert(delay + 999L, MILLISECONDS));
158    }
159
160    /**
161     * Creates and initializes a throttling filter in a heap environment.
162     */
163    public static class Heaplet extends GenericHeaplet {
164        @Override
165        public Object create() throws HeapException {
166            TimeService time = heap.get(Keys.TIME_SERVICE_HEAP_KEY, TimeService.class);
167
168            JsonValue rate = config.get("rate").required();
169
170            Integer numberOfRequests = rate.get("numberOfRequests").required().asInteger();
171            Duration duration = duration(rate.get("duration").required().asString());
172            Expression<String> partitionKey = asExpression(config.get("partitionKey").defaultTo(DEFAULT_PARTITION_KEY),
173                                                           String.class);
174
175            return new ThrottlingFilter(time, numberOfRequests, duration, partitionKey);
176        }
177    }
178}