diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java index 9274b989dd..4461e50bd3 100644 --- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java +++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java @@ -59,6 +59,7 @@ import org.thingsboard.server.service.ws.WsCommandsWrapper; import org.thingsboard.server.service.ws.notification.cmd.NotificationCmdsWrapper; import org.thingsboard.server.service.ws.telemetry.cmd.TelemetryCmdsWrapper; +import javax.annotation.PostConstruct; import javax.websocket.RemoteEndpoint; import javax.websocket.SendHandler; import javax.websocket.SendResult; @@ -104,6 +105,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke private long pingTimeout; @Value("${server.ws.max_queue_messages_per_session:1000}") private int wsMaxQueueMessagesPerSession; + @Value("${server.ws.auth_timeout_ms:10000}") + private int authTimeoutMs; private final ConcurrentMap blacklistedSessions = new ConcurrentHashMap<>(); @@ -112,18 +115,23 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke private final ConcurrentMap> regularUserSessionsMap = new ConcurrentHashMap<>(); private final ConcurrentMap> publicUserSessionsMap = new ConcurrentHashMap<>(); - private final Cache pendingSessions = Caffeine.newBuilder() - .expireAfterWrite(10, TimeUnit.SECONDS) - .removalListener((sessionId, sessionMd, removalCause) -> { - if (removalCause == RemovalCause.EXPIRED && sessionMd != null) { - try { - close(sessionMd.sessionRef, CloseStatus.POLICY_VIOLATION); - } catch (IOException e) { - log.warn("IO error", e); + private Cache pendingSessions; + + @PostConstruct + private void init() { + pendingSessions = Caffeine.newBuilder() + .expireAfterWrite(authTimeoutMs, TimeUnit.MILLISECONDS) + .removalListener((sessionId, sessionMd, removalCause) -> { + if (removalCause == RemovalCause.EXPIRED && sessionMd != null) { + try { + close(sessionMd.sessionRef, CloseStatus.POLICY_VIOLATION); + } catch (IOException e) { + log.warn("IO error", e); + } } - } - }) - .build(); + }) + .build(); + } @Override public void handleTextMessage(WebSocketSession session, TextMessage message) { @@ -134,8 +142,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke session.close(CloseStatus.SERVER_ERROR.withReason("Session not found!")); return; } - String msg = message.getPayload(); - sessionMd.onMsg(msg); + sessionMd.onMsg(message.getPayload()); } catch (IOException e) { log.warn("IO error", e); } @@ -159,7 +166,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke return; } } catch (Exception e) { - log.warn("Failed to decode subscription cmd: {}", e.getMessage(), e); + log.debug("{} Failed to decode subscription cmd: {}", sessionRef.toString(), e.getMessage(), e); if (sessionRef.getSecurityCtx() != null) { webSocketService.sendError(sessionRef, 1, SubscriptionErrorCode.BAD_REQUEST, "Failed to parse the payload"); } else { @@ -169,7 +176,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } if (sessionRef.getSecurityCtx() != null) { - log.trace("[{}][{}] Processing {}", sessionRef.getSecurityCtx().getTenantId(), sessionMd.session.getId(), msg); + log.trace("{} Processing {}", sessionRef.toString(), msg); webSocketService.handleCommands(sessionRef, cmdsWrapper); } else { AuthCmd authCmd = cmdsWrapper.getAuthCmd(); @@ -177,7 +184,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Auth cmd is missing")); return; } - log.trace("[{}] Authenticating session", sessionMd.session.getId()); + log.trace("{} Authenticating session", sessionRef.toString()); SecurityUser securityCtx; try { securityCtx = authenticationProvider.authenticate(authCmd.getToken()); @@ -188,6 +195,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke sessionRef.setSecurityCtx(securityCtx); pendingSessions.invalidate(sessionMd.session.getId()); establishSession(sessionMd.session, sessionRef, sessionMd); + webSocketService.handleCommands(sessionRef, cmdsWrapper); } } @@ -197,7 +205,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke try { SessionMetaData sessionMd = getSessionMd(session.getId()); if (sessionMd != null) { - log.trace("[{}][{}] Processing pong response {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload()); + log.trace("{} Processing pong response {}", sessionMd.sessionRef.toString(), message.getPayload()); sessionMd.processPongMessage(System.currentTimeMillis()); } else { log.trace("[{}] Failed to find session", session.getId()); @@ -247,7 +255,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke internalSessionMap.put(session.getId(), sessionMd); externalSessionMap.put(sessionRef.getSessionId(), session.getId()); processInWebSocketService(sessionRef, SessionEvent.onEstablished()); - log.info("[{}][{}][{}] Session established from address: {}", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSessionId(), session.getId(), session.getRemoteAddress()); + log.info("[{}][{}][{}][{}] Session established from address: {}", sessionRef.getSecurityCtx().getTenantId(), + sessionRef.getSecurityCtx().getId(), sessionRef.getSessionId(), session.getId(), session.getRemoteAddress()); } else { sessionMd = new SessionMetaData(session, sessionRef); pendingSessions.put(session.getId(), sessionMd); @@ -280,7 +289,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke cleanupLimits(session, sessionMd.sessionRef); processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed()); } - log.info("[{}][{}][{}] Session is closed", sessionMd.sessionRef.getSecurityCtx().getTenantId(), sessionMd.sessionRef.getSessionId(), session.getId()); + log.info("{} Session is closed", sessionMd.sessionRef.toString()); } else { log.info("[{}] Session is closed", session.getId()); } @@ -293,7 +302,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke try { webSocketService.handleSessionEvent(sessionRef, event); } catch (BeanCreationNotAllowedException e) { - log.warn("[{}] Failed to close session due to possible shutdown state", sessionRef.getSessionId()); + log.warn("{} Failed to close session due to possible shutdown state", sessionRef.toString()); } } @@ -359,13 +368,13 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke try { long timeSinceLastActivity = currentTime - lastActivityTime; if (timeSinceLastActivity >= pingTimeout) { - log.warn("[{}] Closing session due to ping timeout", session.getId()); + log.warn("{} Closing session due to ping timeout", sessionRef.toString()); closeSession(CloseStatus.SESSION_NOT_RELIABLE); } else if (timeSinceLastActivity >= pingTimeout / NUMBER_OF_PING_ATTEMPTS) { sendMsg(TbWebSocketPingMsg.INSTANCE); } } catch (Exception e) { - log.trace("[{}] Failed to send ping msg", session.getId(), e); + log.trace("{} Failed to send ping msg", sessionRef.toString(), e); closeSession(CloseStatus.SESSION_NOT_RELIABLE); } } @@ -374,7 +383,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke try { close(this.sessionRef, reason); } catch (IOException ioe) { - log.trace("[{}] Session transport error", session.getId(), ioe); + log.trace("{} Session transport error", sessionRef.toString(), ioe); } finally { outboundMsgQueue.clear(); } @@ -394,7 +403,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke outboundMsgQueueSize.incrementAndGet(); processNextMsg(); } else { - log.info("[{}][{}] Session closed due to updates queue size exceeded", sessionRef.getSecurityCtx().getTenantId(), session.getId()); + log.info("{} Session closed due to updates queue size exceeded", sessionRef.toString()); closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!")); } } @@ -412,7 +421,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke processNextMsg(); } } catch (Exception e) { - log.trace("[{}] Failed to send msg", session.getId(), e); + log.trace("{} Failed to send msg", sessionRef.toString(), e); closeSession(CloseStatus.SESSION_NOT_RELIABLE); } } @@ -420,7 +429,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke @Override public void onResult(SendResult result) { if (!result.isOK()) { - log.trace("[{}] Failed to send msg", session.getId(), result.getException()); + log.trace("{} Failed to send msg", sessionRef.toString(), result.getException()); closeSession(CloseStatus.SESSION_NOT_RELIABLE); return; } @@ -467,8 +476,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke @Override public void send(WebSocketSessionRef sessionRef, int subscriptionId, String msg) throws IOException { + log.debug("{} Sending {}", sessionRef.toString(), msg); String externalId = sessionRef.getSessionId(); - log.debug("[{}] Sending {}", externalId, msg); String internalId = externalSessionMap.get(externalId); if (internalId != null) { SessionMetaData sessionMd = internalSessionMap.get(internalId); @@ -476,13 +485,12 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke TenantId tenantId = sessionRef.getSecurityCtx().getTenantId(); if (!rateLimitService.checkRateLimit(LimitedApi.WS_UPDATES_PER_SESSION, tenantId, (Object) sessionRef.getSessionId())) { if (blacklistedSessions.putIfAbsent(externalId, sessionRef) == null) { - log.info("[{}][{}][{}] Failed to process session update. Max session updates limit reached" - , tenantId, sessionRef.getSecurityCtx().getId(), externalId); + log.info("{} Failed to process session update. Max session updates limit reached", sessionRef.toString()); sessionMd.sendMsg("{\"subscriptionId\":" + subscriptionId + ", \"errorCode\":" + ThingsboardErrorCode.TOO_MANY_UPDATES.getErrorCode() + ", \"errorMsg\":\"Too many updates!\"}"); } return; } else { - log.debug("[{}][{}][{}] Session is no longer blacklisted.", tenantId, sessionRef.getSecurityCtx().getId(), externalId); + log.debug("{} Session is no longer blacklisted.", sessionRef.toString()); blacklistedSessions.remove(externalId); } sessionMd.sendMsg(msg); @@ -513,7 +521,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke @Override public void close(WebSocketSessionRef sessionRef, CloseStatus reason) throws IOException { String externalId = sessionRef.getSessionId(); - log.debug("[{}] Processing close request", externalId); + log.debug("{} Processing close request", sessionRef.toString()); String internalId = externalSessionMap.get(externalId); if (internalId != null) { SessionMetaData sessionMd = getSessionMd(internalId); @@ -543,8 +551,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } if (!limitAllowed) { - log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached" - , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); + log.info("{} Failed to start session. Max tenant sessions limit reached", sessionRef.toString()); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!")); return false; } @@ -560,8 +567,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } if (!limitAllowed) { - log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached" - , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); + log.info("{} Failed to start session. Max customer sessions limit reached", sessionRef.toString()); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached")); return false; } @@ -576,8 +582,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } if (!limitAllowed) { - log.info("[{}][{}][{}] Failed to start session. Max regular user sessions limit reached" - , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); + log.info("{} Failed to start session. Max regular user sessions limit reached", sessionRef.toString()); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached")); return false; } @@ -592,8 +597,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } if (!limitAllowed) { - log.info("[{}][{}][{}] Failed to start session. Max public user sessions limit reached" - , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); + log.info("{} Failed to start session. Max public user sessions limit reached", sessionRef.toString()); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached")); return false; } diff --git a/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java b/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java index b1118224ef..3e9ee69cfc 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java @@ -86,6 +86,7 @@ import javax.annotation.PreDestroy; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.EnumMap; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -147,7 +148,7 @@ public class DefaultWebSocketService implements WebSocketService { private ScheduledExecutorService pingExecutor; private String serviceId; - private List> cmdsHandlers; + private Map> cmdsHandlers; @PostConstruct public void init() { @@ -157,24 +158,23 @@ public class DefaultWebSocketService implements WebSocketService { pingExecutor = Executors.newSingleThreadScheduledExecutor(ThingsBoardThreadFactory.forName("telemetry-web-socket-ping")); pingExecutor.scheduleWithFixedDelay(this::sendPing, pingTimeout / NUMBER_OF_PING_ATTEMPTS, pingTimeout / NUMBER_OF_PING_ATTEMPTS, TimeUnit.MILLISECONDS); - cmdsHandlers = List.of( - newCmdHandler(WsCmdType.ATTRIBUTES, this::handleWsAttributesSubscriptionCmd), - newCmdHandler(WsCmdType.TIMESERIES, this::handleWsTimeseriesSubscriptionCmd), - newCmdHandler(WsCmdType.TIMESERIES_HISTORY, this::handleWsHistoryCmd), - newCmdHandler(WsCmdType.ENTITY_DATA, this::handleWsEntityDataCmd), - newCmdHandler(WsCmdType.ALARM_DATA, this::handleWsAlarmDataCmd), - newCmdHandler(WsCmdType.ENTITY_COUNT, this::handleWsEntityCountCmd), - newCmdHandler(WsCmdType.ALARM_COUNT, this::handleWsAlarmCountCmd), - newCmdHandler(WsCmdType.ENTITY_DATA_UNSUBSCRIBE, this::handleWsDataUnsubscribeCmd), - newCmdHandler(WsCmdType.ALARM_DATA_UNSUBSCRIBE, this::handleWsDataUnsubscribeCmd), - newCmdHandler(WsCmdType.ENTITY_COUNT_UNSUBSCRIBE, this::handleWsDataUnsubscribeCmd), - newCmdHandler(WsCmdType.ALARM_COUNT_UNSUBSCRIBE, this::handleWsDataUnsubscribeCmd), - newCmdHandler(WsCmdType.NOTIFICATIONS, notificationCmdsHandler::handleUnreadNotificationsSubCmd), - newCmdHandler(WsCmdType.NOTIFICATIONS_COUNT, notificationCmdsHandler::handleUnreadNotificationsCountSubCmd), - newCmdHandler(WsCmdType.MARK_NOTIFICATIONS_AS_READ, notificationCmdsHandler::handleMarkAsReadCmd), - newCmdHandler(WsCmdType.MARK_ALL_NOTIFICATIONS_AS_READ, notificationCmdsHandler::handleMarkAllAsReadCmd), - newCmdHandler(WsCmdType.NOTIFICATIONS_UNSUBSCRIBE, notificationCmdsHandler::handleUnsubCmd) - ); + cmdsHandlers = new EnumMap<>(WsCmdType.class); + cmdsHandlers.put(WsCmdType.ATTRIBUTES, newCmdHandler(this::handleWsAttributesSubscriptionCmd)); + cmdsHandlers.put(WsCmdType.TIMESERIES, newCmdHandler(this::handleWsTimeseriesSubscriptionCmd)); + cmdsHandlers.put(WsCmdType.TIMESERIES_HISTORY, newCmdHandler(this::handleWsHistoryCmd)); + cmdsHandlers.put(WsCmdType.ENTITY_DATA, newCmdHandler(this::handleWsEntityDataCmd)); + cmdsHandlers.put(WsCmdType.ALARM_DATA, newCmdHandler(this::handleWsAlarmDataCmd)); + cmdsHandlers.put(WsCmdType.ENTITY_COUNT, newCmdHandler(this::handleWsEntityCountCmd)); + cmdsHandlers.put(WsCmdType.ALARM_COUNT, newCmdHandler(this::handleWsAlarmCountCmd)); + cmdsHandlers.put(WsCmdType.ENTITY_DATA_UNSUBSCRIBE, newCmdHandler(this::handleWsDataUnsubscribeCmd)); + cmdsHandlers.put(WsCmdType.ALARM_DATA_UNSUBSCRIBE, newCmdHandler(this::handleWsDataUnsubscribeCmd)); + cmdsHandlers.put(WsCmdType.ENTITY_COUNT_UNSUBSCRIBE, newCmdHandler(this::handleWsDataUnsubscribeCmd)); + cmdsHandlers.put(WsCmdType.ALARM_COUNT_UNSUBSCRIBE, newCmdHandler(this::handleWsDataUnsubscribeCmd)); + cmdsHandlers.put(WsCmdType.NOTIFICATIONS, newCmdHandler(notificationCmdsHandler::handleUnreadNotificationsSubCmd)); + cmdsHandlers.put(WsCmdType.NOTIFICATIONS_COUNT, newCmdHandler(notificationCmdsHandler::handleUnreadNotificationsCountSubCmd)); + cmdsHandlers.put(WsCmdType.MARK_NOTIFICATIONS_AS_READ, newCmdHandler(notificationCmdsHandler::handleMarkAsReadCmd)); + cmdsHandlers.put(WsCmdType.MARK_ALL_NOTIFICATIONS_AS_READ, newCmdHandler(notificationCmdsHandler::handleMarkAllAsReadCmd)); + cmdsHandlers.put(WsCmdType.NOTIFICATIONS_UNSUBSCRIBE, newCmdHandler(notificationCmdsHandler::handleUnsubCmd)); } @PreDestroy @@ -221,7 +221,7 @@ public class DefaultWebSocketService implements WebSocketService { for (WsCmd cmd : commandsWrapper.getCmds()) { log.debug("[{}][{}][{}] Processing cmd: {}", sessionId, cmd.getType(), cmd.getCmdId(), cmd); try { - Optional.ofNullable(getCmdHandler(cmd.getType())) + Optional.ofNullable(cmdsHandlers.get(cmd.getType())) .ifPresent(cmdHandler -> cmdHandler.handle(sessionRef, cmd)); } catch (Exception e) { log.error("[sessionId: {}, tenantId: {}, userId: {}] Failed to handle WS cmd: {}", sessionId, @@ -963,24 +963,14 @@ public class DefaultWebSocketService implements WebSocketService { .map(TenantProfile::getDefaultProfileConfiguration).orElse(null); } - public WsCmdHandler getCmdHandler(WsCmdType cmdType) { - for (WsCmdHandler cmdHandler : cmdsHandlers) { - if (cmdHandler.getCmdType() == cmdType) { - return cmdHandler; - } - } - return null; - } - - public static WsCmdHandler newCmdHandler(WsCmdType cmdType, BiConsumer handler) { - return new WsCmdHandler<>(cmdType, handler); + public static WsCmdHandler newCmdHandler(BiConsumer handler) { + return new WsCmdHandler<>(handler); } @RequiredArgsConstructor @Getter @SuppressWarnings("unchecked") public static class WsCmdHandler { - private final WsCmdType cmdType; protected final BiConsumer handler; public void handle(WebSocketSessionRef sessionRef, WsCmd cmd) { diff --git a/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java b/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java index 1dc75b0149..ba9d27b44c 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java @@ -54,11 +54,13 @@ public class WebSocketSessionRef { @Override public String toString() { - return "WebSocketSessionRef{" + - "sessionId='" + sessionId + '\'' + - ", localAddress=" + localAddress + - ", remoteAddress=" + remoteAddress + - ", sessionType=" + sessionType + - '}'; + String info = ""; + if (securityCtx != null) { + info += "[" + securityCtx.getTenantId() + "]"; + info += "[" + securityCtx.getId() + "]"; + } + info += "[" + sessionId + "]"; + return info; } + } diff --git a/application/src/main/resources/thingsboard.yml b/application/src/main/resources/thingsboard.yml index 9421dfa742..b7829d2b88 100644 --- a/application/src/main/resources/thingsboard.yml +++ b/application/src/main/resources/thingsboard.yml @@ -76,6 +76,8 @@ server: max_entities_per_alarm_subscription: "${TB_SERVER_WS_MAX_ENTITIES_PER_ALARM_SUBSCRIPTION:10000}" # Maximum queue size of the websocket updates per session. This restriction prevents infinite updates of WS max_queue_messages_per_session: "${TB_SERVER_WS_DEFAULT_QUEUE_MESSAGES_PER_SESSION:1000}" + # Maximum time between WS session opening and sending auth command + auth_timeout_ms: "${TB_SERVER_WS_AUTH_TIMEOUT_MS:10000}" rest: server_side_rpc: # Minimum value of the server-side RPC timeout. May override value provided in the REST API call.