From 5eebbf89859ee08ee6c481646c060495f51086ce Mon Sep 17 00:00:00 2001 From: Sergey Matvienko Date: Thu, 13 Jul 2023 22:28:59 +0200 Subject: [PATCH] web socket handler tests added. ws msg queue fixed the last msg pickup (and msg order as result) --- .../controller/plugin/TbWebSocketHandler.java | 44 +++-- .../plugin/TbWebSocketHandlerTest.java | 160 ++++++++++++++++++ 2 files changed, 186 insertions(+), 18 deletions(-) create mode 100644 application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java index 56d88143a5..481d412d59 100644 --- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java +++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java @@ -219,12 +219,12 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke .build(); } - private class SessionMetaData implements SendHandler { + class SessionMetaData implements SendHandler { private final WebSocketSession session; private final RemoteEndpoint.Async asyncRemote; private final WebSocketSessionRef sessionRef; - private final AtomicBoolean isSending = new AtomicBoolean(false); + final AtomicBoolean isSending = new AtomicBoolean(false); private final Queue> msgQueue; private volatile long lastActivityTime; @@ -254,11 +254,13 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } - private void closeSession(CloseStatus reason) { + void closeSession(CloseStatus reason) { try { close(this.sessionRef, reason); } catch (IOException ioe) { log.trace("[{}] Session transport error", session.getId(), ioe); + } finally { + msgQueue.clear(); } } @@ -271,20 +273,19 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } void sendMsg(TbWebSocketMsg msg) { - if (isSending.compareAndSet(false, true)) { - sendMsgInternal(msg); - } else { - try { - msgQueue.add(msg); - } catch (RuntimeException e) { - if (log.isTraceEnabled()) { - log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e); - } else { - log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId()); - } - closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!")); + try { + msgQueue.add(msg); + } catch (RuntimeException e) { + if (log.isTraceEnabled()) { + log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e); + } else { + log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId()); } + closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!")); + return; } + + processNextMsg(); } private void sendMsgInternal(TbWebSocketMsg msg) { @@ -292,9 +293,11 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke if (TbWebSocketMsgType.TEXT.equals(msg.getType())) { TbWebSocketTextMsg textMsg = (TbWebSocketTextMsg) msg; this.asyncRemote.sendText(textMsg.getMsg(), this); + // isSending status will be reset in the onResult method by call back } else { TbWebSocketPingMsg pingMsg = (TbWebSocketPingMsg) msg; - this.asyncRemote.sendPing(pingMsg.getMsg()); + this.asyncRemote.sendPing(pingMsg.getMsg()); // blocking call + isSending.set(false); processNextMsg(); } } catch (Exception e) { @@ -308,12 +311,17 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke if (!result.isOK()) { log.trace("[{}] Failed to send msg", session.getId(), result.getException()); closeSession(CloseStatus.SESSION_NOT_RELIABLE); - } else { - processNextMsg(); + return; } + + isSending.set(false); + processNextMsg(); } private void processNextMsg() { + if (msgQueue.isEmpty() || !isSending.compareAndSet(false, true)) { + return; + } TbWebSocketMsg msg = msgQueue.poll(); if (msg != null) { sendMsgInternal(msg); diff --git a/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java b/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java new file mode 100644 index 0000000000..0394e8a505 --- /dev/null +++ b/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java @@ -0,0 +1,160 @@ +/** + * Copyright © 2016-2023 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.controller.plugin; + +import lombok.extern.slf4j.Slf4j; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.adapter.NativeWebSocketSession; +import org.thingsboard.common.util.ThingsBoardThreadFactory; +import org.thingsboard.server.service.ws.WebSocketSessionRef; + +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; +import javax.websocket.Session; +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.BDDMockito.willDoNothing; +import static org.mockito.BDDMockito.willReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@Slf4j +class TbWebSocketHandlerTest { + + TbWebSocketHandler wsHandler; + NativeWebSocketSession session; + Session nativeSession; + RemoteEndpoint.Async asyncRemote; + WebSocketSessionRef sessionRef; + int maxMsgQueuePerSession; + TbWebSocketHandler.SessionMetaData sendHandler; + ExecutorService executor; + + @BeforeEach + void setUp() throws IOException { + maxMsgQueuePerSession = 100; + executor = Executors.newCachedThreadPool(ThingsBoardThreadFactory.forName(getClass().getSimpleName())); + wsHandler = spy(new TbWebSocketHandler()); + willDoNothing().given(wsHandler).close(any(), any()); + session = mock(NativeWebSocketSession.class); + nativeSession = mock(Session.class); + willReturn(nativeSession).given(session).getNativeSession(Session.class); + asyncRemote = mock(RemoteEndpoint.Async.class); + willReturn(asyncRemote).given(nativeSession).getAsyncRemote(); + sessionRef = mock(WebSocketSessionRef.class, Mockito.RETURNS_DEEP_STUBS); //prevent NPE on logs + sendHandler = spy(wsHandler.new SessionMetaData(session, sessionRef, maxMsgQueuePerSession)); + } + + @AfterEach + void tearDown() { + if (executor != null) { + executor.shutdownNow(); + } + } + + @Test + void sendHandler_sendMsg_parallel_no_race() throws InterruptedException { + CountDownLatch finishLatch = new CountDownLatch(maxMsgQueuePerSession * 2); + AtomicInteger sendersCount = new AtomicInteger(); + willAnswer(invocation -> { + assertThat(sendersCount.incrementAndGet()).as("no race").isEqualTo(1); + String text = invocation.getArgument(0); + SendHandler onResultHandler = invocation.getArgument(1); + SendResult sendResult = new SendResult(); + executor.submit(() -> { + sendersCount.decrementAndGet(); + onResultHandler.onResult(sendResult); + finishLatch.countDown(); + }); + return null; + }).given(asyncRemote).sendText(anyString(), any()); + + assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse(); + //first batch + IntStream.range(0, maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i)); + Awaitility.await("first batch processed").atMost(30, TimeUnit.SECONDS).until(() -> finishLatch.getCount() == maxMsgQueuePerSession); + assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse(); + //second batch - to test pause between big msg batches + IntStream.range(100, 100 + maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i)); + assertThat(finishLatch.await(30, TimeUnit.SECONDS)).as("all callbacks fired").isTrue(); + + verify(sendHandler, never()).closeSession(any()); + verify(sendHandler, times(maxMsgQueuePerSession * 2)).onResult(any()); + assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse(); + } + + @Test + void sendHandler_sendMsg_message_order() throws InterruptedException { + CountDownLatch finishLatch = new CountDownLatch(maxMsgQueuePerSession); + Collection outputs = new ConcurrentLinkedQueue<>(); + willAnswer(invocation -> { + String text = invocation.getArgument(0); + outputs.add(text); + SendHandler onResultHandler = invocation.getArgument(1); + SendResult sendResult = new SendResult(); + executor.submit(() -> { + onResultHandler.onResult(sendResult); + finishLatch.countDown(); + }); + return null; + }).given(asyncRemote).sendText(anyString(), any()); + + List inputs = IntStream.range(0, maxMsgQueuePerSession).mapToObj(i -> "msg " + i).collect(Collectors.toList()); + inputs.forEach(s -> sendHandler.sendMsg(s)); + + assertThat(finishLatch.await(30, TimeUnit.SECONDS)).as("all callbacks fired").isTrue(); + assertThat(outputs).as("inputs exactly the same as outputs").containsExactlyElementsOf(inputs); + + verify(sendHandler, never()).closeSession(any()); + verify(sendHandler, times(maxMsgQueuePerSession)).onResult(any()); + } + + @Test + void sendHandler_sendMsg_queue_size_exceed() { + willDoNothing().given(asyncRemote).sendText(anyString(), any()); // send text will never call back, so queue will grow each sendMsg + sendHandler.sendMsg("first message to stay in-flight all the time during this test"); + IntStream.range(0, maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i)); + verify(sendHandler, never()).closeSession(any()); + sendHandler.sendMsg("excessive message"); + verify(sendHandler, times(1)).closeSession(eq(new CloseStatus(1008, "Max pending updates limit reached!"))); + verify(asyncRemote, times(1)).sendText(anyString(), any()); + } + +}