Correct close and cleanup of the MQTT session context

This commit is contained in:
Andrii Shvaika 2021-08-05 17:22:40 +03:00
parent c1d8aa1370
commit daac250c2e
3 changed files with 84 additions and 74 deletions

View File

@ -81,6 +81,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
@ -153,10 +154,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
if (message.decoderResult().isSuccess()) {
processMqttMsg(ctx, message);
} else {
log.error("[{}] Message processing failed: {}", sessionId, message.decoderResult().cause().getMessage());
log.error("[{}] Message decoding failed: {}", sessionId, message.decoderResult().cause().getMessage());
ctx.close();
}
} else {
log.debug("[{}] Received non mqtt message: {}", sessionId, msg.getClass().getSimpleName());
ctx.close();
}
} finally {
@ -168,7 +170,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
address = getAddress(ctx);
if (msg.fixedHeader() == null) {
log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort());
processDisconnect(ctx);
ctx.close();
return;
}
deviceSessionCtx.setChannel(ctx);
@ -208,8 +210,8 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
}
} else {
log.debug("[{}] Unsupported topic for provisioning requests: {}!", sessionId, topicName);
ctx.close();
throw new RuntimeException("Unsupported topic for provisioning requests!");
}
} catch (RuntimeException | AdaptorException e) {
log.warn("[{}] Failed to process publish msg [{}][{}]", sessionId, topicName, msgId, e);
@ -220,48 +222,30 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
ctx.writeAndFlush(new MqttMessage(new MqttFixedHeader(PINGRESP, false, AT_MOST_ONCE, false, 0)));
break;
case DISCONNECT:
if (checkConnected(ctx, msg)) {
processDisconnect(ctx);
}
ctx.close();
break;
}
}
void enqueueRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
final int queueSize = deviceSessionCtx.getMsgQueueSize().incrementAndGet();
if (queueSize > context.getMessageQueueSizePerDeviceLimit()) {
log.warn("Closing current session because msq queue size for device {} exceed limit {} with msgQueueSize counter {} and actual queue size {}",
deviceSessionCtx.getDeviceId(), context.getMessageQueueSizePerDeviceLimit(), queueSize, deviceSessionCtx.getMsgQueue().size());
final int queueSize = deviceSessionCtx.getMsgQueueSize();
if (queueSize >= context.getMessageQueueSizePerDeviceLimit()) {
log.info("Closing current session because msq queue size for device {} exceed limit {} with msgQueueSize counter {} and actual queue size {}",
deviceSessionCtx.getDeviceId(), context.getMessageQueueSizePerDeviceLimit(), queueSize, deviceSessionCtx.getMsgQueueSize());
ctx.close();
return;
}
deviceSessionCtx.getMsgQueue().add(msg);
ReferenceCountUtil.retain(msg);
deviceSessionCtx.addToQueue(msg);
processMsgQueue(ctx); //Under the normal conditions the msg queue will contain 0 messages. Many messages will be processed on device connect event in separate thread pool
}
void processMsgQueue(ChannelHandlerContext ctx) {
if (!deviceSessionCtx.isConnected()) {
log.trace("[{}][{}] Postpone processing msg due to device is not connected. Msg queue size is {}", sessionId, deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueue().size());
log.trace("[{}][{}] Postpone processing msg due to device is not connected. Msg queue size is {}", sessionId, deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueueSize());
return;
}
while (!deviceSessionCtx.getMsgQueue().isEmpty()) {
if (deviceSessionCtx.getMsgQueueProcessorLock().tryLock()) {
try {
MqttMessage msg;
while ((msg = deviceSessionCtx.getMsgQueue().poll()) != null) {
deviceSessionCtx.getMsgQueueSize().decrementAndGet();
processRegularSessionMsg(ctx, msg);
ReferenceCountUtil.safeRelease(msg);
}
} finally {
deviceSessionCtx.getMsgQueueProcessorLock().unlock();
}
} else {
return;
}
}
deviceSessionCtx.tryProcessQueuedMsgs(msg -> processRegularSessionMsg(ctx, msg));
}
void processRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
@ -282,9 +266,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
break;
case DISCONNECT:
if (checkConnected(ctx, msg)) {
processDisconnect(ctx);
}
ctx.close();
break;
case PUBACK:
int msgId = ((MqttPubAckMessage) msg).variableHeader().messageId();
@ -438,7 +420,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void onError(Throwable e) {
log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e);
processDisconnect(ctx);
ctx.close();
}
};
}
@ -464,7 +446,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} else {
deviceSessionCtx.getContext().getProtoMqttAdaptor().convertToPublish(deviceSessionCtx, provisionResponseMsg).ifPresent(deviceSessionCtx.getChannel()::writeAndFlush);
}
scheduler.schedule(() -> processDisconnect(ctx), 60, TimeUnit.SECONDS);
scheduler.schedule((Callable<ChannelFuture>) ctx::close, 60, TimeUnit.SECONDS);
} catch (Exception e) {
log.trace("[{}] Failed to convert device attributes response to MQTT msg", sessionId, e);
}
@ -473,7 +455,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void onError(Throwable e) {
log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e);
processDisconnect(ctx);
ctx.close();
}
}
@ -508,7 +490,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void onError(Throwable e) {
log.trace("[{}] Failed to get firmware: {}", sessionId, msg, e);
processDisconnect(ctx);
ctx.close();
}
}
@ -530,7 +512,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
deviceSessionCtx.getChannel().writeAndFlush(deviceSessionCtx
.getPayloadAdaptor()
.createMqttPublishMsg(deviceSessionCtx, MqttTopics.DEVICE_FIRMWARE_ERROR_TOPIC, error.getBytes()));
processDisconnect(ctx);
ctx.close();
}
private void processSubscribe(ChannelHandlerContext ctx, MqttSubscribeMessage mqttMsg) {
@ -699,6 +681,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
});
} catch (Exception e) {
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
log.trace("[{}] X509 auth failure: {}", sessionId, address, e);
ctx.close();
}
}
@ -716,12 +699,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
return null;
}
void processDisconnect(ChannelHandlerContext ctx) {
ctx.close();
log.info("[{}] Client disconnected!", sessionId);
doDisconnect();
}
private MqttConnAckMessage createMqttConnAckMsg(MqttConnectReturnCode returnCode, MqttConnectMessage msg) {
MqttFixedHeader mqttFixedHeader =
new MqttFixedHeader(CONNACK, false, AT_MOST_ONCE, false, 0);
@ -766,7 +743,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
return true;
} else {
log.info("[{}] Closing current session due to invalid msg order: {}", sessionId, msg);
ctx.close();
return false;
}
}
@ -791,11 +767,13 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void operationComplete(Future<? super Void> future) throws Exception {
log.trace("[{}] Channel closed!", sessionId);
doDisconnect();
}
private void doDisconnect() {
public void doDisconnect() {
if (deviceSessionCtx.isConnected()) {
log.info("[{}] Client disconnected!", sessionId);
transportService.process(deviceSessionCtx.getSessionInfo(), DefaultTransportService.getSessionEventMsg(SessionEvent.CLOSED), null);
transportService.deregisterSession(deviceSessionCtx.getSessionInfo());
if (gatewaySessionHandler != null) {
@ -803,11 +781,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
deviceSessionCtx.setDisconnected();
}
if (!deviceSessionCtx.getMsgQueue().isEmpty()) {
log.warn("doDisconnect for device {} but unprocessed messages {} left in the msg queue", deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueue().size());
deviceSessionCtx.getMsgQueue().clear();
}
deviceSessionCtx.release();
}
@ -866,7 +840,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void onRemoteSessionCloseCommand(UUID sessionId, TransportProtos.SessionCloseNotificationProto sessionCloseNotification) {
log.trace("[{}] Received the remote command to close the session: {}", sessionId, sessionCloseNotification.getMessage());
processDisconnect(deviceSessionCtx.getChannel());
deviceSessionCtx.getChannel().close();
}
@Override

View File

@ -19,6 +19,7 @@ import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.util.ReferenceCountUtil;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
@ -35,12 +36,16 @@ import org.thingsboard.server.transport.mqtt.adaptors.MqttTransportAdaptor;
import org.thingsboard.server.transport.mqtt.util.MqttTopicFilter;
import org.thingsboard.server.transport.mqtt.util.MqttTopicFilterFactory;
import java.util.Collection;
import java.util.Collections;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
/**
* @author Andrew Shvayka
@ -57,13 +62,11 @@ public class DeviceSessionCtx extends MqttDeviceAwareSessionContext {
private final AtomicInteger msgIdSeq = new AtomicInteger(0);
@Getter
private final ConcurrentLinkedQueue<MqttMessage> msgQueue = new ConcurrentLinkedQueue<>();
@Getter
private final Lock msgQueueProcessorLock = new ReentrantLock();
@Getter
private final AtomicInteger msgQueueSize = new AtomicInteger(0);
@Getter
@ -160,4 +163,49 @@ public class DeviceSessionCtx extends MqttDeviceAwareSessionContext {
rpcResponseDynamicMessageDescriptor = protoTransportPayloadConfig.getRpcResponseDynamicMessageDescriptor(protoTransportPayloadConfig.getDeviceRpcResponseProtoSchema());
rpcRequestDynamicMessageBuilder = protoTransportPayloadConfig.getRpcRequestDynamicMessageBuilder(protoTransportPayloadConfig.getDeviceRpcRequestProtoSchema());
}
public void addToQueue(MqttMessage msg) {
msgQueueSize.incrementAndGet();
ReferenceCountUtil.retain(msg);
msgQueue.add(msg);
}
public void tryProcessQueuedMsgs(Consumer<MqttMessage> msgProcessor) {
while (!msgQueue.isEmpty()) {
if (msgQueueProcessorLock.tryLock()) {
try {
MqttMessage msg;
while ((msg = msgQueue.poll()) != null) {
try {
msgQueueSize.decrementAndGet();
msgProcessor.accept(msg);
} finally {
ReferenceCountUtil.safeRelease(msg);
}
}
} finally {
msgQueueProcessorLock.unlock();
}
} else {
return;
}
}
}
public int getMsgQueueSize() {
return msgQueueSize.get();
}
public void release() {
if (!msgQueue.isEmpty()) {
log.warn("doDisconnect for device {} but unprocessed messages {} left in the msg queue", getDeviceId(), msgQueue.size());
msgQueue.forEach(ReferenceCountUtil::safeRelease);
msgQueue.clear();
}
}
public Collection<MqttMessage> getMsgQueueSnapshot(){
return Collections.unmodifiableCollection(msgQueue);
}
}

View File

@ -111,18 +111,6 @@ public class MqttTransportHandlerTest {
return new MqttPublishMessage(mqttFixedHeader, variableHeader, payload);
}
@Test
public void givenMessageWithoutFixedHeader_whenProcessMqttMsg_thenProcessDisconnect() {
MqttFixedHeader mqttFixedHeader = null;
MqttMessage msg = new MqttMessage(mqttFixedHeader);
willDoNothing().given(handler).processDisconnect(ctx);
handler.processMqttMsg(ctx, msg);
assertThat(handler.address, is(IP_ADDR));
verify(handler, times(1)).processDisconnect(ctx);
}
@Test
public void givenMqttConnectMessage_whenProcessMqttMsg_thenProcessConnect() {
MqttConnectMessage msg = getMqttConnectMessage();
@ -132,7 +120,7 @@ public class MqttTransportHandlerTest {
assertThat(handler.address, is(IP_ADDR));
assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
verify(handler, never()).processDisconnect(any());
verify(handler, never()).doDisconnect();
verify(handler, times(1)).processConnect(ctx, msg);
}
@ -140,8 +128,8 @@ public class MqttTransportHandlerTest {
public void givenQueueLimit_whenEnqueueRegularSessionMsgOverLimit_thenOK() {
List<MqttPublishMessage> messages = Stream.generate(this::getMqttPublishMessage).limit(MSG_QUEUE_LIMIT).collect(Collectors.toList());
messages.forEach(msg -> handler.enqueueRegularSessionMsg(ctx, msg));
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(MSG_QUEUE_LIMIT));
assertThat(handler.deviceSessionCtx.getMsgQueue(), contains(messages.toArray()));
assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), contains(messages.toArray()));
}
@Test
@ -152,7 +140,7 @@ public class MqttTransportHandlerTest {
messages.forEach((msg) -> handler.enqueueRegularSessionMsg(ctx, msg));
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(limit));
assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
verify(handler, times(limit)).enqueueRegularSessionMsg(any(), any());
verify(handler, times(MSG_QUEUE_LIMIT)).processMsgQueue(any());
verify(ctx, times(1)).close();
@ -169,9 +157,9 @@ public class MqttTransportHandlerTest {
assertThat(handler.address, is(IP_ADDR));
assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
assertThat(handler.deviceSessionCtx.isConnected(), is(false));
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(MSG_QUEUE_LIMIT));
assertThat(handler.deviceSessionCtx.getMsgQueue(), contains(messages.toArray()));
verify(handler, never()).processDisconnect(any());
assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), contains(messages.toArray()));
verify(handler, never()).doDisconnect();
verify(handler, times(1)).processConnect(any(), any());
verify(handler, times(MSG_QUEUE_LIMIT)).enqueueRegularSessionMsg(any(), any());
verify(handler, never()).processRegularSessionMsg(any(), any());
@ -212,8 +200,8 @@ public class MqttTransportHandlerTest {
assertThat(finishLatch.await(TIMEOUT, TimeUnit.SECONDS), is(true));
//then
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(0));
assertThat(handler.deviceSessionCtx.getMsgQueue(), empty());
assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(0));
assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), empty());
verify(handler, times(MSG_QUEUE_LIMIT)).processRegularSessionMsg(any(), any());
messages.forEach((msg) -> verify(handler, times(1)).processRegularSessionMsg(ctx, msg));
}