Merge pull request #10063 from Rhyaldir/add_dtls_cid_support

Support DTLS Connection ID with configuration
This commit is contained in:
Andrew Shvayka 2024-02-15 17:54:33 +02:00 committed by GitHub
commit 921159d262
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 721 additions and 5 deletions

View File

@ -1044,6 +1044,8 @@ transport:
dtls:
# RFC7925_RETRANSMISSION_TIMEOUT_IN_MILLISECONDS = 9000
retransmission_timeout: "${LWM2M_DTLS_RETRANSMISSION_TIMEOUT_MS:9000}"
# "" disables connection id support, 0 enables support but not for incoming traffic, any value greater than 0 set the connection id size in bytes
connection_id_length: "${LWM2M_DTLS_CONNECTION_ID_LENGTH:6}"
server:
# LwM2M Server ID
id: "${LWM2M_SERVER_ID:123}"

View File

@ -94,6 +94,11 @@
<artifactId>junit-vintage-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>

View File

@ -38,6 +38,7 @@ import javax.annotation.PreDestroy;
import java.security.cert.X509Certificate;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_CONNECTION_ID_LENGTH;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_RECOMMENDED_CIPHER_SUITES_ONLY;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_RECOMMENDED_CURVES_ONLY;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_RETRANSMISSION_TIMEOUT;
@ -95,6 +96,7 @@ public class LwM2MTransportBootstrapService {
dtlsConfig.set(DTLS_RECOMMENDED_CURVES_ONLY, serverConfig.isRecommendedSupportedGroups());
dtlsConfig.set(DTLS_RECOMMENDED_CIPHER_SUITES_ONLY, serverConfig.isRecommendedCiphers());
dtlsConfig.set(DTLS_RETRANSMISSION_TIMEOUT, serverConfig.getDtlsRetransmissionTimeout(), MILLISECONDS);
dtlsConfig.set(DTLS_CONNECTION_ID_LENGTH, serverConfig.getDtlsConnectionIdLength());
dtlsConfig.set(DTLS_ROLE, SERVER_ONLY);
setServerWithCredentials(builder, dtlsConfig);

View File

@ -41,6 +41,10 @@ public class LwM2MTransportServerConfig implements LwM2MSecureServerConfig {
@Value("${transport.lwm2m.dtls.retransmission_timeout:9000}")
private int dtlsRetransmissionTimeout;
@Getter
@Value("${transport.lwm2m.dtls.connection_id_length:6}")
private Integer dtlsConnectionIdLength;
@Getter
@Value("${transport.lwm2m.timeout:}")
private Long timeout;

View File

@ -43,6 +43,7 @@ import javax.annotation.PreDestroy;
import java.security.cert.X509Certificate;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_CONNECTION_ID_LENGTH;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_RECOMMENDED_CIPHER_SUITES_ONLY;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_RECOMMENDED_CURVES_ONLY;
import static org.eclipse.californium.scandium.config.DtlsConfig.DTLS_RETRANSMISSION_TIMEOUT;
@ -139,6 +140,7 @@ public class DefaultLwM2mTransportService implements LwM2MTransportService {
dtlsConfig.set(DTLS_RECOMMENDED_CURVES_ONLY, config.isRecommendedSupportedGroups());
dtlsConfig.set(DTLS_RECOMMENDED_CIPHER_SUITES_ONLY, config.isRecommendedCiphers());
dtlsConfig.set(DTLS_RETRANSMISSION_TIMEOUT, config.getDtlsRetransmissionTimeout(), MILLISECONDS);
dtlsConfig.set(DTLS_CONNECTION_ID_LENGTH, config.getDtlsConnectionIdLength());
dtlsConfig.set(DTLS_ROLE, SERVER_ONLY);
/* Create credentials */

View File

@ -29,7 +29,6 @@ import org.eclipse.leshan.core.util.NamedThreadFactory;
import org.eclipse.leshan.core.util.Validate;
import org.eclipse.leshan.server.californium.registration.CaliforniumRegistrationStore;
import org.eclipse.leshan.server.redis.RedisRegistrationStore;
import org.eclipse.leshan.server.redis.serialization.IdentitySerDes;
import org.eclipse.leshan.server.redis.serialization.ObservationSerDes;
import org.eclipse.leshan.server.redis.serialization.RegistrationSerDes;
import org.eclipse.leshan.server.registration.Deregistration;
@ -45,6 +44,7 @@ import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.integration.redis.util.RedisLockRegistry;
import org.thingsboard.server.transport.lwm2m.server.store.util.LwM2MIdentitySerDes;
import java.net.InetSocketAddress;
import java.util.ArrayList;
@ -110,12 +110,18 @@ public class TbLwM2mRedisRegistrationStore implements CaliforniumRegistrationSto
public TbLwM2mRedisRegistrationStore(RedisConnectionFactory connectionFactory, ScheduledExecutorService schedExecutor, long cleanPeriodInSec,
long lifetimeGracePeriodInSec, int cleanLimit) {
this(connectionFactory, schedExecutor, cleanPeriodInSec, lifetimeGracePeriodInSec, cleanLimit,
new RedisLockRegistry(connectionFactory, "Registration"));
}
public TbLwM2mRedisRegistrationStore(RedisConnectionFactory connectionFactory, ScheduledExecutorService schedExecutor, long cleanPeriodInSec,
long lifetimeGracePeriodInSec, int cleanLimit, RedisLockRegistry lockRegistry) {
this.connectionFactory = connectionFactory;
this.schedExecutor = schedExecutor;
this.cleanPeriod = cleanPeriodInSec;
this.cleanLimit = cleanLimit;
this.gracePeriod = lifetimeGracePeriodInSec;
this.redisLock = new RedisLockRegistry(connectionFactory, "Registration");
this.redisLock = lockRegistry;
}
/* *************** Redis Key utility function **************** */
@ -173,7 +179,7 @@ public class TbLwM2mRedisRegistrationStore implements CaliforniumRegistrationSto
if (!oldRegistration.getSocketAddress().equals(registration.getSocketAddress())) {
removeAddrIndex(connection, oldRegistration);
}
if (!oldRegistration.getIdentity().equals(registration.getIdentity())) {
if (registrationsHaveDifferentIdentities(oldRegistration, registration)) {
removeIdentityIndex(connection, oldRegistration);
}
// remove old observation
@ -231,7 +237,7 @@ public class TbLwM2mRedisRegistrationStore implements CaliforniumRegistrationSto
if (!r.getSocketAddress().equals(updatedRegistration.getSocketAddress())) {
removeAddrIndex(connection, r);
}
if (!r.getIdentity().equals(updatedRegistration.getIdentity())) {
if (registrationsHaveDifferentIdentities(r, updatedRegistration)) {
removeIdentityIndex(connection, r);
}
@ -402,6 +408,12 @@ public class TbLwM2mRedisRegistrationStore implements CaliforniumRegistrationSto
connection.zRem(EXP_EP, registration.getEndpoint().getBytes(UTF_8));
}
private boolean registrationsHaveDifferentIdentities(Registration first, Registration second){
var first_identity_string = LwM2MIdentitySerDes.serialize(first.getIdentity()).toString();
var second_identity_string = LwM2MIdentitySerDes.serialize(second.getIdentity()).toString();
return !first_identity_string.equals(second_identity_string);
}
private byte[] toRegIdKey(String registrationId) {
return toKey(REG_EP_REGID_IDX, registrationId);
}
@ -411,7 +423,7 @@ public class TbLwM2mRedisRegistrationStore implements CaliforniumRegistrationSto
}
private byte[] toRegIdentityKey(Identity identity) {
return toKey(REG_EP_IDENTITY, IdentitySerDes.serialize(identity).toString());
return toKey(REG_EP_IDENTITY, LwM2MIdentitySerDes.serialize(identity).toString());
}
private byte[] toEndpointKey(String endpoint) {

View File

@ -0,0 +1,63 @@
/**
* Copyright © 2016-2024 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.lwm2m.server.store.util;
import com.eclipsesource.json.Json;
import com.eclipsesource.json.JsonObject;
import org.apache.commons.lang3.NotImplementedException;
import org.eclipse.leshan.core.request.Identity;
import org.eclipse.leshan.core.util.Hex;
import java.security.PublicKey;
public class LwM2MIdentitySerDes {
private static final String KEY_ADDRESS = "address";
private static final String KEY_PORT = "port";
private static final String KEY_ID = "id";
private static final String KEY_CN = "cn";
private static final String KEY_RPK = "rpk";
protected static final String KEY_LWM2MIDENTITY_TYPE = "type";
protected static final String LWM2MIDENTITY_TYPE_UNSECURE = "unsecure";
protected static final String LWM2MIDENTITY_TYPE_PSK = "psk";
protected static final String LWM2MIDENTITY_TYPE_X509 = "x509";
protected static final String LWM2MIDENTITY_TYPE_RPK = "rpk";
public static JsonObject serialize(Identity identity) {
JsonObject o = Json.object();
if (identity.isPSK()) {
o.set(KEY_LWM2MIDENTITY_TYPE, LWM2MIDENTITY_TYPE_PSK);
o.set(KEY_ID, identity.getPskIdentity());
} else if (identity.isRPK()) {
o.set(KEY_LWM2MIDENTITY_TYPE, LWM2MIDENTITY_TYPE_RPK);
PublicKey publicKey = identity.getRawPublicKey();
o.set(KEY_RPK, Hex.encodeHexString(publicKey.getEncoded()));
} else if (identity.isX509()) {
o.set(KEY_LWM2MIDENTITY_TYPE, LWM2MIDENTITY_TYPE_X509);
o.set(KEY_CN, identity.getX509CommonName());
} else {
o.set(KEY_LWM2MIDENTITY_TYPE, LWM2MIDENTITY_TYPE_UNSECURE);
o.set(KEY_ADDRESS, identity.getPeerAddress().getHostString());
o.set(KEY_PORT, identity.getPeerAddress().getPort());
}
return o;
}
public static Identity deserialize(JsonObject peer) {
throw new NotImplementedException();
}
}

View File

@ -0,0 +1,105 @@
/**
* Copyright © 2016-2024 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.lwm2m.bootstrap;
import org.eclipse.californium.core.network.CoapEndpoint;
import org.eclipse.californium.scandium.config.DtlsConnectorConfig;
import org.eclipse.leshan.server.californium.LeshanServer;
import org.eclipse.leshan.server.californium.bootstrap.LeshanBootstrapServer;
import org.eclipse.leshan.server.californium.registration.CaliforniumRegistrationStore;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.util.ReflectionTestUtils;
import org.thingsboard.server.cache.ota.OtaPackageDataCache;
import org.thingsboard.server.common.transport.TransportService;
import org.thingsboard.server.transport.lwm2m.bootstrap.secure.TbLwM2MDtlsBootstrapCertificateVerifier;
import org.thingsboard.server.transport.lwm2m.bootstrap.store.LwM2MBootstrapSecurityStore;
import org.thingsboard.server.transport.lwm2m.bootstrap.store.LwM2MInMemoryBootstrapConfigStore;
import org.thingsboard.server.transport.lwm2m.config.LwM2MTransportBootstrapConfig;
import org.thingsboard.server.transport.lwm2m.config.LwM2MTransportServerConfig;
import org.thingsboard.server.transport.lwm2m.secure.TbLwM2MAuthorizer;
import org.thingsboard.server.transport.lwm2m.secure.TbLwM2MDtlsCertificateVerifier;
import org.thingsboard.server.transport.lwm2m.server.store.TbSecurityStore;
import org.thingsboard.server.transport.lwm2m.server.uplink.LwM2mUplinkMsgHandler;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.when;
@ExtendWith(MockitoExtension.class)
public class LwM2MTransportBootstrapServiceTest {
@Mock
private LwM2MTransportServerConfig serverConfig;
@Mock
private LwM2MTransportBootstrapConfig bootstrapConfig;
@Mock
private LwM2MBootstrapSecurityStore lwM2MBootstrapSecurityStore;
@Mock
private LwM2MInMemoryBootstrapConfigStore lwM2MInMemoryBootstrapConfigStore;
@Mock
private TransportService transportService;
@Mock
private TbLwM2MDtlsBootstrapCertificateVerifier certificateVerifier;
@Test
public void getLHServer_creates_ConnectionIdGenerator_when_connection_id_length_not_null(){
final Integer CONNECTION_ID_LENGTH = 6;
when(serverConfig.getDtlsConnectionIdLength()).thenReturn(CONNECTION_ID_LENGTH);
var lwM2MBootstrapService = createLwM2MBootstrapService();
var server = lwM2MBootstrapService.getLhBootstrapServer();
var securedEndpoint = (CoapEndpoint) ReflectionTestUtils.getField(server, "securedEndpoint");
assertThat(securedEndpoint).isNotNull();
var config = (DtlsConnectorConfig) ReflectionTestUtils.getField(securedEndpoint.getConnector(), "config");
assertThat(config).isNotNull();
assertThat(config.getConnectionIdGenerator()).isNotNull();
assertThat((Integer) ReflectionTestUtils.getField(config.getConnectionIdGenerator(), "connectionIdLength"))
.isEqualTo(CONNECTION_ID_LENGTH);
}
@Test
public void getLHServer_creates_no_ConnectionIdGenerator_when_connection_id_length_is_null(){
when(serverConfig.getDtlsConnectionIdLength()).thenReturn(null);
var lwM2MBootstrapService = createLwM2MBootstrapService();
var server = lwM2MBootstrapService.getLhBootstrapServer();
var securedEndpoint = (CoapEndpoint) ReflectionTestUtils.getField(server, "securedEndpoint");
assertThat(securedEndpoint).isNotNull();
var config = (DtlsConnectorConfig) ReflectionTestUtils.getField(securedEndpoint.getConnector(), "config");
assertThat(config).isNotNull();
assertThat(config.getConnectionIdGenerator()).isNull();
}
private LwM2MTransportBootstrapService createLwM2MBootstrapService() {
setDefaultConfigVariables();
return new LwM2MTransportBootstrapService(serverConfig, bootstrapConfig, lwM2MBootstrapSecurityStore,
lwM2MInMemoryBootstrapConfigStore, transportService, certificateVerifier);
}
private void setDefaultConfigVariables(){
when(bootstrapConfig.getPort()).thenReturn(5683);
when(bootstrapConfig.getSecurePort()).thenReturn(5684);
when(serverConfig.isRecommendedCiphers()).thenReturn(false);
when(serverConfig.getDtlsRetransmissionTimeout()).thenReturn(9000);
}
}

View File

@ -0,0 +1,61 @@
/**
* Copyright © 2016-2024 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.lwm2m.config;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.SpringBootContextLoader;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.thingsboard.server.common.transport.config.ssl.SslCredentialsConfig;
import static org.assertj.core.api.Assertions.assertThat;
@ExtendWith(SpringExtension.class)
@EnableConfigurationProperties(value = LwM2MTransportServerConfig.class)
@ContextConfiguration(classes = {LwM2MTransportServerConfig.class}, loader = SpringBootContextLoader.class)
@TestPropertySource(properties = {
"transport.sessions.report_timeout=10",
"transport.lwm2m.security.recommended_ciphers=true",
"transport.lwm2m.security.recommended_supported_groups=true",
"transport.lwm2m.downlink_pool_size=10",
"transport.lwm2m.uplink_pool_size=10",
"transport.lwm2m.ota_pool_size=10",
"transport.lwm2m.clean_period_in_sec=2",
"transport.lwm2m.dtls.connection_id_length="
})
class LwM2MTransportServerConfigTest {
@MockBean(name = "lwm2mServerCredentials")
private SslCredentialsConfig credentialsConfig;
@MockBean(name = "lwm2mTrustCredentials")
private SslCredentialsConfig trustCredentialsConfig;
@Autowired
private LwM2MTransportServerConfig serverConfig;
@Test
void getDtlsConnectionIdLength_return_null_is_property_is_empty() {
// note: transport.lwm2m.dtls.connect_id_length is set in TestPropertySource
assertThat(serverConfig.getDtlsConnectionIdLength()).isNull();
}
}

View File

@ -0,0 +1,109 @@
/**
* Copyright © 2016-2024 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.lwm2m.server;
import org.eclipse.californium.core.network.CoapEndpoint;
import org.eclipse.californium.scandium.config.DtlsConnectorConfig;
import org.eclipse.leshan.server.californium.LeshanServer;
import org.eclipse.leshan.server.californium.registration.CaliforniumRegistrationStore;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.util.ReflectionTestUtils;
import org.thingsboard.server.cache.ota.OtaPackageDataCache;
import org.thingsboard.server.transport.lwm2m.config.LwM2MTransportServerConfig;
import org.thingsboard.server.transport.lwm2m.secure.TbLwM2MAuthorizer;
import org.thingsboard.server.transport.lwm2m.secure.TbLwM2MDtlsCertificateVerifier;
import org.thingsboard.server.transport.lwm2m.server.store.TbSecurityStore;
import org.thingsboard.server.transport.lwm2m.server.uplink.LwM2mUplinkMsgHandler;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.when;
@ExtendWith(MockitoExtension.class)
public class DefaultLwM2mTransportServiceTest {
@Mock
private LwM2mTransportContext context;
@Mock
private LwM2MTransportServerConfig config;
@Mock
private OtaPackageDataCache otaPackageDataCache;
@Mock
private LwM2mUplinkMsgHandler handler;
@Mock
private CaliforniumRegistrationStore registrationStore;
@Mock
private TbSecurityStore securityStore;
@Mock
private TbLwM2MDtlsCertificateVerifier certificateVerifier;
@Mock
private TbLwM2MAuthorizer authorizer;
@Mock
private LwM2mVersionedModelProvider modelProvider;
@Test
public void getLHServer_creates_ConnectionIdGenerator_when_connection_id_length_not_null(){
final Integer CONNECTION_ID_LENGTH = 6;
when(config.getDtlsConnectionIdLength()).thenReturn(CONNECTION_ID_LENGTH);
var lwm2mService = createLwM2MService();
LeshanServer server = ReflectionTestUtils.invokeMethod(lwm2mService, "getLhServer");
assertThat(server).isNotNull();
var securedEndpoint = (CoapEndpoint) ReflectionTestUtils.getField(server, "securedEndpoint");
assertThat(securedEndpoint).isNotNull();
var config = (DtlsConnectorConfig) ReflectionTestUtils.getField(securedEndpoint.getConnector(), "config");
assertThat(config).isNotNull();
assertThat(config.getConnectionIdGenerator()).isNotNull();
assertThat((Integer) ReflectionTestUtils.getField(config.getConnectionIdGenerator(), "connectionIdLength"))
.isEqualTo(CONNECTION_ID_LENGTH);
}
@Test
public void getLHServer_creates_no_ConnectionIdGenerator_when_connection_id_length_is_null(){
when(config.getDtlsConnectionIdLength()).thenReturn(null);
var lwm2mService = createLwM2MService();
LeshanServer server = ReflectionTestUtils.invokeMethod(lwm2mService, "getLhServer");
assertThat(server).isNotNull();
var securedEndpoint = (CoapEndpoint) ReflectionTestUtils.getField(server, "securedEndpoint");
assertThat(securedEndpoint).isNotNull();
var config = (DtlsConnectorConfig) ReflectionTestUtils.getField(securedEndpoint.getConnector(), "config");
assertThat(config).isNotNull();
assertThat(config.getConnectionIdGenerator()).isNull();
}
private DefaultLwM2mTransportService createLwM2MService() {
setDefaultConfigVariables();
return new DefaultLwM2mTransportService(context, config, otaPackageDataCache, handler, registrationStore,
securityStore, certificateVerifier, authorizer, modelProvider);
}
private void setDefaultConfigVariables(){
when(config.getPort()).thenReturn(5683);
when(config.getSecurePort()).thenReturn(5684);
when(config.isRecommendedCiphers()).thenReturn(false);
when(config.getDtlsRetransmissionTimeout()).thenReturn(9000);
}
}

View File

@ -0,0 +1,265 @@
/**
* Copyright © 2016-2024 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.lwm2m.server.store;
import org.eclipse.leshan.core.link.Link;
import org.eclipse.leshan.core.request.Identity;
import org.eclipse.leshan.core.util.NamedThreadFactory;
import org.eclipse.leshan.server.redis.serialization.RegistrationSerDes;
import org.eclipse.leshan.server.registration.Registration;
import org.eclipse.leshan.server.registration.RegistrationUpdate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.integration.redis.util.RedisLockRegistry;
import org.springframework.test.util.ReflectionTestUtils;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.locks.Lock;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.thingsboard.server.transport.lwm2m.server.store.TbLwM2mRedisRegistrationStore.DEFAULT_CLEAN_LIMIT;
import static org.thingsboard.server.transport.lwm2m.server.store.TbLwM2mRedisRegistrationStore.DEFAULT_CLEAN_PERIOD;
import static org.thingsboard.server.transport.lwm2m.server.store.TbLwM2mRedisRegistrationStore.DEFAULT_GRACE_PERIOD;
@ExtendWith(MockitoExtension.class)
class TbLwM2mRedisRegistrationStoreTest {
RedisConnectionFactory connectionFactory;
RedisConnection connection;
RedisLockRegistry lockRegistry;
TbLwM2mRedisRegistrationStore registrationStore;
@BeforeEach
void setUp() {
lockRegistry = mock(RedisLockRegistry.class);
lenient().when(lockRegistry.obtain(any())).thenReturn(mock(Lock.class));
connection = mock(RedisConnection.class);
//when(connection.set(any(byte[].class), any(byte[].class))).
connectionFactory = mock(RedisConnectionFactory.class);
lenient().when(connectionFactory.getConnection()).thenReturn(connection);
ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1,
new NamedThreadFactory(String.format("RedisRegistrationStore Cleaner (%ds)", DEFAULT_CLEAN_PERIOD)));
registrationStore = new TbLwM2mRedisRegistrationStore(connectionFactory, executorService,
DEFAULT_CLEAN_PERIOD, DEFAULT_GRACE_PERIOD, DEFAULT_CLEAN_LIMIT, lockRegistry);
}
@Test
void testAddRegistrationWithNoOldRegistration() {
setOldRegistration(null);
Registration registration = buildRegistration();
assertThat(registrationStore.addRegistration(registration)).isNull();
byte[] endpoint = registration.getEndpoint().getBytes(UTF_8);
verify(connection, times(1)).set(getRegIdKey(registration), endpoint);
verify(connection, times(1)).set(getRegAddrKey(registration), endpoint);
verify(connection, times(1)).set(getRegIdentityKey(registration), endpoint);
verify(connection, times(3)).set(any(byte[].class), any(byte[].class));
verify(connection, times(0)).del(any(byte[].class));
}
@Test
void testAddRegistrationWithOldRegistrationEqualToCurrent(){
var oldRegistration = buildRegistration();
setOldRegistration(oldRegistration);
Registration registration = buildRegistration();
var deregistration = registrationStore.addRegistration(registration);
assertThat(deregistration.getRegistration()).isEqualTo(oldRegistration);
byte[] endpoint = registration.getEndpoint().getBytes(UTF_8);
verify(connection, times(1)).set(getRegIdKey(registration), endpoint);
verify(connection, times(1)).set(getRegAddrKey(registration), endpoint);
verify(connection, times(1)).set(getRegIdentityKey(registration), endpoint);
verify(connection, times(3)).set(any(byte[].class), any(byte[].class));
verify(connection, times(1)).del(getTknsRegIdKey(oldRegistration));
verify(connection, times(1)).del(any(byte[].class));
}
@Test
void testAddRegistrationRemovesIndexes(){
var oldRegistration = buildRegistration(Identity.unsecure(getTestAddress(1234)));
setOldRegistration(oldRegistration);
var registration = buildRegistration(Identity.unsecure(getTestAddress(2345)));
var deregistration = registrationStore.addRegistration(registration);
assertThat(deregistration.getRegistration()).isEqualTo(oldRegistration);
byte[] endpoint = registration.getEndpoint().getBytes(UTF_8);
verify(connection, times(1)).set(getRegIdKey(registration), endpoint);
verify(connection, times(1)).set(getRegAddrKey(registration), endpoint);
verify(connection, times(1)).set(getRegIdentityKey(registration), endpoint);
verify(connection, times(3)).set(any(byte[].class), any(byte[].class));
verify(connection, times(1)).del(getRegAddrKey(oldRegistration));
verify(connection, times(1)).del(getRegIdentityKey(oldRegistration));
verify(connection, times(1)).del(getTknsRegIdKey(oldRegistration));
verify(connection, times(3)).del(any(byte[].class));
}
@Test
void testUpdateRegistrationWhenNoRegistrationFound() {
setOldRegistration(null);
Registration registration = buildRegistration();
RegistrationUpdate update = createUpdateFromRegistration(registration);
assertThat(registrationStore.updateRegistration(update)).isNull();
verify(connection, times(1)).get(getRegIdKey(registration));
verify(connection, times(1)).get(any(byte[].class));
verify(connection, times(0)).del(any(byte[].class));
}
@Test
void testUpdateRegistrationWithSameRegistration() {
Registration registration = buildRegistration();
setOldRegistration(registration);
RegistrationUpdate update = createUpdateFromRegistration(registration);
assertThat(registrationStore.updateRegistration(update)).isNotNull();
var endpoint = registration.getEndpoint().getBytes(UTF_8);
// check registration and addressIndex here updated
verify(connection, times(1)).set(eq(getEndpointKey(endpoint)), any(byte[].class));
verify(connection, times(1)).set(getRegAddrKey(registration), endpoint);
verify(connection, times(2)).set(any(byte[].class), any(byte[].class));
verify(connection, times(0)).del(any(byte[].class));
}
@Test
void testUpdateRegistrationWithRegistrationFromSecureIdentitiesWithDifferentAddress() {
Registration oldRegistration = buildRegistration(Identity.psk(getTestAddress(1234), "my:psk"));
setOldRegistration(oldRegistration);
Registration newRegistration = buildRegistration(Identity.psk(getTestAddress(2345), "my:psk"));
RegistrationUpdate update = createUpdateFromRegistration(newRegistration);
assertThat(oldRegistration.getEndpoint()).isEqualTo(newRegistration.getEndpoint());
assertThat(registrationStore.updateRegistration(update)).isNotNull();
var endpoint = newRegistration.getEndpoint().getBytes(UTF_8);
// check registration and addressIndex here updated
verify(connection, times(1)).set(eq(getEndpointKey(endpoint)), any(byte[].class));
verify(connection, times(1)).set(getRegAddrKey(newRegistration), endpoint);
// check old AddrIndex has been removed
verify(connection, times(1)).del(getRegAddrKey(oldRegistration));
// check identityIndex has not been removed
verify(connection, times(0)).del(getRegIdentityKey(oldRegistration));
// check only one key (AddrIndex) in total was removed
verify(connection, times(1)).del(any(byte[].class));
}
@Test
void testGetRegistrationByIdentityReturnsRegistrationForSecureIdentityWithDifferentAddress() {
Registration registration = buildRegistration(Identity.psk(getTestAddress(1234), "my:psk"));
setOldRegistration(registration);
Identity sameIdentityWithDifferentAddress = Identity.psk(getTestAddress(2345), "my:psk");
Registration retrievedRegistration = registrationStore.getRegistrationByIdentity(sameIdentityWithDifferentAddress);
assertThat(retrievedRegistration).isEqualTo(registration);
}
private void setOldRegistration(Registration oldRegistration){
byte[] serializedRegistration = null;
if (oldRegistration != null){
byte[] endpoint = oldRegistration.getEndpoint().getBytes(UTF_8);
// set the AddrIndex
byte[] regAddrKey = getRegAddrKey(oldRegistration);
lenient().when(connection.get(eq(regAddrKey))).thenReturn(endpoint);
// set the IdentityIndex
byte[] regIdentityKey = getRegIdentityKey(oldRegistration);
lenient().when(connection.get(eq(regIdentityKey))).thenReturn(endpoint);
// set the IdIndex
byte[] regIdKey = getRegIdKey(oldRegistration);
lenient().when(connection.get(eq(regIdKey))).thenReturn(endpoint);
// set the registration
serializedRegistration = RegistrationSerDes.bSerialize(oldRegistration);
lenient().when(connection.get(eq(getEndpointKey(endpoint)))).thenReturn(serializedRegistration);
}
lenient().when(connection.getSet(any(byte[].class), any(byte[].class))).thenReturn(serializedRegistration);
}
private byte[] getRegAddrKey(Registration registration){
return ReflectionTestUtils.invokeMethod(registrationStore, "toRegAddrKey", registration.getSocketAddress());
}
private byte[] getRegIdentityKey(Registration registration){
return ReflectionTestUtils.invokeMethod(registrationStore, "toRegIdentityKey", registration.getIdentity());
}
private byte[] getRegIdKey(Registration registration){
return ReflectionTestUtils.invokeMethod(registrationStore, "toRegIdKey", registration.getId());
}
private byte[] getEndpointKey(byte[] endpoint){
return ReflectionTestUtils.invokeMethod(registrationStore, "toEndpointKey", endpoint);
}
private byte[] getTknsRegIdKey(Registration registration){
return ReflectionTestUtils.invokeMethod(registrationStore, "toKey", "TKNS:REGID:", registration.getId());
}
private static Registration buildRegistration() {
return buildRegistration(Identity.psk(getTestAddress(), "my:psk"));
}
private static Registration buildRegistration(Identity identity){
return new Registration.Builder("my_reg_id", "abcde", identity)
.objectLinks(new Link[]{})
.build();
}
private static RegistrationUpdate createUpdateFromRegistration(Registration registration){
return new RegistrationUpdate(
registration.getId(),
registration.getIdentity(),
registration.getLifeTimeInSec(),
registration.getSmsNumber(),
registration.getBindingMode(),
registration.getObjectLinks(),
registration.getAdditionalRegistrationAttributes()
);
}
private static InetSocketAddress getTestAddress() {
return getTestAddress(5684);
}
private static InetSocketAddress getTestAddress(int port) {
try {
return new InetSocketAddress(InetAddress.getByName("1.2.3.4"), port);
} catch (UnknownHostException e) {
throw new AssertionError("Cannot create test address");
}
}
}

View File

@ -0,0 +1,77 @@
/**
* Copyright © 2016-2024 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.lwm2m.server.store.util;
import com.eclipsesource.json.JsonObject;
import org.apache.commons.lang3.NotImplementedException;
import org.eclipse.leshan.core.request.Identity;
import org.junit.jupiter.api.Test;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.security.PublicKey;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class LwM2MIdentitySerDesTest {
@Test
void serializePskIdentity() {
assertThat(LwM2MIdentitySerDes.serialize(Identity.psk(getTestAddress(), "my:psk")).toString())
.isEqualTo("{\"type\":\"psk\",\"id\":\"my:psk\"}");
}
@Test
void serializeRpkIdentity() {
var public_key = mock(PublicKey.class);
when(public_key.getEncoded()).thenReturn(new byte[]{1,2,3,4,5,6,7,8,9});
assertThat(LwM2MIdentitySerDes.serialize(Identity.rpk(getTestAddress(), public_key)).toString())
.isEqualTo("{\"type\":\"rpk\",\"rpk\":\"010203040506070809\"}");
}
@Test
void serializeX509Identity() {
assertThat(LwM2MIdentitySerDes.serialize(Identity.x509(getTestAddress(), "MyCommonName")).toString())
.isEqualTo("{\"type\":\"x509\",\"cn\":\"MyCommonName\"}");
}
@Test
void serializeUnsecureIdentity() {
assertThat(LwM2MIdentitySerDes.serialize(Identity.unsecure(getTestAddress())).toString())
.isEqualTo("{\"type\":\"unsecure\",\"address\":\"1.2.3.4\",\"port\":5684}");
}
@Test
void deserialize() {
assertThatThrownBy(() -> LwM2MIdentitySerDes.deserialize(mock(JsonObject.class)))
.isInstanceOf(NotImplementedException.class);
}
private static InetSocketAddress getTestAddress() {
try {
return new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 5684);
} catch (UnknownHostException e) {
throw new AssertionError("Cannot create test address");
}
}
}

View File

@ -131,6 +131,7 @@
<dbunit.version>2.7.2</dbunit.version>
<java-websocket.version>1.5.2</java-websocket.version>
<jupiter.version>5.8.2</jupiter.version> <!-- keep the same version as spring-boot-starter-test depend on jupiter-->
<mockito.version>4.5.1</mockito.version>
<json-path.version>2.6.0</json-path.version>
<mock-server.version>5.13.1</mock-server.version>
<spring-test-dbunit.version>1.3.0</spring-test-dbunit.version> <!-- 2016 -->
@ -1649,6 +1650,12 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>

View File

@ -155,6 +155,8 @@ transport:
dtls:
# RFC7925_RETRANSMISSION_TIMEOUT_IN_MILLISECONDS = 9000
retransmission_timeout: "${LWM2M_DTLS_RETRANSMISSION_TIMEOUT_MS:9000}"
# "" disables connection id support, 0 enables support but not for incoming traffic, any value greater than 0 set the connection id size in bytes
connection_id_length: "${LWM2M_DTLS_CONNECTION_ID_LENGTH:6}"
server:
# LwM2M Server ID
id: "${LWM2M_SERVER_ID:123}"