diff --git a/src/main/java/de/thm/arsnova/config/WebSocketConfig.java b/src/main/java/de/thm/arsnova/config/WebSocketConfig.java index 10a44514551d48afab041ad67bbdd397b7bf4ff3..edf6cf3a3a5f1b3fbaaa969a7b8eee96fe0e2ef5 100644 --- a/src/main/java/de/thm/arsnova/config/WebSocketConfig.java +++ b/src/main/java/de/thm/arsnova/config/WebSocketConfig.java @@ -1,6 +1,10 @@ package de.thm.arsnova.config; +import de.thm.arsnova.websocket.handler.AuthChannelInterceptorAdapter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; @@ -10,6 +14,16 @@ import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerCo @EnableWebSocketMessageBroker public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { + private final AuthChannelInterceptorAdapter authChannelInterceptorAdapter; + @Value(value = "${security.cors.origins:}") private String[] corsOrigins; + + @Autowired + public WebSocketConfig(AuthChannelInterceptorAdapter authChannelInterceptorAdapter) { + this.authChannelInterceptorAdapter = authChannelInterceptorAdapter; + } + + + @Override public void configureMessageBroker(MessageBrokerRegistry config) { config @@ -19,7 +33,13 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @Override public void registerStompEndpoints(StompEndpointRegistry registry) { - registry.addEndpoint("/ws").setAllowedOrigins("*").withSockJS(); + registry.addEndpoint("/ws").setAllowedOrigins(corsOrigins).withSockJS(); + } + + + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.setInterceptors(authChannelInterceptorAdapter); } } diff --git a/src/main/java/de/thm/arsnova/service/UserService.java b/src/main/java/de/thm/arsnova/service/UserService.java index 331afcf2c6675a53808a1f1f11d08504af0c4997..b483ad088a80593a15aa5f7fa1080488ceeb2bd2 100644 --- a/src/main/java/de/thm/arsnova/service/UserService.java +++ b/src/main/java/de/thm/arsnova/service/UserService.java @@ -89,4 +89,8 @@ public interface UserService extends EntityService<UserProfile> { void initiatePasswordReset(String username); boolean resetPassword(UserProfile userProfile, String key, String password); + + void addWsSessionToJwtMapping(String wsSessionId, String jwt); + + User getAuthenticatedUserByWsSession(String wsSessionId); } diff --git a/src/main/java/de/thm/arsnova/service/UserServiceImpl.java b/src/main/java/de/thm/arsnova/service/UserServiceImpl.java index 8cb27b005b83ade89f57fcb062c9136c5e595ff3..8b51132d7398dd77dd84ed3a118e8b1ebd6cec8a 100644 --- a/src/main/java/de/thm/arsnova/service/UserServiceImpl.java +++ b/src/main/java/de/thm/arsnova/service/UserServiceImpl.java @@ -93,6 +93,9 @@ public class UserServiceImpl extends DefaultEntityServiceImpl<UserProfile> imple private static final ConcurrentHashMap<UUID, String> socketIdToUserId = new ConcurrentHashMap<>(); + /* for the new STOMP over ws functionality */ + private static final ConcurrentHashMap<String, String> wsSessionIdToJwt = new ConcurrentHashMap<>(); + /* used for Socket.IO online check solution (new) */ private static final ConcurrentHashMap<String, String> userIdToRoomId = new ConcurrentHashMap<>(); @@ -635,4 +638,16 @@ public class UserServiceImpl extends DefaultEntityServiceImpl<UserProfile> imple public void setJwtService(final JwtService jwtService) { this.jwtService = jwtService; } + + public void addWsSessionToJwtMapping(final String wsSessionId, final String jwt) { + wsSessionIdToJwt.put(wsSessionId, jwt); + } + + public User getAuthenticatedUserByWsSession(final String wsSessionId) { + String jwt = wsSessionIdToJwt.getOrDefault(wsSessionId, null); + if (jwt == null) return null; + User u = jwtService.verifyToken(jwt); + if (u == null) return null; + return u; + } } diff --git a/src/main/java/de/thm/arsnova/websocket/handler/AuthChannelInterceptorAdapter.java b/src/main/java/de/thm/arsnova/websocket/handler/AuthChannelInterceptorAdapter.java new file mode 100644 index 0000000000000000000000000000000000000000..3e06953afa5e7b6c95dd42f62526441d36c6aff1 --- /dev/null +++ b/src/main/java/de/thm/arsnova/websocket/handler/AuthChannelInterceptorAdapter.java @@ -0,0 +1,65 @@ +package de.thm.arsnova.websocket.handler; + +import de.thm.arsnova.security.User; +import de.thm.arsnova.service.UserService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.stereotype.Component; + +import org.checkerframework.checker.nullness.qual.Nullable; +import java.util.List; + +@Component +public class AuthChannelInterceptorAdapter implements ChannelInterceptor { + + private static final Logger logger = LoggerFactory.getLogger(AuthChannelInterceptorAdapter.class); + + private final UserService service; + + @Autowired + public AuthChannelInterceptorAdapter(final UserService service) { + this.service = service; + } + + @Nullable + @Override + public Message<?> preSend(final Message<?> message, final MessageChannel channel) { + StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); + + String sessionId = accessor.getSessionId(); + if (accessor.getCommand() != null && accessor.getCommand().equals(StompCommand.CONNECT)) { + // user needs to authorize + List<String> tokenList = accessor.getNativeHeader("token"); + if (tokenList != null && tokenList.size() > 0) { + String token = tokenList.get(0); + service.addWsSessionToJwtMapping(sessionId, token); + } else { + // no token given -> auth failed + logger.debug("no auth token given, dropping connection attempt"); + return null; + } + } else { + List<String> userIdList = accessor.getNativeHeader("ars-user-id"); + if (userIdList != null && userIdList.size() > 0) { + // user-id is given, check for auth + String userId = userIdList.get(0); + User u = service.getAuthenticatedUserByWsSession(sessionId); + if (u == null || !userId.equals(u.getId())) { + // user isn't authorized, drop message + logger.debug("user-id not validated, dropping frame"); + return null; + } + } + } + + // default is to pass the frame along + return MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders()); + } +}