From 755fd78fa0ee0f2a67417a119382c63e02c1091e Mon Sep 17 00:00:00 2001
From: Alexandre Rocha Wendling <alexandrerw@celepar.pr.gov.br>
Date: Tue, 16 Jul 2024 14:15:23 +0000
Subject: [PATCH] Proxy ticket service and proxy ticket validation Proxy endpoints improvements suggested by Jacek Kowalski Add ticket type to storage key Rename isreuse to isReusable Remove "parsing" of "codeUUID" that is String, not UUID Improve error reporting in CAS ticket validation

---
 src/main/java/org/keycloak/protocol/cas/endpoints/AbstractValidateEndpoint.java |  168 +++++++++++++++++++++++++++++++++++++++++++------------
 1 files changed, 131 insertions(+), 37 deletions(-)

diff --git a/src/main/java/org/keycloak/protocol/cas/endpoints/AbstractValidateEndpoint.java b/src/main/java/org/keycloak/protocol/cas/endpoints/AbstractValidateEndpoint.java
index e60b4e7..e166bb0 100644
--- a/src/main/java/org/keycloak/protocol/cas/endpoints/AbstractValidateEndpoint.java
+++ b/src/main/java/org/keycloak/protocol/cas/endpoints/AbstractValidateEndpoint.java
@@ -1,52 +1,52 @@
 package org.keycloak.protocol.cas.endpoints;
 
+import jakarta.ws.rs.core.Response;
 import org.jboss.logging.Logger;
-import org.jboss.resteasy.spi.HttpRequest;
-import org.keycloak.common.ClientConnection;
 import org.keycloak.events.Details;
 import org.keycloak.events.Errors;
 import org.keycloak.events.EventBuilder;
+import org.keycloak.common.util.Time;
 import org.keycloak.models.*;
 import org.keycloak.protocol.ProtocolMapper;
 import org.keycloak.protocol.cas.CASLoginProtocol;
 import org.keycloak.protocol.cas.mappers.CASAttributeMapper;
 import org.keycloak.protocol.cas.representations.CASErrorCode;
 import org.keycloak.protocol.cas.utils.CASValidationException;
-import org.keycloak.protocol.oidc.utils.OAuth2CodeParser;
+import org.keycloak.protocol.oidc.utils.OAuth2Code;
 import org.keycloak.protocol.oidc.utils.RedirectUtils;
 import org.keycloak.services.managers.AuthenticationManager;
+import org.keycloak.services.managers.UserSessionCrossDCManager;
 import org.keycloak.services.util.DefaultClientSessionContext;
 
-import javax.ws.rs.core.Context;
-import javax.ws.rs.core.HttpHeaders;
-import javax.ws.rs.core.Response;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
+import org.apache.http.client.utils.URIBuilder;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.HttpResponse;
+import org.apache.http.impl.client.HttpClientBuilder;
 
 public abstract class AbstractValidateEndpoint {
     protected final Logger logger = Logger.getLogger(getClass());
-    @Context
+    private static final Pattern DOT = Pattern.compile("\\.");
     protected KeycloakSession session;
-    @Context
-    protected ClientConnection clientConnection;
-    @Context
-    protected HttpRequest request;
-    @Context
-    protected HttpHeaders headers;
     protected RealmModel realm;
     protected EventBuilder event;
     protected ClientModel client;
     protected AuthenticatedClientSessionModel clientSession;
+    protected String pgtIou;
 
-    public AbstractValidateEndpoint(RealmModel realm, EventBuilder event) {
+    public AbstractValidateEndpoint(KeycloakSession session, RealmModel realm, EventBuilder event) {
+        this.session = session;
         this.realm = realm;
         this.event = event;
     }
 
     protected void checkSsl() {
-        if (!session.getContext().getUri().getBaseUri().getScheme().equals("https") && realm.getSslRequired().isRequired(clientConnection)) {
+        if (!session.getContext().getUri().getBaseUri().getScheme().equals("https") && realm.getSslRequired().isRequired(session.getContext().getConnection())) {
             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "HTTPS required", Response.Status.FORBIDDEN);
         }
     }
@@ -84,50 +84,78 @@
         session.getContext().setClient(client);
     }
 
-    protected void checkTicket(String ticket, boolean requireReauth) {
+    protected void checkTicket(String ticket, String prefix, boolean requireReauth) {
         if (ticket == null) {
             event.error(Errors.INVALID_CODE);
             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "Missing parameter: " + CASLoginProtocol.TICKET_PARAM, Response.Status.BAD_REQUEST);
         }
-        if (!ticket.startsWith(CASLoginProtocol.SERVICE_TICKET_PREFIX)) {
+
+        if (!ticket.startsWith(prefix)) {
             event.error(Errors.INVALID_CODE);
             throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Malformed service ticket", Response.Status.BAD_REQUEST);
         }
 
-        String code = ticket.substring(CASLoginProtocol.SERVICE_TICKET_PREFIX.length());
+        boolean isReusable = ticket.startsWith(CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
 
-        OAuth2CodeParser.ParseResult parseResult = OAuth2CodeParser.parseCode(session, code, realm, event);
-        if (parseResult.isIllegalCode()) {
+        String[] parsed = DOT.split(ticket.substring(prefix.length()), 3);
+        if (parsed.length != 3) {
             event.error(Errors.INVALID_CODE);
+            throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Invalid format of the code", Response.Status.BAD_REQUEST);
+        }
 
-            // Attempt to use same code twice should invalidate existing clientSession
-            AuthenticatedClientSessionModel clientSession = parseResult.getClientSession();
-            if (clientSession != null) {
-                clientSession.detachFromUserSession();
+        String codeUUID = parsed[0];
+        String userSessionId = parsed[1];
+        String clientUUID = parsed[2];
+
+        event.detail(Details.CODE_ID, userSessionId);
+        event.session(userSessionId);
+
+        // Retrieve UserSession
+        UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, userSessionId, clientUUID);
+        if (userSession == null) {
+            // Needed to track if code is invalid
+            userSession = session.sessions().getUserSession(realm, userSessionId);
+            if (userSession == null) {
+                event.error(Errors.USER_SESSION_NOT_FOUND);
+                throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
             }
+        }
 
+        clientSession = userSession.getAuthenticatedClientSessionByClient(clientUUID);
+        if (clientSession == null) {
+            event.error(Errors.INVALID_CODE);
             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
         }
 
-        clientSession = parseResult.getClientSession();
+        SingleUseObjectProvider codeStore = session.singleUseObjects();
+        Map<String, String> codeDataSerialized = isReusable ? codeStore.get(prefix + codeUUID) : codeStore.remove(prefix + codeUUID);
 
-        if (parseResult.isExpiredCode()) {
+        // Either code not available
+        if (codeDataSerialized == null) {
+            event.error(Errors.INVALID_CODE);
+            throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
+        }
+
+        OAuth2Code codeData = OAuth2Code.deserializeCode(codeDataSerialized);
+
+        String persistedUserSessionId = codeData.getUserSessionId();
+        if (!userSessionId.equals(persistedUserSessionId)) {
+            event.error(Errors.INVALID_CODE);
+            throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
+        }
+
+        // Finally doublecheck if code is not expired
+        int currentTime = Time.currentTime();
+        if (currentTime > codeData.getExpiration()) {
             event.error(Errors.EXPIRED_CODE);
             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code is expired", Response.Status.BAD_REQUEST);
         }
 
-        clientSession.setNote(CASLoginProtocol.SESSION_SERVICE_TICKET, ticket);
+        clientSession.setNote(prefix, ticket);
 
         if (requireReauth && AuthenticationManager.isSSOAuthentication(clientSession)) {
             event.error(Errors.SESSION_EXPIRED);
             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Interactive authentication was requested but not performed", Response.Status.BAD_REQUEST);
-        }
-
-        UserSessionModel userSession = clientSession.getUserSession();
-
-        if (userSession == null) {
-            event.error(Errors.USER_SESSION_NOT_FOUND);
-            throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User session not found", Response.Status.BAD_REQUEST);
         }
 
         UserModel user = userSession.getUser();
@@ -143,14 +171,44 @@
         event.user(userSession.getUser());
         event.session(userSession.getId());
 
-        if (!client.getClientId().equals(clientSession.getClient().getClientId())) {
-            event.error(Errors.INVALID_CODE);
-            throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Auth error", Response.Status.BAD_REQUEST);
+        if (client == null) {
+            client = clientSession.getClient();
+        } else {
+            if (!client.getClientId().equals(clientSession.getClient().getClientId())) {
+                event.error(Errors.INVALID_CODE);
+                throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Invalid service", Response.Status.BAD_REQUEST);
+            }
         }
 
         if (!AuthenticationManager.isSessionValid(realm, userSession)) {
             event.error(Errors.USER_SESSION_NOT_FOUND);
             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Session not active", Response.Status.BAD_REQUEST);
+        }
+
+    }
+
+    protected void createProxyGrant(String pgtUrl) {
+        if ( RedirectUtils.verifyRedirectUri(session, pgtUrl, client) == null ) {
+            event.error(Errors.INVALID_REQUEST);
+            throw new CASValidationException(CASErrorCode.INVALID_PROXY_CALLBACK, "Proxy callback is invalid", Response.Status.BAD_REQUEST);
+        }
+
+        String pgtIou = getPGTIOU();
+        String pgtId  = getPGT(session, clientSession, pgtUrl);
+
+        try {
+            HttpResponse response = HttpClientBuilder.create().build().execute(
+                new HttpGet(new URIBuilder(pgtUrl).setParameter("pgtIou",pgtIou).setParameter("pgtId",pgtId).build())
+            );
+
+            if (response.getStatusLine().getStatusCode() != 200) {
+                throw new Exception();
+            }
+
+            this.pgtIou = pgtIou;
+        } catch (Exception e) {
+            event.error(Errors.INVALID_REQUEST);
+            throw new CASValidationException(CASErrorCode.PROXY_CALLBACK_ERROR, "Proxy callback returned an error", Response.Status.BAD_REQUEST);
         }
     }
 
@@ -170,4 +228,40 @@
         }
         return attributes;
     }
+
+    protected String getPGTIOU()
+    {
+        return CASLoginProtocol.PROXY_GRANTING_TICKET_IOU_PREFIX + UUID.randomUUID().toString();
+    }
+
+    protected String getPGT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String pgtUrl)
+    {
+        return persistedTicket(pgtUrl, CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
+    }
+
+    protected String getPT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String targetService)
+    {
+        return persistedTicket(targetService, CASLoginProtocol.PROXY_TICKET_PREFIX);
+    }
+
+    protected String getST(String redirectUri)
+    {
+        return persistedTicket(redirectUri, CASLoginProtocol.SERVICE_TICKET_PREFIX);
+    }
+
+    public static String getST(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String redirectUri)
+    {
+        ValidateEndpoint vp = new ValidateEndpoint(session,null,null);
+        vp.clientSession = clientSession;
+        return vp.getST(redirectUri);
+    }
+
+    protected String persistedTicket(String redirectUriParam, String prefix)
+    {
+        String key = UUID.randomUUID().toString();
+        UserSessionModel userSession = clientSession.getUserSession();
+        OAuth2Code codeData = new OAuth2Code(key, Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(), null, null, redirectUriParam, null, null, userSession.getId());
+        session.singleUseObjects().put(prefix + key, clientSession.getUserSession().getRealm().getAccessCodeLifespan(), codeData.serializeCode());
+        return prefix + key + "." + clientSession.getUserSession().getId() + "." + clientSession.getClient().getId();
+    }
 }

--
Gitblit v1.9.1