netty client - added channel reader idle state handling, fixed ping/pong keepalive logic

This commit is contained in:
Dima Landiak 2022-07-12 12:33:16 +03:00
parent adab546ab9
commit d6244a8422
4 changed files with 127 additions and 11 deletions

View File

@ -53,6 +53,37 @@
<groupId>com.google.guava</groupId> <groupId>com.google.guava</groupId>
<artifactId>guava</artifactId> <artifactId>guava</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>log4j-over-slf4j</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@ -45,6 +45,7 @@ import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
@ -60,6 +61,7 @@ import java.util.concurrent.atomic.AtomicInteger;
* Represents an MqttClientImpl connected to a single MQTT server. Will try to keep the connection going at all times * Represents an MqttClientImpl connected to a single MQTT server. Will try to keep the connection going at all times
*/ */
@SuppressWarnings({"WeakerAccess", "unused"}) @SuppressWarnings({"WeakerAccess", "unused"})
@Slf4j
final class MqttClientImpl implements MqttClient { final class MqttClientImpl implements MqttClient {
private final Set<String> serverSubscriptions = new HashSet<>(); private final Set<String> serverSubscriptions = new HashSet<>();
@ -131,6 +133,7 @@ final class MqttClientImpl implements MqttClient {
} }
private Future<MqttConnectResult> connect(String host, int port, boolean reconnect) { private Future<MqttConnectResult> connect(String host, int port, boolean reconnect) {
log.trace("[{}] Connecting to server, isReconnect - {}", channel != null ? channel.id() : "UNKNOWN", reconnect);
if (this.eventLoop == null) { if (this.eventLoop == null) {
this.eventLoop = new NioEventLoopGroup(); this.eventLoop = new NioEventLoopGroup();
} }
@ -147,10 +150,12 @@ final class MqttClientImpl implements MqttClient {
future.addListener((ChannelFutureListener) f -> { future.addListener((ChannelFutureListener) f -> {
if (f.isSuccess()) { if (f.isSuccess()) {
MqttClientImpl.this.channel = f.channel(); MqttClientImpl.this.channel = f.channel();
log.debug("[{}][{}] Connected successfully {}!", host, port, this.channel.id());
MqttClientImpl.this.channel.closeFuture().addListener((ChannelFutureListener) channelFuture -> { MqttClientImpl.this.channel.closeFuture().addListener((ChannelFutureListener) channelFuture -> {
if (isConnected()) { if (isConnected()) {
return; return;
} }
log.debug("[{}][{}] Channel is closed {}!", host, port, this.channel.id());
ChannelClosedException e = new ChannelClosedException("Channel is closed!"); ChannelClosedException e = new ChannelClosedException("Channel is closed!");
if (callback != null) { if (callback != null) {
callback.connectionLost(e); callback.connectionLost(e);
@ -169,6 +174,7 @@ final class MqttClientImpl implements MqttClient {
scheduleConnectIfRequired(host, port, true); scheduleConnectIfRequired(host, port, true);
}); });
} else { } else {
log.debug("[{}][{}] Connect failed, trying reconnect!", host, port);
scheduleConnectIfRequired(host, port, reconnect); scheduleConnectIfRequired(host, port, reconnect);
} }
}); });
@ -176,6 +182,7 @@ final class MqttClientImpl implements MqttClient {
} }
private void scheduleConnectIfRequired(String host, int port, boolean reconnect) { private void scheduleConnectIfRequired(String host, int port, boolean reconnect) {
log.trace("[{}] Scheduling connect to server, isReconnect - {}", channel != null ? channel.id() : "UNKNOWN", reconnect);
if (clientConfig.isReconnect() && !disconnected) { if (clientConfig.isReconnect() && !disconnected) {
if (reconnect) { if (reconnect) {
this.reconnect = true; this.reconnect = true;
@ -191,6 +198,7 @@ final class MqttClientImpl implements MqttClient {
@Override @Override
public Future<MqttConnectResult> reconnect() { public Future<MqttConnectResult> reconnect() {
log.trace("[{}] Reconnecting to server, isReconnect - {}", channel != null ? channel.id() : "UNKNOWN", reconnect);
if (host == null) { if (host == null) {
throw new IllegalStateException("Cannot reconnect. Call connect() first"); throw new IllegalStateException("Cannot reconnect. Call connect() first");
} }
@ -281,6 +289,7 @@ final class MqttClientImpl implements MqttClient {
*/ */
@Override @Override
public Future<Void> off(String topic, MqttHandler handler) { public Future<Void> off(String topic, MqttHandler handler) {
log.trace("[{}] Unsubscribing from {}", channel != null ? channel.id() : "UNKNOWN", topic);
Promise<Void> future = new DefaultPromise<>(this.eventLoop.next()); Promise<Void> future = new DefaultPromise<>(this.eventLoop.next());
for (MqttSubscription subscription : this.handlerToSubscription.get(handler)) { for (MqttSubscription subscription : this.handlerToSubscription.get(handler)) {
this.subscriptions.remove(topic, subscription); this.subscriptions.remove(topic, subscription);
@ -299,6 +308,7 @@ final class MqttClientImpl implements MqttClient {
*/ */
@Override @Override
public Future<Void> off(String topic) { public Future<Void> off(String topic) {
log.trace("[{}] Unsubscribing from {}", channel != null ? channel.id() : "UNKNOWN", topic);
Promise<Void> future = new DefaultPromise<>(this.eventLoop.next()); Promise<Void> future = new DefaultPromise<>(this.eventLoop.next());
ImmutableSet<MqttSubscription> subscriptions = ImmutableSet.copyOf(this.subscriptions.get(topic)); ImmutableSet<MqttSubscription> subscriptions = ImmutableSet.copyOf(this.subscriptions.get(topic));
for (MqttSubscription subscription : subscriptions) { for (MqttSubscription subscription : subscriptions) {
@ -360,6 +370,7 @@ final class MqttClientImpl implements MqttClient {
*/ */
@Override @Override
public Future<Void> publish(String topic, ByteBuf payload, MqttQoS qos, boolean retain) { public Future<Void> publish(String topic, ByteBuf payload, MqttQoS qos, boolean retain) {
log.trace("[{}] Publishing message to {}", channel != null ? channel.id() : "UNKNOWN", topic);
Promise<Void> future = new DefaultPromise<>(this.eventLoop.next()); Promise<Void> future = new DefaultPromise<>(this.eventLoop.next());
MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PUBLISH, false, qos, retain, 0); MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PUBLISH, false, qos, retain, 0);
MqttPublishVariableHeader variableHeader = new MqttPublishVariableHeader(topic, getNewMessageId().messageId()); MqttPublishVariableHeader variableHeader = new MqttPublishVariableHeader(topic, getNewMessageId().messageId());
@ -404,6 +415,7 @@ final class MqttClientImpl implements MqttClient {
@Override @Override
public void disconnect() { public void disconnect() {
log.trace("[{}] Disconnecting from server", channel != null ? channel.id() : "UNKNOWN");
disconnected = true; disconnected = true;
if (this.channel != null) { if (this.channel != null) {
MqttMessage message = new MqttMessage(new MqttFixedHeader(MqttMessageType.DISCONNECT, false, MqttQoS.AT_MOST_ONCE, false, 0)); MqttMessage message = new MqttMessage(new MqttFixedHeader(MqttMessageType.DISCONNECT, false, MqttQoS.AT_MOST_ONCE, false, 0));
@ -435,6 +447,7 @@ final class MqttClientImpl implements MqttClient {
return null; return null;
} }
if (this.channel.isActive()) { if (this.channel.isActive()) {
log.trace("[{}] Sending message {}", channel != null ? channel.id() : "UNKNOWN", message);
return this.channel.writeAndFlush(message); return this.channel.writeAndFlush(message);
} }
return this.channel.newFailedFuture(new ChannelClosedException("Channel is closed!")); return this.channel.newFailedFuture(new ChannelClosedException("Channel is closed!"));
@ -450,6 +463,7 @@ final class MqttClientImpl implements MqttClient {
} }
private Future<Void> createSubscription(String topic, MqttHandler handler, boolean once, MqttQoS qos) { private Future<Void> createSubscription(String topic, MqttHandler handler, boolean once, MqttQoS qos) {
log.trace("[{}] Creating subscription to {}", channel != null ? channel.id() : "UNKNOWN", topic);
if (this.pendingSubscribeTopics.contains(topic)) { if (this.pendingSubscribeTopics.contains(topic)) {
Optional<Map.Entry<Integer, MqttPendingSubscription>> subscriptionEntry = this.pendingSubscriptions.entrySet().stream().filter((e) -> e.getValue().getTopic().equals(topic)).findAny(); Optional<Map.Entry<Integer, MqttPendingSubscription>> subscriptionEntry = this.pendingSubscriptions.entrySet().stream().filter((e) -> e.getValue().getTopic().equals(topic)).findAny();
if (subscriptionEntry.isPresent()) { if (subscriptionEntry.isPresent()) {

View File

@ -26,9 +26,11 @@ import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.timeout.IdleStateEvent; import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.ScheduledFuture; import io.netty.util.concurrent.ScheduledFuture;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@Slf4j
final class MqttPingHandler extends ChannelInboundHandlerAdapter { final class MqttPingHandler extends ChannelInboundHandlerAdapter {
private final int keepaliveSeconds; private final int keepaliveSeconds;
@ -46,11 +48,11 @@ final class MqttPingHandler extends ChannelInboundHandlerAdapter {
return; return;
} }
MqttMessage message = (MqttMessage) msg; MqttMessage message = (MqttMessage) msg;
if(message.fixedHeader().messageType() == MqttMessageType.PINGREQ){ if (message.fixedHeader().messageType() == MqttMessageType.PINGREQ) {
this.handlePingReq(ctx.channel()); this.handlePingReq(ctx.channel());
} else if(message.fixedHeader().messageType() == MqttMessageType.PINGRESP){ } else if (message.fixedHeader().messageType() == MqttMessageType.PINGRESP) {
this.handlePingResp(); this.handlePingResp(ctx.channel());
}else{ } else {
ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
} }
} }
@ -59,23 +61,27 @@ final class MqttPingHandler extends ChannelInboundHandlerAdapter {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
super.userEventTriggered(ctx, evt); super.userEventTriggered(ctx, evt);
if(evt instanceof IdleStateEvent){ if (evt instanceof IdleStateEvent) {
IdleStateEvent event = (IdleStateEvent) evt; IdleStateEvent event = (IdleStateEvent) evt;
switch(event.state()){ switch (event.state()) {
case READER_IDLE: case READER_IDLE:
log.debug("[{}] No reads were performed for specified period for channel {}", event.state(), ctx.channel().id());
this.sendPingReq(ctx.channel());
break; break;
case WRITER_IDLE: case WRITER_IDLE:
log.debug("[{}] No writes were performed for specified period for channel {}", event.state(), ctx.channel().id());
this.sendPingReq(ctx.channel()); this.sendPingReq(ctx.channel());
break; break;
} }
} }
} }
private void sendPingReq(Channel channel){ private void sendPingReq(Channel channel) {
log.trace("[{}] Sending ping request", channel.id());
MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PINGREQ, false, MqttQoS.AT_MOST_ONCE, false, 0); MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PINGREQ, false, MqttQoS.AT_MOST_ONCE, false, 0);
channel.writeAndFlush(new MqttMessage(fixedHeader)); channel.writeAndFlush(new MqttMessage(fixedHeader));
if(this.pingRespTimeout != null){ if (this.pingRespTimeout == null) {
this.pingRespTimeout = channel.eventLoop().schedule(() -> { this.pingRespTimeout = channel.eventLoop().schedule(() -> {
MqttFixedHeader fixedHeader2 = new MqttFixedHeader(MqttMessageType.DISCONNECT, false, MqttQoS.AT_MOST_ONCE, false, 0); MqttFixedHeader fixedHeader2 = new MqttFixedHeader(MqttMessageType.DISCONNECT, false, MqttQoS.AT_MOST_ONCE, false, 0);
channel.writeAndFlush(new MqttMessage(fixedHeader2)).addListener(ChannelFutureListener.CLOSE); channel.writeAndFlush(new MqttMessage(fixedHeader2)).addListener(ChannelFutureListener.CLOSE);
@ -84,13 +90,15 @@ final class MqttPingHandler extends ChannelInboundHandlerAdapter {
} }
} }
private void handlePingReq(Channel channel){ private void handlePingReq(Channel channel) {
log.trace("[{}] Handling ping request", channel.id());
MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PINGRESP, false, MqttQoS.AT_MOST_ONCE, false, 0); MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.PINGRESP, false, MqttQoS.AT_MOST_ONCE, false, 0);
channel.writeAndFlush(new MqttMessage(fixedHeader)); channel.writeAndFlush(new MqttMessage(fixedHeader));
} }
private void handlePingResp(){ private void handlePingResp(Channel channel) {
if(this.pingRespTimeout != null && !this.pingRespTimeout.isCancelled() && !this.pingRespTimeout.isDone()){ log.trace("[{}] Handling ping response", channel.id());
if (this.pingRespTimeout != null && !this.pingRespTimeout.isCancelled() && !this.pingRespTimeout.isDone()) {
this.pingRespTimeout.cancel(true); this.pingRespTimeout.cancel(true);
this.pingRespTimeout = null; this.pingRespTimeout = null;
} }

View File

@ -0,0 +1,63 @@
/**
* Copyright © 2016-2022 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.mqtt;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.DefaultEventLoop;
import io.netty.handler.timeout.IdleStateEvent;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.concurrent.TimeUnit;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
class MqttPingHandlerTest {
static final int KEEP_ALIVE_SECONDS = 0;
static final int PROCESS_SEND_DISCONNECT_MSG_TIME_MS = 500;
MqttPingHandler mqttPingHandler;
@BeforeEach
void setUp() {
mqttPingHandler = new MqttPingHandler(KEEP_ALIVE_SECONDS);
}
@Test
void givenChannelReaderIdleState_whenNoPingResponse_thenDisconnectClient() throws Exception {
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
Channel channel = mock(Channel.class);
when(ctx.channel()).thenReturn(channel);
when(channel.eventLoop()).thenReturn(new DefaultEventLoop());
ChannelFuture channelFuture = mock(ChannelFuture.class);
when(channel.writeAndFlush(any())).thenReturn(channelFuture);
mqttPingHandler.userEventTriggered(ctx, IdleStateEvent.FIRST_READER_IDLE_STATE_EVENT);
verify(
channelFuture,
after(TimeUnit.SECONDS.toMillis(KEEP_ALIVE_SECONDS) + PROCESS_SEND_DISCONNECT_MSG_TIME_MS)
).addListener(eq(ChannelFutureListener.CLOSE));
}
}