diff --git a/application/src/main/resources/thingsboard.yml b/application/src/main/resources/thingsboard.yml index 96e409373e..b0b8923dcc 100644 --- a/application/src/main/resources/thingsboard.yml +++ b/application/src/main/resources/thingsboard.yml @@ -1117,6 +1117,7 @@ queue: automatic_recovery_enabled: "${TB_QUEUE_RABBIT_MQ_AUTOMATIC_RECOVERY_ENABLED:false}" connection_timeout: "${TB_QUEUE_RABBIT_MQ_CONNECTION_TIMEOUT:60000}" handshake_timeout: "${TB_QUEUE_RABBIT_MQ_HANDSHAKE_TIMEOUT:10000}" + max_poll_messages: "${TB_QUEUE_RABBIT_MQ_MAX_POLL_MESSAGES:1}" queue-properties: rule-engine: "${TB_QUEUE_RABBIT_MQ_RE_QUEUE_PROPERTIES:x-max-length-bytes:1048576000;x-message-ttl:604800000}" core: "${TB_QUEUE_RABBIT_MQ_CORE_QUEUE_PROPERTIES:x-max-length-bytes:1048576000;x-message-ttl:604800000}" diff --git a/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplate.java b/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplate.java index e073a46a05..4891f65e30 100644 --- a/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplate.java +++ b/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplate.java @@ -20,6 +20,8 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.rabbitmq.client.Channel; import com.rabbitmq.client.Connection; import com.rabbitmq.client.GetResponse; +import java.util.ArrayList; +import java.util.Collection; import lombok.extern.slf4j.Slf4j; import org.thingsboard.server.common.msg.queue.TopicPartitionInfo; import org.thingsboard.server.queue.TbQueueAdmin; @@ -31,7 +33,6 @@ import org.thingsboard.server.queue.common.DefaultTbQueueMsg; import java.io.IOException; import java.util.Collections; import java.util.List; -import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -44,6 +45,7 @@ public class TbRabbitMqConsumerTemplate extends AbstractTb private final TbQueueMsgDecoder decoder; private final Channel channel; private final Connection connection; + private final int maxPollMessages; private volatile Set queues; @@ -51,6 +53,7 @@ public class TbRabbitMqConsumerTemplate extends AbstractTb super(topic); this.admin = admin; this.decoder = decoder; + this.maxPollMessages = rabbitMqSettings.getMaxPollMessages(); try { connection = rabbitMqSettings.getConnectionFactory().newConnection(); } catch (IOException | TimeoutException e) { @@ -70,13 +73,19 @@ public class TbRabbitMqConsumerTemplate extends AbstractTb protected List doPoll(long durationInMillis) { List result = queues.stream() .map(queue -> { - try { - return channel.basicGet(queue, false); - } catch (IOException e) { - log.error("Failed to get messages from queue: [{}]", queue); - throw new RuntimeException("Failed to get messages from queue.", e); + List messages = new ArrayList<>(); + for (int i = 0; i < maxPollMessages; i++) { + GetResponse response = doQueuePoll(queue); + if (response == null) { + break; + } + messages.add(response); } - }).filter(Objects::nonNull).collect(Collectors.toList()); + return messages; + }) + .filter(r -> !r.isEmpty()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); if (result.size() > 0) { return result; } else { @@ -84,6 +93,15 @@ public class TbRabbitMqConsumerTemplate extends AbstractTb } } + protected GetResponse doQueuePoll(String queue) { + try { + return channel.basicGet(queue, false); + } catch (IOException e) { + log.error("Failed to get messages from queue: [{}]", queue); + throw new RuntimeException("Failed to get messages from queue.", e); + } + } + @Override protected void doSubscribe(List topicNames) { queues = partitions.stream() diff --git a/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqSettings.java b/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqSettings.java index 0273374ddc..fd2977f7cc 100644 --- a/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqSettings.java +++ b/common/queue/src/main/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqSettings.java @@ -47,6 +47,8 @@ public class TbRabbitMqSettings { private int connectionTimeout; @Value("${queue.rabbitmq.handshake_timeout:}") private int handshakeTimeout; + @Value("${queue.rabbitmq.max_poll_messages:1}") + private int maxPollMessages; private ConnectionFactory connectionFactory; diff --git a/common/queue/src/test/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplateTest.java b/common/queue/src/test/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplateTest.java new file mode 100644 index 0000000000..ab132912fc --- /dev/null +++ b/common/queue/src/test/java/org/thingsboard/server/queue/rabbitmq/TbRabbitMqConsumerTemplateTest.java @@ -0,0 +1,128 @@ +/** + * Copyright © 2016-2023 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.queue.rabbitmq; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.rabbitmq.client.Channel; +import com.rabbitmq.client.Connection; +import com.rabbitmq.client.ConnectionFactory; +import com.rabbitmq.client.GetResponse; +import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.thingsboard.server.common.msg.queue.TopicPartitionInfo; +import org.thingsboard.server.queue.TbQueueAdmin; +import org.thingsboard.server.queue.TbQueueMsgDecoder; +import org.thingsboard.server.queue.common.DefaultTbQueueMsg; + +@ExtendWith(MockitoExtension.class) +class TbRabbitMqConsumerTemplateTest { + + private static final String TOPIC = "some-topic"; + + @Mock + private TbQueueAdmin admin; + + @Mock + private ConnectionFactory connectionFactory; + + @Mock + private TbQueueMsgDecoder decoder; + + @Mock + private Connection connection; + + @Mock + private Channel channel; + + @Mock + private TopicPartitionInfo partition; + + @Mock + private GetResponse getResponse; + + private TbRabbitMqConsumerTemplate consumer; + + private void setUpConsumerWithMaxPollMessages(int maxPollMessages) throws Exception { + when(connectionFactory.newConnection()).thenReturn(connection); + when(connection.createChannel()).thenReturn(channel); + TbRabbitMqSettings settings = new TbRabbitMqSettings(); + settings.setMaxPollMessages(maxPollMessages); + settings.setConnectionFactory(connectionFactory); + + consumer = new TbRabbitMqConsumerTemplate<>(admin, settings, TOPIC, decoder); + when(partition.getFullTopicName()).thenReturn(TOPIC); + consumer.subscribe(Set.of(partition)); + } + + @Test + void pollWithMax5PollMessagesReturnsEmptyListIfNoMessages() throws Exception { + setUpConsumerWithMaxPollMessages(5); + when(channel.basicGet(anyString(), anyBoolean())).thenReturn(null); + + assertThat(consumer.poll(0L)).isEmpty(); + + verify(channel).basicGet(anyString(), anyBoolean()); + } + + @Test + void pollWithMax5PollMessagesReturns5MessagesIfQueueContains5() throws Exception { + setUpConsumerWithMaxPollMessages(5); + when(getResponse.getBody()).thenReturn(newMessageBody()); + when(channel.basicGet(anyString(), anyBoolean())).thenReturn(getResponse); + + assertThat(consumer.poll(0L)).hasSize(5); + + verify(channel, times(5)).basicGet(anyString(), anyBoolean()); + } + + @Test + void pollWithMax1PollMessageReturns1MessageIfQueueContainsMore() throws Exception { + setUpConsumerWithMaxPollMessages(1); + when(getResponse.getBody()).thenReturn(newMessageBody()); + when(channel.basicGet(anyString(), anyBoolean())).thenReturn(getResponse); + + assertThat(consumer.poll(0L)).hasSize(1); + + verify(channel).basicGet(anyString(), anyBoolean()); + } + + @Test + void pollWithMax3PollMessagesReturns2MessagesIfQueueContains2() throws Exception { + setUpConsumerWithMaxPollMessages(3); + when(getResponse.getBody()).thenReturn(newMessageBody()); + when(channel.basicGet(anyString(), anyBoolean())).thenReturn(getResponse, getResponse, null); + + assertThat(consumer.poll(0L)).hasSize(2); + + verify(channel, times(3)).basicGet(anyString(), anyBoolean()); + } + + private byte[] newMessageBody() { + return ("{\"key\": \"" + UUID.randomUUID() + "\"}").getBytes(StandardCharsets.UTF_8); + } + +} \ No newline at end of file