Merge pull request #8924 from smatvienko-tb/hotfix/web-socket-sync-fix

[Hotfix] web socket synchronized fix
This commit is contained in:
Andrew Shvayka 2023-07-18 16:05:00 +03:00 committed by GitHub
commit 04a9f41e52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 211 additions and 35 deletions

View File

@ -219,12 +219,12 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
.build(); .build();
} }
private class SessionMetaData implements SendHandler { class SessionMetaData implements SendHandler {
private final WebSocketSession session; private final WebSocketSession session;
private final RemoteEndpoint.Async asyncRemote; private final RemoteEndpoint.Async asyncRemote;
private final WebSocketSessionRef sessionRef; private final WebSocketSessionRef sessionRef;
private final AtomicBoolean isSending = new AtomicBoolean(false); final AtomicBoolean isSending = new AtomicBoolean(false);
private final Queue<TbWebSocketMsg<?>> msgQueue; private final Queue<TbWebSocketMsg<?>> msgQueue;
private volatile long lastActivityTime; private volatile long lastActivityTime;
@ -239,7 +239,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
this.lastActivityTime = System.currentTimeMillis(); this.lastActivityTime = System.currentTimeMillis();
} }
synchronized void sendPing(long currentTime) { void sendPing(long currentTime) {
try { try {
long timeSinceLastActivity = currentTime - lastActivityTime; long timeSinceLastActivity = currentTime - lastActivityTime;
if (timeSinceLastActivity >= pingTimeout) { if (timeSinceLastActivity >= pingTimeout) {
@ -254,37 +254,38 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
} }
} }
private void closeSession(CloseStatus reason) { void closeSession(CloseStatus reason) {
try { try {
close(this.sessionRef, reason); close(this.sessionRef, reason);
} catch (IOException ioe) { } catch (IOException ioe) {
log.trace("[{}] Session transport error", session.getId(), ioe); log.trace("[{}] Session transport error", session.getId(), ioe);
} finally {
msgQueue.clear();
} }
} }
synchronized void processPongMessage(long currentTime) { void processPongMessage(long currentTime) {
lastActivityTime = currentTime; lastActivityTime = currentTime;
} }
synchronized void sendMsg(String msg) { void sendMsg(String msg) {
sendMsg(new TbWebSocketTextMsg(msg)); sendMsg(new TbWebSocketTextMsg(msg));
} }
synchronized void sendMsg(TbWebSocketMsg<?> msg) { void sendMsg(TbWebSocketMsg<?> msg) {
if (isSending.compareAndSet(false, true)) { try {
sendMsgInternal(msg); msgQueue.add(msg);
} else { } catch (RuntimeException e) {
try { if (log.isTraceEnabled()) {
msgQueue.add(msg); log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e);
} catch (RuntimeException e) { } else {
if (log.isTraceEnabled()) { log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId());
log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e);
} else {
log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId());
}
closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!"));
} }
closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!"));
return;
} }
processNextMsg();
} }
private void sendMsgInternal(TbWebSocketMsg<?> msg) { private void sendMsgInternal(TbWebSocketMsg<?> msg) {
@ -292,9 +293,11 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
if (TbWebSocketMsgType.TEXT.equals(msg.getType())) { if (TbWebSocketMsgType.TEXT.equals(msg.getType())) {
TbWebSocketTextMsg textMsg = (TbWebSocketTextMsg) msg; TbWebSocketTextMsg textMsg = (TbWebSocketTextMsg) msg;
this.asyncRemote.sendText(textMsg.getMsg(), this); this.asyncRemote.sendText(textMsg.getMsg(), this);
// isSending status will be reset in the onResult method by call back
} else { } else {
TbWebSocketPingMsg pingMsg = (TbWebSocketPingMsg) msg; TbWebSocketPingMsg pingMsg = (TbWebSocketPingMsg) msg;
this.asyncRemote.sendPing(pingMsg.getMsg()); this.asyncRemote.sendPing(pingMsg.getMsg()); // blocking call
isSending.set(false);
processNextMsg(); processNextMsg();
} }
} catch (Exception e) { } catch (Exception e) {
@ -308,12 +311,17 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
if (!result.isOK()) { if (!result.isOK()) {
log.trace("[{}] Failed to send msg", session.getId(), result.getException()); log.trace("[{}] Failed to send msg", session.getId(), result.getException());
closeSession(CloseStatus.SESSION_NOT_RELIABLE); closeSession(CloseStatus.SESSION_NOT_RELIABLE);
} else { return;
processNextMsg();
} }
isSending.set(false);
processNextMsg();
} }
private void processNextMsg() { private void processNextMsg() {
if (msgQueue.isEmpty() || !isSending.compareAndSet(false, true)) {
return;
}
TbWebSocketMsg<?> msg = msgQueue.poll(); TbWebSocketMsg<?> msg = msgQueue.poll();
if (msg != null) { if (msg != null) {
sendMsgInternal(msg); sendMsgInternal(msg);
@ -397,19 +405,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
if (tenantProfileConfiguration == null) { if (tenantProfileConfiguration == null) {
return true; return true;
} }
boolean limitAllowed;
String sessionId = session.getId(); String sessionId = session.getId();
if (tenantProfileConfiguration.getMaxWsSessionsPerTenant() > 0) { if (tenantProfileConfiguration.getMaxWsSessionsPerTenant() > 0) {
Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
synchronized (tenantSessions) { synchronized (tenantSessions) {
if (tenantSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerTenant()) { limitAllowed = tenantSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerTenant();
if (limitAllowed) {
tenantSessions.add(sessionId); tenantSessions.add(sessionId);
} else { }
}
if (!limitAllowed) {
log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached" log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached"
, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!")); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!"));
return false; return false;
}
} }
} }
@ -417,42 +427,48 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
if (tenantProfileConfiguration.getMaxWsSessionsPerCustomer() > 0) { if (tenantProfileConfiguration.getMaxWsSessionsPerCustomer() > 0) {
Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
synchronized (customerSessions) { synchronized (customerSessions) {
if (customerSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerCustomer()) { limitAllowed = customerSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerCustomer();
if (limitAllowed) {
customerSessions.add(sessionId); customerSessions.add(sessionId);
} else { }
}
if (!limitAllowed) {
log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached" log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached"
, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached")); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached"));
return false; return false;
}
} }
} }
if (tenantProfileConfiguration.getMaxWsSessionsPerRegularUser() > 0 if (tenantProfileConfiguration.getMaxWsSessionsPerRegularUser() > 0
&& UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
synchronized (regularUserSessions) { synchronized (regularUserSessions) {
if (regularUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerRegularUser()) { limitAllowed = regularUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerRegularUser();
if (limitAllowed) {
regularUserSessions.add(sessionId); regularUserSessions.add(sessionId);
} else { }
}
if (!limitAllowed) {
log.info("[{}][{}][{}] Failed to start session. Max regular user sessions limit reached" log.info("[{}][{}][{}] Failed to start session. Max regular user sessions limit reached"
, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached")); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached"));
return false; return false;
}
} }
} }
if (tenantProfileConfiguration.getMaxWsSessionsPerPublicUser() > 0 if (tenantProfileConfiguration.getMaxWsSessionsPerPublicUser() > 0
&& UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
synchronized (publicUserSessions) { synchronized (publicUserSessions) {
if (publicUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerPublicUser()) { limitAllowed = publicUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerPublicUser();
if (limitAllowed) {
publicUserSessions.add(sessionId); publicUserSessions.add(sessionId);
} else { }
}
if (!limitAllowed) {
log.info("[{}][{}][{}] Failed to start session. Max public user sessions limit reached" log.info("[{}][{}][{}] Failed to start session. Max public user sessions limit reached"
, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached")); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached"));
return false; return false;
}
} }
} }
} }

View File

@ -0,0 +1,160 @@
/**
* Copyright © 2016-2023 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.controller.plugin;
import lombok.extern.slf4j.Slf4j;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.adapter.NativeWebSocketSession;
import org.thingsboard.common.util.ThingsBoardThreadFactory;
import org.thingsboard.server.service.ws.WebSocketSessionRef;
import javax.websocket.RemoteEndpoint;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import javax.websocket.Session;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.BDDMockito.willDoNothing;
import static org.mockito.BDDMockito.willReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@Slf4j
class TbWebSocketHandlerTest {
TbWebSocketHandler wsHandler;
NativeWebSocketSession session;
Session nativeSession;
RemoteEndpoint.Async asyncRemote;
WebSocketSessionRef sessionRef;
int maxMsgQueuePerSession;
TbWebSocketHandler.SessionMetaData sendHandler;
ExecutorService executor;
@BeforeEach
void setUp() throws IOException {
maxMsgQueuePerSession = 100;
executor = Executors.newCachedThreadPool(ThingsBoardThreadFactory.forName(getClass().getSimpleName()));
wsHandler = spy(new TbWebSocketHandler());
willDoNothing().given(wsHandler).close(any(), any());
session = mock(NativeWebSocketSession.class);
nativeSession = mock(Session.class);
willReturn(nativeSession).given(session).getNativeSession(Session.class);
asyncRemote = mock(RemoteEndpoint.Async.class);
willReturn(asyncRemote).given(nativeSession).getAsyncRemote();
sessionRef = mock(WebSocketSessionRef.class, Mockito.RETURNS_DEEP_STUBS); //prevent NPE on logs
sendHandler = spy(wsHandler.new SessionMetaData(session, sessionRef, maxMsgQueuePerSession));
}
@AfterEach
void tearDown() {
if (executor != null) {
executor.shutdownNow();
}
}
@Test
void sendHandler_sendMsg_parallel_no_race() throws InterruptedException {
CountDownLatch finishLatch = new CountDownLatch(maxMsgQueuePerSession * 2);
AtomicInteger sendersCount = new AtomicInteger();
willAnswer(invocation -> {
assertThat(sendersCount.incrementAndGet()).as("no race").isEqualTo(1);
String text = invocation.getArgument(0);
SendHandler onResultHandler = invocation.getArgument(1);
SendResult sendResult = new SendResult();
executor.submit(() -> {
sendersCount.decrementAndGet();
onResultHandler.onResult(sendResult);
finishLatch.countDown();
});
return null;
}).given(asyncRemote).sendText(anyString(), any());
assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse();
//first batch
IntStream.range(0, maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i));
Awaitility.await("first batch processed").atMost(30, TimeUnit.SECONDS).until(() -> finishLatch.getCount() == maxMsgQueuePerSession);
assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse();
//second batch - to test pause between big msg batches
IntStream.range(100, 100 + maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i));
assertThat(finishLatch.await(30, TimeUnit.SECONDS)).as("all callbacks fired").isTrue();
verify(sendHandler, never()).closeSession(any());
verify(sendHandler, times(maxMsgQueuePerSession * 2)).onResult(any());
assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse();
}
@Test
void sendHandler_sendMsg_message_order() throws InterruptedException {
CountDownLatch finishLatch = new CountDownLatch(maxMsgQueuePerSession);
Collection<String> outputs = new ConcurrentLinkedQueue<>();
willAnswer(invocation -> {
String text = invocation.getArgument(0);
outputs.add(text);
SendHandler onResultHandler = invocation.getArgument(1);
SendResult sendResult = new SendResult();
executor.submit(() -> {
onResultHandler.onResult(sendResult);
finishLatch.countDown();
});
return null;
}).given(asyncRemote).sendText(anyString(), any());
List<String> inputs = IntStream.range(0, maxMsgQueuePerSession).mapToObj(i -> "msg " + i).collect(Collectors.toList());
inputs.forEach(s -> sendHandler.sendMsg(s));
assertThat(finishLatch.await(30, TimeUnit.SECONDS)).as("all callbacks fired").isTrue();
assertThat(outputs).as("inputs exactly the same as outputs").containsExactlyElementsOf(inputs);
verify(sendHandler, never()).closeSession(any());
verify(sendHandler, times(maxMsgQueuePerSession)).onResult(any());
}
@Test
void sendHandler_sendMsg_queue_size_exceed() {
willDoNothing().given(asyncRemote).sendText(anyString(), any()); // send text will never call back, so queue will grow each sendMsg
sendHandler.sendMsg("first message to stay in-flight all the time during this test");
IntStream.range(0, maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i));
verify(sendHandler, never()).closeSession(any());
sendHandler.sendMsg("excessive message");
verify(sendHandler, times(1)).closeSession(eq(new CloseStatus(1008, "Max pending updates limit reached!")));
verify(asyncRemote, times(1)).sendText(anyString(), any());
}
}