MQTT Rate Limits Draft

This commit is contained in:
Andrew Shvayka 2022-01-13 08:42:48 +02:00
parent db6c1075eb
commit d8097d2b76
11 changed files with 184 additions and 6 deletions

View File

@ -619,6 +619,13 @@ transport:
log: log:
enabled: "${TB_TRANSPORT_LOG_ENABLED:true}" enabled: "${TB_TRANSPORT_LOG_ENABLED:true}"
max_length: "${TB_TRANSPORT_LOG_MAX_LENGTH:1024}" max_length: "${TB_TRANSPORT_LOG_MAX_LENGTH:1024}"
rate_limits:
# Maximum number of simultaneous connections from a single ip address
max_connections_per_ip: "${TB_TRANSPORT_MAX_CONNECTIONS_PER_IP:50}"
# Maximum number of connect attempts with invalid credentials
max_wrong_credentials_per_ip: "${TB_TRANSPORT_MAX_WRONG_CREDENTIALS_PER_IP:10}"
# Timeout to expire block IP addresses
ip_block_timeout: "${TB_TRANSPORT_IP_BLOCK_TIMEOUT:10000}"
# Local HTTP transport parameters # Local HTTP transport parameters
http: http:
enabled: "${HTTP_ENABLED:true}" enabled: "${HTTP_ENABLED:true}"
@ -630,6 +637,9 @@ transport:
enabled: "${MQTT_ENABLED:true}" enabled: "${MQTT_ENABLED:true}"
bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}" bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}"
bind_port: "${MQTT_BIND_PORT:1883}" bind_port: "${MQTT_BIND_PORT:1883}"
# Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2.
# Useful to get the real IP address of the client in the logs and for rate limits.
proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}"
timeout: "${MQTT_TIMEOUT:10000}" timeout: "${MQTT_TIMEOUT:10000}"
msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism
netty: netty:

View File

@ -30,6 +30,7 @@ import org.thingsboard.server.transport.mqtt.adaptors.ProtoMqttAdaptor;
import javax.annotation.PostConstruct; import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy; import javax.annotation.PreDestroy;
import java.net.InetSocketAddress;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -73,6 +74,10 @@ public class MqttTransportContext extends TransportContext {
@Value("${transport.mqtt.timeout:10000}") @Value("${transport.mqtt.timeout:10000}")
private long timeout; private long timeout;
@Getter
@Value("${transport.mqtt.proxy_enabled:false}")
private boolean proxyEnabled;
private final AtomicInteger connectionsCounter = new AtomicInteger(); private final AtomicInteger connectionsCounter = new AtomicInteger();
@PostConstruct @PostConstruct
@ -88,4 +93,13 @@ public class MqttTransportContext extends TransportContext {
public void channelUnregistered() { public void channelUnregistered() {
connectionsCounter.decrementAndGet(); connectionsCounter.decrementAndGet();
} }
public boolean checkAddress(InetSocketAddress address){
return rateLimitService.checkAddress(address);
}
public void onAuthFailed(InetSocketAddress address){
rateLimitService.onAuthFailed(address);
}
} }

View File

@ -20,6 +20,7 @@ import com.google.gson.JsonParseException;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.mqtt.MqttConnAckMessage; import io.netty.handler.codec.mqtt.MqttConnAckMessage;
import io.netty.handler.codec.mqtt.MqttConnAckVariableHeader; import io.netty.handler.codec.mqtt.MqttConnAckVariableHeader;
import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttConnectMessage;
@ -164,6 +165,9 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) { public void channelRead(ChannelHandlerContext ctx, Object msg) {
log.trace("[{}] Processing msg: {}", sessionId, msg); log.trace("[{}] Processing msg: {}", sessionId, msg);
if (address == null) {
address = getAddress(ctx);
}
try { try {
if (msg instanceof MqttMessage) { if (msg instanceof MqttMessage) {
MqttMessage message = (MqttMessage) msg; MqttMessage message = (MqttMessage) msg;
@ -182,8 +186,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} }
} }
InetSocketAddress getAddress(ChannelHandlerContext ctx) {
return ctx.channel().attr(MqttTransportService.ADDRESS).get();
}
void processMqttMsg(ChannelHandlerContext ctx, MqttMessage msg) { void processMqttMsg(ChannelHandlerContext ctx, MqttMessage msg) {
address = getAddress(ctx);
if (msg.fixedHeader() == null) { if (msg.fixedHeader() == null) {
log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort()); log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort());
ctx.close(); ctx.close();
@ -199,10 +206,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} }
} }
InetSocketAddress getAddress(ChannelHandlerContext ctx) {
return (InetSocketAddress) ctx.channel().remoteAddress();
}
private void processProvisionSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) { private void processProvisionSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
switch (msg.fixedHeader().messageType()) { switch (msg.fixedHeader().messageType()) {
case PUBLISH: case PUBLISH:
@ -771,7 +774,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
private void processAuthTokenConnect(ChannelHandlerContext ctx, MqttConnectMessage connectMessage) { private void processAuthTokenConnect(ChannelHandlerContext ctx, MqttConnectMessage connectMessage) {
String userName = connectMessage.payload().userName(); String userName = connectMessage.payload().userName();
log.debug("[{}] Processing connect msg for client with user name: {}!", sessionId, userName); log.debug("[{}][{}] Processing connect msg for client with user name: {}!", address, sessionId, userName);
TransportProtos.ValidateBasicMqttCredRequestMsg.Builder request = TransportProtos.ValidateBasicMqttCredRequestMsg.newBuilder() TransportProtos.ValidateBasicMqttCredRequestMsg.Builder request = TransportProtos.ValidateBasicMqttCredRequestMsg.newBuilder()
.setClientId(connectMessage.payload().clientIdentifier()); .setClientId(connectMessage.payload().clientIdentifier());
if (userName != null) { if (userName != null) {
@ -820,6 +823,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} }
}); });
} catch (Exception e) { } catch (Exception e) {
context.onAuthFailed(address);
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage)); ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
log.trace("[{}] X509 auth failure: {}", sessionId, address, e); log.trace("[{}] X509 auth failure: {}", sessionId, address, e);
ctx.close(); ctx.close();
@ -931,6 +935,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
private void onValidateDeviceResponse(ValidateDeviceCredentialsResponse msg, ChannelHandlerContext ctx, MqttConnectMessage connectMessage) { private void onValidateDeviceResponse(ValidateDeviceCredentialsResponse msg, ChannelHandlerContext ctx, MqttConnectMessage connectMessage) {
if (!msg.hasDeviceInfo()) { if (!msg.hasDeviceInfo()) {
context.onAuthFailed(address);
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage)); ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
ctx.close(); ctx.close();
} else { } else {

View File

@ -18,9 +18,12 @@ package org.thingsboard.server.transport.mqtt;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.codec.mqtt.MqttDecoder; import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder; import io.netty.handler.codec.mqtt.MqttEncoder;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import org.thingsboard.server.transport.mqtt.limits.IpFilter;
import org.thingsboard.server.transport.mqtt.limits.ProxyIpFilter;
/** /**
* @author Andrew Shvayka * @author Andrew Shvayka
@ -39,6 +42,12 @@ public class MqttTransportServerInitializer extends ChannelInitializer<SocketCha
public void initChannel(SocketChannel ch) { public void initChannel(SocketChannel ch) {
ChannelPipeline pipeline = ch.pipeline(); ChannelPipeline pipeline = ch.pipeline();
SslHandler sslHandler = null; SslHandler sslHandler = null;
if (context.isProxyEnabled()) {
pipeline.addLast("proxy", new HAProxyMessageDecoder());
pipeline.addLast("ipFilter", new ProxyIpFilter(context));
} else {
pipeline.addLast("ipFilter", new IpFilter(context));
}
if (sslEnabled && context.getSslHandlerProvider() != null) { if (sslEnabled && context.getSslHandlerProvider() != null) {
sslHandler = context.getSslHandlerProvider().getSslHandler(); sslHandler = context.getSslHandlerProvider().getSslHandler();
pipeline.addLast(sslHandler); pipeline.addLast(sslHandler);

View File

@ -21,6 +21,7 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.AttributeKey;
import io.netty.util.ResourceLeakDetector; import io.netty.util.ResourceLeakDetector;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -33,6 +34,8 @@ import org.thingsboard.server.common.data.TbTransportService;
import javax.annotation.PostConstruct; import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy; import javax.annotation.PreDestroy;
import java.net.InetAddress;
import java.net.InetSocketAddress;
/** /**
* @author Andrew Shvayka * @author Andrew Shvayka
@ -42,6 +45,8 @@ import javax.annotation.PreDestroy;
@Slf4j @Slf4j
public class MqttTransportService implements TbTransportService { public class MqttTransportService implements TbTransportService {
public static AttributeKey<InetSocketAddress> ADDRESS = AttributeKey.newInstance("SRC_ADDRESS");
@Value("${transport.mqtt.bind_address}") @Value("${transport.mqtt.bind_address}")
private String host; private String host;
@Value("${transport.mqtt.bind_port}") @Value("${transport.mqtt.bind_port}")

View File

@ -0,0 +1,46 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.limits;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.ipfilter.AbstractRemoteAddressFilter;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.server.transport.mqtt.MqttTransportContext;
import org.thingsboard.server.transport.mqtt.MqttTransportService;
import java.net.InetSocketAddress;
@Slf4j
public class IpFilter extends AbstractRemoteAddressFilter<InetSocketAddress> {
private MqttTransportContext context;
public IpFilter(MqttTransportContext context) {
this.context = context;
}
@Override
protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) throws Exception {
if(context.checkAddress(remoteAddress)){
ctx.channel().attr(MqttTransportService.ADDRESS).set(remoteAddress);
return true;
} else {
return false;
}
}
}

View File

@ -0,0 +1,60 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.limits;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.server.transport.mqtt.MqttTransportContext;
import org.thingsboard.server.transport.mqtt.MqttTransportService;
import java.net.InetAddress;
import java.net.InetSocketAddress;
@Slf4j
public class ProxyIpFilter extends ChannelInboundHandlerAdapter {
private MqttTransportContext context;
public ProxyIpFilter(MqttTransportContext context) {
this.context = context;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if(msg instanceof HAProxyMessage){
HAProxyMessage proxyMsg = (HAProxyMessage) msg;
if(proxyMsg.sourceAddress() != null && proxyMsg.sourcePort() > 0) {
InetSocketAddress address = new InetSocketAddress(proxyMsg.sourceAddress(), proxyMsg.sourcePort());
if(!context.checkAddress(address)){
ctx.close();
} else {
ctx.channel().attr(MqttTransportService.ADDRESS).set(address);
// We no longer need this channel in the pipeline. Similar to HAProxyMessageDecoder
ctx.pipeline().remove(this);
}
} else {
log.debug("Received local health-check connection message: {}", proxyMsg);
ctx.close();
}
} else {
super.channelRead(ctx, msg);
}
}
}

View File

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.thingsboard.common.util.ThingsBoardExecutors; import org.thingsboard.common.util.ThingsBoardExecutors;
import org.thingsboard.server.cache.ota.OtaPackageDataCache; import org.thingsboard.server.cache.ota.OtaPackageDataCache;
import org.thingsboard.server.common.transport.limits.TransportRateLimitService;
import org.thingsboard.server.queue.discovery.TbServiceInfoProvider; import org.thingsboard.server.queue.discovery.TbServiceInfoProvider;
import org.thingsboard.server.queue.scheduler.SchedulerComponent; import org.thingsboard.server.queue.scheduler.SchedulerComponent;
@ -57,6 +58,9 @@ public abstract class TransportContext {
@Autowired @Autowired
private TransportResourceCache transportResourceCache; private TransportResourceCache transportResourceCache;
@Autowired
protected TransportRateLimitService rateLimitService;
@PostConstruct @PostConstruct
public void init() { public void init() {
executor = ThingsBoardExecutors.newWorkStealingPool(50, getClass()); executor = ThingsBoardExecutors.newWorkStealingPool(50, getClass());
@ -73,4 +77,6 @@ public abstract class TransportContext {
return serviceInfoProvider.getServiceId(); return serviceInfoProvider.getServiceId();
} }
} }

View File

@ -30,6 +30,8 @@ import org.thingsboard.server.common.transport.TransportTenantProfileCache;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult; import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import org.thingsboard.server.queue.util.TbTransportComponent; import org.thingsboard.server.queue.util.TbTransportComponent;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -116,6 +118,18 @@ public class DefaultTransportRateLimitService implements TransportRateLimitServi
tenantAllowed.put(tenantId, allowed); tenantAllowed.put(tenantId, allowed);
} }
private Set<InetAddress> blockedAddresses = new HashSet<>();
@Override
public boolean checkAddress(InetSocketAddress address) {
return !blockedAddresses.contains(address.getAddress());
}
@Override
public void onAuthFailed(InetSocketAddress address) {
blockedAddresses.add(address.getAddress());
}
private <T extends EntityId> void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits, private <T extends EntityId> void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits,
Function<T, EntityTransportRateLimits> getFunction, Function<T, EntityTransportRateLimits> getFunction,
BiConsumer<T, EntityTransportRateLimits> putFunction) { BiConsumer<T, EntityTransportRateLimits> putFunction) {

View File

@ -20,6 +20,8 @@ import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult; import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import java.net.InetSocketAddress;
public interface TransportRateLimitService { public interface TransportRateLimitService {
EntityType checkLimits(TenantId tenantId, DeviceId deviceId, int dataPoints); EntityType checkLimits(TenantId tenantId, DeviceId deviceId, int dataPoints);
@ -33,4 +35,8 @@ public interface TransportRateLimitService {
void remove(DeviceId deviceId); void remove(DeviceId deviceId);
void update(TenantId tenantId, boolean transportEnabled); void update(TenantId tenantId, boolean transportEnabled);
boolean checkAddress(InetSocketAddress address);
void onAuthFailed(InetSocketAddress address);
} }

View File

@ -88,6 +88,9 @@ transport:
mqtt: mqtt:
bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}" bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}"
bind_port: "${MQTT_BIND_PORT:1883}" bind_port: "${MQTT_BIND_PORT:1883}"
# Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2.
# Useful to get the real IP address of the client in the logs and for rate limits.
proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}"
timeout: "${MQTT_TIMEOUT:10000}" timeout: "${MQTT_TIMEOUT:10000}"
msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism
netty: netty: