Refactor rate limiting to be configurable and use existing class

This commit is contained in:
Dmytro Skarzhynets 2024-02-15 17:09:32 +02:00 committed by Dmytro Skarzhynets
parent c5c18202c2
commit e55d14325a
6 changed files with 80 additions and 145 deletions

View File

@ -561,6 +561,10 @@ public class ActorSystemContext {
@Getter
private boolean externalNodeForceAck;
/**
 * Rate limit configuration string for the device state rule node, injected from the
 * {@code state.rule.node.deviceState.rateLimit} property. When the property is not set,
 * the placeholder default {@code 1:1,30:60,60:3600} is used.
 */
@Value("${state.rule.node.deviceState.rateLimit:1:1,30:60,60:3600}")
@Getter
private String deviceStateNodeRateLimitConfig;
@Getter
@Setter
private TbActorSystem actorSystem;

View File

@ -689,6 +689,11 @@ class DefaultTbContext implements TbContext {
return mainCtx.getDeviceStateManager();
}
/**
 * Returns the device state rule node rate limit configuration by delegating to {@code mainCtx}.
 */
@Override
public String getDeviceStateNodeRateLimitConfig() {
return mainCtx.getDeviceStateNodeRateLimitConfig();
}
@Override
public TbClusterService getClusterService() {
return mainCtx.getClusterService();

View File

@ -794,6 +794,15 @@ state:
# Used only when state.persistToTelemetry is set to 'true' and Cassandra is used for timeseries data.
# 0 means time-to-live mechanism is disabled.
telemetryTtl: "${STATE_TELEMETRY_TTL:0}"
# Configuration properties for rule nodes related to device activity state
rule:
node:
# Device state rule node
deviceState:
# Defines the rate at which device connectivity events can be triggered.
# Comma-separated list of capacity:duration pairs that define bandwidth capacity and refill duration for token bucket rate limit algorithm.
# Refill is set to be greedy. Please refer to Bucket4j library documentation for more details.
rateLimit: "${DEVICE_STATE_NODE_RATE_LIMIT_CONFIGURATION:1:1,30:60,60:3600}"
# Tbel parameters
tbel:

View File

@ -282,6 +282,8 @@ public interface TbContext {
RuleEngineDeviceStateManager getDeviceStateManager();
/**
 * Returns the rate limit configuration string that the device state rule node uses
 * to build its per-device rate limiters.
 */
String getDeviceStateNodeRateLimitConfig();
TbClusterService getClusterService();
DashboardService getDashboardService();

View File

@ -15,8 +15,8 @@
*/
package org.thingsboard.rule.engine.action;
import com.google.common.base.Stopwatch;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.thingsboard.rule.engine.api.RuleEngineDeviceStateManager;
import org.thingsboard.rule.engine.api.RuleNode;
import org.thingsboard.rule.engine.api.TbContext;
@ -28,17 +28,15 @@ import org.thingsboard.server.common.data.EntityType;
import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.msg.TbMsgType;
import org.thingsboard.server.common.data.msg.TbNodeConnectionType;
import org.thingsboard.server.common.data.plugin.ComponentType;
import org.thingsboard.server.common.msg.TbMsg;
import org.thingsboard.server.common.msg.TbMsgMetaData;
import org.thingsboard.server.common.msg.queue.PartitionChangeMsg;
import org.thingsboard.server.common.msg.queue.TbCallback;
import org.thingsboard.server.common.msg.tools.TbRateLimits;
import java.time.Duration;
import java.util.EnumSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@Slf4j
@RuleNode(
@ -47,8 +45,8 @@ import java.util.concurrent.ConcurrentMap;
nodeDescription = "Triggers device connectivity events",
nodeDetails = "If incoming message originator is a device, registers configured event for that device in the Device State Service, which sends appropriate message to the Rule Engine." +
" If metadata <code>ts</code> property is present, it will be used as event timestamp. Otherwise, the message timestamp will be used." +
" Incoming message is forwarded using the <code>Success</code> chain, unless an unexpected error occurs during message processing" +
" then incoming message is forwarded using the <code>Failure</code> chain." +
" If originator entity type is not <code>DEVICE</code> or unexpected error happened during processing, then incoming message is forwarded using <code>Failure</code> chain." +
" If rate of connectivity events for a given originator is too high, then incoming message is forwarded using <code>Rate limited</code> chain. " +
"<br>" +
"Supported device connectivity events are:" +
"<ul>" +
@ -59,6 +57,7 @@ import java.util.concurrent.ConcurrentMap;
"</ul>" +
"This node is particularly useful when device isn't using transports to receive data, such as when fetching data from external API or computing new data within the rule chain.",
configClazz = TbDeviceStateNodeConfiguration.class,
relationTypes = {TbNodeConnectionType.SUCCESS, TbNodeConnectionType.FAILURE, "Rate limited"},
uiResources = {"static/rulenode/rulenode-core-config.js"},
configDirective = "tbActionNodeDeviceStateConfig"
)
@ -67,12 +66,9 @@ public class TbDeviceStateNode implements TbNode {
private static final Set<TbMsgType> SUPPORTED_EVENTS = EnumSet.of(
TbMsgType.CONNECT_EVENT, TbMsgType.ACTIVITY_EVENT, TbMsgType.DISCONNECT_EVENT, TbMsgType.INACTIVITY_EVENT
);
private static final Duration ONE_SECOND = Duration.ofSeconds(1L);
private static final Duration ENTRY_EXPIRATION_TIME = Duration.ofDays(1L);
private static final Duration ENTRY_CLEANUP_PERIOD = Duration.ofHours(1L);
private Stopwatch stopwatch;
private ConcurrentMap<DeviceId, Duration> lastActivityEventTimestamps;
// Fallback configuration used by init when building a TbRateLimits from the context-provided value throws.
private static final String DEFAULT_RATE_LIMIT_CONFIG = "1:1,30:60,60:3600";
// One rate limiter per originator device; an entry is created when the first message from that device arrives.
private ConcurrentReferenceHashMap<DeviceId, TbRateLimits> rateLimits;
// Configuration string used to construct each per-device TbRateLimits instance.
private String rateLimitConfig;
private TbMsgType event;
@Override
@ -85,19 +81,19 @@ public class TbDeviceStateNode implements TbNode {
throw new TbNodeException("Unsupported event: " + event, true);
}
this.event = event;
lastActivityEventTimestamps = new ConcurrentHashMap<>();
stopwatch = Stopwatch.createStarted();
scheduleCleanupMsg(ctx);
rateLimits = new ConcurrentReferenceHashMap<>();
String deviceStateNodeRateLimitConfig = ctx.getDeviceStateNodeRateLimitConfig();
try {
rateLimitConfig = new TbRateLimits(deviceStateNodeRateLimitConfig).getConfiguration();
} catch (Exception e) {
log.error("[{}][{}] Invalid rate limit configuration provided: [{}]. Will use default value [{}].",
ctx.getTenantId().getId(), ctx.getSelfId().getId(), deviceStateNodeRateLimitConfig, DEFAULT_RATE_LIMIT_CONFIG, e);
rateLimitConfig = DEFAULT_RATE_LIMIT_CONFIG;
}
}
@Override
public void onMsg(TbContext ctx, TbMsg msg) {
if (msg.isTypeOf(TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG)) {
removeStaleEntries();
scheduleCleanupMsg(ctx);
return;
}
EntityType originatorEntityType = msg.getOriginator().getEntityType();
if (!EntityType.DEVICE.equals(originatorEntityType)) {
ctx.tellFailure(msg, new IllegalArgumentException(
@ -105,41 +101,18 @@ public class TbDeviceStateNode implements TbNode {
));
return;
}
DeviceId originator = new DeviceId(msg.getOriginator().getId());
lastActivityEventTimestamps.compute(originator, (__, lastEventTs) -> {
Duration now = stopwatch.elapsed();
if (lastEventTs == null) {
rateLimits.compute(originator, (__, rateLimit) -> {
if (rateLimit == null) {
rateLimit = new TbRateLimits(rateLimitConfig);
}
boolean isNotRateLimited = rateLimit.tryConsume();
if (isNotRateLimited) {
sendEventAndTell(ctx, originator, msg);
return now;
} else {
ctx.tellNext(msg, "Rate limited");
}
Duration elapsedSinceLastEventSent = now.minus(lastEventTs);
if (elapsedSinceLastEventSent.compareTo(ONE_SECOND) < 0) {
ctx.tellSuccess(msg);
return lastEventTs;
}
sendEventAndTell(ctx, originator, msg);
return now;
});
}
/**
 * Builds a {@code DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG} message originated by this node
 * (empty metadata, empty payload) and hands it to {@code ctx.tellSelf} together with
 * {@code ENTRY_CLEANUP_PERIOD} converted to milliseconds.
 */
private void scheduleCleanupMsg(TbContext ctx) {
TbMsg cleanupMsg = ctx.newMsg(
null, TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG, ctx.getSelfId(), TbMsgMetaData.EMPTY, TbMsg.EMPTY_STRING
);
// NOTE(review): the second argument is presumably the delivery delay in milliseconds — confirm against TbContext.tellSelf.
ctx.tellSelf(cleanupMsg, ENTRY_CLEANUP_PERIOD.toMillis());
}
private void removeStaleEntries() {
lastActivityEventTimestamps.entrySet().removeIf(entry -> {
Duration now = stopwatch.elapsed();
Duration lastEventTs = entry.getValue();
Duration elapsedSinceLastEventSent = now.minus(lastEventTs);
return elapsedSinceLastEventSent.compareTo(ENTRY_EXPIRATION_TIME) > 0;
return rateLimit;
});
}
@ -184,16 +157,16 @@ public class TbDeviceStateNode implements TbNode {
/**
 * Reacts to a partition change by dropping every per-device entry whose device key
 * is not reported as local by {@code ctx.isLocalEntity}.
 */
@Override
public void onPartitionChangeMsg(TbContext ctx, PartitionChangeMsg msg) {
lastActivityEventTimestamps.entrySet().removeIf(entry -> !ctx.isLocalEntity(entry.getKey()));
rateLimits.entrySet().removeIf(entry -> !ctx.isLocalEntity(entry.getKey()));
}
@Override
public void destroy() {
if (lastActivityEventTimestamps != null) {
lastActivityEventTimestamps.clear();
lastActivityEventTimestamps = null;
if (rateLimits != null) {
rateLimits.clear();
rateLimits = null;
}
stopwatch = null;
rateLimitConfig = null;
event = null;
}

View File

@ -15,8 +15,6 @@
*/
package org.thingsboard.rule.engine.action;
import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@ -29,6 +27,7 @@ import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.rule.engine.api.RuleEngineDeviceStateManager;
import org.thingsboard.rule.engine.api.TbContext;
@ -45,11 +44,9 @@ import org.thingsboard.server.common.msg.TbMsgMetaData;
import org.thingsboard.server.common.msg.queue.PartitionChangeMsg;
import org.thingsboard.server.common.msg.queue.ServiceType;
import org.thingsboard.server.common.msg.queue.TbCallback;
import org.thingsboard.server.common.msg.tools.TbRateLimits;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Stream;
import static org.assertj.core.api.Assertions.assertThat;
@ -57,9 +54,10 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ExtendWith(MockitoExtension.class)
public class TbDeviceStateNodeTest {
@ -71,21 +69,12 @@ public class TbDeviceStateNodeTest {
@Captor
private static ArgumentCaptor<TbCallback> callbackCaptor;
private TbDeviceStateNode node;
private RuleNodeId nodeId;
private TbDeviceStateNodeConfiguration config;
private static final TenantId TENANT_ID = TenantId.fromUUID(UUID.randomUUID());
private static final DeviceId DEVICE_ID = new DeviceId(UUID.randomUUID());
private static final long METADATA_TS = 123L;
private TbMsg cleanupMsg;
private TbMsg msg;
private long nowNanos;
private final Ticker controlledTicker = new Ticker() {
@Override
public long read() {
return nowNanos;
}
};
@BeforeEach
public void setup() {
@ -96,8 +85,6 @@ public class TbDeviceStateNodeTest {
var data = JacksonUtil.newObjectNode();
data.put("humidity", 58.3);
msg = TbMsg.newMsg(TbMsgType.POST_TELEMETRY_REQUEST, DEVICE_ID, metaData, JacksonUtil.toString(data));
nodeId = new RuleNodeId(UUID.randomUUID());
cleanupMsg = TbMsg.newMsg(null, TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG, nodeId, TbMsgMetaData.EMPTY, TbMsg.EMPTY_STRING);
}
@BeforeEach
@ -121,80 +108,42 @@ public class TbDeviceStateNodeTest {
}
@Test
public void givenValidConfig_whenInit_thenSchedulesCleanupMsg() {
public void givenInvalidRateLimitConfig_whenInit_thenUsesDefaultConfig() {
// GIVEN
given(ctxMock.getSelfId()).willReturn(nodeId);
given(ctxMock.newMsg(isNull(), eq(TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG), eq(nodeId), eq(TbMsgMetaData.EMPTY), eq(TbMsg.EMPTY_STRING))).willReturn(cleanupMsg);
given(ctxMock.getDeviceStateNodeRateLimitConfig()).willReturn("invalid rate limit config");
given(ctxMock.getTenantId()).willReturn(TENANT_ID);
given(ctxMock.getSelfId()).willReturn(new RuleNodeId(UUID.randomUUID()));
// WHEN
try {
initNode(TbMsgType.ACTIVITY_EVENT);
} catch (Exception e) {
fail("Node failed to initialize.");
fail("Node failed to initialize!", e);
}
// THEN
verifyCleanupMsgSent();
}
@Test
public void givenCleanupMsg_whenOnMsg_thenCleansStaleEntries() {
// GIVEN
given(ctxMock.getSelfId()).willReturn(nodeId);
given(ctxMock.newMsg(isNull(), eq(TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG), eq(nodeId), eq(TbMsgMetaData.EMPTY), eq(TbMsg.EMPTY_STRING))).willReturn(cleanupMsg);
ConcurrentMap<DeviceId, Duration> lastActivityEventTimestamps = new ConcurrentHashMap<>();
ReflectionTestUtils.setField(node, "lastActivityEventTimestamps", lastActivityEventTimestamps);
Stopwatch stopwatch = Stopwatch.createStarted(controlledTicker);
ReflectionTestUtils.setField(node, "stopwatch", stopwatch);
// WHEN
Duration expirationTime = Duration.ofDays(1L);
DeviceId staleId = DEVICE_ID;
Duration staleTs = Duration.ofHours(4L);
lastActivityEventTimestamps.put(staleId, staleTs);
DeviceId goodId = new DeviceId(UUID.randomUUID());
Duration goodTs = staleTs.plus(expirationTime);
lastActivityEventTimestamps.put(goodId, goodTs);
nowNanos = staleTs.toNanos() + expirationTime.toNanos() + 1;
node.onMsg(ctxMock, cleanupMsg);
// THEN
assertThat(lastActivityEventTimestamps)
.containsKey(goodId)
.doesNotContainKey(staleId)
.size().isOne();
verifyCleanupMsgSent();
then(ctxMock).shouldHaveNoMoreInteractions();
String actualRateLimitConfig = (String) ReflectionTestUtils.getField(node, "rateLimitConfig");
assertThat(actualRateLimitConfig).isEqualTo("1:1,30:60,60:3600");
}
@Test
public void givenMsgArrivedTooFast_whenOnMsg_thenRateLimitsThisMsg() {
// GIVEN
ConcurrentMap<DeviceId, Duration> lastActivityEventTimestamps = new ConcurrentHashMap<>();
ReflectionTestUtils.setField(node, "lastActivityEventTimestamps", lastActivityEventTimestamps);
ConcurrentReferenceHashMap<DeviceId, TbRateLimits> rateLimits = new ConcurrentReferenceHashMap<>();
ReflectionTestUtils.setField(node, "rateLimits", rateLimits);
Stopwatch stopwatch = Stopwatch.createStarted(controlledTicker);
ReflectionTestUtils.setField(node, "stopwatch", stopwatch);
var rateLimitMock = mock(TbRateLimits.class);
rateLimits.put(DEVICE_ID, rateLimitMock);
given(rateLimitMock.tryConsume()).willReturn(false);
// WHEN
Duration firstEventTs = Duration.ofMillis(1000L);
lastActivityEventTimestamps.put(DEVICE_ID, firstEventTs);
Duration tooFastEventTs = firstEventTs.plus(Duration.ofMillis(999L));
nowNanos = tooFastEventTs.toNanos();
node.onMsg(ctxMock, msg);
// THEN
Duration actualEventTs = lastActivityEventTimestamps.get(DEVICE_ID);
assertThat(actualEventTs).isEqualTo(firstEventTs);
then(ctxMock).should().tellSuccess(msg);
then(ctxMock).should().tellNext(msg, "Rate limited");
then(ctxMock).should(never()).tellSuccess(any());
then(ctxMock).should(never()).tellFailure(any(), any());
then(ctxMock).shouldHaveNoMoreInteractions();
then(deviceStateManagerMock).shouldHaveNoInteractions();
}
@ -202,25 +151,25 @@ public class TbDeviceStateNodeTest {
@Test
public void givenHasNonLocalDevices_whenOnPartitionChange_thenRemovesEntriesForNonLocalDevices() {
// GIVEN
ConcurrentMap<DeviceId, Duration> lastActivityEventTimestamps = new ConcurrentHashMap<>();
ReflectionTestUtils.setField(node, "lastActivityEventTimestamps", lastActivityEventTimestamps);
ConcurrentReferenceHashMap<DeviceId, TbRateLimits> rateLimits = new ConcurrentReferenceHashMap<>();
ReflectionTestUtils.setField(node, "rateLimits", rateLimits);
lastActivityEventTimestamps.put(DEVICE_ID, Duration.ofHours(24L));
rateLimits.put(DEVICE_ID, new TbRateLimits("1:1"));
given(ctxMock.isLocalEntity(eq(DEVICE_ID))).willReturn(true);
DeviceId nonLocalDeviceId1 = new DeviceId(UUID.randomUUID());
lastActivityEventTimestamps.put(nonLocalDeviceId1, Duration.ofHours(30L));
rateLimits.put(nonLocalDeviceId1, new TbRateLimits("2:2"));
given(ctxMock.isLocalEntity(eq(nonLocalDeviceId1))).willReturn(false);
DeviceId nonLocalDeviceId2 = new DeviceId(UUID.randomUUID());
lastActivityEventTimestamps.put(nonLocalDeviceId2, Duration.ofHours(32L));
rateLimits.put(nonLocalDeviceId2, new TbRateLimits("3:3"));
given(ctxMock.isLocalEntity(eq(nonLocalDeviceId2))).willReturn(false);
// WHEN
node.onPartitionChangeMsg(ctxMock, new PartitionChangeMsg(ServiceType.TB_RULE_ENGINE));
// THEN
assertThat(lastActivityEventTimestamps)
assertThat(rateLimits)
.containsKey(DEVICE_ID)
.doesNotContainKey(nonLocalDeviceId1)
.doesNotContainKey(nonLocalDeviceId2)
@ -275,6 +224,7 @@ public class TbDeviceStateNodeTest {
@Test
public void givenMetadataDoesNotContainTs_whenOnMsg_thenMsgTsIsUsedAsEventTs() {
// GIVEN
given(ctxMock.getDeviceStateNodeRateLimitConfig()).willReturn("1:1");
try {
initNode(TbMsgType.ACTIVITY_EVENT);
} catch (TbNodeException e) {
@ -298,9 +248,8 @@ public class TbDeviceStateNodeTest {
@MethodSource
public void givenSupportedEventAndDeviceOriginator_whenOnMsg_thenCorrectEventIsSentWithCorrectCallback(TbMsgType supportedEventType, Runnable actionVerification) {
// GIVEN
given(ctxMock.getSelfId()).willReturn(nodeId);
given(ctxMock.newMsg(isNull(), eq(TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG), eq(nodeId), eq(TbMsgMetaData.EMPTY), eq(TbMsg.EMPTY_STRING))).willReturn(cleanupMsg);
given(ctxMock.getTenantId()).willReturn(TENANT_ID);
given(ctxMock.getDeviceStateNodeRateLimitConfig()).willReturn("1:1");
given(ctxMock.getDeviceStateManager()).willReturn(deviceStateManagerMock);
try {
@ -308,7 +257,6 @@ public class TbDeviceStateNodeTest {
} catch (TbNodeException e) {
fail("Node failed to initialize!", e);
}
verifyCleanupMsgSent();
// WHEN
node.onMsg(ctxMock, msg);
@ -345,10 +293,4 @@ public class TbDeviceStateNodeTest {
node.init(ctxMock, nodeConfig);
}
/**
 * Verifies on the context mock that the node looked up its own id, created the stale-entries
 * cleanup self-message, and passed that message to {@code tellSelf} with a value of one hour in milliseconds.
 */
private void verifyCleanupMsgSent() {
then(ctxMock).should().getSelfId();
then(ctxMock).should().newMsg(isNull(), eq(TbMsgType.DEVICE_STATE_STALE_ENTRIES_CLEANUP_SELF_MSG), eq(nodeId), eq(TbMsgMetaData.EMPTY), eq(TbMsg.EMPTY_STRING));
then(ctxMock).should().tellSelf(eq(cleanupMsg), eq(Duration.ofHours(1L).toMillis()));
}
}