From 17adf8bd8b330527056f05cdeec7c0015eec60e5 Mon Sep 17 00:00:00 2001 From: Igor Kulikov Date: Mon, 4 Jun 2018 11:06:42 +0300 Subject: [PATCH] Improved MQTT client and added default handler to handle reconnect subscriptions --- .../thingsboard/mqtt/MqttChannelHandler.java | 45 +++---- .../java/org/thingsboard/mqtt/MqttClient.java | 12 +- .../thingsboard/mqtt/MqttClientCallback.java | 10 +- .../org/thingsboard/mqtt/MqttClientImpl.java | 121 ++++++++++++------ ...ubscribtion.java => MqttSubscription.java} | 6 +- .../rule/engine/mqtt/TbMqttNode.java | 2 +- 6 files changed, 118 insertions(+), 78 deletions(-) rename netty-mqtt/src/main/java/org/thingsboard/mqtt/{MqttSubscribtion.java => MqttSubscription.java} (93%) diff --git a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttChannelHandler.java b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttChannelHandler.java index ef5e7a531d..024e15a2c3 100644 --- a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttChannelHandler.java +++ b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttChannelHandler.java @@ -98,33 +98,25 @@ final class MqttChannelHandler extends SimpleChannelInboundHandler } private void invokeHandlersForIncomingPublish(MqttPublishMessage message) { - for (MqttSubscribtion subscribtion : ImmutableSet.copyOf(this.client.getSubscriptions().values())) { - if (subscribtion.matches(message.variableHeader().topicName())) { - if (subscribtion.isOnce() && subscribtion.isCalled()) { + boolean handlerInvoked = false; + for (MqttSubscription subscription : ImmutableSet.copyOf(this.client.getSubscriptions().values())) { + if (subscription.matches(message.variableHeader().topicName())) { + if (subscription.isOnce() && subscription.isCalled()) { continue; } message.payload().markReaderIndex(); - subscribtion.setCalled(true); - subscribtion.getHandler().onMessage(message.variableHeader().topicName(), message.payload()); - if (subscribtion.isOnce()) { - this.client.off(subscribtion.getTopic(), subscribtion.getHandler()); + subscription.setCalled(true); + subscription.getHandler().onMessage(message.variableHeader().topicName(), message.payload()); + if (subscription.isOnce()) { + this.client.off(subscription.getTopic(), subscription.getHandler()); } message.payload().resetReaderIndex(); + handlerInvoked = true; } } - /*Set subscribtions = ImmutableSet.copyOf(this.client.getSubscriptions().get(message.variableHeader().topicName())); - for (MqttSubscribtion subscribtion : subscribtions) { - if(subscribtion.isOnce() && subscribtion.isCalled()){ - continue; - } - message.payload().markReaderIndex(); - subscribtion.setCalled(true); - subscribtion.getHandler().onMessage(message.variableHeader().topicName(), message.payload()); - if(subscribtion.isOnce()){ - this.client.off(subscribtion.getTopic(), subscribtion.getHandler()); - } - message.payload().resetReaderIndex(); - }*/ + if (!handlerInvoked && client.getDefaultHandler() != null) { + client.getDefaultHandler().onMessage(message.variableHeader().topicName(), message.payload()); + } message.payload().release(); } @@ -133,7 +125,7 @@ final class MqttChannelHandler extends SimpleChannelInboundHandler case CONNECTION_ACCEPTED: this.connectFuture.setSuccess(new MqttConnectResult(true, MqttConnectReturnCode.CONNECTION_ACCEPTED, channel.closeFuture())); - this.client.getPendingSubscribtions().entrySet().stream().filter((e) -> !e.getValue().isSent()).forEach((e) -> { + this.client.getPendingSubscriptions().entrySet().stream().filter((e) -> !e.getValue().isSent()).forEach((e) -> { channel.write(e.getValue().getSubscribeMessage()); e.getValue().setSent(true); }); @@ -148,6 +140,9 @@ final class MqttChannelHandler extends SimpleChannelInboundHandler } }); channel.flush(); + if (this.client.isReconnect()) { + this.client.onSuccessfulReconnect(); + } break; case CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD: @@ -163,19 +158,19 @@ final class MqttChannelHandler extends SimpleChannelInboundHandler } private void handleSubAck(MqttSubAckMessage message) { - MqttPendingSubscribtion pendingSubscription = this.client.getPendingSubscribtions().remove(message.variableHeader().messageId()); + MqttPendingSubscribtion pendingSubscription = this.client.getPendingSubscriptions().remove(message.variableHeader().messageId()); if (pendingSubscription == null) { return; } pendingSubscription.onSubackReceived(); for (MqttPendingSubscribtion.MqttPendingHandler handler : pendingSubscription.getHandlers()) { - MqttSubscribtion subscribtion = new MqttSubscribtion(pendingSubscription.getTopic(), handler.getHandler(), handler.isOnce()); + MqttSubscription subscribtion = new MqttSubscription(pendingSubscription.getTopic(), handler.getHandler(), handler.isOnce()); this.client.getSubscriptions().put(pendingSubscription.getTopic(), subscribtion); this.client.getHandlerToSubscribtion().put(handler.getHandler(), subscribtion); } this.client.getPendingSubscribeTopics().remove(pendingSubscription.getTopic()); - this.client.getServerSubscribtions().add(pendingSubscription.getTopic()); + this.client.getServerSubscriptions().add(pendingSubscription.getTopic()); if (!pendingSubscription.getFuture().isDone()) { pendingSubscription.getFuture().setSuccess(null); @@ -220,7 +215,7 @@ final class MqttChannelHandler extends SimpleChannelInboundHandler return; } unsubscribtion.onUnsubackReceived(); - this.client.getServerSubscribtions().remove(unsubscribtion.getTopic()); + this.client.getServerSubscriptions().remove(unsubscribtion.getTopic()); unsubscribtion.getFuture().setSuccess(null); this.client.getPendingServerUnsubscribes().remove(message.variableHeader().messageId()); } diff --git a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClient.java b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClient.java index 6563525783..7e57a0a00c 100644 --- a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClient.java +++ b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClient.java @@ -172,24 +172,18 @@ public interface MqttClient { */ MqttClientConfig getClientConfig(); - /** - * Construct the MqttClientImpl with default config - */ - static MqttClient create(){ - return new MqttClientImpl(); - } /** * Construct the MqttClientImpl with additional config. * This config can also be changed using the {@link #getClientConfig()} function * * @param config The config object to use while looking for settings + * @param defaultHandler The handler for incoming messages that do not match any topic subscriptions */ - static MqttClient create(MqttClientConfig config){ - return new MqttClientImpl(config); + static MqttClient create(MqttClientConfig config, MqttHandler defaultHandler){ + return new MqttClientImpl(config, defaultHandler); } - /** * Send disconnect and close channel * diff --git a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientCallback.java b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientCallback.java index d7f0a08409..9f86b8e6cd 100644 --- a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientCallback.java +++ b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientCallback.java @@ -15,6 +15,8 @@ */ package org.thingsboard.mqtt; +import io.netty.channel.ChannelId; + /** * Created by Valerii Sosliuk on 12/30/2017. */ @@ -25,5 +27,11 @@ public interface MqttClientCallback { * * @param cause the reason behind the loss of connection. */ - public void connectionLost(Throwable cause); + void connectionLost(Throwable cause); + + /** + * This method is called when the connection to the server is recovered. + * + */ + void onSuccessfulReconnect(); } diff --git a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientImpl.java b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientImpl.java index 391410525f..d43ce55746 100644 --- a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientImpl.java +++ b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttClientImpl.java @@ -40,23 +40,26 @@ import java.util.concurrent.atomic.AtomicInteger; @SuppressWarnings({"WeakerAccess", "unused"}) final class MqttClientImpl implements MqttClient { - private final Set serverSubscribtions = new HashSet<>(); + private final Set serverSubscriptions = new HashSet<>(); private final IntObjectHashMap pendingServerUnsubscribes = new IntObjectHashMap<>(); private final IntObjectHashMap qos2PendingIncomingPublishes = new IntObjectHashMap<>(); private final IntObjectHashMap pendingPublishes = new IntObjectHashMap<>(); - private final HashMultimap subscriptions = HashMultimap.create(); - private final IntObjectHashMap pendingSubscribtions = new IntObjectHashMap<>(); + private final HashMultimap subscriptions = HashMultimap.create(); + private final IntObjectHashMap pendingSubscriptions = new IntObjectHashMap<>(); private final Set pendingSubscribeTopics = new HashSet<>(); - private final HashMultimap handlerToSubscribtion = HashMultimap.create(); + private final HashMultimap handlerToSubscribtion = HashMultimap.create(); private final AtomicInteger nextMessageId = new AtomicInteger(1); private final MqttClientConfig clientConfig; + private final MqttHandler defaultHandler; + private EventLoopGroup eventLoop; - private Channel channel; + private volatile Channel channel; - private boolean disconnected = false; + private volatile boolean disconnected = false; + private volatile boolean reconnect = false; private String host; private int port; private MqttClientCallback callback; @@ -65,8 +68,9 @@ final class MqttClientImpl implements MqttClient { /** * Construct the MqttClientImpl with default config */ - public MqttClientImpl() { + public MqttClientImpl(MqttHandler defaultHandler) { this.clientConfig = new MqttClientConfig(); + this.defaultHandler = defaultHandler; } /** @@ -75,8 +79,9 @@ final class MqttClientImpl implements MqttClient { * * @param clientConfig The config object to use while looking for settings */ - public MqttClientImpl(MqttClientConfig clientConfig) { + public MqttClientImpl(MqttClientConfig clientConfig, MqttHandler defaultHandler) { this.clientConfig = clientConfig; + this.defaultHandler = defaultHandler; } /** @@ -100,12 +105,15 @@ final class MqttClientImpl implements MqttClient { */ @Override public Future connect(String host, int port) { + return connect(host, port, false); + } + + private Future connect(String host, int port, boolean reconnect) { if (this.eventLoop == null) { this.eventLoop = new NioEventLoopGroup(); } this.host = host; this.port = port; - Promise connectFuture = new DefaultPromise<>(this.eventLoop.next()); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(this.eventLoop); @@ -113,22 +121,47 @@ final class MqttClientImpl implements MqttClient { bootstrap.remoteAddress(host, port); bootstrap.handler(new MqttChannelInitializer(connectFuture, host, port, clientConfig.getSslContext())); ChannelFuture future = bootstrap.connect(); + future.addListener((ChannelFutureListener) f -> { if (f.isSuccess()) { MqttClientImpl.this.channel = f.channel(); - } else if (clientConfig.isReconnect() && !disconnected) { - eventLoop.schedule((Runnable) () -> connect(host, port), 1L, TimeUnit.SECONDS); + MqttClientImpl.this.channel.closeFuture().addListener((ChannelFutureListener) channelFuture -> { + if (isConnected()) { + return; + } + ChannelClosedException e = new ChannelClosedException("Channel is closed!"); + if (callback != null) { + callback.connectionLost(e); + } + pendingSubscriptions.clear(); + serverSubscriptions.clear(); + subscriptions.clear(); + pendingServerUnsubscribes.clear(); + qos2PendingIncomingPublishes.clear(); + pendingPublishes.clear(); + pendingSubscribeTopics.clear(); + handlerToSubscribtion.clear(); + scheduleConnectIfRequired(host, port, true); + }); + } else { + scheduleConnectIfRequired(host, port, reconnect); } }); return connectFuture; } + private void scheduleConnectIfRequired(String host, int port, boolean reconnect) { + if (clientConfig.isReconnect() && !disconnected) { + if (reconnect) { + this.reconnect = true; + } + eventLoop.schedule((Runnable) () -> connect(host, port, reconnect), 1L, TimeUnit.SECONDS); + } + } + @Override public boolean isConnected() { - if (!disconnected) { - return channel == null ? false : channel.isActive(); - }; - return false; + return !disconnected && channel != null && channel.isActive(); } @Override @@ -183,7 +216,7 @@ final class MqttClientImpl implements MqttClient { */ @Override public Future on(String topic, MqttHandler handler, MqttQoS qos) { - return createSubscribtion(topic, handler, false, qos); + return createSubscription(topic, handler, false, qos); } /** @@ -210,7 +243,7 @@ final class MqttClientImpl implements MqttClient { */ @Override public Future once(String topic, MqttHandler handler, MqttQoS qos) { - return createSubscribtion(topic, handler, true, qos); + return createSubscription(topic, handler, true, qos); } /** @@ -224,7 +257,7 @@ final class MqttClientImpl implements MqttClient { @Override public Future off(String topic, MqttHandler handler) { Promise future = new DefaultPromise<>(this.eventLoop.next()); - for (MqttSubscribtion subscribtion : this.handlerToSubscribtion.get(handler)) { + for (MqttSubscription subscribtion : this.handlerToSubscribtion.get(handler)) { this.subscriptions.remove(topic, subscribtion); } this.handlerToSubscribtion.removeAll(handler); @@ -242,9 +275,9 @@ final class MqttClientImpl implements MqttClient { @Override public Future off(String topic) { Promise future = new DefaultPromise<>(this.eventLoop.next()); - ImmutableSet subscribtions = ImmutableSet.copyOf(this.subscriptions.get(topic)); - for (MqttSubscribtion subscribtion : subscribtions) { - for (MqttSubscribtion handSub : this.handlerToSubscribtion.get(subscribtion.getHandler())) { + ImmutableSet subscribtions = ImmutableSet.copyOf(this.subscriptions.get(topic)); + for (MqttSubscription subscribtion : subscribtions) { + for (MqttSubscription handSub : this.handlerToSubscribtion.get(subscribtion.getHandler())) { this.subscriptions.remove(topic, handSub); } this.handlerToSubscribtion.remove(subscribtion.getHandler(), subscribtion); @@ -310,7 +343,7 @@ final class MqttClientImpl implements MqttClient { ChannelFuture channelFuture = this.sendAndFlushPacket(message); if (channelFuture != null) { - pendingPublish.setSent(channelFuture != null); + pendingPublish.setSent(true); if (channelFuture.cause() != null) { future.setFailure(channelFuture.cause()); return future; @@ -352,6 +385,15 @@ final class MqttClientImpl implements MqttClient { ///////////////////////////////////////////// PRIVATE API ///////////////////////////////////////////// + public boolean isReconnect() { + return reconnect; + } + + public void onSuccessfulReconnect() { + callback.onSuccessfulReconnect(); + } + + ChannelFuture sendAndFlushPacket(Object message) { if (this.channel == null) { return null; @@ -359,11 +401,7 @@ final class MqttClientImpl implements MqttClient { if (this.channel.isActive()) { return this.channel.writeAndFlush(message); } - ChannelClosedException e = new ChannelClosedException("Channel is closed"); - if (callback != null) { - callback.connectionLost(e); - } - return this.channel.newFailedFuture(e); + return this.channel.newFailedFuture(new ChannelClosedException("Channel is closed!")); } private MqttMessageIdVariableHeader getNewMessageId() { @@ -371,16 +409,16 @@ final class MqttClientImpl implements MqttClient { return MqttMessageIdVariableHeader.from(this.nextMessageId.getAndIncrement()); } - private Future createSubscribtion(String topic, MqttHandler handler, boolean once, MqttQoS qos) { + private Future createSubscription(String topic, MqttHandler handler, boolean once, MqttQoS qos) { if (this.pendingSubscribeTopics.contains(topic)) { - Optional> subscribtionEntry = this.pendingSubscribtions.entrySet().stream().filter((e) -> e.getValue().getTopic().equals(topic)).findAny(); + Optional> subscribtionEntry = this.pendingSubscriptions.entrySet().stream().filter((e) -> e.getValue().getTopic().equals(topic)).findAny(); if (subscribtionEntry.isPresent()) { subscribtionEntry.get().getValue().addHandler(handler, once); return subscribtionEntry.get().getValue().getFuture(); } } - if (this.serverSubscribtions.contains(topic)) { - MqttSubscribtion subscribtion = new MqttSubscribtion(topic, handler, once); + if (this.serverSubscriptions.contains(topic)) { + MqttSubscription subscribtion = new MqttSubscription(topic, handler, once); this.subscriptions.put(topic, subscribtion); this.handlerToSubscribtion.put(handler, subscribtion); return this.channel.newSucceededFuture(); @@ -395,7 +433,7 @@ final class MqttClientImpl implements MqttClient { final MqttPendingSubscribtion pendingSubscribtion = new MqttPendingSubscribtion(future, topic, message); pendingSubscribtion.addHandler(handler, once); - this.pendingSubscribtions.put(variableHeader.messageId(), pendingSubscribtion); + this.pendingSubscriptions.put(variableHeader.messageId(), pendingSubscribtion); this.pendingSubscribeTopics.add(topic); pendingSubscribtion.setSent(this.sendAndFlushPacket(message) != null); //If not sent, we will send it when the connection is opened @@ -405,7 +443,7 @@ final class MqttClientImpl implements MqttClient { } private void checkSubscribtions(String topic, Promise promise) { - if (!(this.subscriptions.containsKey(topic) && this.subscriptions.get(topic).size() != 0) && this.serverSubscribtions.contains(topic)) { + if (!(this.subscriptions.containsKey(topic) && this.subscriptions.get(topic).size() != 0) && this.serverSubscriptions.contains(topic)) { MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0); MqttMessageIdVariableHeader variableHeader = getNewMessageId(); MqttUnsubscribePayload payload = new MqttUnsubscribePayload(Collections.singletonList(topic)); @@ -421,11 +459,11 @@ final class MqttClientImpl implements MqttClient { } } - IntObjectHashMap getPendingSubscribtions() { - return pendingSubscribtions; + IntObjectHashMap getPendingSubscriptions() { + return pendingSubscriptions; } - HashMultimap getSubscriptions() { + HashMultimap getSubscriptions() { return subscriptions; } @@ -433,12 +471,12 @@ final class MqttClientImpl implements MqttClient { return pendingSubscribeTopics; } - HashMultimap getHandlerToSubscribtion() { + HashMultimap getHandlerToSubscribtion() { return handlerToSubscribtion; } - Set getServerSubscribtions() { - return serverSubscribtions; + Set getServerSubscriptions() { + return serverSubscriptions; } IntObjectHashMap getPendingServerUnsubscribes() { @@ -481,4 +519,9 @@ final class MqttClientImpl implements MqttClient { ch.pipeline().addLast("mqttHandler", new MqttChannelHandler(MqttClientImpl.this, connectFuture)); } } + + MqttHandler getDefaultHandler() { + return defaultHandler; + } + } diff --git a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttSubscribtion.java b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttSubscription.java similarity index 93% rename from netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttSubscribtion.java rename to netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttSubscription.java index 27f4cb969e..88761054b2 100644 --- a/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttSubscribtion.java +++ b/netty-mqtt/src/main/java/org/thingsboard/mqtt/MqttSubscription.java @@ -17,7 +17,7 @@ package org.thingsboard.mqtt; import java.util.regex.Pattern; -final class MqttSubscribtion { +final class MqttSubscription { private final String topic; private final Pattern topicRegex; @@ -27,7 +27,7 @@ final class MqttSubscribtion { private boolean called; - MqttSubscribtion(String topic, MqttHandler handler, boolean once) { + MqttSubscription(String topic, MqttHandler handler, boolean once) { if(topic == null){ throw new NullPointerException("topic"); } @@ -65,7 +65,7 @@ final class MqttSubscribtion { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - MqttSubscribtion that = (MqttSubscribtion) o; + MqttSubscription that = (MqttSubscription) o; return once == that.once && topic.equals(that.topic) && handler.equals(that.handler); } diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/mqtt/TbMqttNode.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/mqtt/TbMqttNode.java index c0a813cef4..6b05b7c526 100644 --- a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/mqtt/TbMqttNode.java +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/mqtt/TbMqttNode.java @@ -113,7 +113,7 @@ public class TbMqttNode implements TbNode { } config.setCleanSession(this.config.isCleanSession()); this.config.getCredentials().configure(config); - MqttClient client = MqttClient.create(config); + MqttClient client = MqttClient.create(config, null); client.setEventLoop(this.eventLoopGroup); Future connectFuture = client.connect(this.config.getHost(), this.config.getPort()); MqttConnectResult result;