Merge pull request #7788 from ViacheslavKlimov/fix/sysadmin-ws-session

[3.4.3] Fix NPE on WS subscription for sysadmin
This commit is contained in:
Andrew Shvayka 2022-12-15 12:07:51 +02:00 committed by GitHub
commit 10587d797e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 21 deletions

View File

@ -28,10 +28,12 @@ import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.NativeWebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.thingsboard.server.common.data.StringUtils; import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.exception.ThingsboardErrorCode; import org.thingsboard.server.common.data.exception.ThingsboardErrorCode;
import org.thingsboard.server.common.data.id.CustomerId; import org.thingsboard.server.common.data.id.CustomerId;
import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.id.UserId; import org.thingsboard.server.common.data.id.UserId;
import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration;
import org.thingsboard.server.common.msg.tools.TbRateLimits; import org.thingsboard.server.common.msg.tools.TbRateLimits;
import org.thingsboard.server.config.WebSocketConfiguration; import org.thingsboard.server.config.WebSocketConfiguration;
import org.thingsboard.server.dao.tenant.TbTenantProfileCache; import org.thingsboard.server.dao.tenant.TbTenantProfileCache;
@ -50,6 +52,7 @@ import javax.websocket.Session;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.security.InvalidParameterException; import java.security.InvalidParameterException;
import java.util.Optional;
import java.util.Queue; import java.util.Queue;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
@ -136,8 +139,9 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
if (!checkLimits(session, sessionRef)) { if (!checkLimits(session, sessionRef)) {
return; return;
} }
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef, tenantProfileConfiguration.getWsMsgQueueLimitPerSession() > 0 ? internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef,
tenantProfileConfiguration != null && tenantProfileConfiguration.getWsMsgQueueLimitPerSession() > 0 ?
tenantProfileConfiguration.getWsMsgQueueLimitPerSession() : 500)); tenantProfileConfiguration.getWsMsgQueueLimitPerSession() : 500));
externalSessionMap.put(externalSessionId, internalSessionId); externalSessionMap.put(externalSessionId, internalSessionId);
@ -316,7 +320,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
if (internalId != null) { if (internalId != null) {
SessionMetaData sessionMd = internalSessionMap.get(internalId); SessionMetaData sessionMd = internalSessionMap.get(internalId);
if (sessionMd != null) { if (sessionMd != null) {
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration != null) {
if (StringUtils.isNotEmpty(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())) { if (StringUtils.isNotEmpty(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())) {
TbRateLimits rateLimits = perSessionUpdateLimits.computeIfAbsent(sessionRef.getSessionId(), sid -> new TbRateLimits(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())); TbRateLimits rateLimits = perSessionUpdateLimits.computeIfAbsent(sessionRef.getSessionId(), sid -> new TbRateLimits(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit()));
if (!rateLimits.tryConsume()) { if (!rateLimits.tryConsume()) {
@ -333,6 +338,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
} else { } else {
perSessionUpdateLimits.remove(sessionRef.getSessionId()); perSessionUpdateLimits.remove(sessionRef.getSessionId());
} }
}
sessionMd.sendMsg(msg); sessionMd.sendMsg(msg);
} else { } else {
log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId); log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId);
@ -376,8 +382,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
} }
private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception { private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception {
var tenantProfileConfiguration = var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
if (tenantProfileConfiguration == null) { if (tenantProfileConfiguration == null) {
return true; return true;
} }
@ -444,7 +449,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
} }
private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration == null) return;
String sessionId = session.getId(); String sessionId = session.getId();
perSessionUpdateLimits.remove(sessionRef.getSessionId()); perSessionUpdateLimits.remove(sessionRef.getSessionId());
@ -477,4 +483,9 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
} }
} }
private DefaultTenantProfileConfiguration getTenantProfileConfiguration(TelemetryWebSocketSessionRef sessionRef) {
return Optional.ofNullable(tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()))
.map(TenantProfile::getDefaultProfileConfiguration).orElse(null);
}
} }

View File

@ -31,6 +31,7 @@ import org.thingsboard.common.util.ThingsBoardExecutors;
import org.thingsboard.common.util.ThingsBoardThreadFactory; import org.thingsboard.common.util.ThingsBoardThreadFactory;
import org.thingsboard.server.common.data.DataConstants; import org.thingsboard.server.common.data.DataConstants;
import org.thingsboard.server.common.data.StringUtils; import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.id.CustomerId; import org.thingsboard.server.common.data.id.CustomerId;
import org.thingsboard.server.common.data.id.EntityId; import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.EntityIdFactory; import org.thingsboard.server.common.data.id.EntityIdFactory;
@ -316,7 +317,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
} }
private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) { private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) {
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration != null) { if (tenantProfileConfiguration != null) {
String sessionId = "[" + sessionRef.getSessionId() + "]"; String sessionId = "[" + sessionRef.getSessionId() + "]";
@ -350,7 +351,8 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
} }
private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) { private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) {
var tenantProfileConfiguration = (DefaultTenantProfileConfiguration) tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration == null) return true;
String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]"; String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]";
try { try {
@ -932,4 +934,10 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
private int getLimit(int limit) { private int getLimit(int limit) {
return limit == 0 ? DEFAULT_LIMIT : limit; return limit == 0 ? DEFAULT_LIMIT : limit;
} }
private DefaultTenantProfileConfiguration getTenantProfileConfiguration(TelemetryWebSocketSessionRef sessionRef) {
return Optional.ofNullable(tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()))
.map(TenantProfile::getDefaultProfileConfiguration).orElse(null);
}
} }

View File

@ -15,6 +15,7 @@
*/ */
package org.thingsboard.server.controller; package org.thingsboard.server.controller;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.FutureCallback;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.Nullable;
@ -23,11 +24,15 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.server.common.data.Device; import org.thingsboard.server.common.data.Device;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.kv.AttributeKvEntry; import org.thingsboard.server.common.data.kv.AttributeKvEntry;
import org.thingsboard.server.common.data.kv.BaseAttributeKvEntry; import org.thingsboard.server.common.data.kv.BaseAttributeKvEntry;
import org.thingsboard.server.common.data.kv.BasicTsKvEntry; import org.thingsboard.server.common.data.kv.BasicTsKvEntry;
import org.thingsboard.server.common.data.kv.LongDataEntry; import org.thingsboard.server.common.data.kv.LongDataEntry;
import org.thingsboard.server.common.data.kv.StringDataEntry;
import org.thingsboard.server.common.data.kv.TsKvEntry; import org.thingsboard.server.common.data.kv.TsKvEntry;
import org.thingsboard.server.common.data.page.PageData; import org.thingsboard.server.common.data.page.PageData;
import org.thingsboard.server.common.data.query.DeviceTypeFilter; import org.thingsboard.server.common.data.query.DeviceTypeFilter;
@ -41,12 +46,14 @@ import org.thingsboard.server.common.data.query.EntityKeyValueType;
import org.thingsboard.server.common.data.query.FilterPredicateValue; import org.thingsboard.server.common.data.query.FilterPredicateValue;
import org.thingsboard.server.common.data.query.KeyFilter; import org.thingsboard.server.common.data.query.KeyFilter;
import org.thingsboard.server.common.data.query.NumericFilterPredicate; import org.thingsboard.server.common.data.query.NumericFilterPredicate;
import org.thingsboard.server.common.data.query.SingleEntityFilter;
import org.thingsboard.server.common.data.query.TsValue; import org.thingsboard.server.common.data.query.TsValue;
import org.thingsboard.server.service.subscription.TbAttributeSubscriptionScope; import org.thingsboard.server.service.subscription.TbAttributeSubscriptionScope;
import org.thingsboard.server.service.telemetry.TelemetrySubscriptionService; import org.thingsboard.server.service.telemetry.TelemetrySubscriptionService;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataUpdate; import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataUpdate;
import org.thingsboard.server.service.telemetry.sub.SubscriptionErrorCode;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -54,6 +61,8 @@ import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@Slf4j @Slf4j
@ -78,6 +87,7 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest {
@After @After
public void tearDown() throws Exception { public void tearDown() throws Exception {
loginTenantAdmin();
doDelete("/api/device/" + device.getId().getId()) doDelete("/api/device/" + device.getId().getId())
.andExpect(status().isOk()); .andExpect(status().isOk());
} }
@ -532,6 +542,28 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest {
Assert.assertEquals(new TsValue(dataPoint5.getLastUpdateTs(), dataPoint5.getValueAsString()), attrValue); Assert.assertEquals(new TsValue(dataPoint5.getLastUpdateTs(), dataPoint5.getValueAsString()), attrValue);
} }
@Test
public void testAttributesSubscription_sysAdmin() throws Exception {
loginSysAdmin();
SingleEntityFilter entityFilter = new SingleEntityFilter();
entityFilter.setSingleEntity(tenantId);
assertThatNoException().isThrownBy(() -> {
JsonNode update = getWsClient().subscribeForAttributes(tenantId, TbAttributeSubscriptionScope.SERVER_SCOPE.name(), List.of("attr"));
assertThat(update.get("errorMsg").isNull()).isTrue();
assertThat(update.get("errorCode").asInt()).isEqualTo(SubscriptionErrorCode.NO_ERROR.getCode());
});
getWsClient().registerWaitForUpdate();
String expectedAttrValue = "42";
sendAttributes(TenantId.SYS_TENANT_ID, tenantId, TbAttributeSubscriptionScope.SERVER_SCOPE, List.of(
new BaseAttributeKvEntry(System.currentTimeMillis(), new StringDataEntry("attr", expectedAttrValue))
));
JsonNode update = JacksonUtil.toJsonNode(getWsClient().waitForUpdate());
assertThat(update).isNotNull();
assertThat(update.get("data").get("attr").get(0).get(1).asText()).isEqualTo(expectedAttrValue);
}
private void sendTelemetry(Device device, List<TsKvEntry> tsData) throws InterruptedException { private void sendTelemetry(Device device, List<TsKvEntry> tsData) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
tsService.saveAndNotify(device.getTenantId(), null, device.getId(), tsData, 0, new FutureCallback<Void>() { tsService.saveAndNotify(device.getTenantId(), null, device.getId(), tsData, 0, new FutureCallback<Void>() {
@ -549,8 +581,12 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest {
} }
private void sendAttributes(Device device, TbAttributeSubscriptionScope scope, List<AttributeKvEntry> attrData) throws InterruptedException { private void sendAttributes(Device device, TbAttributeSubscriptionScope scope, List<AttributeKvEntry> attrData) throws InterruptedException {
sendAttributes(device.getTenantId(), device.getId(), scope, attrData);
}
private void sendAttributes(TenantId tenantId, EntityId entityId, TbAttributeSubscriptionScope scope, List<AttributeKvEntry> attrData) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
tsService.saveAndNotify(device.getTenantId(), device.getId(), scope.name(), attrData, new FutureCallback<Void>() { tsService.saveAndNotify(tenantId, entityId, scope.name(), attrData, new FutureCallback<Void>() {
@Override @Override
public void onSuccess(@Nullable Void result) { public void onSuccess(@Nullable Void result) {
latch.countDown(); latch.countDown();

View File

@ -15,16 +15,20 @@
*/ */
package org.thingsboard.server.controller; package org.thingsboard.server.controller;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient; import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake; import org.java_websocket.handshake.ServerHandshake;
import org.thingsboard.common.util.JacksonUtil; import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.kv.Aggregation; import org.thingsboard.server.common.data.kv.Aggregation;
import org.thingsboard.server.common.data.query.EntityDataPageLink; import org.thingsboard.server.common.data.query.EntityDataPageLink;
import org.thingsboard.server.common.data.query.EntityDataQuery; import org.thingsboard.server.common.data.query.EntityDataQuery;
import org.thingsboard.server.common.data.query.EntityFilter; import org.thingsboard.server.common.data.query.EntityFilter;
import org.thingsboard.server.common.data.query.EntityKey; import org.thingsboard.server.common.data.query.EntityKey;
import org.thingsboard.server.service.telemetry.cmd.TelemetryPluginCmdsWrapper; import org.thingsboard.server.service.telemetry.cmd.TelemetryPluginCmdsWrapper;
import org.thingsboard.server.service.telemetry.cmd.v1.AttributesSubscriptionCmd;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataCmd; import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataCmd;
@ -177,6 +181,21 @@ public class TbTestWebSocketClient extends WebSocketClient {
return parseDataReply(waitForReply()); return parseDataReply(waitForReply());
} }
public JsonNode subscribeForAttributes(EntityId entityId, String scope, List<String> keys) {
AttributesSubscriptionCmd cmd = new AttributesSubscriptionCmd();
cmd.setCmdId(1);
cmd.setEntityType(entityId.getEntityType().toString());
cmd.setEntityId(entityId.getId().toString());
cmd.setScope(scope);
cmd.setKeys(String.join(",", keys));
TelemetryPluginCmdsWrapper cmdsWrapper = new TelemetryPluginCmdsWrapper();
cmdsWrapper.setAttrSubCmds(List.of(cmd));
JsonNode msg = JacksonUtil.valueToTree(cmdsWrapper);
((ObjectNode) msg.get("attrSubCmds").get(0)).remove("type");
send(msg.toString());
return JacksonUtil.toJsonNode(waitForReply());
}
public EntityDataUpdate sendHistoryCmd(List<String> keys, long startTs, long timeWindow) { public EntityDataUpdate sendHistoryCmd(List<String> keys, long startTs, long timeWindow) {
return sendHistoryCmd(keys, startTs, timeWindow, (EntityDataQuery) null); return sendHistoryCmd(keys, startTs, timeWindow, (EntityDataQuery) null);
} }