001/*
002 * The contents of this file are subject to the terms of the Common Development and
003 * Distribution License (the License). You may not use this file except in compliance with the
004 * License.
005 *
006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
007 * specific language governing permission and limitations under the License.
008 *
009 * When distributing Covered Software, include this CDDL Header Notice in each file and include
010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
011 * Header, with the fields enclosed by brackets [] replaced by your own identifying
012 * information: "Portions Copyright [year] [name of copyright owner]".
013 *
014 * Copyright 2008 Sun Microsystems, Inc.
015 * Portions Copyright 2015-2017 ForgeRock AS.
016 */
017package org.opends.admin.ads.util;
018
019import com.forgerock.opendj.cli.ConnectionFactoryProvider;
020
021import java.io.IOException;
022import java.net.InetAddress;
023import java.net.Socket;
024import java.security.GeneralSecurityException;
025import java.security.KeyManagementException;
026import java.security.NoSuchAlgorithmException;
027import java.util.Arrays;
028import java.util.HashMap;
029import java.util.HashSet;
030import java.util.List;
031import java.util.Map;
032import java.util.Set;
033
034import javax.net.SocketFactory;
035import javax.net.ssl.KeyManager;
036import javax.net.ssl.SSLContext;
037import javax.net.ssl.SSLKeyException;
038import javax.net.ssl.SSLSocket;
039import javax.net.ssl.SSLSocketFactory;
040import javax.net.ssl.TrustManager;
041
042/**
043 * An implementation of SSLSocketFactory.
044 * <p>
045 * Note: The class must be public so it can be instantiated by the
046 * {@link javax.naming.ldap.InitialLdapContext}.
047 */
048public class TrustedSocketFactory extends SSLSocketFactory
049{
050  private static final Map<Thread, TrustManager> hmTrustManager = new HashMap<>();
051  private static final Map<Thread, KeyManager> hmKeyManager = new HashMap<>();
052
053  private static final Map<TrustManager, SocketFactory> hmDefaultFactoryTm = new HashMap<>();
054  private static final Map<KeyManager, SocketFactory> hmDefaultFactoryKm = new HashMap<>();
055
056  private SSLSocketFactory innerFactory;
057  private final TrustManager trustManager;
058  private final KeyManager keyManager;
059
060  /**
061   * Constructor of the TrustedSocketFactory.
062   * <p>
063   * Note: The class must be public so it can be instantiated by the
064   * {@link javax.naming.ldap.InitialLdapContext}.
065   *
066   * @param trustManager
067   *          the trust manager to use.
068   * @param keyManager
069   *          the key manager to use.
070   */
071  public TrustedSocketFactory(TrustManager trustManager, KeyManager keyManager)
072  {
073    this.trustManager = trustManager;
074    this.keyManager   = keyManager;
075  }
076
077  /**
078   * Sets the provided trust and key manager for the operations in the
079   * current thread.
080   *
081   * @param trustManager
082   *          the trust manager to use.
083   * @param keyManager
084   *          the key manager to use.
085   */
086  static synchronized void setCurrentThreadTrustManager(TrustManager trustManager, KeyManager keyManager)
087  {
088    setThreadTrustManager(trustManager, Thread.currentThread());
089    setThreadKeyManager  (keyManager, Thread.currentThread());
090  }
091
092  /**
093   * Sets the provided trust manager for the operations in the provided thread.
094   * @param trustManager the trust manager to use.
095   * @param thread the thread where we want to use the provided trust manager.
096   */
097  static synchronized void setThreadTrustManager(TrustManager trustManager, Thread thread)
098  {
099    TrustManager currentTrustManager = hmTrustManager.get(thread);
100    if (currentTrustManager != null) {
101      hmDefaultFactoryTm.remove(currentTrustManager);
102      hmTrustManager.remove(thread);
103    }
104    if (trustManager != null) {
105      hmTrustManager.put(thread, trustManager);
106    }
107  }
108
109  /**
110   * Sets the provided key manager for the operations in the provided thread.
111   * @param keyManager the key manager to use.
112   * @param thread the thread where we want to use the provided key manager.
113   */
114  static synchronized void setThreadKeyManager(KeyManager keyManager, Thread thread)
115  {
116    KeyManager currentKeyManager = hmKeyManager.get(thread);
117    if (currentKeyManager != null) {
118      hmDefaultFactoryKm.remove(currentKeyManager);
119      hmKeyManager.remove(thread);
120    }
121    if (keyManager != null) {
122      hmKeyManager.put(thread, keyManager);
123    }
124  }
125
126  // SocketFactory implementation
127  /**
128   * Returns the default SSL socket factory. The default
129   * implementation can be changed by setting the value of the
130   * "ssl.SocketFactory.provider" security property (in the Java
131   * security properties file) to the desired class. If SSL has not
132   * been configured properly for this virtual machine, the factory
133   * will be inoperative (reporting instantiation exceptions).
134   *
135   * @return the default SocketFactory
136   */
137  public static synchronized SocketFactory getDefault()
138  {
139    Thread currentThread = Thread.currentThread();
140    TrustManager trustManager = hmTrustManager.get(currentThread);
141    KeyManager   keyManager   = hmKeyManager.get(currentThread);
142    SocketFactory result;
143
144    if (trustManager == null)
145    {
146      if (keyManager == null)
147      {
148        result = new TrustedSocketFactory(null,null);
149      }
150      else
151      {
152        result = hmDefaultFactoryKm.get(keyManager);
153        if (result == null)
154        {
155          result = new TrustedSocketFactory(null,keyManager);
156          hmDefaultFactoryKm.put(keyManager, result);
157        }
158      }
159    }
160    else
161    {
162      if (keyManager == null)
163      {
164        result = hmDefaultFactoryTm.get(trustManager);
165        if (result == null)
166        {
167          result = new TrustedSocketFactory(trustManager, null);
168          hmDefaultFactoryTm.put(trustManager, result);
169        }
170      }
171      else
172      {
173        SocketFactory tmsf = hmDefaultFactoryTm.get(trustManager);
174        SocketFactory kmsf = hmDefaultFactoryKm.get(keyManager);
175        if (tmsf == null || kmsf == null)
176        {
177          result = new TrustedSocketFactory(trustManager, keyManager);
178          hmDefaultFactoryTm.put(trustManager, result);
179          hmDefaultFactoryKm.put(keyManager, result);
180        }
181        else if (!tmsf.equals(kmsf))
182        {
183          result = new TrustedSocketFactory(trustManager, keyManager);
184          hmDefaultFactoryTm.put(trustManager, result);
185          hmDefaultFactoryKm.put(keyManager, result);
186        }
187        else
188        {
189          result = tmsf;
190        }
191      }
192    }
193
194    return result;
195  }
196
197  @Override
198  public Socket createSocket(InetAddress address, int port) throws IOException {
199    return reconfigureSocket(getInnerFactory().createSocket(address, port));
200  }
201
202  @Override
203  public Socket createSocket(InetAddress address, int port,
204      InetAddress clientAddress, int clientPort) throws IOException
205  {
206    return reconfigureSocket(getInnerFactory().createSocket(address, port, clientAddress, clientPort));
207  }
208
209  @Override
210  public Socket createSocket(String host, int port) throws IOException
211  {
212    return reconfigureSocket(getInnerFactory().createSocket(host, port));
213  }
214
215  @Override
216  public Socket createSocket(String host, int port, InetAddress clientHost,
217      int clientPort) throws IOException
218  {
219    return reconfigureSocket(getInnerFactory().createSocket(host, port, clientHost, clientPort));
220  }
221
222  @Override
223  public Socket createSocket(Socket s, String host, int port, boolean autoClose)
224  throws IOException
225  {
226    return reconfigureSocket(getInnerFactory().createSocket(s, host, port, autoClose));
227  }
228
229  /*
230   * Reconfigure the created socket so that it has a list of multiple enabled protocols. There seems to be no other way
231   * for a factory to do this.
232   *
233   * @param s  The socket to reconfigure.
234   * @return The reconfigured socket (if an SSLSocket) or the original socket if not.
235   */
236  private Socket reconfigureSocket(Socket s) throws IOException
237  {
238    if (s instanceof SSLSocket)
239    {
240      try
241      {
242        SSLSocket sslSocket = (SSLSocket) s;
243        final List<String> protocols = ConnectionFactoryProvider.getDefaultProtocols();
244        sslSocket.setEnabledProtocols(protocols.toArray(new String[0]));
245        final Set<String> enabledCiphers = new HashSet<>();
246        for (String protocol : protocols)
247        {
248          try
249          {
250            final SSLContext context = SSLContext.getInstance(protocol);
251            context.init(null, null, null);
252            enabledCiphers.addAll(Arrays.asList(context.createSSLEngine().getEnabledCipherSuites()));
253          }
254          catch (KeyManagementException ignored)
255          {
256            // ignore
257          }
258        }
259        sslSocket.setEnabledCipherSuites(enabledCiphers.toArray(new String[0]));
260      }
261      catch (NoSuchAlgorithmException e)
262      {
263        throw new IOException(e.getMessage());
264      }
265    }
266    return s;
267  }
268
269  @Override
270  public String[] getDefaultCipherSuites()
271  {
272    try
273    {
274      return getInnerFactory().getDefaultCipherSuites();
275    }
276    catch(IOException x)
277    {
278      return new String[0];
279    }
280  }
281
282  @Override
283  public String[] getSupportedCipherSuites()
284  {
285    try
286    {
287      return getInnerFactory().getSupportedCipherSuites();
288    }
289    catch(IOException x)
290    {
291      return new String[0];
292    }
293  }
294
295  private SSLSocketFactory getInnerFactory() throws IOException {
296    if (innerFactory == null)
297    {
298      String algorithm = "TLSv1";
299
300      try {
301        KeyManager[] km = null;
302        TrustManager[] tm = null;
303        SSLContext sslCtx = SSLContext.getInstance(algorithm);
304        if (trustManager != null)
305        {
306          tm = new TrustManager[] { trustManager };
307        }
308        if (keyManager != null)
309        {
310          km = new KeyManager[] { keyManager };
311        }
312        sslCtx.init(km, tm, new java.security.SecureRandom() );
313        innerFactory = sslCtx.getSocketFactory();
314      }
315      catch(GeneralSecurityException x) {
316        SSLKeyException xx = new SSLKeyException("Failed to create SSLContext for " + algorithm);
317        xx.initCause(x);
318        throw xx;
319      }
320    }
321    return innerFactory;
322  }
323}