From fb92aef8cb24cb3ba067ecc4d1431fd6d147d26e Mon Sep 17 00:00:00 2001 From: ViacheslavKlimov Date: Mon, 11 Dec 2023 11:24:55 +0200 Subject: [PATCH] WS inbound msg queue --- .../controller/plugin/TbWebSocketHandler.java | 184 ++++++++++-------- .../ws/telemetry/cmd/v2/AuthCmdUpdate.java | 34 ---- .../ws/telemetry/cmd/v2/CmdUpdateType.java | 3 +- .../controller/TbTestWebSocketClient.java | 1 - .../plugin/TbWebSocketHandlerTest.java | 29 ++- 5 files changed, 135 insertions(+), 116 deletions(-) delete mode 100644 application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java index 855cbbc1e5..0cdb82b612 100644 --- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java +++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java @@ -19,6 +19,7 @@ import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; import com.github.benmanes.caffeine.cache.RemovalCause; import lombok.RequiredArgsConstructor; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.BeanCreationNotAllowedException; @@ -57,7 +58,6 @@ import org.thingsboard.server.service.ws.WebSocketSessionType; import org.thingsboard.server.service.ws.WsCommandsWrapper; import org.thingsboard.server.service.ws.notification.cmd.NotificationCmdsWrapper; import org.thingsboard.server.service.ws.telemetry.cmd.TelemetryCmdsWrapper; -import org.thingsboard.server.service.ws.telemetry.cmd.v2.AuthCmdUpdate; import javax.websocket.RemoteEndpoint; import javax.websocket.SendHandler; @@ -70,10 +70,13 @@ import java.util.Queue; import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import static org.thingsboard.server.service.ws.DefaultWebSocketService.NUMBER_OF_PING_ATTEMPTS; @@ -131,61 +134,63 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke session.close(CloseStatus.SERVER_ERROR.withReason("Session not found!")); return; } - WebSocketSessionRef sessionRef = sessionMd.sessionRef; String msg = message.getPayload(); - - WsCommandsWrapper cmdsWrapper; - try { - switch (sessionRef.getSessionType()) { - case GENERAL: - cmdsWrapper = JacksonUtil.fromString(msg, WsCommandsWrapper.class); - break; - case TELEMETRY: - cmdsWrapper = JacksonUtil.fromString(msg, TelemetryCmdsWrapper.class).toCommonCmdsWrapper(); - break; - case NOTIFICATIONS: - cmdsWrapper = JacksonUtil.fromString(msg, NotificationCmdsWrapper.class).toCommonCmdsWrapper(); - break; - default: - return; - } - } catch (Exception e) { - log.warn("Failed to decode subscription cmd: {}", e.getMessage(), e); - if (sessionRef.getSecurityCtx() != null) { - webSocketService.sendError(sessionRef, 1, SubscriptionErrorCode.BAD_REQUEST, "Failed to parse the payload"); - } else { - close(sessionRef, CloseStatus.BAD_DATA.withReason(e.getMessage())); - } - return; - } - - if (sessionRef.getSecurityCtx() != null) { - log.trace("[{}][{}] Processing {}", sessionRef.getSecurityCtx().getTenantId(), session.getId(), msg); - webSocketService.handleCommands(sessionRef, cmdsWrapper); - } else { - AuthCmd authCmd = cmdsWrapper.getAuthCmd(); - if (authCmd == null) { - close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Auth cmd is missing")); - return; - } - log.trace("[{}] Authenticating session", session.getId()); - SecurityUser securityCtx; - try { - securityCtx = authenticationProvider.authenticate(authCmd.getToken()); - } catch (Exception e) { - close(sessionRef, CloseStatus.BAD_DATA.withReason(e.getMessage())); - return; - } - sessionRef.setSecurityCtx(securityCtx); - pendingSessions.invalidate(session.getId()); - establishSession(session, sessionRef); - webSocketService.sendUpdate(sessionRef.getSessionId(), new AuthCmdUpdate(1)); - } + sessionMd.onMsg(msg); } catch (IOException e) { log.warn("IO error", e); } } + void processMsg(SessionMetaData sessionMd, String msg) throws IOException { + WebSocketSessionRef sessionRef = sessionMd.sessionRef; + WsCommandsWrapper cmdsWrapper; + try { + switch (sessionRef.getSessionType()) { + case GENERAL: + cmdsWrapper = JacksonUtil.fromString(msg, WsCommandsWrapper.class); + break; + case TELEMETRY: + cmdsWrapper = JacksonUtil.fromString(msg, TelemetryCmdsWrapper.class).toCommonCmdsWrapper(); + break; + case NOTIFICATIONS: + cmdsWrapper = JacksonUtil.fromString(msg, NotificationCmdsWrapper.class).toCommonCmdsWrapper(); + break; + default: + return; + } + } catch (Exception e) { + log.warn("Failed to decode subscription cmd: {}", e.getMessage(), e); + if (sessionRef.getSecurityCtx() != null) { + webSocketService.sendError(sessionRef, 1, SubscriptionErrorCode.BAD_REQUEST, "Failed to parse the payload"); + } else { + close(sessionRef, CloseStatus.BAD_DATA.withReason(e.getMessage())); + } + return; + } + + if (sessionRef.getSecurityCtx() != null) { + log.trace("[{}][{}] Processing {}", sessionRef.getSecurityCtx().getTenantId(), sessionMd.session.getId(), msg); + webSocketService.handleCommands(sessionRef, cmdsWrapper); + } else { + AuthCmd authCmd = cmdsWrapper.getAuthCmd(); + if (authCmd == null) { + close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Auth cmd is missing")); + return; + } + log.trace("[{}] Authenticating session", sessionMd.session.getId()); + SecurityUser securityCtx; + try { + securityCtx = authenticationProvider.authenticate(authCmd.getToken()); + } catch (Exception e) { + close(sessionRef, CloseStatus.BAD_DATA.withReason(e.getMessage())); + return; + } + sessionRef.setSecurityCtx(securityCtx); + pendingSessions.invalidate(sessionMd.session.getId()); + establishSession(sessionMd.session, sessionRef, sessionMd); + } + } + @Override protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception { try { @@ -214,7 +219,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } WebSocketSessionRef sessionRef = toRef(session); log.debug("[{}][{}] Session opened from address: {}", sessionRef.getSessionId(), session.getId(), session.getRemoteAddress()); - establishSession(session, sessionRef); + establishSession(session, sessionRef, null); } catch (InvalidParameterException e) { log.warn("[{}] Failed to start session", session.getId(), e); session.close(CloseStatus.BAD_DATA.withReason(e.getMessage())); @@ -224,24 +229,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } - private void establishSession(WebSocketSession session, WebSocketSessionRef sessionRef) throws IOException { + private void establishSession(WebSocketSession session, WebSocketSessionRef sessionRef, SessionMetaData sessionMd) throws IOException { if (sessionRef.getSecurityCtx() != null) { if (!checkLimits(session, sessionRef)) { return; } - var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); - int wsTenantProfileQueueLimit = tenantProfileConfiguration != null ? - tenantProfileConfiguration.getWsMsgQueueLimitPerSession() : wsMaxQueueMessagesPerSession; - SessionMetaData sessionMd = new SessionMetaData(session, sessionRef, - (wsTenantProfileQueueLimit > 0 && wsTenantProfileQueueLimit < wsMaxQueueMessagesPerSession) ? - wsTenantProfileQueueLimit : wsMaxQueueMessagesPerSession); + int maxMsgQueueSize = Optional.ofNullable(getTenantProfileConfiguration(sessionRef)) + .map(DefaultTenantProfileConfiguration::getWsMsgQueueLimitPerSession) + .filter(profileLimit -> profileLimit > 0 && profileLimit < wsMaxQueueMessagesPerSession) + .orElse(wsMaxQueueMessagesPerSession); + if (sessionMd == null) { + sessionMd = new SessionMetaData(session, sessionRef); + } + sessionMd.setMaxMsgQueueSize(maxMsgQueueSize); internalSessionMap.put(session.getId(), sessionMd); externalSessionMap.put(sessionRef.getSessionId(), session.getId()); processInWebSocketService(sessionRef, SessionEvent.onEstablished()); log.info("[{}][{}][{}] Session established from address: {}", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSessionId(), session.getId(), session.getRemoteAddress()); } else { - SessionMetaData sessionMd = new SessionMetaData(session, sessionRef, wsMaxQueueMessagesPerSession); + sessionMd = new SessionMetaData(session, sessionRef); pendingSessions.put(session.getId(), sessionMd); externalSessionMap.put(sessionRef.getSessionId(), session.getId()); } @@ -328,19 +335,22 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke private final WebSocketSessionRef sessionRef; final AtomicBoolean isSending = new AtomicBoolean(false); - private final Queue> msgQueue; + private final Queue> outboundMsgQueue = new ConcurrentLinkedQueue<>(); + private final AtomicInteger outboundMsgQueueSize = new AtomicInteger(); + @Setter + private int maxMsgQueueSize = wsMaxQueueMessagesPerSession; - // TODO: msg queue as in org.thingsboard.server.transport.mqtt.session.DeviceSessionCtx + private final Queue inboundMsgQueue = new ConcurrentLinkedQueue<>(); + private final Lock inboundMsgQueueProcessorLock = new ReentrantLock(); private volatile long lastActivityTime; - SessionMetaData(WebSocketSession session, WebSocketSessionRef sessionRef, int maxMsgQueuePerSession) { + SessionMetaData(WebSocketSession session, WebSocketSessionRef sessionRef) { super(); this.session = session; Session nativeSession = ((NativeWebSocketSession) session).getNativeSession(Session.class); this.asyncRemote = nativeSession.getAsyncRemote(); this.sessionRef = sessionRef; - this.msgQueue = new LinkedBlockingQueue<>(maxMsgQueuePerSession); this.lastActivityTime = System.currentTimeMillis(); } @@ -365,7 +375,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } catch (IOException ioe) { log.trace("[{}] Session transport error", session.getId(), ioe); } finally { - msgQueue.clear(); + outboundMsgQueue.clear(); } } @@ -378,19 +388,14 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } void sendMsg(TbWebSocketMsg msg) { - try { - msgQueue.add(msg); - } catch (RuntimeException e) { - if (log.isTraceEnabled()) { - log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e); - } else { - log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId()); - } + if (outboundMsgQueueSize.get() < maxMsgQueueSize) { + outboundMsgQueue.add(msg); + outboundMsgQueueSize.incrementAndGet(); + processNextMsg(); + } else { + log.info("[{}][{}] Session closed due to updates queue size exceeded", sessionRef.getSecurityCtx().getTenantId(), session.getId()); closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!")); - return; } - - processNextMsg(); } private void sendMsgInternal(TbWebSocketMsg msg) { @@ -424,16 +429,39 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } private void processNextMsg() { - if (msgQueue.isEmpty() || !isSending.compareAndSet(false, true)) { + if (outboundMsgQueue.isEmpty() || !isSending.compareAndSet(false, true)) { return; } - TbWebSocketMsg msg = msgQueue.poll(); + TbWebSocketMsg msg = outboundMsgQueue.poll(); if (msg != null) { + outboundMsgQueueSize.decrementAndGet(); sendMsgInternal(msg); } else { isSending.set(false); } } + + public void onMsg(String msg) throws IOException { + inboundMsgQueue.add(msg); + tryProcessInboundMsgs(); + } + + void tryProcessInboundMsgs() throws IOException { + while (!inboundMsgQueue.isEmpty()) { + if (inboundMsgQueueProcessorLock.tryLock()) { + try { + String msg; + while ((msg = inboundMsgQueue.poll()) != null) { + processMsg(this, msg); + } + } finally { + inboundMsgQueueProcessorLock.unlock(); + } + } else { + return; + } + } + } } @Override diff --git a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java deleted file mode 100644 index 61ed4fc9ce..0000000000 --- a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright © 2016-2023 The Thingsboard Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.thingsboard.server.service.ws.telemetry.cmd.v2; - -import org.thingsboard.server.service.subscription.SubscriptionErrorCode; - -public class AuthCmdUpdate extends CmdUpdate { - - public AuthCmdUpdate(int cmdId) { - this(cmdId, SubscriptionErrorCode.NO_ERROR.getCode(), null); - } - - public AuthCmdUpdate(int cmdId, int errorCode, String errorMsg) { - super(cmdId, errorCode, errorMsg); - } - - @Override - public CmdUpdateType getCmdUpdateType() { - return CmdUpdateType.AUTH; - } -} diff --git a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java index 04b3cbd06e..f5b3809ce2 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java @@ -21,6 +21,5 @@ public enum CmdUpdateType { ALARM_COUNT_DATA, COUNT_DATA, NOTIFICATIONS, - NOTIFICATIONS_COUNT, - AUTH + NOTIFICATIONS_COUNT } diff --git a/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java b/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java index 69eefa78bd..f7082782a3 100644 --- a/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java +++ b/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java @@ -69,7 +69,6 @@ public class TbTestWebSocketClient extends WebSocketClient { WsCommandsWrapper cmdsWrapper = new WsCommandsWrapper(); cmdsWrapper.setAuthCmd(new AuthCmd(1, token)); send(JacksonUtil.toString(cmdsWrapper)); - waitForReply(); } @Override diff --git a/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java b/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java index 0394e8a505..7d3d68f2d3 100644 --- a/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java +++ b/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java @@ -32,7 +32,10 @@ import javax.websocket.SendResult; import javax.websocket.Session; import java.io.IOException; import java.util.Collection; +import java.util.Deque; import java.util.List; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -49,6 +52,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willDoNothing; import static org.mockito.BDDMockito.willReturn; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -79,7 +83,9 @@ class TbWebSocketHandlerTest { asyncRemote = mock(RemoteEndpoint.Async.class); willReturn(asyncRemote).given(nativeSession).getAsyncRemote(); sessionRef = mock(WebSocketSessionRef.class, Mockito.RETURNS_DEEP_STUBS); //prevent NPE on logs - sendHandler = spy(wsHandler.new SessionMetaData(session, sessionRef, maxMsgQueuePerSession)); + TbWebSocketHandler.SessionMetaData sessionMd = wsHandler.new SessionMetaData(session, sessionRef); + sessionMd.setMaxMsgQueueSize(maxMsgQueuePerSession); + sendHandler = spy(sessionMd); } @AfterEach @@ -157,4 +163,25 @@ class TbWebSocketHandlerTest { verify(asyncRemote, times(1)).sendText(anyString(), any()); } + @Test + void sendHandler_onMsg_allProcessed() throws Exception { + Deque msgs = new ConcurrentLinkedDeque<>(); + doAnswer(inv -> msgs.add(inv.getArgument(1))).when(wsHandler).processMsg(any(), any()); + for (int i = 0; i < 100; i++) { + String msg = String.valueOf(i); + executor.submit(() -> { + try { + Thread.sleep(new Random().nextInt(50)); + sendHandler.onMsg(msg); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + executor.shutdown(); + executor.awaitTermination(5, TimeUnit.SECONDS); + + assertThat(msgs).map(Integer::parseInt).doesNotHaveDuplicates().hasSize(100); + } + }