Fix NPE on WS subscription for sysadmin

This commit is contained in:
ViacheslavKlimov 2022-12-13 15:06:29 +02:00
parent db0288dda3
commit 4a834a0442
4 changed files with 95 additions and 21 deletions

View File

@ -28,10 +28,12 @@ import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.NativeWebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.exception.ThingsboardErrorCode;
import org.thingsboard.server.common.data.id.CustomerId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.id.UserId;
import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration;
import org.thingsboard.server.common.msg.tools.TbRateLimits;
import org.thingsboard.server.config.WebSocketConfiguration;
import org.thingsboard.server.dao.tenant.TbTenantProfileCache;
@ -50,6 +52,7 @@ import javax.websocket.Session;
import java.io.IOException;
import java.net.URI;
import java.security.InvalidParameterException;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
@ -136,8 +139,9 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
if (!checkLimits(session, sessionRef)) {
return;
}
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef, tenantProfileConfiguration.getWsMsgQueueLimitPerSession() > 0 ?
var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef,
tenantProfileConfiguration != null && tenantProfileConfiguration.getWsMsgQueueLimitPerSession() > 0 ?
tenantProfileConfiguration.getWsMsgQueueLimitPerSession() : 500));
externalSessionMap.put(externalSessionId, internalSessionId);
@ -316,22 +320,24 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
if (internalId != null) {
SessionMetaData sessionMd = internalSessionMap.get(internalId);
if (sessionMd != null) {
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
if (StringUtils.isNotEmpty(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())) {
TbRateLimits rateLimits = perSessionUpdateLimits.computeIfAbsent(sessionRef.getSessionId(), sid -> new TbRateLimits(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit()));
if (!rateLimits.tryConsume()) {
if (blacklistedSessions.putIfAbsent(externalId, sessionRef) == null) {
log.info("[{}][{}][{}] Failed to process session update. Max session updates limit reached"
, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId);
sessionMd.sendMsg("{\"subscriptionId\":" + subscriptionId + ", \"errorCode\":" + ThingsboardErrorCode.TOO_MANY_UPDATES.getErrorCode() + ", \"errorMsg\":\"Too many updates!\"}");
var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration != null) {
if (StringUtils.isNotEmpty(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())) {
TbRateLimits rateLimits = perSessionUpdateLimits.computeIfAbsent(sessionRef.getSessionId(), sid -> new TbRateLimits(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit()));
if (!rateLimits.tryConsume()) {
if (blacklistedSessions.putIfAbsent(externalId, sessionRef) == null) {
log.info("[{}][{}][{}] Failed to process session update. Max session updates limit reached"
, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId);
sessionMd.sendMsg("{\"subscriptionId\":" + subscriptionId + ", \"errorCode\":" + ThingsboardErrorCode.TOO_MANY_UPDATES.getErrorCode() + ", \"errorMsg\":\"Too many updates!\"}");
}
return;
} else {
log.debug("[{}][{}][{}] Session is no longer blacklisted.", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId);
blacklistedSessions.remove(externalId);
}
return;
} else {
log.debug("[{}][{}][{}] Session is no longer blacklisted.", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId);
blacklistedSessions.remove(externalId);
perSessionUpdateLimits.remove(sessionRef.getSessionId());
}
} else {
perSessionUpdateLimits.remove(sessionRef.getSessionId());
}
sessionMd.sendMsg(msg);
} else {
@ -376,8 +382,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
}
private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception {
var tenantProfileConfiguration =
tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration == null) {
return true;
}
@ -444,7 +449,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
}
private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration == null) return;
String sessionId = session.getId();
perSessionUpdateLimits.remove(sessionRef.getSessionId());
@ -477,4 +483,9 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
}
}
private DefaultTenantProfileConfiguration getTenantProfileConfiguration(TelemetryWebSocketSessionRef sessionRef) {
return Optional.ofNullable(tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()))
.map(TenantProfile::getDefaultProfileConfiguration).orElse(null);
}
}

View File

@ -31,6 +31,7 @@ import org.thingsboard.common.util.ThingsBoardExecutors;
import org.thingsboard.common.util.ThingsBoardThreadFactory;
import org.thingsboard.server.common.data.DataConstants;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.id.CustomerId;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.EntityIdFactory;
@ -316,7 +317,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
}
private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) {
var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration != null) {
String sessionId = "[" + sessionRef.getSessionId() + "]";
@ -350,7 +351,8 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
}
private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) {
var tenantProfileConfiguration = (DefaultTenantProfileConfiguration) tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration();
var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef);
if (tenantProfileConfiguration == null) return true;
String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]";
try {
@ -932,4 +934,10 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
private int getLimit(int limit) {
return limit == 0 ? DEFAULT_LIMIT : limit;
}
private DefaultTenantProfileConfiguration getTenantProfileConfiguration(TelemetryWebSocketSessionRef sessionRef) {
return Optional.ofNullable(tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()))
.map(TenantProfile::getDefaultProfileConfiguration).orElse(null);
}
}

View File

@ -15,6 +15,7 @@
*/
package org.thingsboard.server.controller;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.util.concurrent.FutureCallback;
import lombok.extern.slf4j.Slf4j;
import org.checkerframework.checker.nullness.qual.Nullable;
@ -23,11 +24,15 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.server.common.data.Device;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.kv.AttributeKvEntry;
import org.thingsboard.server.common.data.kv.BaseAttributeKvEntry;
import org.thingsboard.server.common.data.kv.BasicTsKvEntry;
import org.thingsboard.server.common.data.kv.LongDataEntry;
import org.thingsboard.server.common.data.kv.StringDataEntry;
import org.thingsboard.server.common.data.kv.TsKvEntry;
import org.thingsboard.server.common.data.page.PageData;
import org.thingsboard.server.common.data.query.DeviceTypeFilter;
@ -41,12 +46,14 @@ import org.thingsboard.server.common.data.query.EntityKeyValueType;
import org.thingsboard.server.common.data.query.FilterPredicateValue;
import org.thingsboard.server.common.data.query.KeyFilter;
import org.thingsboard.server.common.data.query.NumericFilterPredicate;
import org.thingsboard.server.common.data.query.SingleEntityFilter;
import org.thingsboard.server.common.data.query.TsValue;
import org.thingsboard.server.service.subscription.TbAttributeSubscriptionScope;
import org.thingsboard.server.service.telemetry.TelemetrySubscriptionService;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataUpdate;
import org.thingsboard.server.service.telemetry.sub.SubscriptionErrorCode;
import java.util.Arrays;
import java.util.Collections;
@ -54,6 +61,8 @@ import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@Slf4j
@ -78,6 +87,7 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest {
@After
public void tearDown() throws Exception {
loginTenantAdmin();
doDelete("/api/device/" + device.getId().getId())
.andExpect(status().isOk());
}
@ -532,6 +542,28 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest {
Assert.assertEquals(new TsValue(dataPoint5.getLastUpdateTs(), dataPoint5.getValueAsString()), attrValue);
}
@Test
public void testAttributesSubscription_sysAdmin() throws Exception {
loginSysAdmin();
SingleEntityFilter entityFilter = new SingleEntityFilter();
entityFilter.setSingleEntity(tenantId);
assertThatNoException().isThrownBy(() -> {
JsonNode update = getWsClient().subscribeForAttributes(tenantId, TbAttributeSubscriptionScope.SERVER_SCOPE.name(), List.of("attr"));
assertThat(update.get("errorMsg").isNull()).isTrue();
assertThat(update.get("errorCode").asInt()).isEqualTo(SubscriptionErrorCode.NO_ERROR.getCode());
});
getWsClient().registerWaitForUpdate();
String expectedAttrValue = "42";
sendAttributes(TenantId.SYS_TENANT_ID, tenantId, TbAttributeSubscriptionScope.SERVER_SCOPE, List.of(
new BaseAttributeKvEntry(System.currentTimeMillis(), new StringDataEntry("attr", expectedAttrValue))
));
JsonNode update = JacksonUtil.toJsonNode(getWsClient().waitForUpdate());
assertThat(update).isNotNull();
assertThat(update.get("data").get("attr").get(0).get(1).asText()).isEqualTo(expectedAttrValue);
}
private void sendTelemetry(Device device, List<TsKvEntry> tsData) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
tsService.saveAndNotify(device.getTenantId(), null, device.getId(), tsData, 0, new FutureCallback<Void>() {
@ -549,8 +581,12 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest {
}
private void sendAttributes(Device device, TbAttributeSubscriptionScope scope, List<AttributeKvEntry> attrData) throws InterruptedException {
sendAttributes(device.getTenantId(), device.getId(), scope, attrData);
}
private void sendAttributes(TenantId tenantId, EntityId entityId, TbAttributeSubscriptionScope scope, List<AttributeKvEntry> attrData) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
tsService.saveAndNotify(device.getTenantId(), device.getId(), scope.name(), attrData, new FutureCallback<Void>() {
tsService.saveAndNotify(tenantId, entityId, scope.name(), attrData, new FutureCallback<Void>() {
@Override
public void onSuccess(@Nullable Void result) {
latch.countDown();

View File

@ -15,16 +15,20 @@
*/
package org.thingsboard.server.controller;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.kv.Aggregation;
import org.thingsboard.server.common.data.query.EntityDataPageLink;
import org.thingsboard.server.common.data.query.EntityDataQuery;
import org.thingsboard.server.common.data.query.EntityFilter;
import org.thingsboard.server.common.data.query.EntityKey;
import org.thingsboard.server.service.telemetry.cmd.TelemetryPluginCmdsWrapper;
import org.thingsboard.server.service.telemetry.cmd.v1.AttributesSubscriptionCmd;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate;
import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataCmd;
@ -177,6 +181,21 @@ public class TbTestWebSocketClient extends WebSocketClient {
return parseDataReply(waitForReply());
}
public JsonNode subscribeForAttributes(EntityId entityId, String scope, List<String> keys) {
AttributesSubscriptionCmd cmd = new AttributesSubscriptionCmd();
cmd.setCmdId(1);
cmd.setEntityType(entityId.getEntityType().toString());
cmd.setEntityId(entityId.getId().toString());
cmd.setScope(scope);
cmd.setKeys(String.join(",", keys));
TelemetryPluginCmdsWrapper cmdsWrapper = new TelemetryPluginCmdsWrapper();
cmdsWrapper.setAttrSubCmds(List.of(cmd));
JsonNode msg = JacksonUtil.valueToTree(cmdsWrapper);
((ObjectNode) msg.get("attrSubCmds").get(0)).remove("type");
send(msg.toString());
return JacksonUtil.toJsonNode(waitForReply());
}
public EntityDataUpdate sendHistoryCmd(List<String> keys, long startTs, long timeWindow) {
return sendHistoryCmd(keys, startTs, timeWindow, (EntityDataQuery) null);
}