diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java index 432b35fd80..3a429add9b 100644 --- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java +++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java @@ -28,10 +28,12 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; import org.thingsboard.server.common.data.StringUtils; +import org.thingsboard.server.common.data.TenantProfile; import org.thingsboard.server.common.data.exception.ThingsboardErrorCode; import org.thingsboard.server.common.data.id.CustomerId; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.id.UserId; +import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration; import org.thingsboard.server.common.msg.tools.TbRateLimits; import org.thingsboard.server.config.WebSocketConfiguration; import org.thingsboard.server.dao.tenant.TbTenantProfileCache; @@ -50,6 +52,7 @@ import javax.websocket.Session; import java.io.IOException; import java.net.URI; import java.security.InvalidParameterException; +import java.util.Optional; import java.util.Queue; import java.util.Set; import java.util.UUID; @@ -136,8 +139,9 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr if (!checkLimits(session, sessionRef)) { return; } - var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); - internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef, tenantProfileConfiguration.getWsMsgQueueLimitPerSession() > 0 ? + var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); + internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef, + tenantProfileConfiguration != null && tenantProfileConfiguration.getWsMsgQueueLimitPerSession() > 0 ? tenantProfileConfiguration.getWsMsgQueueLimitPerSession() : 500)); externalSessionMap.put(externalSessionId, internalSessionId); @@ -316,22 +320,24 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr if (internalId != null) { SessionMetaData sessionMd = internalSessionMap.get(internalId); if (sessionMd != null) { - var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); - if (StringUtils.isNotEmpty(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())) { - TbRateLimits rateLimits = perSessionUpdateLimits.computeIfAbsent(sessionRef.getSessionId(), sid -> new TbRateLimits(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())); - if (!rateLimits.tryConsume()) { - if (blacklistedSessions.putIfAbsent(externalId, sessionRef) == null) { - log.info("[{}][{}][{}] Failed to process session update. Max session updates limit reached" - , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId); - sessionMd.sendMsg("{\"subscriptionId\":" + subscriptionId + ", \"errorCode\":" + ThingsboardErrorCode.TOO_MANY_UPDATES.getErrorCode() + ", \"errorMsg\":\"Too many updates!\"}"); + var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); + if (tenantProfileConfiguration != null) { + if (StringUtils.isNotEmpty(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())) { + TbRateLimits rateLimits = perSessionUpdateLimits.computeIfAbsent(sessionRef.getSessionId(), sid -> new TbRateLimits(tenantProfileConfiguration.getWsUpdatesPerSessionRateLimit())); + if (!rateLimits.tryConsume()) { + if (blacklistedSessions.putIfAbsent(externalId, sessionRef) == null) { + log.info("[{}][{}][{}] Failed to process session update. Max session updates limit reached" + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId); + sessionMd.sendMsg("{\"subscriptionId\":" + subscriptionId + ", \"errorCode\":" + ThingsboardErrorCode.TOO_MANY_UPDATES.getErrorCode() + ", \"errorMsg\":\"Too many updates!\"}"); + } + return; + } else { + log.debug("[{}][{}][{}] Session is no longer blacklisted.", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId); + blacklistedSessions.remove(externalId); } - return; } else { - log.debug("[{}][{}][{}] Session is no longer blacklisted.", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), externalId); - blacklistedSessions.remove(externalId); + perSessionUpdateLimits.remove(sessionRef.getSessionId()); } - } else { - perSessionUpdateLimits.remove(sessionRef.getSessionId()); } sessionMd.sendMsg(msg); } else { @@ -376,8 +382,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr } private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception { - var tenantProfileConfiguration = - tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); + var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); if (tenantProfileConfiguration == null) { return true; } @@ -444,7 +449,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr } private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { - var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); + var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); + if (tenantProfileConfiguration == null) return; String sessionId = session.getId(); perSessionUpdateLimits.remove(sessionRef.getSessionId()); @@ -477,4 +483,9 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr } } + private DefaultTenantProfileConfiguration getTenantProfileConfiguration(TelemetryWebSocketSessionRef sessionRef) { + return Optional.ofNullable(tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId())) + .map(TenantProfile::getDefaultProfileConfiguration).orElse(null); + } + } \ No newline at end of file diff --git a/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java b/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java index d9beb1ed33..06de7a5fe7 100644 --- a/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java +++ b/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java @@ -31,6 +31,7 @@ import org.thingsboard.common.util.ThingsBoardExecutors; import org.thingsboard.common.util.ThingsBoardThreadFactory; import org.thingsboard.server.common.data.DataConstants; import org.thingsboard.server.common.data.StringUtils; +import org.thingsboard.server.common.data.TenantProfile; import org.thingsboard.server.common.data.id.CustomerId; import org.thingsboard.server.common.data.id.EntityId; import org.thingsboard.server.common.data.id.EntityIdFactory; @@ -316,7 +317,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi } private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) { - var tenantProfileConfiguration = tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); + var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); if (tenantProfileConfiguration != null) { String sessionId = "[" + sessionRef.getSessionId() + "]"; @@ -350,7 +351,8 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi } private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) { - var tenantProfileConfiguration = (DefaultTenantProfileConfiguration) tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId()).getDefaultProfileConfiguration(); + var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); + if (tenantProfileConfiguration == null) return true; String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]"; try { @@ -932,4 +934,10 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi private int getLimit(int limit) { return limit == 0 ? DEFAULT_LIMIT : limit; } + + private DefaultTenantProfileConfiguration getTenantProfileConfiguration(TelemetryWebSocketSessionRef sessionRef) { + return Optional.ofNullable(tenantProfileCache.get(sessionRef.getSecurityCtx().getTenantId())) + .map(TenantProfile::getDefaultProfileConfiguration).orElse(null); + } + } diff --git a/application/src/test/java/org/thingsboard/server/controller/BaseWebsocketApiTest.java b/application/src/test/java/org/thingsboard/server/controller/BaseWebsocketApiTest.java index f7838d9689..5f133658af 100644 --- a/application/src/test/java/org/thingsboard/server/controller/BaseWebsocketApiTest.java +++ b/application/src/test/java/org/thingsboard/server/controller/BaseWebsocketApiTest.java @@ -15,6 +15,7 @@ */ package org.thingsboard.server.controller; +import com.fasterxml.jackson.databind.JsonNode; import com.google.common.util.concurrent.FutureCallback; import lombok.extern.slf4j.Slf4j; import org.checkerframework.checker.nullness.qual.Nullable; @@ -23,11 +24,15 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; +import org.thingsboard.common.util.JacksonUtil; import org.thingsboard.server.common.data.Device; +import org.thingsboard.server.common.data.id.EntityId; +import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.kv.AttributeKvEntry; import org.thingsboard.server.common.data.kv.BaseAttributeKvEntry; import org.thingsboard.server.common.data.kv.BasicTsKvEntry; import org.thingsboard.server.common.data.kv.LongDataEntry; +import org.thingsboard.server.common.data.kv.StringDataEntry; import org.thingsboard.server.common.data.kv.TsKvEntry; import org.thingsboard.server.common.data.page.PageData; import org.thingsboard.server.common.data.query.DeviceTypeFilter; @@ -41,12 +46,14 @@ import org.thingsboard.server.common.data.query.EntityKeyValueType; import org.thingsboard.server.common.data.query.FilterPredicateValue; import org.thingsboard.server.common.data.query.KeyFilter; import org.thingsboard.server.common.data.query.NumericFilterPredicate; +import org.thingsboard.server.common.data.query.SingleEntityFilter; import org.thingsboard.server.common.data.query.TsValue; import org.thingsboard.server.service.subscription.TbAttributeSubscriptionScope; import org.thingsboard.server.service.telemetry.TelemetrySubscriptionService; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate; import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataUpdate; +import org.thingsboard.server.service.telemetry.sub.SubscriptionErrorCode; import java.util.Arrays; import java.util.Collections; @@ -54,6 +61,8 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @Slf4j @@ -78,6 +87,7 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest { @After public void tearDown() throws Exception { + loginTenantAdmin(); doDelete("/api/device/" + device.getId().getId()) .andExpect(status().isOk()); } @@ -532,6 +542,28 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest { Assert.assertEquals(new TsValue(dataPoint5.getLastUpdateTs(), dataPoint5.getValueAsString()), attrValue); } + @Test + public void testAttributesSubscription_sysAdmin() throws Exception { + loginSysAdmin(); + SingleEntityFilter entityFilter = new SingleEntityFilter(); + entityFilter.setSingleEntity(tenantId); + + assertThatNoException().isThrownBy(() -> { + JsonNode update = getWsClient().subscribeForAttributes(tenantId, TbAttributeSubscriptionScope.SERVER_SCOPE.name(), List.of("attr")); + assertThat(update.get("errorMsg").isNull()).isTrue(); + assertThat(update.get("errorCode").asInt()).isEqualTo(SubscriptionErrorCode.NO_ERROR.getCode()); + }); + + getWsClient().registerWaitForUpdate(); + String expectedAttrValue = "42"; + sendAttributes(TenantId.SYS_TENANT_ID, tenantId, TbAttributeSubscriptionScope.SERVER_SCOPE, List.of( + new BaseAttributeKvEntry(System.currentTimeMillis(), new StringDataEntry("attr", expectedAttrValue)) + )); + JsonNode update = JacksonUtil.toJsonNode(getWsClient().waitForUpdate()); + assertThat(update).isNotNull(); + assertThat(update.get("data").get("attr").get(0).get(1).asText()).isEqualTo(expectedAttrValue); + } + private void sendTelemetry(Device device, List tsData) throws InterruptedException { CountDownLatch latch = new CountDownLatch(1); tsService.saveAndNotify(device.getTenantId(), null, device.getId(), tsData, 0, new FutureCallback() { @@ -549,8 +581,12 @@ public abstract class BaseWebsocketApiTest extends AbstractControllerTest { } private void sendAttributes(Device device, TbAttributeSubscriptionScope scope, List attrData) throws InterruptedException { + sendAttributes(device.getTenantId(), device.getId(), scope, attrData); + } + + private void sendAttributes(TenantId tenantId, EntityId entityId, TbAttributeSubscriptionScope scope, List attrData) throws InterruptedException { CountDownLatch latch = new CountDownLatch(1); - tsService.saveAndNotify(device.getTenantId(), device.getId(), scope.name(), attrData, new FutureCallback() { + tsService.saveAndNotify(tenantId, entityId, scope.name(), attrData, new FutureCallback() { @Override public void onSuccess(@Nullable Void result) { latch.countDown(); diff --git a/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java b/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java index be7973d517..efd2336ac1 100644 --- a/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java +++ b/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java @@ -15,16 +15,20 @@ */ package org.thingsboard.server.controller; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; import lombok.extern.slf4j.Slf4j; import org.java_websocket.client.WebSocketClient; import org.java_websocket.handshake.ServerHandshake; import org.thingsboard.common.util.JacksonUtil; +import org.thingsboard.server.common.data.id.EntityId; import org.thingsboard.server.common.data.kv.Aggregation; import org.thingsboard.server.common.data.query.EntityDataPageLink; import org.thingsboard.server.common.data.query.EntityDataQuery; import org.thingsboard.server.common.data.query.EntityFilter; import org.thingsboard.server.common.data.query.EntityKey; import org.thingsboard.server.service.telemetry.cmd.TelemetryPluginCmdsWrapper; +import org.thingsboard.server.service.telemetry.cmd.v1.AttributesSubscriptionCmd; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountCmd; import org.thingsboard.server.service.telemetry.cmd.v2.EntityCountUpdate; import org.thingsboard.server.service.telemetry.cmd.v2.EntityDataCmd; @@ -177,6 +181,21 @@ public class TbTestWebSocketClient extends WebSocketClient { return parseDataReply(waitForReply()); } + public JsonNode subscribeForAttributes(EntityId entityId, String scope, List keys) { + AttributesSubscriptionCmd cmd = new AttributesSubscriptionCmd(); + cmd.setCmdId(1); + cmd.setEntityType(entityId.getEntityType().toString()); + cmd.setEntityId(entityId.getId().toString()); + cmd.setScope(scope); + cmd.setKeys(String.join(",", keys)); + TelemetryPluginCmdsWrapper cmdsWrapper = new TelemetryPluginCmdsWrapper(); + cmdsWrapper.setAttrSubCmds(List.of(cmd)); + JsonNode msg = JacksonUtil.valueToTree(cmdsWrapper); + ((ObjectNode) msg.get("attrSubCmds").get(0)).remove("type"); + send(msg.toString()); + return JacksonUtil.toJsonNode(waitForReply()); + } + public EntityDataUpdate sendHistoryCmd(List keys, long startTs, long timeWindow) { return sendHistoryCmd(keys, startTs, timeWindow, (EntityDataQuery) null); }