Correct close and cleanup of the MQTT session context

This commit is contained in:
Andrii Shvaika 2021-08-05 17:22:40 +03:00
parent c1d8aa1370
commit daac250c2e
3 changed files with 84 additions and 74 deletions

View File

@ -81,6 +81,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -153,10 +154,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
if (message.decoderResult().isSuccess()) { if (message.decoderResult().isSuccess()) {
processMqttMsg(ctx, message); processMqttMsg(ctx, message);
} else { } else {
log.error("[{}] Message processing failed: {}", sessionId, message.decoderResult().cause().getMessage()); log.error("[{}] Message decoding failed: {}", sessionId, message.decoderResult().cause().getMessage());
ctx.close(); ctx.close();
} }
} else { } else {
log.debug("[{}] Received non mqtt message: {}", sessionId, msg.getClass().getSimpleName());
ctx.close(); ctx.close();
} }
} finally { } finally {
@ -168,7 +170,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
address = getAddress(ctx); address = getAddress(ctx);
if (msg.fixedHeader() == null) { if (msg.fixedHeader() == null) {
log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort()); log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort());
processDisconnect(ctx); ctx.close();
return; return;
} }
deviceSessionCtx.setChannel(ctx); deviceSessionCtx.setChannel(ctx);
@ -208,8 +210,8 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} }
} }
} else { } else {
log.debug("[{}] Unsupported topic for provisioning requests: {}!", sessionId, topicName);
ctx.close(); ctx.close();
throw new RuntimeException("Unsupported topic for provisioning requests!");
} }
} catch (RuntimeException | AdaptorException e) { } catch (RuntimeException | AdaptorException e) {
log.warn("[{}] Failed to process publish msg [{}][{}]", sessionId, topicName, msgId, e); log.warn("[{}] Failed to process publish msg [{}][{}]", sessionId, topicName, msgId, e);
@ -220,48 +222,30 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
ctx.writeAndFlush(new MqttMessage(new MqttFixedHeader(PINGRESP, false, AT_MOST_ONCE, false, 0))); ctx.writeAndFlush(new MqttMessage(new MqttFixedHeader(PINGRESP, false, AT_MOST_ONCE, false, 0)));
break; break;
case DISCONNECT: case DISCONNECT:
if (checkConnected(ctx, msg)) { ctx.close();
processDisconnect(ctx);
}
break; break;
} }
} }
void enqueueRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) { void enqueueRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
final int queueSize = deviceSessionCtx.getMsgQueueSize().incrementAndGet(); final int queueSize = deviceSessionCtx.getMsgQueueSize();
if (queueSize > context.getMessageQueueSizePerDeviceLimit()) { if (queueSize >= context.getMessageQueueSizePerDeviceLimit()) {
log.warn("Closing current session because msq queue size for device {} exceed limit {} with msgQueueSize counter {} and actual queue size {}", log.info("Closing current session because msq queue size for device {} exceed limit {} with msgQueueSize counter {} and actual queue size {}",
deviceSessionCtx.getDeviceId(), context.getMessageQueueSizePerDeviceLimit(), queueSize, deviceSessionCtx.getMsgQueue().size()); deviceSessionCtx.getDeviceId(), context.getMessageQueueSizePerDeviceLimit(), queueSize, deviceSessionCtx.getMsgQueueSize());
ctx.close(); ctx.close();
return; return;
} }
deviceSessionCtx.getMsgQueue().add(msg); deviceSessionCtx.addToQueue(msg);
ReferenceCountUtil.retain(msg);
processMsgQueue(ctx); //Under the normal conditions the msg queue will contain 0 messages. Many messages will be processed on device connect event in separate thread pool processMsgQueue(ctx); //Under the normal conditions the msg queue will contain 0 messages. Many messages will be processed on device connect event in separate thread pool
} }
void processMsgQueue(ChannelHandlerContext ctx) { void processMsgQueue(ChannelHandlerContext ctx) {
if (!deviceSessionCtx.isConnected()) { if (!deviceSessionCtx.isConnected()) {
log.trace("[{}][{}] Postpone processing msg due to device is not connected. Msg queue size is {}", sessionId, deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueue().size()); log.trace("[{}][{}] Postpone processing msg due to device is not connected. Msg queue size is {}", sessionId, deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueueSize());
return; return;
} }
while (!deviceSessionCtx.getMsgQueue().isEmpty()) { deviceSessionCtx.tryProcessQueuedMsgs(msg -> processRegularSessionMsg(ctx, msg));
if (deviceSessionCtx.getMsgQueueProcessorLock().tryLock()) {
try {
MqttMessage msg;
while ((msg = deviceSessionCtx.getMsgQueue().poll()) != null) {
deviceSessionCtx.getMsgQueueSize().decrementAndGet();
processRegularSessionMsg(ctx, msg);
ReferenceCountUtil.safeRelease(msg);
}
} finally {
deviceSessionCtx.getMsgQueueProcessorLock().unlock();
}
} else {
return;
}
}
} }
void processRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) { void processRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
@ -282,9 +266,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} }
break; break;
case DISCONNECT: case DISCONNECT:
if (checkConnected(ctx, msg)) { ctx.close();
processDisconnect(ctx);
}
break; break;
case PUBACK: case PUBACK:
int msgId = ((MqttPubAckMessage) msg).variableHeader().messageId(); int msgId = ((MqttPubAckMessage) msg).variableHeader().messageId();
@ -438,7 +420,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override @Override
public void onError(Throwable e) { public void onError(Throwable e) {
log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e); log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e);
processDisconnect(ctx); ctx.close();
} }
}; };
} }
@ -464,7 +446,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} else { } else {
deviceSessionCtx.getContext().getProtoMqttAdaptor().convertToPublish(deviceSessionCtx, provisionResponseMsg).ifPresent(deviceSessionCtx.getChannel()::writeAndFlush); deviceSessionCtx.getContext().getProtoMqttAdaptor().convertToPublish(deviceSessionCtx, provisionResponseMsg).ifPresent(deviceSessionCtx.getChannel()::writeAndFlush);
} }
scheduler.schedule(() -> processDisconnect(ctx), 60, TimeUnit.SECONDS); scheduler.schedule((Callable<ChannelFuture>) ctx::close, 60, TimeUnit.SECONDS);
} catch (Exception e) { } catch (Exception e) {
log.trace("[{}] Failed to convert device attributes response to MQTT msg", sessionId, e); log.trace("[{}] Failed to convert device attributes response to MQTT msg", sessionId, e);
} }
@ -473,7 +455,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override @Override
public void onError(Throwable e) { public void onError(Throwable e) {
log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e); log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e);
processDisconnect(ctx); ctx.close();
} }
} }
@ -508,7 +490,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override @Override
public void onError(Throwable e) { public void onError(Throwable e) {
log.trace("[{}] Failed to get firmware: {}", sessionId, msg, e); log.trace("[{}] Failed to get firmware: {}", sessionId, msg, e);
processDisconnect(ctx); ctx.close();
} }
} }
@ -530,7 +512,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
deviceSessionCtx.getChannel().writeAndFlush(deviceSessionCtx deviceSessionCtx.getChannel().writeAndFlush(deviceSessionCtx
.getPayloadAdaptor() .getPayloadAdaptor()
.createMqttPublishMsg(deviceSessionCtx, MqttTopics.DEVICE_FIRMWARE_ERROR_TOPIC, error.getBytes())); .createMqttPublishMsg(deviceSessionCtx, MqttTopics.DEVICE_FIRMWARE_ERROR_TOPIC, error.getBytes()));
processDisconnect(ctx); ctx.close();
} }
private void processSubscribe(ChannelHandlerContext ctx, MqttSubscribeMessage mqttMsg) { private void processSubscribe(ChannelHandlerContext ctx, MqttSubscribeMessage mqttMsg) {
@ -699,6 +681,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}); });
} catch (Exception e) { } catch (Exception e) {
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage)); ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
log.trace("[{}] X509 auth failure: {}", sessionId, address, e);
ctx.close(); ctx.close();
} }
} }
@ -716,12 +699,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
return null; return null;
} }
void processDisconnect(ChannelHandlerContext ctx) {
ctx.close();
log.info("[{}] Client disconnected!", sessionId);
doDisconnect();
}
private MqttConnAckMessage createMqttConnAckMsg(MqttConnectReturnCode returnCode, MqttConnectMessage msg) { private MqttConnAckMessage createMqttConnAckMsg(MqttConnectReturnCode returnCode, MqttConnectMessage msg) {
MqttFixedHeader mqttFixedHeader = MqttFixedHeader mqttFixedHeader =
new MqttFixedHeader(CONNACK, false, AT_MOST_ONCE, false, 0); new MqttFixedHeader(CONNACK, false, AT_MOST_ONCE, false, 0);
@ -766,7 +743,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
return true; return true;
} else { } else {
log.info("[{}] Closing current session due to invalid msg order: {}", sessionId, msg); log.info("[{}] Closing current session due to invalid msg order: {}", sessionId, msg);
ctx.close();
return false; return false;
} }
} }
@ -791,11 +767,13 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override @Override
public void operationComplete(Future<? super Void> future) throws Exception { public void operationComplete(Future<? super Void> future) throws Exception {
log.trace("[{}] Channel closed!", sessionId);
doDisconnect(); doDisconnect();
} }
private void doDisconnect() { public void doDisconnect() {
if (deviceSessionCtx.isConnected()) { if (deviceSessionCtx.isConnected()) {
log.info("[{}] Client disconnected!", sessionId);
transportService.process(deviceSessionCtx.getSessionInfo(), DefaultTransportService.getSessionEventMsg(SessionEvent.CLOSED), null); transportService.process(deviceSessionCtx.getSessionInfo(), DefaultTransportService.getSessionEventMsg(SessionEvent.CLOSED), null);
transportService.deregisterSession(deviceSessionCtx.getSessionInfo()); transportService.deregisterSession(deviceSessionCtx.getSessionInfo());
if (gatewaySessionHandler != null) { if (gatewaySessionHandler != null) {
@ -803,11 +781,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} }
deviceSessionCtx.setDisconnected(); deviceSessionCtx.setDisconnected();
} }
deviceSessionCtx.release();
if (!deviceSessionCtx.getMsgQueue().isEmpty()) {
log.warn("doDisconnect for device {} but unprocessed messages {} left in the msg queue", deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueue().size());
deviceSessionCtx.getMsgQueue().clear();
}
} }
@ -866,7 +840,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override @Override
public void onRemoteSessionCloseCommand(UUID sessionId, TransportProtos.SessionCloseNotificationProto sessionCloseNotification) { public void onRemoteSessionCloseCommand(UUID sessionId, TransportProtos.SessionCloseNotificationProto sessionCloseNotification) {
log.trace("[{}] Received the remote command to close the session: {}", sessionId, sessionCloseNotification.getMessage()); log.trace("[{}] Received the remote command to close the session: {}", sessionId, sessionCloseNotification.getMessage());
processDisconnect(deviceSessionCtx.getChannel()); deviceSessionCtx.getChannel().close();
} }
@Override @Override

View File

@ -19,6 +19,7 @@ import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage; import com.google.protobuf.DynamicMessage;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.mqtt.MqttMessage; import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.util.ReferenceCountUtil;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -35,12 +36,16 @@ import org.thingsboard.server.transport.mqtt.adaptors.MqttTransportAdaptor;
import org.thingsboard.server.transport.mqtt.util.MqttTopicFilter; import org.thingsboard.server.transport.mqtt.util.MqttTopicFilter;
import org.thingsboard.server.transport.mqtt.util.MqttTopicFilterFactory; import org.thingsboard.server.transport.mqtt.util.MqttTopicFilterFactory;
import java.util.Collection;
import java.util.Collections;
import java.util.Queue;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
/** /**
* @author Andrew Shvayka * @author Andrew Shvayka
@ -57,13 +62,11 @@ public class DeviceSessionCtx extends MqttDeviceAwareSessionContext {
private final AtomicInteger msgIdSeq = new AtomicInteger(0); private final AtomicInteger msgIdSeq = new AtomicInteger(0);
@Getter
private final ConcurrentLinkedQueue<MqttMessage> msgQueue = new ConcurrentLinkedQueue<>(); private final ConcurrentLinkedQueue<MqttMessage> msgQueue = new ConcurrentLinkedQueue<>();
@Getter @Getter
private final Lock msgQueueProcessorLock = new ReentrantLock(); private final Lock msgQueueProcessorLock = new ReentrantLock();
@Getter
private final AtomicInteger msgQueueSize = new AtomicInteger(0); private final AtomicInteger msgQueueSize = new AtomicInteger(0);
@Getter @Getter
@ -160,4 +163,49 @@ public class DeviceSessionCtx extends MqttDeviceAwareSessionContext {
rpcResponseDynamicMessageDescriptor = protoTransportPayloadConfig.getRpcResponseDynamicMessageDescriptor(protoTransportPayloadConfig.getDeviceRpcResponseProtoSchema()); rpcResponseDynamicMessageDescriptor = protoTransportPayloadConfig.getRpcResponseDynamicMessageDescriptor(protoTransportPayloadConfig.getDeviceRpcResponseProtoSchema());
rpcRequestDynamicMessageBuilder = protoTransportPayloadConfig.getRpcRequestDynamicMessageBuilder(protoTransportPayloadConfig.getDeviceRpcRequestProtoSchema()); rpcRequestDynamicMessageBuilder = protoTransportPayloadConfig.getRpcRequestDynamicMessageBuilder(protoTransportPayloadConfig.getDeviceRpcRequestProtoSchema());
} }
public void addToQueue(MqttMessage msg) {
msgQueueSize.incrementAndGet();
ReferenceCountUtil.retain(msg);
msgQueue.add(msg);
}
public void tryProcessQueuedMsgs(Consumer<MqttMessage> msgProcessor) {
while (!msgQueue.isEmpty()) {
if (msgQueueProcessorLock.tryLock()) {
try {
MqttMessage msg;
while ((msg = msgQueue.poll()) != null) {
try {
msgQueueSize.decrementAndGet();
msgProcessor.accept(msg);
} finally {
ReferenceCountUtil.safeRelease(msg);
}
}
} finally {
msgQueueProcessorLock.unlock();
}
} else {
return;
}
}
}
public int getMsgQueueSize() {
return msgQueueSize.get();
}
public void release() {
if (!msgQueue.isEmpty()) {
log.warn("doDisconnect for device {} but unprocessed messages {} left in the msg queue", getDeviceId(), msgQueue.size());
msgQueue.forEach(ReferenceCountUtil::safeRelease);
msgQueue.clear();
}
}
public Collection<MqttMessage> getMsgQueueSnapshot(){
return Collections.unmodifiableCollection(msgQueue);
}
} }

View File

@ -111,18 +111,6 @@ public class MqttTransportHandlerTest {
return new MqttPublishMessage(mqttFixedHeader, variableHeader, payload); return new MqttPublishMessage(mqttFixedHeader, variableHeader, payload);
} }
@Test
public void givenMessageWithoutFixedHeader_whenProcessMqttMsg_thenProcessDisconnect() {
MqttFixedHeader mqttFixedHeader = null;
MqttMessage msg = new MqttMessage(mqttFixedHeader);
willDoNothing().given(handler).processDisconnect(ctx);
handler.processMqttMsg(ctx, msg);
assertThat(handler.address, is(IP_ADDR));
verify(handler, times(1)).processDisconnect(ctx);
}
@Test @Test
public void givenMqttConnectMessage_whenProcessMqttMsg_thenProcessConnect() { public void givenMqttConnectMessage_whenProcessMqttMsg_thenProcessConnect() {
MqttConnectMessage msg = getMqttConnectMessage(); MqttConnectMessage msg = getMqttConnectMessage();
@ -132,7 +120,7 @@ public class MqttTransportHandlerTest {
assertThat(handler.address, is(IP_ADDR)); assertThat(handler.address, is(IP_ADDR));
assertThat(handler.deviceSessionCtx.getChannel(), is(ctx)); assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
verify(handler, never()).processDisconnect(any()); verify(handler, never()).doDisconnect();
verify(handler, times(1)).processConnect(ctx, msg); verify(handler, times(1)).processConnect(ctx, msg);
} }
@ -140,8 +128,8 @@ public class MqttTransportHandlerTest {
public void givenQueueLimit_whenEnqueueRegularSessionMsgOverLimit_thenOK() { public void givenQueueLimit_whenEnqueueRegularSessionMsgOverLimit_thenOK() {
List<MqttPublishMessage> messages = Stream.generate(this::getMqttPublishMessage).limit(MSG_QUEUE_LIMIT).collect(Collectors.toList()); List<MqttPublishMessage> messages = Stream.generate(this::getMqttPublishMessage).limit(MSG_QUEUE_LIMIT).collect(Collectors.toList());
messages.forEach(msg -> handler.enqueueRegularSessionMsg(ctx, msg)); messages.forEach(msg -> handler.enqueueRegularSessionMsg(ctx, msg));
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(MSG_QUEUE_LIMIT)); assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
assertThat(handler.deviceSessionCtx.getMsgQueue(), contains(messages.toArray())); assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), contains(messages.toArray()));
} }
@Test @Test
@ -152,7 +140,7 @@ public class MqttTransportHandlerTest {
messages.forEach((msg) -> handler.enqueueRegularSessionMsg(ctx, msg)); messages.forEach((msg) -> handler.enqueueRegularSessionMsg(ctx, msg));
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(limit)); assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
verify(handler, times(limit)).enqueueRegularSessionMsg(any(), any()); verify(handler, times(limit)).enqueueRegularSessionMsg(any(), any());
verify(handler, times(MSG_QUEUE_LIMIT)).processMsgQueue(any()); verify(handler, times(MSG_QUEUE_LIMIT)).processMsgQueue(any());
verify(ctx, times(1)).close(); verify(ctx, times(1)).close();
@ -169,9 +157,9 @@ public class MqttTransportHandlerTest {
assertThat(handler.address, is(IP_ADDR)); assertThat(handler.address, is(IP_ADDR));
assertThat(handler.deviceSessionCtx.getChannel(), is(ctx)); assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
assertThat(handler.deviceSessionCtx.isConnected(), is(false)); assertThat(handler.deviceSessionCtx.isConnected(), is(false));
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(MSG_QUEUE_LIMIT)); assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
assertThat(handler.deviceSessionCtx.getMsgQueue(), contains(messages.toArray())); assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), contains(messages.toArray()));
verify(handler, never()).processDisconnect(any()); verify(handler, never()).doDisconnect();
verify(handler, times(1)).processConnect(any(), any()); verify(handler, times(1)).processConnect(any(), any());
verify(handler, times(MSG_QUEUE_LIMIT)).enqueueRegularSessionMsg(any(), any()); verify(handler, times(MSG_QUEUE_LIMIT)).enqueueRegularSessionMsg(any(), any());
verify(handler, never()).processRegularSessionMsg(any(), any()); verify(handler, never()).processRegularSessionMsg(any(), any());
@ -212,8 +200,8 @@ public class MqttTransportHandlerTest {
assertThat(finishLatch.await(TIMEOUT, TimeUnit.SECONDS), is(true)); assertThat(finishLatch.await(TIMEOUT, TimeUnit.SECONDS), is(true));
//then //then
assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(0)); assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(0));
assertThat(handler.deviceSessionCtx.getMsgQueue(), empty()); assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), empty());
verify(handler, times(MSG_QUEUE_LIMIT)).processRegularSessionMsg(any(), any()); verify(handler, times(MSG_QUEUE_LIMIT)).processRegularSessionMsg(any(), any());
messages.forEach((msg) -> verify(handler, times(1)).processRegularSessionMsg(ctx, msg)); messages.forEach((msg) -> verify(handler, times(1)).processRegularSessionMsg(ctx, msg));
} }