MQTT Rate Limits Draft

This commit is contained in:
Andrew Shvayka 2022-01-13 08:42:48 +02:00
parent db6c1075eb
commit d8097d2b76
11 changed files with 184 additions and 6 deletions

View File

@ -619,6 +619,13 @@ transport:
log:
enabled: "${TB_TRANSPORT_LOG_ENABLED:true}"
max_length: "${TB_TRANSPORT_LOG_MAX_LENGTH:1024}"
rate_limits:
# Maximum number of simultaneous connections from a single ip address
max_connections_per_ip: "${TB_TRANSPORT_MAX_CONNECTIONS_PER_IP:50}"
# Maximum number of connect attempts with invalid credentials
max_wrong_credentials_per_ip: "${TB_TRANSPORT_MAX_WRONG_CREDENTIALS_PER_IP:10}"
# Timeout to expire block IP addresses
ip_block_timeout: "${TB_TRANSPORT_IP_BLOCK_TIMEOUT:10000}"
# Local HTTP transport parameters
http:
enabled: "${HTTP_ENABLED:true}"
@ -630,6 +637,9 @@ transport:
enabled: "${MQTT_ENABLED:true}"
bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}"
bind_port: "${MQTT_BIND_PORT:1883}"
# Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2.
# Useful to get the real IP address of the client in the logs and for rate limits.
proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}"
timeout: "${MQTT_TIMEOUT:10000}"
msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism
netty:

View File

@ -30,6 +30,7 @@ import org.thingsboard.server.transport.mqtt.adaptors.ProtoMqttAdaptor;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.net.InetSocketAddress;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
@ -73,6 +74,10 @@ public class MqttTransportContext extends TransportContext {
@Value("${transport.mqtt.timeout:10000}")
private long timeout;
@Getter
@Value("${transport.mqtt.proxy_enabled:false}")
private boolean proxyEnabled;
private final AtomicInteger connectionsCounter = new AtomicInteger();
@PostConstruct
@ -88,4 +93,13 @@ public class MqttTransportContext extends TransportContext {
public void channelUnregistered() {
connectionsCounter.decrementAndGet();
}
public boolean checkAddress(InetSocketAddress address){
return rateLimitService.checkAddress(address);
}
public void onAuthFailed(InetSocketAddress address){
rateLimitService.onAuthFailed(address);
}
}

View File

@ -20,6 +20,7 @@ import com.google.gson.JsonParseException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.mqtt.MqttConnAckMessage;
import io.netty.handler.codec.mqtt.MqttConnAckVariableHeader;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
@ -164,6 +165,9 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
log.trace("[{}] Processing msg: {}", sessionId, msg);
if (address == null) {
address = getAddress(ctx);
}
try {
if (msg instanceof MqttMessage) {
MqttMessage message = (MqttMessage) msg;
@ -182,8 +186,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
}
InetSocketAddress getAddress(ChannelHandlerContext ctx) {
return ctx.channel().attr(MqttTransportService.ADDRESS).get();
}
void processMqttMsg(ChannelHandlerContext ctx, MqttMessage msg) {
address = getAddress(ctx);
if (msg.fixedHeader() == null) {
log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort());
ctx.close();
@ -199,10 +206,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
}
InetSocketAddress getAddress(ChannelHandlerContext ctx) {
return (InetSocketAddress) ctx.channel().remoteAddress();
}
private void processProvisionSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
switch (msg.fixedHeader().messageType()) {
case PUBLISH:
@ -771,7 +774,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
private void processAuthTokenConnect(ChannelHandlerContext ctx, MqttConnectMessage connectMessage) {
String userName = connectMessage.payload().userName();
log.debug("[{}] Processing connect msg for client with user name: {}!", sessionId, userName);
log.debug("[{}][{}] Processing connect msg for client with user name: {}!", address, sessionId, userName);
TransportProtos.ValidateBasicMqttCredRequestMsg.Builder request = TransportProtos.ValidateBasicMqttCredRequestMsg.newBuilder()
.setClientId(connectMessage.payload().clientIdentifier());
if (userName != null) {
@ -820,6 +823,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
});
} catch (Exception e) {
context.onAuthFailed(address);
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
log.trace("[{}] X509 auth failure: {}", sessionId, address, e);
ctx.close();
@ -931,6 +935,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
private void onValidateDeviceResponse(ValidateDeviceCredentialsResponse msg, ChannelHandlerContext ctx, MqttConnectMessage connectMessage) {
if (!msg.hasDeviceInfo()) {
context.onAuthFailed(address);
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
ctx.close();
} else {

View File

@ -18,9 +18,12 @@ package org.thingsboard.server.transport.mqtt;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder;
import io.netty.handler.ssl.SslHandler;
import org.thingsboard.server.transport.mqtt.limits.IpFilter;
import org.thingsboard.server.transport.mqtt.limits.ProxyIpFilter;
/**
* @author Andrew Shvayka
@ -39,6 +42,12 @@ public class MqttTransportServerInitializer extends ChannelInitializer<SocketCha
public void initChannel(SocketChannel ch) {
ChannelPipeline pipeline = ch.pipeline();
SslHandler sslHandler = null;
if (context.isProxyEnabled()) {
pipeline.addLast("proxy", new HAProxyMessageDecoder());
pipeline.addLast("ipFilter", new ProxyIpFilter(context));
} else {
pipeline.addLast("ipFilter", new IpFilter(context));
}
if (sslEnabled && context.getSslHandlerProvider() != null) {
sslHandler = context.getSslHandlerProvider().getSslHandler();
pipeline.addLast(sslHandler);

View File

@ -21,6 +21,7 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.AttributeKey;
import io.netty.util.ResourceLeakDetector;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@ -33,6 +34,8 @@ import org.thingsboard.server.common.data.TbTransportService;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.net.InetAddress;
import java.net.InetSocketAddress;
/**
* @author Andrew Shvayka
@ -42,6 +45,8 @@ import javax.annotation.PreDestroy;
@Slf4j
public class MqttTransportService implements TbTransportService {
public static AttributeKey<InetSocketAddress> ADDRESS = AttributeKey.newInstance("SRC_ADDRESS");
@Value("${transport.mqtt.bind_address}")
private String host;
@Value("${transport.mqtt.bind_port}")

View File

@ -0,0 +1,46 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.limits;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.ipfilter.AbstractRemoteAddressFilter;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.server.transport.mqtt.MqttTransportContext;
import org.thingsboard.server.transport.mqtt.MqttTransportService;
import java.net.InetSocketAddress;
@Slf4j
public class IpFilter extends AbstractRemoteAddressFilter<InetSocketAddress> {
private MqttTransportContext context;
public IpFilter(MqttTransportContext context) {
this.context = context;
}
@Override
protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) throws Exception {
if(context.checkAddress(remoteAddress)){
ctx.channel().attr(MqttTransportService.ADDRESS).set(remoteAddress);
return true;
} else {
return false;
}
}
}

View File

@ -0,0 +1,60 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.limits;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.server.transport.mqtt.MqttTransportContext;
import org.thingsboard.server.transport.mqtt.MqttTransportService;
import java.net.InetAddress;
import java.net.InetSocketAddress;
@Slf4j
public class ProxyIpFilter extends ChannelInboundHandlerAdapter {
private MqttTransportContext context;
public ProxyIpFilter(MqttTransportContext context) {
this.context = context;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if(msg instanceof HAProxyMessage){
HAProxyMessage proxyMsg = (HAProxyMessage) msg;
if(proxyMsg.sourceAddress() != null && proxyMsg.sourcePort() > 0) {
InetSocketAddress address = new InetSocketAddress(proxyMsg.sourceAddress(), proxyMsg.sourcePort());
if(!context.checkAddress(address)){
ctx.close();
} else {
ctx.channel().attr(MqttTransportService.ADDRESS).set(address);
// We no longer need this channel in the pipeline. Similar to HAProxyMessageDecoder
ctx.pipeline().remove(this);
}
} else {
log.debug("Received local health-check connection message: {}", proxyMsg);
ctx.close();
}
} else {
super.channelRead(ctx, msg);
}
}
}

View File

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.thingsboard.common.util.ThingsBoardExecutors;
import org.thingsboard.server.cache.ota.OtaPackageDataCache;
import org.thingsboard.server.common.transport.limits.TransportRateLimitService;
import org.thingsboard.server.queue.discovery.TbServiceInfoProvider;
import org.thingsboard.server.queue.scheduler.SchedulerComponent;
@ -57,6 +58,9 @@ public abstract class TransportContext {
@Autowired
private TransportResourceCache transportResourceCache;
@Autowired
protected TransportRateLimitService rateLimitService;
@PostConstruct
public void init() {
executor = ThingsBoardExecutors.newWorkStealingPool(50, getClass());
@ -73,4 +77,6 @@ public abstract class TransportContext {
return serviceInfoProvider.getServiceId();
}
}

View File

@ -30,6 +30,8 @@ import org.thingsboard.server.common.transport.TransportTenantProfileCache;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import org.thingsboard.server.queue.util.TbTransportComponent;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@ -116,6 +118,18 @@ public class DefaultTransportRateLimitService implements TransportRateLimitServi
tenantAllowed.put(tenantId, allowed);
}
private Set<InetAddress> blockedAddresses = new HashSet<>();
@Override
public boolean checkAddress(InetSocketAddress address) {
return !blockedAddresses.contains(address.getAddress());
}
@Override
public void onAuthFailed(InetSocketAddress address) {
blockedAddresses.add(address.getAddress());
}
private <T extends EntityId> void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits,
Function<T, EntityTransportRateLimits> getFunction,
BiConsumer<T, EntityTransportRateLimits> putFunction) {

View File

@ -20,6 +20,8 @@ import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import java.net.InetSocketAddress;
public interface TransportRateLimitService {
EntityType checkLimits(TenantId tenantId, DeviceId deviceId, int dataPoints);
@ -33,4 +35,8 @@ public interface TransportRateLimitService {
void remove(DeviceId deviceId);
void update(TenantId tenantId, boolean transportEnabled);
boolean checkAddress(InetSocketAddress address);
void onAuthFailed(InetSocketAddress address);
}

View File

@ -88,6 +88,9 @@ transport:
mqtt:
bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}"
bind_port: "${MQTT_BIND_PORT:1883}"
# Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2.
# Useful to get the real IP address of the client in the logs and for rate limits.
proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}"
timeout: "${MQTT_TIMEOUT:10000}"
msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism
netty: