Skip to content
Snippets Groups Projects
Commit 29043042 authored by Daniel Gerhardt's avatar Daniel Gerhardt
Browse files

Merge branch 'STOMP-security' into 'master'

Add auth logic for stomp clients over ws

See merge request arsnova/arsnova-backend!123
parents 8e10a530 dbaef881
Branches
1 merge request!123Add auth logic for stomp clients over ws
Pipeline #25518 passed with warnings with stages
in 1 minute and 53 seconds
package de.thm.arsnova.config; package de.thm.arsnova.config;
import de.thm.arsnova.websocket.handler.AuthChannelInterceptorAdapter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
...@@ -10,6 +14,16 @@ import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerCo ...@@ -10,6 +14,16 @@ import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerCo
@EnableWebSocketMessageBroker @EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
private final AuthChannelInterceptorAdapter authChannelInterceptorAdapter;
@Value(value = "${security.cors.origins:}") private String[] corsOrigins;
@Autowired
public WebSocketConfig(AuthChannelInterceptorAdapter authChannelInterceptorAdapter) {
this.authChannelInterceptorAdapter = authChannelInterceptorAdapter;
}
@Override @Override
public void configureMessageBroker(MessageBrokerRegistry config) { public void configureMessageBroker(MessageBrokerRegistry config) {
config config
...@@ -19,7 +33,13 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { ...@@ -19,7 +33,13 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").setAllowedOrigins("*").withSockJS(); registry.addEndpoint("/ws").setAllowedOrigins(corsOrigins).withSockJS();
}
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.setInterceptors(authChannelInterceptorAdapter);
} }
} }
...@@ -89,4 +89,8 @@ public interface UserService extends EntityService<UserProfile> { ...@@ -89,4 +89,8 @@ public interface UserService extends EntityService<UserProfile> {
void initiatePasswordReset(String username); void initiatePasswordReset(String username);
boolean resetPassword(UserProfile userProfile, String key, String password); boolean resetPassword(UserProfile userProfile, String key, String password);
void addWsSessionToJwtMapping(String wsSessionId, String jwt);
User getAuthenticatedUserByWsSession(String wsSessionId);
} }
...@@ -93,6 +93,9 @@ public class UserServiceImpl extends DefaultEntityServiceImpl<UserProfile> imple ...@@ -93,6 +93,9 @@ public class UserServiceImpl extends DefaultEntityServiceImpl<UserProfile> imple
private static final ConcurrentHashMap<UUID, String> socketIdToUserId = new ConcurrentHashMap<>(); private static final ConcurrentHashMap<UUID, String> socketIdToUserId = new ConcurrentHashMap<>();
/* for the new STOMP over ws functionality */
private static final ConcurrentHashMap<String, String> wsSessionIdToJwt = new ConcurrentHashMap<>();
/* used for Socket.IO online check solution (new) */ /* used for Socket.IO online check solution (new) */
private static final ConcurrentHashMap<String, String> userIdToRoomId = new ConcurrentHashMap<>(); private static final ConcurrentHashMap<String, String> userIdToRoomId = new ConcurrentHashMap<>();
...@@ -635,4 +638,16 @@ public class UserServiceImpl extends DefaultEntityServiceImpl<UserProfile> imple ...@@ -635,4 +638,16 @@ public class UserServiceImpl extends DefaultEntityServiceImpl<UserProfile> imple
public void setJwtService(final JwtService jwtService) { public void setJwtService(final JwtService jwtService) {
this.jwtService = jwtService; this.jwtService = jwtService;
} }
public void addWsSessionToJwtMapping(final String wsSessionId, final String jwt) {
wsSessionIdToJwt.put(wsSessionId, jwt);
}
public User getAuthenticatedUserByWsSession(final String wsSessionId) {
String jwt = wsSessionIdToJwt.getOrDefault(wsSessionId, null);
if (jwt == null) return null;
User u = jwtService.verifyToken(jwt);
if (u == null) return null;
return u;
}
} }
package de.thm.arsnova.websocket.handler;
import de.thm.arsnova.security.User;
import de.thm.arsnova.service.UserService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.stereotype.Component;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.util.List;
@Component
public class AuthChannelInterceptorAdapter implements ChannelInterceptor {
private static final Logger logger = LoggerFactory.getLogger(AuthChannelInterceptorAdapter.class);
private final UserService service;
@Autowired
public AuthChannelInterceptorAdapter(final UserService service) {
this.service = service;
}
@Nullable
@Override
public Message<?> preSend(final Message<?> message, final MessageChannel channel) {
StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
String sessionId = accessor.getSessionId();
if (accessor.getCommand() != null && accessor.getCommand().equals(StompCommand.CONNECT)) {
// user needs to authorize
List<String> tokenList = accessor.getNativeHeader("token");
if (tokenList != null && tokenList.size() > 0) {
String token = tokenList.get(0);
service.addWsSessionToJwtMapping(sessionId, token);
} else {
// no token given -> auth failed
logger.debug("no auth token given, dropping connection attempt");
return null;
}
} else {
List<String> userIdList = accessor.getNativeHeader("ars-user-id");
if (userIdList != null && userIdList.size() > 0) {
// user-id is given, check for auth
String userId = userIdList.get(0);
User u = service.getAuthenticatedUserByWsSession(sessionId);
if (u == null || !userId.equals(u.getId())) {
// user isn't authorized, drop message
logger.debug("user-id not validated, dropping frame");
return null;
}
}
}
// default is to pass the frame along
return MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders());
}
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment