001/*
002 * The contents of this file are subject to the terms of the Common Development and
003 * Distribution License (the License). You may not use this file except in compliance with the
004 * License.
005 *
006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
007 * specific language governing permission and limitations under the License.
008 *
009 * When distributing Covered Software, include this CDDL Header Notice in each file and include
010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
011 * Header, with the fields enclosed by brackets [] replaced by your own identifying
012 * information: "Portions Copyright [year] [name of copyright owner]".
013 *
014 * Copyright 2010-2011 ApexIdentity Inc.
015 * Portions Copyright 2011-2015 ForgeRock AS.
016 */
017
018package org.forgerock.openig.filter;
019
020import static org.forgerock.openig.el.Bindings.bindings;
021import static org.forgerock.openig.util.JsonValues.asExpression;
022
023import java.util.Collections;
024import java.util.List;
025import java.util.Map;
026
027import org.forgerock.http.Filter;
028import org.forgerock.http.Handler;
029import org.forgerock.http.protocol.Headers;
030import org.forgerock.http.protocol.Message;
031import org.forgerock.http.protocol.Request;
032import org.forgerock.http.protocol.Response;
033import org.forgerock.http.util.CaseInsensitiveSet;
034import org.forgerock.json.JsonValue;
035import org.forgerock.openig.el.Bindings;
036import org.forgerock.openig.heap.GenericHeapObject;
037import org.forgerock.openig.heap.GenericHeaplet;
038import org.forgerock.openig.heap.HeapException;
039import org.forgerock.openig.util.MessageType;
040import org.forgerock.services.context.Context;
041import org.forgerock.util.promise.NeverThrowsException;
042import org.forgerock.util.promise.Promise;
043import org.forgerock.util.promise.ResultHandler;
044
045/**
046 * Removes headers from and adds headers to a message.
047 */
048public class HeaderFilter extends GenericHeapObject implements Filter {
049
050    /** Indicates the type of message to filter headers for. */
051    private final MessageType messageType;
052
053    /** The names of header fields to remove from the message. */
054    private final CaseInsensitiveSet removedHeaders = new CaseInsensitiveSet();
055
056    /** Header fields to add to the message. */
057    private final Headers addedHeaders = new Headers();
058
059    /**
060     * Builds a HeaderFilter processing either the incoming or outgoing message.
061     * @param messageType {@link MessageType#REQUEST} or {@link MessageType#RESPONSE}
062     */
063    public HeaderFilter(final MessageType messageType) {
064        this.messageType = messageType;
065    }
066
067    /**
068     * Returns the names of header fields to remove from the message.
069     * @return the names of header fields to remove from the message.
070     */
071    public CaseInsensitiveSet getRemovedHeaders() {
072        return removedHeaders;
073    }
074
075    /**
076     * Returns the header fields to add to the message.
077     * This is a essentially a Map of String to a List of String, each listed value representing
078     * an expression that will be evaluated.
079     * @return the header fields to add to the message.
080     */
081    public Headers getAddedHeaders() {
082        return addedHeaders;
083    }
084
085    /**
086     * Removes all specified headers, then adds all specified headers.
087     */
088    private void process(Message message, Bindings bindings) {
089        for (String s : this.removedHeaders) {
090            message.getHeaders().remove(s);
091        }
092        for (String key : this.addedHeaders.keySet()) {
093            for (String value : this.addedHeaders.get(key).getValues()) {
094                JsonValue jsonValue = new JsonValue(value);
095                message.getHeaders().add(key, asExpression(jsonValue, String.class).eval(bindings));
096            }
097        }
098    }
099
100    @Override
101    public Promise<Response, NeverThrowsException> filter(final Context context,
102                                                          final Request request,
103                                                          final Handler next) {
104        if (messageType == MessageType.REQUEST) {
105            process(request, bindings(context, request));
106        }
107        Promise<Response, NeverThrowsException> promise = next.handle(context, request);
108        if (messageType == MessageType.RESPONSE) {
109            return promise.thenOnResult(new ResultHandler<Response>() {
110                @Override
111                public void handleResult(final Response response) {
112                    process(response, bindings(context, request, response));
113                }
114            });
115        }
116        return promise;
117    }
118
119    /** Creates and initializes a header filter in a heap environment. */
120    public static class Heaplet extends GenericHeaplet {
121        @Override
122        public Object create() throws HeapException {
123            HeaderFilter filter = new HeaderFilter(config.get("messageType")
124                                                         .required()
125                                                         .asEnum(MessageType.class));
126            filter.removedHeaders.addAll(config.get("remove")
127                                         .defaultTo(Collections.emptyList())
128                                         .asList(String.class));
129            JsonValue add = config.get("add")
130                    .defaultTo(Collections.emptyMap())
131                    .expect(Map.class);
132            for (String key : add.keys()) {
133                List<String> values = add.get(key).required().asList(String.class);
134                filter.addedHeaders.put(key, values);
135            }
136            return filter;
137        }
138    }
139}