mirror of https://github.com/jacekkow/keycloak-protocol-cas

Alexandre Rocha Wendling
2024-05-14 5d7080a6157ca47763bf1e682a67ea4875475fad
commit | author | age
74023a 1 package org.keycloak.protocol.cas.endpoints;
EH 2
import jakarta.ws.rs.core.Response;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.jboss.logging.Logger;
import org.keycloak.common.util.Time;
import org.keycloak.events.Details;
import org.keycloak.events.Errors;
import org.keycloak.events.EventBuilder;
import org.keycloak.models.*;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.cas.CASLoginProtocol;
import org.keycloak.protocol.cas.mappers.CASAttributeMapper;
import org.keycloak.protocol.cas.representations.CASErrorCode;
import org.keycloak.protocol.cas.utils.CASValidationException;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.RedirectUtils;
import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.UserSessionCrossDCManager;
import org.keycloak.services.util.DefaultClientSessionContext;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
74023a 31
EH 32 public abstract class AbstractValidateEndpoint {
33     protected final Logger logger = Logger.getLogger(getClass());
5d7080 34     private static final Pattern DOT = Pattern.compile("\\.");
74023a 35     protected KeycloakSession session;
EH 36     protected RealmModel realm;
37     protected EventBuilder event;
38     protected ClientModel client;
39     protected AuthenticatedClientSessionModel clientSession;
5d7080 40     protected String pgtIou;
74023a 41
ceed8f 42     public AbstractValidateEndpoint(KeycloakSession session, RealmModel realm, EventBuilder event) {
JK 43         this.session = session;
74023a 44         this.realm = realm;
EH 45         this.event = event;
46     }
47
48     protected void checkSsl() {
ceed8f 49         if (!session.getContext().getUri().getBaseUri().getScheme().equals("https") && realm.getSslRequired().isRequired(session.getContext().getConnection())) {
74023a 50             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "HTTPS required", Response.Status.FORBIDDEN);
EH 51         }
52     }
53
54     protected void checkRealm() {
55         if (!realm.isEnabled()) {
56             throw new CASValidationException(CASErrorCode.INTERNAL_ERROR, "Realm not enabled", Response.Status.FORBIDDEN);
57         }
58     }
59
60     protected void checkClient(String service) {
61         if (service == null) {
62             event.error(Errors.INVALID_REQUEST);
63             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "Missing parameter: " + CASLoginProtocol.SERVICE_PARAM, Response.Status.BAD_REQUEST);
64         }
65
b88dc3 66         event.detail(Details.REDIRECT_URI, service);
AP 67
ea9555 68         client = realm.getClientsStream()
74023a 69                 .filter(c -> CASLoginProtocol.LOGIN_PROTOCOL.equals(c.getProtocol()))
019db5 70                 .filter(c -> RedirectUtils.verifyRedirectUri(session, service, c) != null)
74023a 71                 .findFirst().orElse(null);
EH 72         if (client == null) {
73             event.error(Errors.CLIENT_NOT_FOUND);
74             throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Client not found", Response.Status.BAD_REQUEST);
75         }
76
77         if (!client.isEnabled()) {
78             event.error(Errors.CLIENT_DISABLED);
79             throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Client disabled", Response.Status.BAD_REQUEST);
80         }
81
82         event.client(client.getClientId());
83
84         session.getContext().setClient(client);
85     }
86
5d7080 87     protected void checkTicket(String ticket, String prefix, boolean requireReauth) {
74023a 88         if (ticket == null) {
EH 89             event.error(Errors.INVALID_CODE);
90             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "Missing parameter: " + CASLoginProtocol.TICKET_PARAM, Response.Status.BAD_REQUEST);
91         }
5d7080 92
ARW 93         if (!ticket.startsWith(prefix)) {
74023a 94             event.error(Errors.INVALID_CODE);
EH 95             throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Malformed service ticket", Response.Status.BAD_REQUEST);
96         }
97
5d7080 98         Boolean isreuse = ticket.startsWith(CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
74023a 99
5d7080 100         String[] parsed = DOT.split(ticket.substring(prefix.length()), 3);
ARW 101         if (parsed.length != 3) {
74023a 102             event.error(Errors.INVALID_CODE);
5d7080 103             throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Invalid format of the code", Response.Status.BAD_REQUEST);
ARW 104         }
74023a 105
5d7080 106         String userSessionId = parsed[1];
ARW 107         String clientUUID = parsed[2];
108
109         event.detail(Details.CODE_ID, userSessionId);
110         event.session(userSessionId);
111
112         // Parse UUID
113         String codeUUID;
114         try {
115             codeUUID = parsed[0];
116         } catch (IllegalArgumentException re) {
117             event.error(Errors.INVALID_CODE);
118             throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Invalid format of the UUID in the code", Response.Status.BAD_REQUEST);
119         }
120
121         // Retrieve UserSession
122         UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, userSessionId, clientUUID);
123         if (userSession == null) {
124             // Needed to track if code is invalid
125             userSession = session.sessions().getUserSession(realm, userSessionId);
126             if (userSession == null) {
127                 event.error(Errors.USER_SESSION_NOT_FOUND);
128                 throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User session not found", Response.Status.BAD_REQUEST);
74023a 129             }
5d7080 130         }
74023a 131
5d7080 132         clientSession = userSession.getAuthenticatedClientSessionByClient(clientUUID);
ARW 133         if (clientSession == null) {
74023a 134             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
EH 135         }
136
5d7080 137         SingleUseObjectProvider codeStore = session.singleUseObjects();
ARW 138         Map<String, String> codeDataSerialized = isreuse? codeStore.get(prefix + codeUUID) : codeStore.remove(prefix + codeUUID);
74023a 139
5d7080 140         // Either code not available
ARW 141         if (codeDataSerialized == null) {
142             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code already used", Response.Status.BAD_REQUEST);
143         }
144
145         OAuth2Code codeData = OAuth2Code.deserializeCode(codeDataSerialized);
146
147         String persistedUserSessionId = codeData.getUserSessionId();
148         if (!userSessionId.equals(persistedUserSessionId)) {
149             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code "+codeUUID+"' is bound to a different session", Response.Status.BAD_REQUEST);
150         }
151
152         // Finally doublecheck if code is not expired
153         int currentTime = Time.currentTime();
154         if (currentTime > codeData.getExpiration()) {
74023a 155             event.error(Errors.EXPIRED_CODE);
EH 156             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code is expired", Response.Status.BAD_REQUEST);
157         }
158
5d7080 159         clientSession.setNote(prefix, ticket);
74023a 160
EH 161         if (requireReauth && AuthenticationManager.isSSOAuthentication(clientSession)) {
162             event.error(Errors.SESSION_EXPIRED);
163             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Interactive authentication was requested but not performed", Response.Status.BAD_REQUEST);
164         }
165
166         UserModel user = userSession.getUser();
167         if (user == null) {
168             event.error(Errors.USER_NOT_FOUND);
169             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User not found", Response.Status.BAD_REQUEST);
170         }
171         if (!user.isEnabled()) {
172             event.error(Errors.USER_DISABLED);
173             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User disabled", Response.Status.BAD_REQUEST);
174         }
175
176         event.user(userSession.getUser());
177         event.session(userSession.getId());
178
5d7080 179         if (client == null) {
ARW 180             client = clientSession.getClient();
181         } else {
182             if (!client.getClientId().equals(clientSession.getClient().getClientId())) {
183                 event.error(Errors.INVALID_CODE);
184                 throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Auth error", Response.Status.BAD_REQUEST);
185             }
74023a 186         }
EH 187
188         if (!AuthenticationManager.isSessionValid(realm, userSession)) {
189             event.error(Errors.USER_SESSION_NOT_FOUND);
190             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Session not active", Response.Status.BAD_REQUEST);
5d7080 191         }
ARW 192
193     }
194
195     protected void createProxyGrant(String pgtUrl) {
196         if ( RedirectUtils.verifyRedirectUri(session, pgtUrl, client) == null ) {
197             event.error(Errors.INVALID_REQUEST);
198             throw new CASValidationException(CASErrorCode.INVALID_PROXY_CALLBACK, "Proxy callback is invalid", Response.Status.BAD_REQUEST);
199         }
200
201         String pgtIou = getPGTIOU();
202         String pgtId  = getPGT(session, clientSession, pgtUrl);
203
204         try {
205             HttpResponse response = HttpClientBuilder.create().build().execute(
206                 new HttpGet(new URIBuilder(pgtUrl).setParameter("pgtIou",pgtIou).setParameter("pgtId",pgtId).build())
207             );
208
209             if (response.getStatusLine().getStatusCode() != 200) {
210                 throw new Exception();
211             }
212
213             this.pgtIou = pgtIou;
214         } catch (Exception e) {
215             event.error(Errors.INVALID_REQUEST);
216             throw new CASValidationException(CASErrorCode.PROXY_CALLBACK_ERROR, "Proxy callback return with error", Response.Status.BAD_REQUEST);
74023a 217         }
EH 218     }
219
220     protected Map<String, Object> getUserAttributes() {
221         UserSessionModel userSession = clientSession.getUserSession();
222         // CAS protocol does not support scopes, so pass null scopeParam
8379a3 223         ClientSessionContext clientSessionCtx = DefaultClientSessionContext.fromClientSessionAndScopeParameter(clientSession, null, session);
74023a 224
ea9555 225         Set<ProtocolMapperModel> mappings = clientSessionCtx.getProtocolMappersStream().collect(Collectors.toSet());
74023a 226         KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
EH 227         Map<String, Object> attributes = new HashMap<>();
228         for (ProtocolMapperModel mapping : mappings) {
229             ProtocolMapper mapper = (ProtocolMapper) sessionFactory.getProviderFactory(ProtocolMapper.class, mapping.getProtocolMapper());
230             if (mapper instanceof CASAttributeMapper) {
231                 ((CASAttributeMapper) mapper).setAttribute(attributes, mapping, userSession, session, clientSessionCtx);
232             }
233         }
234         return attributes;
235     }
5d7080 236
ARW 237     protected String getPGTIOU()
238     {
239         return CASLoginProtocol.PROXY_GRANTING_TICKET_IOU_PREFIX + UUID.randomUUID().toString();
240     }
241
242     protected String getPGT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String pgtUrl)
243     {
244         return persistedTicket(pgtUrl, CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
245     }
246
247     protected String getPT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String targetService)
248     {
249         return persistedTicket(targetService, CASLoginProtocol.PROXY_TICKET_PREFIX);
250     }
251
252     protected String getST(String redirectUri)
253     {
254         return persistedTicket(redirectUri, CASLoginProtocol.SERVICE_TICKET_PREFIX);
255     }
256
257     public static String getST(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String redirectUri)
258     {
259         ValidateEndpoint vp = new ValidateEndpoint(session,null,null);
260         vp.clientSession = clientSession;
261         return vp.getST(redirectUri);
262     }
263
264     protected String persistedTicket(String redirectUriParam, String prefix)
265     {
266         String key = UUID.randomUUID().toString();
267         UserSessionModel userSession = clientSession.getUserSession();
268         OAuth2Code codeData = new OAuth2Code(key, Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(), null, null, redirectUriParam, null, null, userSession.getId());
269         session.singleUseObjects().put(prefix + key, clientSession.getUserSession().getRealm().getAccessCodeLifespan(), codeData.serializeCode());
270         return prefix + key + "." + clientSession.getUserSession().getId() + "." + clientSession.getClient().getId();
271     }
74023a 272 }