mirror of https://github.com/jacekkow/keycloak-protocol-cas

Jakub Malinowski
8 days ago 32997b7c31fc3b27a8df6911e0f8e8e1bcc58437
commit | author | age
74023a 1 package org.keycloak.protocol.cas.endpoints;
EH 2
import jakarta.ws.rs.core.Response;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.jboss.logging.Logger;
import org.keycloak.common.util.Time;
import org.keycloak.events.Details;
import org.keycloak.events.Errors;
import org.keycloak.events.EventBuilder;
import org.keycloak.models.*;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.cas.CASLoginProtocol;
import org.keycloak.protocol.cas.mappers.CASAttributeMapper;
import org.keycloak.protocol.cas.representations.CASErrorCode;
import org.keycloak.protocol.cas.utils.CASValidationException;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.RedirectUtils;
import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.UserSessionCrossDCManager;
import org.keycloak.services.util.DefaultClientSessionContext;
74023a 31
EH 32 public abstract class AbstractValidateEndpoint {
33     protected final Logger logger = Logger.getLogger(getClass());
755fd7 34     private static final Pattern DOT = Pattern.compile("\\.");
74023a 35     protected KeycloakSession session;
EH 36     protected RealmModel realm;
37     protected EventBuilder event;
38     protected ClientModel client;
39     protected AuthenticatedClientSessionModel clientSession;
755fd7 40     protected String pgtIou;
74023a 41
ceed8f 42     public AbstractValidateEndpoint(KeycloakSession session, RealmModel realm, EventBuilder event) {
JK 43         this.session = session;
74023a 44         this.realm = realm;
EH 45         this.event = event;
46     }
47
48     protected void checkSsl() {
ceed8f 49         if (!session.getContext().getUri().getBaseUri().getScheme().equals("https") && realm.getSslRequired().isRequired(session.getContext().getConnection())) {
74023a 50             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "HTTPS required", Response.Status.FORBIDDEN);
EH 51         }
52     }
53
54     protected void checkRealm() {
55         if (!realm.isEnabled()) {
56             throw new CASValidationException(CASErrorCode.INTERNAL_ERROR, "Realm not enabled", Response.Status.FORBIDDEN);
57         }
58     }
59
60     protected void checkClient(String service) {
61         if (service == null) {
62             event.error(Errors.INVALID_REQUEST);
63             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "Missing parameter: " + CASLoginProtocol.SERVICE_PARAM, Response.Status.BAD_REQUEST);
64         }
65
b88dc3 66         event.detail(Details.REDIRECT_URI, service);
AP 67
ea9555 68         client = realm.getClientsStream()
74023a 69                 .filter(c -> CASLoginProtocol.LOGIN_PROTOCOL.equals(c.getProtocol()))
019db5 70                 .filter(c -> RedirectUtils.verifyRedirectUri(session, service, c) != null)
74023a 71                 .findFirst().orElse(null);
EH 72         if (client == null) {
73             event.error(Errors.CLIENT_NOT_FOUND);
74             throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Client not found", Response.Status.BAD_REQUEST);
75         }
76
77         if (!client.isEnabled()) {
78             event.error(Errors.CLIENT_DISABLED);
79             throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Client disabled", Response.Status.BAD_REQUEST);
80         }
81
82         event.client(client.getClientId());
83
84         session.getContext().setClient(client);
85     }
86
755fd7 87     protected void checkTicket(String ticket, String prefix, boolean requireReauth) {
74023a 88         if (ticket == null) {
EH 89             event.error(Errors.INVALID_CODE);
90             throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "Missing parameter: " + CASLoginProtocol.TICKET_PARAM, Response.Status.BAD_REQUEST);
91         }
755fd7 92
ARW 93         if (!ticket.startsWith(prefix)) {
74023a 94             event.error(Errors.INVALID_CODE);
EH 95             throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Malformed service ticket", Response.Status.BAD_REQUEST);
96         }
97
755fd7 98         boolean isReusable = ticket.startsWith(CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
74023a 99
755fd7 100         String[] parsed = DOT.split(ticket.substring(prefix.length()), 3);
ARW 101         if (parsed.length != 3) {
74023a 102             event.error(Errors.INVALID_CODE);
755fd7 103             throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Invalid format of the code", Response.Status.BAD_REQUEST);
ARW 104         }
74023a 105
755fd7 106         String codeUUID = parsed[0];
ARW 107         String userSessionId = parsed[1];
108         String clientUUID = parsed[2];
109
110         event.detail(Details.CODE_ID, userSessionId);
111         event.session(userSessionId);
112
113         // Retrieve UserSession
114         UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, userSessionId, clientUUID);
115         if (userSession == null) {
116             // Needed to track if code is invalid
117             userSession = session.sessions().getUserSession(realm, userSessionId);
118             if (userSession == null) {
119                 event.error(Errors.USER_SESSION_NOT_FOUND);
120                 throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
74023a 121             }
755fd7 122         }
74023a 123
755fd7 124         clientSession = userSession.getAuthenticatedClientSessionByClient(clientUUID);
ARW 125         if (clientSession == null) {
126             event.error(Errors.INVALID_CODE);
74023a 127             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
EH 128         }
129
755fd7 130         SingleUseObjectProvider codeStore = session.singleUseObjects();
ARW 131         Map<String, String> codeDataSerialized = isReusable ? codeStore.get(prefix + codeUUID) : codeStore.remove(prefix + codeUUID);
74023a 132
755fd7 133         // Either code not available
ARW 134         if (codeDataSerialized == null) {
135             event.error(Errors.INVALID_CODE);
136             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
137         }
138
139         OAuth2Code codeData = OAuth2Code.deserializeCode(codeDataSerialized);
140
141         String persistedUserSessionId = codeData.getUserSessionId();
142         if (!userSessionId.equals(persistedUserSessionId)) {
143             event.error(Errors.INVALID_CODE);
144             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
145         }
146
147         // Finally doublecheck if code is not expired
148         int currentTime = Time.currentTime();
149         if (currentTime > codeData.getExpiration()) {
74023a 150             event.error(Errors.EXPIRED_CODE);
EH 151             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code is expired", Response.Status.BAD_REQUEST);
152         }
153
32997b 154         clientSession.setNote(CASLoginProtocol.SESSION_TICKET, ticket);
74023a 155
EH 156         if (requireReauth && AuthenticationManager.isSSOAuthentication(clientSession)) {
157             event.error(Errors.SESSION_EXPIRED);
158             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Interactive authentication was requested but not performed", Response.Status.BAD_REQUEST);
159         }
160
161         UserModel user = userSession.getUser();
162         if (user == null) {
163             event.error(Errors.USER_NOT_FOUND);
164             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User not found", Response.Status.BAD_REQUEST);
165         }
166         if (!user.isEnabled()) {
167             event.error(Errors.USER_DISABLED);
168             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User disabled", Response.Status.BAD_REQUEST);
169         }
170
171         event.user(userSession.getUser());
172         event.session(userSession.getId());
173
755fd7 174         if (client == null) {
ARW 175             client = clientSession.getClient();
176         } else {
177             if (!client.getClientId().equals(clientSession.getClient().getClientId())) {
178                 event.error(Errors.INVALID_CODE);
179                 throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Invalid service", Response.Status.BAD_REQUEST);
180             }
74023a 181         }
EH 182
183         if (!AuthenticationManager.isSessionValid(realm, userSession)) {
184             event.error(Errors.USER_SESSION_NOT_FOUND);
185             throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Session not active", Response.Status.BAD_REQUEST);
755fd7 186         }
ARW 187
188     }
189
190     protected void createProxyGrant(String pgtUrl) {
191         if ( RedirectUtils.verifyRedirectUri(session, pgtUrl, client) == null ) {
192             event.error(Errors.INVALID_REQUEST);
193             throw new CASValidationException(CASErrorCode.INVALID_PROXY_CALLBACK, "Proxy callback is invalid", Response.Status.BAD_REQUEST);
194         }
195
196         String pgtIou = getPGTIOU();
197         String pgtId  = getPGT(session, clientSession, pgtUrl);
198
199         try {
200             HttpResponse response = HttpClientBuilder.create().build().execute(
201                 new HttpGet(new URIBuilder(pgtUrl).setParameter("pgtIou",pgtIou).setParameter("pgtId",pgtId).build())
202             );
203
204             if (response.getStatusLine().getStatusCode() != 200) {
205                 throw new Exception();
206             }
207
208             this.pgtIou = pgtIou;
209         } catch (Exception e) {
210             event.error(Errors.INVALID_REQUEST);
211             throw new CASValidationException(CASErrorCode.PROXY_CALLBACK_ERROR, "Proxy callback returned an error", Response.Status.BAD_REQUEST);
74023a 212         }
EH 213     }
214
215     protected Map<String, Object> getUserAttributes() {
216         UserSessionModel userSession = clientSession.getUserSession();
217         // CAS protocol does not support scopes, so pass null scopeParam
8379a3 218         ClientSessionContext clientSessionCtx = DefaultClientSessionContext.fromClientSessionAndScopeParameter(clientSession, null, session);
74023a 219
ea9555 220         Set<ProtocolMapperModel> mappings = clientSessionCtx.getProtocolMappersStream().collect(Collectors.toSet());
74023a 221         KeycloakSessionFactory sessionFactory = session.getKeycloakSessionFactory();
EH 222         Map<String, Object> attributes = new HashMap<>();
223         for (ProtocolMapperModel mapping : mappings) {
224             ProtocolMapper mapper = (ProtocolMapper) sessionFactory.getProviderFactory(ProtocolMapper.class, mapping.getProtocolMapper());
225             if (mapper instanceof CASAttributeMapper) {
226                 ((CASAttributeMapper) mapper).setAttribute(attributes, mapping, userSession, session, clientSessionCtx);
227             }
228         }
229         return attributes;
230     }
755fd7 231
ARW 232     protected String getPGTIOU()
233     {
234         return CASLoginProtocol.PROXY_GRANTING_TICKET_IOU_PREFIX + UUID.randomUUID().toString();
235     }
236
237     protected String getPGT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String pgtUrl)
238     {
239         return persistedTicket(pgtUrl, CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
240     }
241
242     protected String getPT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String targetService)
243     {
244         return persistedTicket(targetService, CASLoginProtocol.PROXY_TICKET_PREFIX);
245     }
246
247     protected String getST(String redirectUri)
248     {
249         return persistedTicket(redirectUri, CASLoginProtocol.SERVICE_TICKET_PREFIX);
250     }
251
252     public static String getST(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String redirectUri)
253     {
254         ValidateEndpoint vp = new ValidateEndpoint(session,null,null);
255         vp.clientSession = clientSession;
256         return vp.getST(redirectUri);
257     }
258
259     protected String persistedTicket(String redirectUriParam, String prefix)
260     {
261         String key = UUID.randomUUID().toString();
262         UserSessionModel userSession = clientSession.getUserSession();
263         OAuth2Code codeData = new OAuth2Code(key, Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(), null, null, redirectUriParam, null, null, userSession.getId());
264         session.singleUseObjects().put(prefix + key, clientSession.getUserSession().getRealm().getAccessCodeLifespan(), codeData.serializeCode());
265         return prefix + key + "." + clientSession.getUserSession().getId() + "." + clientSession.getClient().getId();
266     }
74023a 267 }