Add JdbcOidcSessionRegistry implementation
The InMemoryOidcSessionRegistry can only store session information on a single instance. A JDBC-based implementation would make OIDC Back-Channel Logout work in a clustered environment.
I was trying to do it on Redis too, but for that I need a Mixin for OidcSessionInformation.
Sample from my code to implement this:
/**
 * OIDC session registry for a clustered server setup with multiple nodes,
 * which saves user session information in a central database.
 * This follows the suggestion in the Spring Security docs:
 * <a href="https://docs.spring.io/spring-security/reference/servlet/oauth2/login/logout.html#_customizing_the_oidc_provider_session_strategy">Customizing the OIDC Provider Session Strategy</a>
 * Implementation logic follows the implementation for the default OIDC session registry, {@code InMemoryOidcSessionRegistry}.
 *
 * <p>Every operation that reads and then writes the repository runs in a single transaction so
 * that concurrent nodes never observe a half-applied save or logout.
 *
 * @see org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry
 */
@Slf4j
@Component
public class ClusteredOidcSessionRegistry implements OidcSessionRegistry {

    private final OidcUserSessionRepository oidcUserSessionRepository;

    /**
     * Creates the registry.
     *
     * @param oidcUserSessionRepository repository used to persist, look up and delete session rows
     */
    public ClusteredOidcSessionRegistry(OidcUserSessionRepository oidcUserSessionRepository) {
        this.oidcUserSessionRepository = oidcUserSessionRepository;
    }

    /**
     * {@inheritDoc}
     *
     * <p>If a row already exists for the client session id it is updated in place instead of
     * inserting a second row. This mirrors the overwrite-on-put behaviour of the in-memory
     * registry and keeps {@code findBySessionId} a single-result lookup.
     */
    @Transactional
    @Override
    public void saveSessionInformation(OidcSessionInformation info) {
        String clientSessionId = info.getSessionId();
        OidcUserSession oidcUserSession = oidcUserSessionRepository.findBySessionId(clientSessionId)
                .orElseGet(OidcUserSession::new);
        oidcUserSession.setSessionId(clientSessionId);
        oidcUserSession.setSessionInformation(info);
        oidcUserSessionRepository.save(oidcUserSession);
    }

    /**
     * {@inheritDoc}
     *
     * @return the removed session information, or {@code null} if no row matched the given id
     */
    @Transactional
    @Override
    public OidcSessionInformation removeSessionInformation(String clientSessionId) {
        Optional<OidcUserSession> oidcUserSession = oidcUserSessionRepository.findBySessionId(clientSessionId);
        oidcUserSession.ifPresent(oidcUserSessionRepository::delete);
        return oidcUserSession.map(OidcUserSession::getSessionInformation).orElse(null);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Sessions are matched by provider session id ({@code sid}) when the token carries one,
     * otherwise by subject; in both cases the issuer must be equal and the audiences must overlap.
     * The scan and all deletes run in one transaction.
     *
     * @return the session information of every row that was deleted; empty if none matched
     */
    @Transactional
    @Override
    public Iterable<OidcSessionInformation> removeSessionInformation(OidcLogoutToken token) {
        List<String> audience = token.getAudience();
        String issuer = token.getIssuer().toString();
        String subject = token.getSubject();
        String providerSessionId = token.getSessionId();
        Predicate<OidcSessionInformation> matcher = (providerSessionId != null)
                ? sessionIdMatcher(audience, issuer, providerSessionId)
                : subjectMatcher(audience, issuer, subject);
        // The claims to match on live inside the stored principal, so every row has to be
        // loaded and tested in memory.
        List<OidcUserSession> allSavedSessions = oidcUserSessionRepository.findAll();
        Set<OidcSessionInformation> deletedOidcSessions = deleteAndGetMatchedSessions(allSavedSessions, matcher);
        if (deletedOidcSessions.isEmpty()) {
            log.debug("Failed to remove any sessions since none matched");
        } else {
            log.trace("Found and removed {} session(s) from mapping of {} session(s)", deletedOidcSessions.size(), allSavedSessions.size());
        }
        return deletedOidcSessions;
    }

    /**
     * Deletes every stored session whose session information satisfies {@code matcher}.
     *
     * @param oidcUserSessions candidate rows to test
     * @param matcher          predicate selecting the sessions to delete
     * @return the session information of the deleted rows
     */
    private Set<OidcSessionInformation> deleteAndGetMatchedSessions(List<OidcUserSession> oidcUserSessions,
            Predicate<OidcSessionInformation> matcher) {
        Set<OidcSessionInformation> infos = new HashSet<>();
        for (OidcUserSession oidcUserSession : oidcUserSessions) {
            OidcSessionInformation sessionInfo = oidcUserSession.getSessionInformation();
            if (matcher.test(sessionInfo)) {
                oidcUserSessionRepository.delete(oidcUserSession);
                infos.add(sessionInfo);
            }
        }
        return infos;
    }

    /**
     * Builds a matcher for sessions with the given issuer, an overlapping audience and the
     * given provider session id ({@code sid} claim).
     */
    private static Predicate<OidcSessionInformation> sessionIdMatcher(List<String> audience, String issuer,
            String sessionId) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SID, sessionId);
        return audienceAndIssuerMatcher(audience, issuer)
                .and(session -> sessionId.equals(session.getPrincipal().getClaimAsString(LogoutTokenClaimNames.SID)));
    }

    /**
     * Builds a matcher for sessions with the given issuer, an overlapping audience and the
     * given subject.
     */
    private static Predicate<OidcSessionInformation> subjectMatcher(List<String> audience, String issuer,
            String subject) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SUB, subject);
        return audienceAndIssuerMatcher(audience, issuer)
                .and(session -> subject.equals(session.getPrincipal().getSubject()));
    }

    /**
     * Shared part of both matchers: the stored principal must have the same issuer and at least
     * one audience in common with the logout token. A stored principal without audience or
     * issuer never matches, so a single incomplete row cannot abort the whole scan.
     */
    private static Predicate<OidcSessionInformation> audienceAndIssuerMatcher(List<String> audience, String issuer) {
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            var thatIssuer = session.getPrincipal().getIssuer();
            if (thatAudience == null || thatIssuer == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer.toString());
        };
    }
}
Sample from my code to implement this:
/** * OIDC Session registry for a clustered server setup with multiple nodes, * which saves user session information in a central database. * This follows the suggestion in the Spring Security docs: * <a href="https://docs.spring.io/spring-security/reference/servlet/oauth2/login/logout.html#_customizing_the_oidc_provider_session_strategy">Customizing the OIDC Provider Session Strategy</a> * Implementation logic follows the implementation for the default OIDC session registry, {@code InMemoryOidcSessionRegistry}. * @see org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry */ @Slf4j @Component public class ClusteredOidcSessionRegistry implements OidcSessionRegistry { private final OidcUserSessionRepository oidcUserSessionRepository; public ClusteredOidcSessionRegistry(OidcUserSessionRepository oidcUserSessionRepository) { this.oidcUserSessionRepository = oidcUserSessionRepository; } /** * {@inheritDoc} */ @Override public void saveSessionInformation(OidcSessionInformation info) { var oidcUserSession = new OidcUserSession(); oidcUserSession.setSessionId(info.getSessionId()); oidcUserSession.setSessionInformation(info); oidcUserSessionRepository.save(oidcUserSession); } /** * {@inheritDoc} */ @Transactional @Override public OidcSessionInformation removeSessionInformation(String clientSessionId) { Optional<OidcUserSession> oidcUserSession = oidcUserSessionRepository.findBySessionId(clientSessionId); oidcUserSession.ifPresent(oidcUserSessionRepository::delete); return oidcUserSession.map(OidcUserSession::getSessionInformation).orElse(null); } /** * {@inheritDoc} */ @Override public Iterable<OidcSessionInformation> removeSessionInformation(OidcLogoutToken token) { List<String> audience = token.getAudience(); String issuer = token.getIssuer().toString(); String subject = token.getSubject(); String providerSessionId = token.getSessionId(); Predicate<OidcSessionInformation> matcher = (providerSessionId != null) ? 
sessionIdMatcher(audience, issuer, providerSessionId) : subjectMatcher(audience, issuer, subject); var allSavedSessions = oidcUserSessionRepository.findAll(); var deletedOidcSessions = deleteAndGetMatchedSessions(allSavedSessions, matcher); if (deletedOidcSessions.isEmpty()) { log.debug("Failed to remove any sessions since none matched"); } else { log.trace("Found and removed {} session(s) from mapping of {} session(s)", deletedOidcSessions.size(), allSavedSessions.size()); } return deletedOidcSessions; } private Set<OidcSessionInformation> deleteAndGetMatchedSessions(List<OidcUserSession> oidcUserSessions, Predicate<OidcSessionInformation> matcher) { Set<OidcSessionInformation> infos = new HashSet<>(); oidcUserSessions.forEach(oidcUserSession -> { var sessionInfo = oidcUserSession.getSessionInformation(); if (matcher.test(sessionInfo)) { oidcUserSessionRepository.delete(oidcUserSession); infos.add(sessionInfo); } }); return infos; } private static Predicate<OidcSessionInformation> sessionIdMatcher(List<String> audience, String issuer, String sessionId) { log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SID, sessionId); return session -> { List<String> thatAudience = session.getPrincipal().getAudience(); String thatIssuer = session.getPrincipal().getIssuer().toString(); String thatSessionId = session.getPrincipal().getClaimAsString(LogoutTokenClaimNames.SID); if (thatAudience == null) { return false; } return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer) && sessionId.equals(thatSessionId); }; } private static Predicate<OidcSessionInformation> subjectMatcher(List<String> audience, String issuer, String subject) { log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SUB, subject); return session -> { List<String> thatAudience = session.getPrincipal().getAudience(); String thatIssuer = session.getPrincipal().getIssuer().toString(); String thatSubject = 
session.getPrincipal().getSubject(); if (thatAudience == null) { return false; } return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer) && subject.equals(thatSubject); }; } }
@aelillie Thanks for sharing this. Would you mind also sharing your implementation of the OidcUserSessionRepository? I'm having a really hard time implementing the logic to properly interact with JDBC. -- Thank you so much
@xiechangning20 , sure, here you go:
/**
 * Repository for {@link OidcUserSession} entities, identified by a {@link UUID} primary key.
 * CRUD operations are inherited from {@code ListCrudRepository}.
 */
@Repository
public interface OidcUserSessionRepository extends ListCrudRepository<OidcUserSession, UUID> {
/**
 * Looks up the stored session for the given session id.
 * Presumably resolved by Spring Data as a derived query on the {@code sessionId} property.
 * NOTE(review): the {@code Optional} return type assumes at most one row per session id,
 * but the entity declares no unique constraint on that column - confirm.
 *
 * @param sessionId the session id to search for
 * @return the matching session, or an empty {@code Optional} if none is stored
 */
Optional<OidcUserSession> findBySessionId(String sessionId);
}
/**
 * Persistent entity holding one OIDC session: the session id together with the
 * {@code OidcSessionInformation} saved for it.
 * Getters and setters are generated by Lombok.
 */
@Getter
@Setter
@Entity
@Table
public class OidcUserSession {
// Surrogate primary key: a generated UUID that cannot be updated after insert.
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Column(nullable = false, updatable = false)
private UUID id;
// Session id used for lookups via the repository.
// NOTE(review): no unique constraint or index is declared on this column - confirm whether the schema adds one.
@Column(nullable = false)
private String sessionId;
// Session information stored with the row.
// NOTE(review): no converter or @Lob is declared, so how this type is written to the column
// (and with what size limit) depends on provider defaults and schema not visible here - confirm.
@Column(nullable = false)
private OidcSessionInformation sessionInformation;
}