spring-security

Add JdbcOidcSessionRegistry implementation

Open jzheaux opened this issue 1 year ago • 3 comments

An InMemoryOidcSessionRegistry can only store sessions on a single instance. A JDBC-based implementation would allow OIDC Back-Channel Logout to work in a clustered environment.

jzheaux avatar Jan 30 '24 23:01 jzheaux

I was trying to do this with Redis too, but I need a Jackson mixin for OidcSessionInformation.

jsantana3c avatar May 14 '24 20:05 jsantana3c

Here is a sample from my code that implements this:

/**
 * OIDC session registry for a clustered server setup with multiple nodes,
 * which saves user session information in a central database.
 * This follows the suggestion in the Spring Security docs:
 * <a href="https://docs.spring.io/spring-security/reference/servlet/oauth2/login/logout.html#_customizing_the_oidc_provider_session_strategy">Customizing the OIDC Provider Session Strategy</a>
 * Implementation logic follows the implementation for the default OIDC session registry, {@code InMemoryOidcSessionRegistry}.
 * <p>
 * Saving information for a client session id that is already stored updates the existing row
 * instead of inserting a second one, so the {@code Optional}-returning
 * {@code findBySessionId} lookup never sees more than one match.
 * @see org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry
 */
@Slf4j
@Component
public class ClusteredOidcSessionRegistry implements OidcSessionRegistry {
    private final OidcUserSessionRepository oidcUserSessionRepository;

    public ClusteredOidcSessionRegistry(OidcUserSessionRepository oidcUserSessionRepository) {
        this.oidcUserSessionRepository = oidcUserSessionRepository;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Upserts by client session id: if a row for {@code info.getSessionId()} already exists it is
     * reused, so repeated saves for the same session do not create duplicate rows.
     */
    @Transactional
    @Override
    public void saveSessionInformation(OidcSessionInformation info) {
        // Reuse an existing row for this session id; a second row would make the single-result
        // findBySessionId lookup fail once more than one row matches.
        OidcUserSession oidcUserSession = oidcUserSessionRepository.findBySessionId(info.getSessionId())
                .orElseGet(OidcUserSession::new);
        oidcUserSession.setSessionId(info.getSessionId());
        oidcUserSession.setSessionInformation(info);
        oidcUserSessionRepository.save(oidcUserSession);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Looks up the row by client session id, deletes it if present, and returns its stored
     * session information (or {@code null} when nothing was stored for that id).
     */
    @Transactional
    @Override
    public OidcSessionInformation removeSessionInformation(String clientSessionId) {
        Optional<OidcUserSession> oidcUserSession = oidcUserSessionRepository.findBySessionId(clientSessionId);
        oidcUserSession.ifPresent(oidcUserSessionRepository::delete);
        return oidcUserSession.map(OidcUserSession::getSessionInformation).orElse(null);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Matches by {@code sid} when the logout token carries one, otherwise by {@code sub}; both
     * matchers also require an overlapping audience and an equal issuer. All matching rows are
     * deleted in a single transaction so a failure part-way through does not leave the registry
     * half-cleared.
     * <p>
     * NOTE(review): this loads every stored session via {@code findAll()} and filters in memory,
     * which scales with the total number of sessions across the cluster.
     */
    @Transactional
    @Override
    public Iterable<OidcSessionInformation> removeSessionInformation(OidcLogoutToken token) {
        List<String> audience = token.getAudience();
        String issuer = token.getIssuer().toString();
        String subject = token.getSubject();
        String providerSessionId = token.getSessionId();
        Predicate<OidcSessionInformation> matcher = (providerSessionId != null)
                ? sessionIdMatcher(audience, issuer, providerSessionId)
                : subjectMatcher(audience, issuer, subject);
        var allSavedSessions = oidcUserSessionRepository.findAll();
        var deletedOidcSessions = deleteAndGetMatchedSessions(allSavedSessions, matcher);
        if (deletedOidcSessions.isEmpty()) {
            log.debug("Failed to remove any sessions since none matched");
        } else {
            log.trace("Found and removed {} session(s) from mapping of {} session(s)", deletedOidcSessions.size(), allSavedSessions.size());
        }
        return deletedOidcSessions;
    }

    /**
     * Deletes every row whose session information satisfies {@code matcher}.
     *
     * @param oidcUserSessions candidate rows
     * @param matcher predicate selecting the sessions to remove
     * @return the session information of the deleted rows
     */
    private Set<OidcSessionInformation> deleteAndGetMatchedSessions(List<OidcUserSession> oidcUserSessions,
                                                                    Predicate<OidcSessionInformation> matcher) {
        Set<OidcSessionInformation> infos = new HashSet<>();
        for (OidcUserSession oidcUserSession : oidcUserSessions) {
            OidcSessionInformation sessionInfo = oidcUserSession.getSessionInformation();
            if (matcher.test(sessionInfo)) {
                oidcUserSessionRepository.delete(oidcUserSession);
                infos.add(sessionInfo);
            }
        }
        return infos;
    }

    /**
     * Builds a predicate matching sessions whose principal shares an audience with the token,
     * has the same issuer, and carries the given {@code sid} claim.
     */
    private static Predicate<OidcSessionInformation> sessionIdMatcher(List<String> audience, String issuer,
                                                                      String sessionId) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SID, sessionId);
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            String thatIssuer = session.getPrincipal().getIssuer().toString();
            String thatSessionId = session.getPrincipal().getClaimAsString(LogoutTokenClaimNames.SID);
            if (thatAudience == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer)
                    && sessionId.equals(thatSessionId);
        };
    }

    /**
     * Builds a predicate matching sessions whose principal shares an audience with the token,
     * has the same issuer, and has the given subject.
     */
    private static Predicate<OidcSessionInformation> subjectMatcher(List<String> audience, String issuer,
                                                                    String subject) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SUB, subject);
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            String thatIssuer = session.getPrincipal().getIssuer().toString();
            String thatSubject = session.getPrincipal().getSubject();
            if (thatAudience == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer)
                    && subject.equals(thatSubject);
        };
    }
}

aelillie avatar Jun 04 '24 14:06 aelillie

Here is a sample from my code that implements this:

/**
 * OIDC session registry for a clustered server setup with multiple nodes,
 * which saves user session information in a central database.
 * This follows the suggestion in the Spring Security docs:
 * <a href="https://docs.spring.io/spring-security/reference/servlet/oauth2/login/logout.html#_customizing_the_oidc_provider_session_strategy">Customizing the OIDC Provider Session Strategy</a>
 * Implementation logic follows the implementation for the default OIDC session registry, {@code InMemoryOidcSessionRegistry}.
 * <p>
 * Saving information for a client session id that is already stored updates the existing row
 * instead of inserting a second one, so the {@code Optional}-returning
 * {@code findBySessionId} lookup never sees more than one match.
 * @see org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry
 */
@Slf4j
@Component
public class ClusteredOidcSessionRegistry implements OidcSessionRegistry {
    private final OidcUserSessionRepository oidcUserSessionRepository;

    public ClusteredOidcSessionRegistry(OidcUserSessionRepository oidcUserSessionRepository) {
        this.oidcUserSessionRepository = oidcUserSessionRepository;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Upserts by client session id: if a row for {@code info.getSessionId()} already exists it is
     * reused, so repeated saves for the same session do not create duplicate rows.
     */
    @Transactional
    @Override
    public void saveSessionInformation(OidcSessionInformation info) {
        // Reuse an existing row for this session id; a second row would make the single-result
        // findBySessionId lookup fail once more than one row matches.
        OidcUserSession oidcUserSession = oidcUserSessionRepository.findBySessionId(info.getSessionId())
                .orElseGet(OidcUserSession::new);
        oidcUserSession.setSessionId(info.getSessionId());
        oidcUserSession.setSessionInformation(info);
        oidcUserSessionRepository.save(oidcUserSession);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Looks up the row by client session id, deletes it if present, and returns its stored
     * session information (or {@code null} when nothing was stored for that id).
     */
    @Transactional
    @Override
    public OidcSessionInformation removeSessionInformation(String clientSessionId) {
        Optional<OidcUserSession> oidcUserSession = oidcUserSessionRepository.findBySessionId(clientSessionId);
        oidcUserSession.ifPresent(oidcUserSessionRepository::delete);
        return oidcUserSession.map(OidcUserSession::getSessionInformation).orElse(null);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Matches by {@code sid} when the logout token carries one, otherwise by {@code sub}; both
     * matchers also require an overlapping audience and an equal issuer. All matching rows are
     * deleted in a single transaction so a failure part-way through does not leave the registry
     * half-cleared.
     * <p>
     * NOTE(review): this loads every stored session via {@code findAll()} and filters in memory,
     * which scales with the total number of sessions across the cluster.
     */
    @Transactional
    @Override
    public Iterable<OidcSessionInformation> removeSessionInformation(OidcLogoutToken token) {
        List<String> audience = token.getAudience();
        String issuer = token.getIssuer().toString();
        String subject = token.getSubject();
        String providerSessionId = token.getSessionId();
        Predicate<OidcSessionInformation> matcher = (providerSessionId != null)
                ? sessionIdMatcher(audience, issuer, providerSessionId)
                : subjectMatcher(audience, issuer, subject);
        var allSavedSessions = oidcUserSessionRepository.findAll();
        var deletedOidcSessions = deleteAndGetMatchedSessions(allSavedSessions, matcher);
        if (deletedOidcSessions.isEmpty()) {
            log.debug("Failed to remove any sessions since none matched");
        } else {
            log.trace("Found and removed {} session(s) from mapping of {} session(s)", deletedOidcSessions.size(), allSavedSessions.size());
        }
        return deletedOidcSessions;
    }

    /**
     * Deletes every row whose session information satisfies {@code matcher}.
     *
     * @param oidcUserSessions candidate rows
     * @param matcher predicate selecting the sessions to remove
     * @return the session information of the deleted rows
     */
    private Set<OidcSessionInformation> deleteAndGetMatchedSessions(List<OidcUserSession> oidcUserSessions,
                                                                    Predicate<OidcSessionInformation> matcher) {
        Set<OidcSessionInformation> infos = new HashSet<>();
        for (OidcUserSession oidcUserSession : oidcUserSessions) {
            OidcSessionInformation sessionInfo = oidcUserSession.getSessionInformation();
            if (matcher.test(sessionInfo)) {
                oidcUserSessionRepository.delete(oidcUserSession);
                infos.add(sessionInfo);
            }
        }
        return infos;
    }

    /**
     * Builds a predicate matching sessions whose principal shares an audience with the token,
     * has the same issuer, and carries the given {@code sid} claim.
     */
    private static Predicate<OidcSessionInformation> sessionIdMatcher(List<String> audience, String issuer,
                                                                      String sessionId) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SID, sessionId);
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            String thatIssuer = session.getPrincipal().getIssuer().toString();
            String thatSessionId = session.getPrincipal().getClaimAsString(LogoutTokenClaimNames.SID);
            if (thatAudience == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer)
                    && sessionId.equals(thatSessionId);
        };
    }

    /**
     * Builds a predicate matching sessions whose principal shares an audience with the token,
     * has the same issuer, and has the given subject.
     */
    private static Predicate<OidcSessionInformation> subjectMatcher(List<String> audience, String issuer,
                                                                    String subject) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SUB, subject);
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            String thatIssuer = session.getPrincipal().getIssuer().toString();
            String thatSubject = session.getPrincipal().getSubject();
            if (thatAudience == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer)
                    && subject.equals(thatSubject);
        };
    }
}

@aelillie Thanks for sharing this. Would you mind also sharing your implementation of the OidcUserSessionRepository? I'm having a really hard time implementing the logic to interact properly with JDBC. Thank you so much!

xiechangning20 avatar Jun 11 '24 23:06 xiechangning20

@xiechangning20, sure, here you go:

/**
 * Spring Data repository for {@link OidcUserSession} rows, identified by a generated {@link UUID}.
 */
@Repository
public interface OidcUserSessionRepository
        extends ListCrudRepository<OidcUserSession, UUID> {

    /**
     * Finds the stored row for a client session id.
     *
     * @param sessionId the client session id the row was saved under
     * @return the matching row, or {@link Optional#empty()} if none is stored
     */
    Optional<OidcUserSession> findBySessionId(String sessionId);
}

/**
 * JPA entity holding one OIDC client session: the client session id plus the
 * {@code OidcSessionInformation} saved for it.
 */
@Getter
@Setter
@Entity
@Table
public class OidcUserSession {
    /** Surrogate primary key, generated as a UUID. */
    @Id
    @GeneratedValue(strategy = GenerationType.UUID)
    @Column(nullable = false, updatable = false)
    private UUID id;

    /**
     * Client session id. Declared unique so the database rejects a second row for the same id,
     * which would otherwise break single-result lookups by session id. This constraint only takes
     * effect when the schema is generated from the entity; hand-written migrations need a matching
     * unique index.
     */
    @Column(nullable = false, unique = true)
    private String sessionId;

    // NOTE(review): no converter or @Lob is declared, so the persistence provider chooses how this
    // object is stored (presumably via Java serialization) and what column size it gets. Confirm
    // the column is large enough for the serialized session information, including the ID token.
    @Column(nullable = false)
    private OidcSessionInformation sessionInformation;
}

aelillie avatar Feb 21 '25 11:02 aelillie