// SSLSessionUtils.java

/*
 * Copyright 2019-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package nl.altindag.ssl.util;

import nl.altindag.ssl.SSLFactory;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Static helpers for inspecting, invalidating and configuring the {@link SSLSession} caches
 * held by the client and server {@link SSLSessionContext} of an {@link SSLContext}.
 *
 * <p>The JDK documents that {@link SSLContext#getServerSessionContext()} and
 * {@link SSLContext#getClientSessionContext()} may return {@code null} when the underlying
 * provider does not supply a session context. Every method in this class treats such a
 * missing context as "nothing to do": lookups yield an empty list, invalidations are no-ops
 * and configuration updates are applied only to the contexts that are available.
 *
 * @author Hakan Altindag
 */
public final class SSLSessionUtils {

    private SSLSessionUtils() {}

    /**
     * Invalidates every cached server and client session of the given factory's SSL context.
     *
     * @param sslFactory the factory whose {@link SSLContext} is used
     */
    public static void invalidateCaches(SSLFactory sslFactory) {
        invalidateCaches(sslFactory.getSslContext());
    }

    /**
     * Invalidates every cached server session of the given factory's SSL context.
     *
     * @param sslFactory the factory whose {@link SSLContext} is used
     */
    public static void invalidateServerCaches(SSLFactory sslFactory) {
        invalidateServerCaches(sslFactory.getSslContext());
    }

    /**
     * Invalidates every cached client session of the given factory's SSL context.
     *
     * @param sslFactory the factory whose {@link SSLContext} is used
     */
    public static void invalidateClientCaches(SSLFactory sslFactory) {
        invalidateClientCaches(sslFactory.getSslContext());
    }

    /**
     * Invalidates every cached server session first and every cached client session second.
     *
     * @param sslContext the SSL context whose session caches are cleared
     */
    public static void invalidateCaches(SSLContext sslContext) {
        invalidateServerCaches(sslContext);
        invalidateClientCaches(sslContext);
    }

    /**
     * Invalidates every cached session of the server session context.
     *
     * @param sslContext the SSL context whose server session cache is cleared
     */
    public static void invalidateServerCaches(SSLContext sslContext) {
        invalidateCaches(sslContext.getServerSessionContext());
    }

    /**
     * Invalidates every cached session of the client session context.
     *
     * @param sslContext the SSL context whose client session cache is cleared
     */
    public static void invalidateClientCaches(SSLContext sslContext) {
        invalidateCaches(sslContext.getClientSessionContext());
    }

    /**
     * Invalidates every session that can currently be resolved from the given session context.
     *
     * @param sslSessionContext the session context to clear; {@code null} is treated as empty
     */
    public static void invalidateCaches(SSLSessionContext sslSessionContext) {
        for (SSLSession sslSession : getSslSessions(sslSessionContext)) {
            sslSession.invalidate();
        }
    }

    /**
     * Invalidates all server and client sessions created strictly before {@code upperBoundary}.
     *
     * @param sslFactory    the factory whose {@link SSLContext} is used
     * @param upperBoundary exclusive upper bound for the session creation time
     */
    public static void invalidateCachesBefore(SSLFactory sslFactory, ZonedDateTime upperBoundary) {
        invalidateCachesBefore(sslFactory.getSslContext(), upperBoundary);
    }

    /**
     * Invalidates all server and client sessions created strictly before {@code upperBoundary}.
     *
     * @param sslContext    the SSL context whose session caches are inspected
     * @param upperBoundary exclusive upper bound for the session creation time
     */
    public static void invalidateCachesBefore(SSLContext sslContext, ZonedDateTime upperBoundary) {
        invalidateCachesBefore(sslContext.getServerSessionContext(), upperBoundary);
        invalidateCachesBefore(sslContext.getClientSessionContext(), upperBoundary);
    }

    /**
     * Invalidates all sessions of the given context created strictly before {@code upperBoundary}.
     *
     * @param sslSessionContext the session context to inspect; {@code null} is treated as empty
     * @param upperBoundary     exclusive upper bound for the session creation time
     */
    public static void invalidateCachesBefore(SSLSessionContext sslSessionContext, ZonedDateTime upperBoundary) {
        invalidateCachesWithTimeStamp(sslSessionContext, creationTime -> creationTime.isBefore(upperBoundary));
    }

    /**
     * Invalidates all server and client sessions created strictly after {@code lowerBoundary}.
     *
     * @param sslFactory    the factory whose {@link SSLContext} is used
     * @param lowerBoundary exclusive lower bound for the session creation time
     */
    public static void invalidateCachesAfter(SSLFactory sslFactory, ZonedDateTime lowerBoundary) {
        invalidateCachesAfter(sslFactory.getSslContext(), lowerBoundary);
    }

    /**
     * Invalidates all server and client sessions created strictly after {@code lowerBoundary}.
     *
     * @param sslContext    the SSL context whose session caches are inspected
     * @param lowerBoundary exclusive lower bound for the session creation time
     */
    public static void invalidateCachesAfter(SSLContext sslContext, ZonedDateTime lowerBoundary) {
        invalidateCachesAfter(sslContext.getServerSessionContext(), lowerBoundary);
        invalidateCachesAfter(sslContext.getClientSessionContext(), lowerBoundary);
    }

    /**
     * Invalidates all sessions of the given context created strictly after {@code lowerBoundary}.
     *
     * @param sslSessionContext the session context to inspect; {@code null} is treated as empty
     * @param lowerBoundary     exclusive lower bound for the session creation time
     */
    public static void invalidateCachesAfter(SSLSessionContext sslSessionContext, ZonedDateTime lowerBoundary) {
        invalidateCachesWithTimeStamp(sslSessionContext, creationTime -> creationTime.isAfter(lowerBoundary));
    }

    /**
     * Invalidates all server and client sessions created strictly between the two boundaries.
     *
     * @param sslFactory    the factory whose {@link SSLContext} is used
     * @param lowerBoundary exclusive lower bound for the session creation time
     * @param upperBoundary exclusive upper bound for the session creation time
     */
    public static void invalidateCachesBetween(SSLFactory sslFactory, ZonedDateTime lowerBoundary, ZonedDateTime upperBoundary) {
        invalidateCachesBetween(sslFactory.getSslContext(), lowerBoundary, upperBoundary);
    }

    /**
     * Invalidates all server and client sessions created strictly between the two boundaries.
     *
     * @param sslContext    the SSL context whose session caches are inspected
     * @param lowerBoundary exclusive lower bound for the session creation time
     * @param upperBoundary exclusive upper bound for the session creation time
     */
    public static void invalidateCachesBetween(SSLContext sslContext, ZonedDateTime lowerBoundary, ZonedDateTime upperBoundary) {
        invalidateCachesBetween(sslContext.getServerSessionContext(), lowerBoundary, upperBoundary);
        invalidateCachesBetween(sslContext.getClientSessionContext(), lowerBoundary, upperBoundary);
    }

    /**
     * Invalidates all sessions of the given context whose creation time lies strictly after
     * {@code lowerBoundary} and strictly before {@code upperBoundary}.
     *
     * @param sslSessionContext the session context to inspect; {@code null} is treated as empty
     * @param lowerBoundary     exclusive lower bound for the session creation time
     * @param upperBoundary     exclusive upper bound for the session creation time
     */
    public static void invalidateCachesBetween(SSLSessionContext sslSessionContext, ZonedDateTime lowerBoundary, ZonedDateTime upperBoundary) {
        invalidateCachesWithTimeStamp(
                sslSessionContext,
                creationTime -> creationTime.isAfter(lowerBoundary) && creationTime.isBefore(upperBoundary));
    }

    /**
     * Invalidates each resolvable session whose creation time satisfies the given predicate.
     * The creation time (epoch milliseconds) is presented to the predicate as a UTC date-time.
     */
    private static void invalidateCachesWithTimeStamp(SSLSessionContext sslSessionContext, Predicate<ZonedDateTime> timeStampPredicate) {
        for (SSLSession sslSession : getSslSessions(sslSessionContext)) {
            Instant createdAt = Instant.ofEpochMilli(sslSession.getCreationTime());
            ZonedDateTime sslSessionCreationTime = ZonedDateTime.ofInstant(createdAt, ZoneOffset.UTC);
            if (timeStampPredicate.test(sslSessionCreationTime)) {
                sslSession.invalidate();
            }
        }
    }

    /**
     * Applies the session timeout to the session contexts of the given factory's SSL context.
     *
     * @param sslFactory       the factory whose {@link SSLContext} is used
     * @param timeoutInSeconds the new timeout in seconds; must not be negative
     * @throws IllegalArgumentException if {@code timeoutInSeconds} is negative
     */
    public static void updateSessionTimeout(SSLFactory sslFactory, int timeoutInSeconds) {
        updateSessionTimeout(sslFactory.getSslContext(), timeoutInSeconds);
    }

    /**
     * Validates the timeout and applies it to the client session context and then to the
     * server session context. A context the provider does not supply is skipped.
     *
     * @param sslContext       the SSL context whose session contexts are configured
     * @param timeoutInSeconds the new timeout in seconds; must not be negative
     * @throws IllegalArgumentException if {@code timeoutInSeconds} is negative
     */
    public static void updateSessionTimeout(SSLContext sslContext, int timeoutInSeconds) {
        validateSessionTimeout(timeoutInSeconds);

        applySessionTimeout(sslContext.getClientSessionContext(), timeoutInSeconds);
        applySessionTimeout(sslContext.getServerSessionContext(), timeoutInSeconds);
    }

    /** Sets the timeout on the given context unless the context is unavailable ({@code null}). */
    private static void applySessionTimeout(SSLSessionContext sslSessionContext, int timeoutInSeconds) {
        if (sslSessionContext != null) {
            sslSessionContext.setSessionTimeout(timeoutInSeconds);
        }
    }

    /**
     * Applies the session cache size to the session contexts of the given factory's SSL context.
     *
     * @param sslFactory       the factory whose {@link SSLContext} is used
     * @param cacheSizeInBytes the new cache size limit; must not be negative
     * @throws IllegalArgumentException if {@code cacheSizeInBytes} is negative
     */
    public static void updateSessionCacheSize(SSLFactory sslFactory, int cacheSizeInBytes) {
        updateSessionCacheSize(sslFactory.getSslContext(), cacheSizeInBytes);
    }

    /**
     * Validates the cache size and applies it to the client session context and then to the
     * server session context. A context the provider does not supply is skipped.
     *
     * <p>NOTE(review): the value is forwarded unchanged to
     * {@link SSLSessionContext#setSessionCacheSize(int)}, which the JDK describes as the size
     * of the cache used for storing {@link SSLSession} objects. Despite the parameter name
     * this is presumably a number of sessions rather than bytes; confirm before relying on it.
     *
     * @param sslContext       the SSL context whose session contexts are configured
     * @param cacheSizeInBytes the new cache size limit; must not be negative
     * @throws IllegalArgumentException if {@code cacheSizeInBytes} is negative
     */
    public static void updateSessionCacheSize(SSLContext sslContext, int cacheSizeInBytes) {
        validateSessionCacheSize(cacheSizeInBytes);

        applySessionCacheSize(sslContext.getClientSessionContext(), cacheSizeInBytes);
        applySessionCacheSize(sslContext.getServerSessionContext(), cacheSizeInBytes);
    }

    /** Sets the cache size on the given context unless the context is unavailable ({@code null}). */
    private static void applySessionCacheSize(SSLSessionContext sslSessionContext, int cacheSizeInBytes) {
        if (sslSessionContext != null) {
            sslSessionContext.setSessionCacheSize(cacheSizeInBytes);
        }
    }

    /**
     * Rejects negative session timeouts; zero and positive values are accepted.
     *
     * @param timeoutInSeconds the timeout to check
     * @throws IllegalArgumentException if {@code timeoutInSeconds} is negative
     */
    public static void validateSessionTimeout(int timeoutInSeconds) {
        if (timeoutInSeconds < 0) {
            throw new IllegalArgumentException(String.format(
                    "Unsupported timeout has been provided. Timeout should be equal or greater than [%d], but received [%d]",
                    0, timeoutInSeconds));
        }
    }

    /**
     * Rejects negative session cache sizes; zero and positive values are accepted.
     *
     * @param cacheSizeInBytes the cache size to check
     * @throws IllegalArgumentException if {@code cacheSizeInBytes} is negative
     */
    public static void validateSessionCacheSize(int cacheSizeInBytes) {
        if (cacheSizeInBytes < 0) {
            throw new IllegalArgumentException(String.format(
                    "Unsupported cache size has been provided. Cache size should be equal or greater than [%d], but received [%d]",
                    0, cacheSizeInBytes));
        }
    }

    /**
     * Returns the sessions of the server session context of the given factory's SSL context.
     *
     * @param sslFactory the factory whose {@link SSLContext} is used
     * @return an unmodifiable list of the resolvable server sessions
     */
    public static List<SSLSession> getServerSslSessions(SSLFactory sslFactory) {
        return getServerSslSessions(sslFactory.getSslContext());
    }

    /**
     * Returns the sessions of the server session context.
     *
     * @param sslContext the SSL context to inspect
     * @return an unmodifiable list of the resolvable server sessions; empty when the server
     *         session context is unavailable
     */
    public static List<SSLSession> getServerSslSessions(SSLContext sslContext) {
        return getSslSessions(sslContext.getServerSessionContext());
    }

    /**
     * Returns the sessions of the client session context of the given factory's SSL context.
     *
     * @param sslFactory the factory whose {@link SSLContext} is used
     * @return an unmodifiable list of the resolvable client sessions
     */
    public static List<SSLSession> getClientSslSessions(SSLFactory sslFactory) {
        return getClientSslSessions(sslFactory.getSslContext());
    }

    /**
     * Returns the sessions of the client session context.
     *
     * @param sslContext the SSL context to inspect
     * @return an unmodifiable list of the resolvable client sessions; empty when the client
     *         session context is unavailable
     */
    public static List<SSLSession> getClientSslSessions(SSLContext sslContext) {
        return getSslSessions(sslContext.getClientSessionContext());
    }

    /**
     * Resolves every session id known to the given context into its {@link SSLSession}.
     * Ids for which {@link SSLSessionContext#getSession(byte[])} returns {@code null} are skipped.
     *
     * @param sslSessionContext the session context to inspect; may be {@code null} because the
     *                          JDK allows an {@link SSLContext} to have no session context
     * @return an unmodifiable list of the resolvable sessions; empty when the context is {@code null}
     */
    public static List<SSLSession> getSslSessions(SSLSessionContext sslSessionContext) {
        // A provider without session-context support hands out null; there is nothing cached then.
        if (sslSessionContext == null) {
            return Collections.emptyList();
        }

        return Collections.list(sslSessionContext.getIds()).stream()
                .map(sslSessionContext::getSession)
                .filter(Objects::nonNull)
                .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList));
    }

}