Merge pull request #13701 from AndriiLandiak/mqtt-client-id-length

MQTT: validate client id length based on protocol version
This commit is contained in:
Viacheslav Klimov 2025-07-16 17:34:00 +03:00 committed by GitHub
commit 21016025f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 119 additions and 238 deletions

View File

@ -15,9 +15,6 @@
*/ */
package org.thingsboard.server.common.data; package org.thingsboard.server.common.data;
/**
* @author Andrew Shvayka
*/
public class DataConstants { public class DataConstants {
public static final String TENANT = "TENANT"; public static final String TENANT = "TENANT";

View File

@ -15,11 +15,11 @@
*/ */
package org.thingsboard.mqtt; package org.thingsboard.mqtt;
/** import java.io.Serial;
* Created by Valerii Sosliuk on 12/26/2017.
*/
public class ChannelClosedException extends RuntimeException { public class ChannelClosedException extends RuntimeException {
@Serial
private static final long serialVersionUID = 6266638352424706909L; private static final long serialVersionUID = 6266638352424706909L;
public ChannelClosedException() { public ChannelClosedException() {
@ -40,4 +40,5 @@ public class ChannelClosedException extends RuntimeException {
public ChannelClosedException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { public ChannelClosedException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace); super(message, cause, enableSuppression, writableStackTrace);
} }
} }

View File

@ -21,9 +21,6 @@ import io.netty.handler.codec.mqtt.MqttPubAckMessage;
import io.netty.handler.codec.mqtt.MqttSubAckMessage; import io.netty.handler.codec.mqtt.MqttSubAckMessage;
import io.netty.handler.codec.mqtt.MqttUnsubAckMessage; import io.netty.handler.codec.mqtt.MqttUnsubAckMessage;
/**
* Created by Valerii Sosliuk on 12/30/2017.
*/
public interface MqttClientCallback { public interface MqttClientCallback {
/** /**
@ -53,4 +50,5 @@ public interface MqttClientCallback {
default void onDisconnect(MqttMessage mqttDisconnectMessage) { default void onDisconnect(MqttMessage mqttDisconnectMessage) {
} }
} }

View File

@ -28,23 +28,46 @@ import java.util.Random;
@SuppressWarnings({"WeakerAccess", "unused"}) @SuppressWarnings({"WeakerAccess", "unused"})
public final class MqttClientConfig { public final class MqttClientConfig {
@Getter
private final SslContext sslContext; private final SslContext sslContext;
private final String randomClientId; private final String randomClientId;
@Getter @Getter
@Setter @Setter
private String ownerId; // [TenantId][IntegrationId] or [TenantId][RuleNodeId] for exceptions logging purposes private String ownerId; // [TenantId][IntegrationId] or [TenantId][RuleNodeId] for exceptions logging purposes
@Nonnull
@Getter
private String clientId; private String clientId;
@Getter
private int timeoutSeconds = 60; private int timeoutSeconds = 60;
@Getter
private MqttVersion protocolVersion = MqttVersion.MQTT_3_1; private MqttVersion protocolVersion = MqttVersion.MQTT_3_1;
@Nullable private String username = null; @Nullable
@Nullable private String password = null; @Getter
@Setter
private String username = null;
@Nullable
@Getter
@Setter
private String password = null;
@Getter
@Setter
private boolean cleanSession = true; private boolean cleanSession = true;
@Nullable private MqttLastWill lastWill; @Nullable
@Getter
@Setter
private MqttLastWill lastWill;
@Setter
@Getter
private Class<? extends Channel> channelClass = NioSocketChannel.class; private Class<? extends Channel> channelClass = NioSocketChannel.class;
@Getter
@Setter
private boolean reconnect = true; private boolean reconnect = true;
@Getter
private long reconnectDelay = 1L; private long reconnectDelay = 1L;
@Getter
private int maxBytesInMessage = 8092; private int maxBytesInMessage = 8092;
@Getter @Getter
@ -74,109 +97,37 @@ public final class MqttClientConfig {
public MqttClientConfig(SslContext sslContext) { public MqttClientConfig(SslContext sslContext) {
this.sslContext = sslContext; this.sslContext = sslContext;
Random random = new Random(); Random random = new Random();
String id = "netty-mqtt/"; StringBuilder id = new StringBuilder("netty-mqtt/");
String[] options = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".split(""); String[] options = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".split("");
for(int i = 0; i < 8; i++){ for (int i = 0; i < 8; i++) {
id += options[random.nextInt(options.length)]; id.append(options[random.nextInt(options.length)]);
} }
this.clientId = id; this.clientId = id.toString();
this.randomClientId = id; this.randomClientId = id.toString();
}
@Nonnull
public String getClientId() {
return clientId;
} }
public void setClientId(@Nullable String clientId) { public void setClientId(@Nullable String clientId) {
if(clientId == null){ if (clientId == null) {
this.clientId = randomClientId; this.clientId = randomClientId;
}else{ } else {
this.clientId = clientId; this.clientId = clientId;
} }
} }
public int getTimeoutSeconds() {
return timeoutSeconds;
}
public void setTimeoutSeconds(int timeoutSeconds) { public void setTimeoutSeconds(int timeoutSeconds) {
if(timeoutSeconds != -1 && timeoutSeconds <= 0){ if (timeoutSeconds != -1 && timeoutSeconds <= 0) {
throw new IllegalArgumentException("timeoutSeconds must be > 0 or -1"); throw new IllegalArgumentException("timeoutSeconds must be > 0 or -1");
} }
this.timeoutSeconds = timeoutSeconds; this.timeoutSeconds = timeoutSeconds;
} }
public MqttVersion getProtocolVersion() {
return protocolVersion;
}
public void setProtocolVersion(MqttVersion protocolVersion) { public void setProtocolVersion(MqttVersion protocolVersion) {
if(protocolVersion == null){ if (protocolVersion == null) {
throw new NullPointerException("protocolVersion"); throw new NullPointerException("protocolVersion");
} }
this.protocolVersion = protocolVersion; this.protocolVersion = protocolVersion;
} }
@Nullable
public String getUsername() {
return username;
}
public void setUsername(@Nullable String username) {
this.username = username;
}
@Nullable
public String getPassword() {
return password;
}
public void setPassword(@Nullable String password) {
this.password = password;
}
public boolean isCleanSession() {
return cleanSession;
}
public void setCleanSession(boolean cleanSession) {
this.cleanSession = cleanSession;
}
@Nullable
public MqttLastWill getLastWill() {
return lastWill;
}
public void setLastWill(@Nullable MqttLastWill lastWill) {
this.lastWill = lastWill;
}
public Class<? extends Channel> getChannelClass() {
return channelClass;
}
public void setChannelClass(Class<? extends Channel> channelClass) {
this.channelClass = channelClass;
}
public SslContext getSslContext() {
return sslContext;
}
public boolean isReconnect() {
return reconnect;
}
public void setReconnect(boolean reconnect) {
this.reconnect = reconnect;
}
public long getReconnectDelay() {
return reconnectDelay;
}
/** /**
* Sets the reconnect delay in seconds. Defaults to 1 second. * Sets the reconnect delay in seconds. Defaults to 1 second.
* @param reconnectDelay * @param reconnectDelay
@ -189,10 +140,6 @@ public final class MqttClientConfig {
this.reconnectDelay = reconnectDelay; this.reconnectDelay = reconnectDelay;
} }
public int getMaxBytesInMessage() {
return maxBytesInMessage;
}
/** /**
* Sets the maximum number of bytes in the message for the {@link io.netty.handler.codec.mqtt.MqttDecoder}. * Sets the maximum number of bytes in the message for the {@link io.netty.handler.codec.mqtt.MqttDecoder}.
* Default value is 8092 as specified by Netty. The absolute maximum size is 256MB as set by the MQTT spec. * Default value is 8092 as specified by Netty. The absolute maximum size is 256MB as set by the MQTT spec.
@ -206,4 +153,5 @@ public final class MqttClientConfig {
} }
this.maxBytesInMessage = maxBytesInMessage; this.maxBytesInMessage = maxBytesInMessage;
} }
} }

View File

@ -46,7 +46,9 @@ import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.thingsboard.common.util.ListeningExecutor; import org.thingsboard.common.util.ListeningExecutor;
@ -67,18 +69,28 @@ import java.util.concurrent.atomic.AtomicInteger;
@Slf4j @Slf4j
final class MqttClientImpl implements MqttClient { final class MqttClientImpl implements MqttClient {
@Getter(AccessLevel.PACKAGE)
private final Set<String> serverSubscriptions = new HashSet<>(); private final Set<String> serverSubscriptions = new HashSet<>();
@Getter(AccessLevel.PACKAGE)
private final ConcurrentMap<Integer, MqttPendingUnsubscription> pendingServerUnsubscribes = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, MqttPendingUnsubscription> pendingServerUnsubscribes = new ConcurrentHashMap<>();
@Getter(AccessLevel.PACKAGE)
private final ConcurrentMap<Integer, MqttIncomingQos2Publish> qos2PendingIncomingPublishes = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, MqttIncomingQos2Publish> qos2PendingIncomingPublishes = new ConcurrentHashMap<>();
@Getter(AccessLevel.PACKAGE)
private final ConcurrentMap<Integer, MqttPendingPublish> pendingPublishes = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, MqttPendingPublish> pendingPublishes = new ConcurrentHashMap<>();
@Getter(AccessLevel.PACKAGE)
private final HashMultimap<String, MqttSubscription> subscriptions = HashMultimap.create(); private final HashMultimap<String, MqttSubscription> subscriptions = HashMultimap.create();
@Getter(AccessLevel.PACKAGE)
private final ConcurrentMap<Integer, MqttPendingSubscription> pendingSubscriptions = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, MqttPendingSubscription> pendingSubscriptions = new ConcurrentHashMap<>();
@Getter(AccessLevel.PACKAGE)
private final Set<String> pendingSubscribeTopics = new HashSet<>(); private final Set<String> pendingSubscribeTopics = new HashSet<>();
@Getter(AccessLevel.PACKAGE)
private final HashMultimap<MqttHandler, MqttSubscription> handlerToSubscription = HashMultimap.create(); private final HashMultimap<MqttHandler, MqttSubscription> handlerToSubscription = HashMultimap.create();
private final AtomicInteger nextMessageId = new AtomicInteger(1); private final AtomicInteger nextMessageId = new AtomicInteger(1);
@Getter
private final MqttClientConfig clientConfig; private final MqttClientConfig clientConfig;
@Getter(AccessLevel.PACKAGE)
private final MqttHandler defaultHandler; private final MqttHandler defaultHandler;
private final ReconnectStrategy reconnectStrategy; private final ReconnectStrategy reconnectStrategy;
@ -88,12 +100,15 @@ final class MqttClientImpl implements MqttClient {
private volatile Channel channel; private volatile Channel channel;
private volatile boolean disconnected = false; private volatile boolean disconnected = false;
@Getter
private volatile boolean reconnect = false; private volatile boolean reconnect = false;
private String host; private String host;
private int port; private int port;
@Getter @Getter
@Setter
private MqttClientCallback callback; private MqttClientCallback callback;
@Getter
private final ListeningExecutor handlerExecutor; private final ListeningExecutor handlerExecutor;
private final static int DISCONNECT_FALLBACK_DELAY_SECS = 1; private final static int DISCONNECT_FALLBACK_DELAY_SECS = 1;
@ -240,11 +255,6 @@ final class MqttClientImpl implements MqttClient {
this.eventLoop = eventLoop; this.eventLoop = eventLoop;
} }
@Override
public ListeningExecutor getHandlerExecutor() {
return this.handlerExecutor;
}
/** /**
* Subscribe on the given topic. When a message is received, MqttClient will invoke the {@link MqttHandler#onMessage(String, ByteBuf)} function of the given handler * Subscribe on the given topic. When a message is received, MqttClient will invoke the {@link MqttHandler#onMessage(String, ByteBuf)} function of the given handler
* *
@ -446,16 +456,6 @@ final class MqttClientImpl implements MqttClient {
return future; return future;
} }
/**
* Retrieve the MqttClient configuration
*
* @return The {@link MqttClientConfig} instance we use
*/
@Override
public MqttClientConfig getClientConfig() {
return clientConfig;
}
@Override @Override
public void disconnect() { public void disconnect() {
if (disconnected) { if (disconnected) {
@ -480,25 +480,15 @@ final class MqttClientImpl implements MqttClient {
} }
} }
@Override
public void setCallback(MqttClientCallback callback) {
this.callback = callback;
}
///////////////////////////////////////////// PRIVATE API ///////////////////////////////////////////// ///////////////////////////////////////////// PRIVATE API /////////////////////////////////////////////
public boolean isReconnect() {
return reconnect;
}
public void onSuccessfulReconnect() { public void onSuccessfulReconnect() {
if (callback != null) { if (callback != null) {
callback.onSuccessfulReconnect(); callback.onSuccessfulReconnect();
} }
} }
ChannelFuture sendAndFlushPacket(Object message) { ChannelFuture sendAndFlushPacket(Object message) {
if (this.channel == null) { if (this.channel == null) {
return null; return null;
@ -576,7 +566,7 @@ final class MqttClientImpl implements MqttClient {
} }
private void checkSubscriptions(String topic, Promise<Void> promise) { private void checkSubscriptions(String topic, Promise<Void> promise) {
if (!(this.subscriptions.containsKey(topic) && this.subscriptions.get(topic).size() != 0) && this.serverSubscriptions.contains(topic)) { if (!(this.subscriptions.containsKey(topic) && !this.subscriptions.get(topic).isEmpty()) && this.serverSubscriptions.contains(topic)) {
MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0); MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader = getNewMessageId(); MqttMessageIdVariableHeader variableHeader = getNewMessageId();
MqttUnsubscribePayload payload = new MqttUnsubscribePayload(Collections.singletonList(topic)); MqttUnsubscribePayload payload = new MqttUnsubscribePayload(Collections.singletonList(topic));
@ -614,38 +604,6 @@ final class MqttClientImpl implements MqttClient {
} }
} }
ConcurrentMap<Integer, MqttPendingSubscription> getPendingSubscriptions() {
return pendingSubscriptions;
}
HashMultimap<String, MqttSubscription> getSubscriptions() {
return subscriptions;
}
Set<String> getPendingSubscribeTopics() {
return pendingSubscribeTopics;
}
HashMultimap<MqttHandler, MqttSubscription> getHandlerToSubscription() {
return handlerToSubscription;
}
Set<String> getServerSubscriptions() {
return serverSubscriptions;
}
ConcurrentMap<Integer, MqttPendingUnsubscription> getPendingServerUnsubscribes() {
return pendingServerUnsubscribes;
}
ConcurrentMap<Integer, MqttPendingPublish> getPendingPublishes() {
return pendingPublishes;
}
ConcurrentMap<Integer, MqttIncomingQos2Publish> getQos2PendingIncomingPublishes() {
return qos2PendingIncomingPublishes;
}
private class MqttChannelInitializer extends ChannelInitializer<SocketChannel> { private class MqttChannelInitializer extends ChannelInitializer<SocketChannel> {
private final Promise<MqttConnectResult> connectFuture; private final Promise<MqttConnectResult> connectFuture;
@ -673,10 +631,7 @@ final class MqttClientImpl implements MqttClient {
ch.pipeline().addLast("mqttPingHandler", new MqttPingHandler(MqttClientImpl.this.clientConfig.getTimeoutSeconds())); ch.pipeline().addLast("mqttPingHandler", new MqttPingHandler(MqttClientImpl.this.clientConfig.getTimeoutSeconds()));
ch.pipeline().addLast("mqttHandler", new MqttChannelHandler(MqttClientImpl.this, connectFuture)); ch.pipeline().addLast("mqttHandler", new MqttChannelHandler(MqttClientImpl.this, connectFuture));
} }
}
MqttHandler getDefaultHandler() {
return defaultHandler;
} }
} }

View File

@ -17,14 +17,18 @@ package org.thingsboard.mqtt;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode; import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
import lombok.Getter;
import lombok.ToString; import lombok.ToString;
@ToString @ToString
@SuppressWarnings({"WeakerAccess", "unused"}) @SuppressWarnings({"WeakerAccess", "unused"})
public final class MqttConnectResult { public final class MqttConnectResult {
@Getter
private final boolean success; private final boolean success;
@Getter
private final MqttConnectReturnCode returnCode; private final MqttConnectReturnCode returnCode;
@Getter
private final ChannelFuture closeFuture; private final ChannelFuture closeFuture;
MqttConnectResult(boolean success, MqttConnectReturnCode returnCode, ChannelFuture closeFuture) { MqttConnectResult(boolean success, MqttConnectReturnCode returnCode, ChannelFuture closeFuture) {
@ -33,16 +37,4 @@ public final class MqttConnectResult {
this.closeFuture = closeFuture; this.closeFuture = closeFuture;
} }
public boolean isSuccess() {
return success;
}
public MqttConnectReturnCode getReturnCode() {
return returnCode;
}
public ChannelFuture getCloseFuture() {
return closeFuture;
}
} }

View File

@ -15,16 +15,23 @@
*/ */
package org.thingsboard.mqtt; package org.thingsboard.mqtt;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import java.util.regex.Pattern; import java.util.regex.Pattern;
final class MqttSubscription { final class MqttSubscription {
@Getter(AccessLevel.PACKAGE)
private final String topic; private final String topic;
private final Pattern topicRegex; private final Pattern topicRegex;
@Getter
private final MqttHandler handler; private final MqttHandler handler;
@Getter(AccessLevel.PACKAGE)
private final boolean once; private final boolean once;
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
private volatile boolean called; private volatile boolean called;
MqttSubscription(String topic, MqttHandler handler, boolean once) { MqttSubscription(String topic, MqttHandler handler, boolean once) {
@ -40,22 +47,6 @@ final class MqttSubscription {
this.topicRegex = Pattern.compile(topic.replace("+", "[^/]+").replace("#", ".+") + "$"); this.topicRegex = Pattern.compile(topic.replace("+", "[^/]+").replace("#", ".+") + "$");
} }
String getTopic() {
return topic;
}
public MqttHandler getHandler() {
return handler;
}
boolean isOnce() {
return once;
}
boolean isCalled() {
return called;
}
boolean matches(String topic) { boolean matches(String topic) {
return this.topicRegex.matcher(topic).matches(); return this.topicRegex.matcher(topic).matches();
} }
@ -78,7 +69,4 @@ final class MqttSubscription {
return result; return result;
} }
void setCalled(boolean called) {
this.called = called;
}
} }

View File

@ -45,7 +45,6 @@ import org.thingsboard.server.common.msg.TbMsg;
import org.thingsboard.server.common.msg.TbMsgMetaData; import org.thingsboard.server.common.msg.TbMsgMetaData;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
@ -64,12 +63,10 @@ import java.util.concurrent.TimeoutException;
) )
public class TbMqttNode extends TbAbstractExternalNode { public class TbMqttNode extends TbAbstractExternalNode {
private static final Charset UTF8 = StandardCharsets.UTF_8; private static final int MQTT_3_MAX_CLIENT_ID_LENGTH = 23;
private static final int MQTT_5_MAX_CLIENT_ID_LENGTH = 256;
private static final String ERROR = "error";
protected TbMqttNodeConfiguration mqttNodeConfiguration; protected TbMqttNodeConfiguration mqttNodeConfiguration;
protected MqttClient mqttClient; protected MqttClient mqttClient;
@Override @Override
@ -87,9 +84,9 @@ public class TbMqttNode extends TbAbstractExternalNode {
@Override @Override
public void onMsg(TbContext ctx, TbMsg msg) { public void onMsg(TbContext ctx, TbMsg msg) {
String topic = TbNodeUtils.processPattern(this.mqttNodeConfiguration.getTopicPattern(), msg); String topic = TbNodeUtils.processPattern(mqttNodeConfiguration.getTopicPattern(), msg);
var tbMsg = ackIfNeeded(ctx, msg); var tbMsg = ackIfNeeded(ctx, msg);
this.mqttClient.publish(topic, Unpooled.wrappedBuffer(getData(tbMsg, mqttNodeConfiguration.isParseToPlainText()).getBytes(UTF8)), this.mqttClient.publish(topic, Unpooled.wrappedBuffer(getData(tbMsg, mqttNodeConfiguration.isParseToPlainText()).getBytes(StandardCharsets.UTF_8)),
MqttQoS.AT_LEAST_ONCE, mqttNodeConfiguration.isRetainedMessage()) MqttQoS.AT_LEAST_ONCE, mqttNodeConfiguration.isRetainedMessage())
.addListener(future -> { .addListener(future -> {
if (future.isSuccess()) { if (future.isSuccess()) {
@ -103,7 +100,7 @@ public class TbMqttNode extends TbAbstractExternalNode {
private TbMsg processException(TbMsg origMsg, Throwable e) { private TbMsg processException(TbMsg origMsg, Throwable e) {
TbMsgMetaData metaData = origMsg.getMetaData().copy(); TbMsgMetaData metaData = origMsg.getMetaData().copy();
metaData.putValue(ERROR, e.getClass() + ": " + e.getMessage()); metaData.putValue("error", e.getClass() + ": " + e.getMessage());
return origMsg.transform() return origMsg.transform()
.metaData(metaData) .metaData(metaData)
.build(); .build();
@ -111,8 +108,8 @@ public class TbMqttNode extends TbAbstractExternalNode {
@Override @Override
public void destroy() { public void destroy() {
if (this.mqttClient != null) { if (mqttClient != null) {
this.mqttClient.disconnect(); mqttClient.disconnect();
} }
} }
@ -123,11 +120,11 @@ public class TbMqttNode extends TbAbstractExternalNode {
protected MqttClient initClient(TbContext ctx) throws Exception { protected MqttClient initClient(TbContext ctx) throws Exception {
MqttClientConfig config = new MqttClientConfig(getSslContext()); MqttClientConfig config = new MqttClientConfig(getSslContext());
config.setOwnerId(getOwnerId(ctx)); config.setOwnerId(getOwnerId(ctx));
if (!StringUtils.isEmpty(this.mqttNodeConfiguration.getClientId())) { if (!StringUtils.isEmpty(mqttNodeConfiguration.getClientId())) {
config.setClientId(getClientId(ctx)); config.setClientId(getClientId(ctx));
} }
config.setCleanSession(this.mqttNodeConfiguration.isCleanSession()); config.setCleanSession(mqttNodeConfiguration.isCleanSession());
config.setProtocolVersion(this.mqttNodeConfiguration.getProtocolVersion()); config.setProtocolVersion(mqttNodeConfiguration.getProtocolVersion());
MqttClientSettings mqttClientSettings = ctx.getMqttClientSettings(); MqttClientSettings mqttClientSettings = ctx.getMqttClientSettings();
config.setRetransmissionConfig(new MqttClientConfig.RetransmissionConfig( config.setRetransmissionConfig(new MqttClientConfig.RetransmissionConfig(
@ -139,32 +136,32 @@ public class TbMqttNode extends TbAbstractExternalNode {
prepareMqttClientConfig(config); prepareMqttClientConfig(config);
MqttClient client = getMqttClient(ctx, config); MqttClient client = getMqttClient(ctx, config);
client.setEventLoop(ctx.getSharedEventLoop()); client.setEventLoop(ctx.getSharedEventLoop());
Promise<MqttConnectResult> connectFuture = client.connect(this.mqttNodeConfiguration.getHost(), this.mqttNodeConfiguration.getPort()); Promise<MqttConnectResult> connectFuture = client.connect(mqttNodeConfiguration.getHost(), mqttNodeConfiguration.getPort());
MqttConnectResult result; MqttConnectResult result;
try { try {
result = connectFuture.get(this.mqttNodeConfiguration.getConnectTimeoutSec(), TimeUnit.SECONDS); result = connectFuture.get(mqttNodeConfiguration.getConnectTimeoutSec(), TimeUnit.SECONDS);
} catch (TimeoutException ex) { } catch (TimeoutException ex) {
connectFuture.cancel(true); connectFuture.cancel(true);
client.disconnect(); client.disconnect();
String hostPort = this.mqttNodeConfiguration.getHost() + ":" + this.mqttNodeConfiguration.getPort(); String hostPort = mqttNodeConfiguration.getHost() + ":" + mqttNodeConfiguration.getPort();
throw new RuntimeException(String.format("Failed to connect to MQTT broker at %s.", hostPort)); throw new RuntimeException(String.format("Failed to connect to MQTT broker at %s.", hostPort));
} }
if (!result.isSuccess()) { if (!result.isSuccess()) {
connectFuture.cancel(true); connectFuture.cancel(true);
client.disconnect(); client.disconnect();
String hostPort = this.mqttNodeConfiguration.getHost() + ":" + this.mqttNodeConfiguration.getPort(); String hostPort = mqttNodeConfiguration.getHost() + ":" + mqttNodeConfiguration.getPort();
throw new RuntimeException(String.format("Failed to connect to MQTT broker at %s. Result code is: %s", hostPort, result.getReturnCode())); throw new RuntimeException(String.format("Failed to connect to MQTT broker at %s. Result code is: %s", hostPort, result.getReturnCode()));
} }
return client; return client;
} }
private String getClientId(TbContext ctx) throws TbNodeException { private String getClientId(TbContext ctx) throws TbNodeException {
String clientId = this.mqttNodeConfiguration.isAppendClientIdSuffix() ? String clientId = mqttNodeConfiguration.isAppendClientIdSuffix() ?
this.mqttNodeConfiguration.getClientId() + "_" + ctx.getServiceId() : mqttNodeConfiguration.getClientId() + "_" + ctx.getServiceId() :
this.mqttNodeConfiguration.getClientId(); mqttNodeConfiguration.getClientId();
if (clientId.length() > 23) { int maxLength = mqttNodeConfiguration.getProtocolVersion() == MqttVersion.MQTT_3_1 ? MQTT_3_MAX_CLIENT_ID_LENGTH : MQTT_5_MAX_CLIENT_ID_LENGTH;
throw new TbNodeException("Client ID is too long '" + clientId + "'. " + if (clientId.length() > maxLength) {
"The length of Client ID cannot be longer than 23, but current length is " + clientId.length() + ".", true); throw new TbNodeException("The length of Client ID cannot be longer than " + maxLength + ", but current length is " + clientId.length() + ".", true);
} }
return clientId; return clientId;
} }
@ -174,7 +171,7 @@ public class TbMqttNode extends TbAbstractExternalNode {
} }
protected void prepareMqttClientConfig(MqttClientConfig config) { protected void prepareMqttClientConfig(MqttClientConfig config) {
ClientCredentials credentials = this.mqttNodeConfiguration.getCredentials(); ClientCredentials credentials = mqttNodeConfiguration.getCredentials();
if (credentials.getType() == CredentialsType.BASIC) { if (credentials.getType() == CredentialsType.BASIC) {
BasicCredentials basicCredentials = (BasicCredentials) credentials; BasicCredentials basicCredentials = (BasicCredentials) credentials;
config.setUsername(basicCredentials.getUsername()); config.setUsername(basicCredentials.getUsername());
@ -183,7 +180,7 @@ public class TbMqttNode extends TbAbstractExternalNode {
} }
private SslContext getSslContext() throws SSLException { private SslContext getSslContext() throws SSLException {
return this.mqttNodeConfiguration.isSsl() ? this.mqttNodeConfiguration.getCredentials().initSslContext() : null; return mqttNodeConfiguration.isSsl() ? mqttNodeConfiguration.getCredentials().initSslContext() : null;
} }
private String getData(TbMsg tbMsg, boolean parseToPlainText) { private String getData(TbMsg tbMsg, boolean parseToPlainText) {

View File

@ -212,40 +212,45 @@ public class TbMqttNodeTest extends AbstractRuleNodeUpgradeTest {
assertThatNoException().isThrownBy(() -> mqttNode.init(ctxMock, new TbNodeConfiguration(JacksonUtil.valueToTree(mqttNodeConfig)))); assertThatNoException().isThrownBy(() -> mqttNode.init(ctxMock, new TbNodeConfiguration(JacksonUtil.valueToTree(mqttNodeConfig))));
} }
@Test @ParameterizedTest
public void givenClientIdIsTooLong_whenInit_thenThrowsException() { @MethodSource("provideInvalidClientIdScenarios")
String invalidClientId = "vhfrbeb38ygwfwrgfwefgterhytjytj"; public void givenInvalidClientId_whenInit_thenThrowsException(MqttVersion version, int maxLength, int repeat, String serviceId, boolean appendSuffix) {
mqttNodeConfig.setClientId(invalidClientId); String baseClientId = "x".repeat(repeat);
mqttNodeConfig.setClientId(baseClientId);
mqttNodeConfig.setAppendClientIdSuffix(appendSuffix);
mqttNodeConfig.setProtocolVersion(version);
given(ctxMock.getTenantId()).willReturn(TENANT_ID); given(ctxMock.getTenantId()).willReturn(TENANT_ID);
given(ctxMock.getSelf()).willReturn(new RuleNode(RULE_NODE_ID)); given(ctxMock.getSelf()).willReturn(new RuleNode(RULE_NODE_ID));
String clientId = appendSuffix ? baseClientId + "_" + serviceId : baseClientId;
if (appendSuffix) {
given(ctxMock.getServiceId()).willReturn(serviceId);
}
String expectedMessage = "The length of Client ID cannot be longer than " + maxLength + ", but current length is " + clientId.length() + ".";
assertThatThrownBy(() -> mqttNode.init(ctxMock, new TbNodeConfiguration(JacksonUtil.valueToTree(mqttNodeConfig)))) assertThatThrownBy(() -> mqttNode.init(ctxMock, new TbNodeConfiguration(JacksonUtil.valueToTree(mqttNodeConfig))))
.isInstanceOf(TbNodeException.class) .isInstanceOf(TbNodeException.class)
.hasMessage("Client ID is too long '" + invalidClientId + "'. " + .hasMessage(expectedMessage)
"The length of Client ID cannot be longer than 23, but current length is " + invalidClientId.length() + ".")
.extracting(e -> ((TbNodeException) e).isUnrecoverable()) .extracting(e -> ((TbNodeException) e).isUnrecoverable())
.isEqualTo(true); .isEqualTo(true);
} }
@Test private static Stream<Arguments> provideInvalidClientIdScenarios() {
public void givenClientIdIsOkAndAppendClientIdSuffixIsTrue_whenInit_thenClientIdBecomesInvalidAndThrowsException() { return Stream.of(
String validClientId = "fertjnhnjj4ge"; // MQTT_5, too long clientId
mqttNodeConfig.setClientId("fertjnhnjj4ge"); Arguments.of(MqttVersion.MQTT_5, 256, 257, null, false),
mqttNodeConfig.setAppendClientIdSuffix(true);
given(ctxMock.getTenantId()).willReturn(TENANT_ID); // MQTT_5, base + suffix exceeds
given(ctxMock.getSelf()).willReturn(new RuleNode(RULE_NODE_ID)); Arguments.of(MqttVersion.MQTT_5, 256, 250, "test-service", true),
String serviceId = "test-service";
given(ctxMock.getServiceId()).willReturn(serviceId);
String resultedClientId = validClientId + "_" + serviceId; // MQTT_3_1, too long clientId
assertThatThrownBy(() -> mqttNode.init(ctxMock, new TbNodeConfiguration(JacksonUtil.valueToTree(mqttNodeConfig)))) Arguments.of(MqttVersion.MQTT_3_1, 23, 24, null, false),
.isInstanceOf(TbNodeException.class)
.hasMessage("Client ID is too long '" + resultedClientId + "'. " + // MQTT_3_1, base + suffix exceeds
"The length of Client ID cannot be longer than 23, but current length is " + resultedClientId.length() + ".") Arguments.of(MqttVersion.MQTT_3_1, 23, 5, "verylongservicename", true)
.extracting(e -> ((TbNodeException) e).isUnrecoverable()) );
.isEqualTo(true);
} }
@Test @Test