diff --git a/application/src/main/java/org/thingsboard/server/service/transport/BasicCredentialsValidationResult.java b/application/src/main/java/org/thingsboard/server/service/transport/BasicCredentialsValidationResult.java new file mode 100644 index 0000000000..56d7b776e8 --- /dev/null +++ b/application/src/main/java/org/thingsboard/server/service/transport/BasicCredentialsValidationResult.java @@ -0,0 +1,18 @@ +/** + * Copyright © 2016-2021 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.service.transport; + +enum BasicCredentialsValidationResult {HASH_MISMATCH, PASSWORD_MISMATCH, VALID} diff --git a/application/src/main/java/org/thingsboard/server/service/transport/DefaultTransportApiService.java b/application/src/main/java/org/thingsboard/server/service/transport/DefaultTransportApiService.java index dc5f6c436d..1cefefd5d1 100644 --- a/application/src/main/java/org/thingsboard/server/service/transport/DefaultTransportApiService.java +++ b/application/src/main/java/org/thingsboard/server/service/transport/DefaultTransportApiService.java @@ -25,7 +25,6 @@ import com.google.protobuf.ByteString; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; -import org.springframework.util.StringUtils; import org.thingsboard.common.util.JacksonUtil; import org.thingsboard.server.cache.ota.OtaPackageDataCache; import org.thingsboard.server.common.data.ApiUsageState; @@ -37,6 +36,7 @@ import org.thingsboard.server.common.data.EntityType; import org.thingsboard.server.common.data.OtaPackage; import org.thingsboard.server.common.data.OtaPackageInfo; import org.thingsboard.server.common.data.ResourceType; +import org.thingsboard.server.common.data.StringUtils; import org.thingsboard.server.common.data.TbResource; import org.thingsboard.server.common.data.TenantProfile; import org.thingsboard.server.common.data.device.credentials.BasicMqttCredentials; @@ -107,6 +107,9 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.stream.Collectors; +import static org.thingsboard.server.service.transport.BasicCredentialsValidationResult.PASSWORD_MISMATCH; +import static org.thingsboard.server.service.transport.BasicCredentialsValidationResult.VALID; + /** * Created by ashvayka on 05.10.18. */ @@ -181,71 +184,89 @@ public class DefaultTransportApiService implements TransportApiService { //TODO: Make async and enable caching DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credentialsId); if (credentials != null && credentials.getCredentialsType() == credentialsType) { - return getDeviceInfo(credentials.getDeviceId(), credentials); + return getDeviceInfo(credentials); } else { return getEmptyTransportApiResponseFuture(); } } private ListenableFuture validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) { - DeviceCredentials credentials = null; - if (!StringUtils.isEmpty(mqtt.getUserName())) { - credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName()); - if (credentials != null) { - if (credentials.getCredentialsType() == DeviceCredentialsType.ACCESS_TOKEN) { - return getDeviceInfo(credentials.getDeviceId(), credentials); - } else if (credentials.getCredentialsType() == DeviceCredentialsType.MQTT_BASIC) { - if (!checkMqttCredentials(mqtt, credentials)) { - credentials = null; - } - } else { - return getEmptyTransportApiResponseFuture(); - } - } - if (credentials == null) { - credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash("|", mqtt.getClientId(), mqtt.getUserName())); - } - } - if (credentials == null) { + DeviceCredentials credentials; + if (StringUtils.isEmpty(mqtt.getUserName())) { credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId())); - } - if (credentials != null) { - return getDeviceInfo(credentials.getDeviceId(), credentials); + if (credentials != null) { + return getDeviceInfo(credentials); + } else { + return getEmptyTransportApiResponseFuture(); + } } else { - return getEmptyTransportApiResponseFuture(); + credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId( + EncryptionUtil.getSha3Hash("|", mqtt.getClientId(), mqtt.getUserName())); + if (checkIsMqttCredentials(credentials)) { + var validationResult = validateMqttCredentials(mqtt, credentials); + if (VALID.equals(validationResult)) { + return getDeviceInfo(credentials); + } else if (PASSWORD_MISMATCH.equals(validationResult)) { + return getEmptyTransportApiResponseFuture(); + } else { + return validateUserNameCredentials(mqtt); + } + } else { + return validateUserNameCredentials(mqtt); + } } } + private ListenableFuture validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) { + DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName()); + if (credentials != null) { + switch (credentials.getCredentialsType()) { + case ACCESS_TOKEN: + return getDeviceInfo(credentials); + case MQTT_BASIC: + if (VALID.equals(validateMqttCredentials(mqtt, credentials))) { + return getDeviceInfo(credentials); + } else { + return getEmptyTransportApiResponseFuture(); + } + } + } + return getEmptyTransportApiResponseFuture(); + } + + private static boolean checkIsMqttCredentials(DeviceCredentials credentials) { + return credentials != null && DeviceCredentialsType.MQTT_BASIC.equals(credentials.getCredentialsType()); + } + private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, String credId) { - DeviceCredentials deviceCredentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credId); + return checkMqttCredentials(clientCred, deviceCredentialsService.findDeviceCredentialsByCredentialsId(credId)); + } + + private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) { if (deviceCredentials != null && deviceCredentials.getCredentialsType() == DeviceCredentialsType.MQTT_BASIC) { - if (!checkMqttCredentials(clientCred, deviceCredentials)) { - return null; - } else { + if (VALID.equals(validateMqttCredentials(clientCred, deviceCredentials))) { return deviceCredentials; } } return null; } - private boolean checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) { + private BasicCredentialsValidationResult validateMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) { BasicMqttCredentials dbCred = JacksonUtil.fromString(deviceCredentials.getCredentialsValue(), BasicMqttCredentials.class); if (!StringUtils.isEmpty(dbCred.getClientId()) && !dbCred.getClientId().equals(clientCred.getClientId())) { - return false; + return BasicCredentialsValidationResult.HASH_MISMATCH; } if (!StringUtils.isEmpty(dbCred.getUserName()) && !dbCred.getUserName().equals(clientCred.getUserName())) { - return false; + return BasicCredentialsValidationResult.HASH_MISMATCH; } if (!StringUtils.isEmpty(dbCred.getPassword())) { if (StringUtils.isEmpty(clientCred.getPassword())) { - return false; + return BasicCredentialsValidationResult.PASSWORD_MISMATCH; } else { - if (!dbCred.getPassword().equals(clientCred.getPassword())) { - return false; - } + return dbCred.getPassword().equals(clientCred.getPassword()) ? VALID : BasicCredentialsValidationResult.PASSWORD_MISMATCH; } } - return true; + return VALID; } private ListenableFuture handle(GetOrCreateDeviceFromGatewayRequestMsg requestMsg) { @@ -437,10 +458,10 @@ public class DefaultTransportApiService implements TransportApiService { .build()); } - private ListenableFuture getDeviceInfo(DeviceId deviceId, DeviceCredentials credentials) { - return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, deviceId), device -> { + private ListenableFuture getDeviceInfo(DeviceCredentials credentials) { + return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, credentials.getDeviceId()), device -> { if (device == null) { - log.trace("[{}] Failed to lookup device by id", deviceId); + log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId()); return getEmptyTransportApiResponse(); } try { @@ -458,7 +479,7 @@ public class DefaultTransportApiService implements TransportApiService { return TransportApiResponseMsg.newBuilder() .setValidateCredResponseMsg(builder.build()).build(); } catch (JsonProcessingException e) { - log.warn("[{}] Failed to lookup device by id", deviceId, e); + log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e); return getEmptyTransportApiResponse(); } }, MoreExecutors.directExecutor()); diff --git a/application/src/test/java/org/thingsboard/server/transport/TransportSqlTestSuite.java b/application/src/test/java/org/thingsboard/server/transport/TransportSqlTestSuite.java index 0fa04d0611..8df7f22867 100644 --- a/application/src/test/java/org/thingsboard/server/transport/TransportSqlTestSuite.java +++ b/application/src/test/java/org/thingsboard/server/transport/TransportSqlTestSuite.java @@ -33,6 +33,7 @@ import java.util.Arrays; "org.thingsboard.server.transport.*.attributes.request.sql.*Test", "org.thingsboard.server.transport.*.claim.sql.*Test", "org.thingsboard.server.transport.*.provision.sql.*Test", + "org.thingsboard.server.transport.*.credentials.sql.*Test", "org.thingsboard.server.transport.lwm2m.sql.*Test" }) public class TransportSqlTestSuite { diff --git a/application/src/test/java/org/thingsboard/server/transport/mqtt/credentials/sql/BasicMqttCredentialsTest.java b/application/src/test/java/org/thingsboard/server/transport/mqtt/credentials/sql/BasicMqttCredentialsTest.java new file mode 100644 index 0000000000..6bb8949f73 --- /dev/null +++ b/application/src/test/java/org/thingsboard/server/transport/mqtt/credentials/sql/BasicMqttCredentialsTest.java @@ -0,0 +1,226 @@ +/** + * Copyright © 2016-2021 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.transport.mqtt.credentials.sql; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.apache.commons.lang3.RandomStringUtils; +import org.eclipse.paho.client.mqttv3.MqttAsyncClient; +import org.eclipse.paho.client.mqttv3.MqttConnectOptions; +import org.eclipse.paho.client.mqttv3.MqttException; +import org.eclipse.paho.client.mqttv3.MqttSecurityException; +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.thingsboard.common.util.JacksonUtil; +import org.thingsboard.server.common.data.Device; +import org.thingsboard.server.common.data.StringUtils; +import org.thingsboard.server.common.data.Tenant; +import org.thingsboard.server.common.data.User; +import org.thingsboard.server.common.data.device.credentials.BasicMqttCredentials; +import org.thingsboard.server.common.data.device.profile.MqttTopics; +import org.thingsboard.server.common.data.security.Authority; +import org.thingsboard.server.common.data.security.DeviceCredentials; +import org.thingsboard.server.common.data.security.DeviceCredentialsType; +import org.thingsboard.server.dao.service.DaoSqlTest; +import org.thingsboard.server.transport.mqtt.AbstractMqttIntegrationTest; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@DaoSqlTest +public class BasicMqttCredentialsTest extends AbstractMqttIntegrationTest { + + public static final String CLIENT_ID = "ClientId"; + public static final String USER_NAME1 = "UserName1"; + public static final String USER_NAME2 = "UserName2"; + public static final String USER_NAME3 = "UserName3"; + public static final String PASSWORD = "secret"; + + private Device clientIdDevice; + private Device clientIdAndUserNameDevice1; + private Device clientIdAndUserNameAndPasswordDevice2; + private Device clientIdAndUserNameAndPasswordDevice3; + private Device accessTokenDevice; + private Device accessToken2Device; + + + @Before + public void before() throws Exception { + loginSysAdmin(); + + Tenant tenant = new Tenant(); + tenant.setTitle("My tenant"); + savedTenant = doPost("/api/tenant", tenant, Tenant.class); + Assert.assertNotNull(savedTenant); + + tenantAdmin = new User(); + tenantAdmin.setAuthority(Authority.TENANT_ADMIN); + tenantAdmin.setTenantId(savedTenant.getId()); + tenantAdmin.setEmail("tenant" + atomicInteger.getAndIncrement() + "@thingsboard.org"); + tenantAdmin.setFirstName("Joe"); + tenantAdmin.setLastName("Downs"); + + tenantAdmin = createUserAndLogin(tenantAdmin, "testPassword1"); + + BasicMqttCredentials credValue = new BasicMqttCredentials(); + credValue.setClientId(CLIENT_ID); + clientIdDevice = createDevice("clientIdDevice", credValue); + + credValue = new BasicMqttCredentials(); + credValue.setClientId(CLIENT_ID); + credValue.setUserName(USER_NAME1); + clientIdAndUserNameDevice1 = createDevice("clientIdAndUserNameDevice", credValue); + + credValue = new BasicMqttCredentials(); + credValue.setClientId(CLIENT_ID); + credValue.setUserName(USER_NAME2); + credValue.setPassword(PASSWORD); + clientIdAndUserNameAndPasswordDevice2 = createDevice("clientIdAndUserNameAndPasswordDevice", credValue); + + credValue = new BasicMqttCredentials(); + credValue.setClientId(CLIENT_ID); + credValue.setUserName(USER_NAME3); + credValue.setPassword(PASSWORD); + clientIdAndUserNameAndPasswordDevice3 = createDevice("clientIdAndUserNameAndPasswordDevice2", credValue); + + accessTokenDevice = createDevice("accessTokenDevice", USER_NAME1); + accessToken2Device = createDevice("accessToken2Device", USER_NAME2); + } + + @Test + public void testCorrectCredentials() throws Exception { + // Check that correct devices receive telemetry + testTelemetryIsDelivered(accessTokenDevice, getMqttAsyncClient(null, USER_NAME1, null)); + testTelemetryIsDelivered(clientIdDevice, getMqttAsyncClient(CLIENT_ID, null, null)); + testTelemetryIsDelivered(clientIdAndUserNameDevice1, getMqttAsyncClient(CLIENT_ID, USER_NAME1, null)); + testTelemetryIsDelivered(clientIdAndUserNameAndPasswordDevice2, getMqttAsyncClient(CLIENT_ID, USER_NAME2, PASSWORD)); + + // Also correct. Random clientId and password, but matches access token + testTelemetryIsDelivered(accessToken2Device, getMqttAsyncClient(RandomStringUtils.randomAlphanumeric(10), USER_NAME2, RandomStringUtils.randomAlphanumeric(10))); + } + + @Test(expected = MqttSecurityException.class) + public void testCorrectClientIdAndUserNameButWrongPassword() throws Exception { + // Not correct. Correct clientId and username, but wrong password + testTelemetryIsNotDelivered(clientIdAndUserNameAndPasswordDevice3, getMqttAsyncClient(CLIENT_ID, USER_NAME3, "WRONG PASSWORD")); + } + + private void testTelemetryIsDelivered(Device device, MqttAsyncClient client) throws Exception { + testTelemetryIsDelivered(device, client, true); + } + + private void testTelemetryIsNotDelivered(Device device, MqttAsyncClient client) throws Exception { + testTelemetryIsDelivered(device, client, false); + } + + private void testTelemetryIsDelivered(Device device, MqttAsyncClient client, boolean ok) throws Exception { + String randomKey = RandomStringUtils.randomAlphanumeric(10); + List expectedKeys = Arrays.asList(randomKey); + publishMqttMsg(client, JacksonUtil.toString(JacksonUtil.newObjectNode().put(randomKey, true)).getBytes(), MqttTopics.DEVICE_TELEMETRY_TOPIC); + + String deviceId = device.getId().getId().toString(); + + long start = System.currentTimeMillis(); + long end = System.currentTimeMillis() + 5000; + + List actualKeys = null; + while (start <= end) { + actualKeys = doGetAsyncTyped("/api/plugins/telemetry/DEVICE/" + deviceId + "/keys/timeseries", new TypeReference<>() { + }); + if (actualKeys.size() == expectedKeys.size()) { + break; + } + Thread.sleep(100); + start += 100; + } + if (ok) { + assertNotNull(actualKeys); + + Set actualKeySet = new HashSet<>(actualKeys); + Set expectedKeySet = new HashSet<>(expectedKeys); + + assertEquals(expectedKeySet, actualKeySet); + } else { + assertNull(actualKeys); + } + client.disconnect().waitForCompletion(); + } + + @After + public void after() throws Exception { + processAfterTest(); + } + + protected MqttAsyncClient getMqttAsyncClient(String clientId, String username, String password) throws MqttException { + if (StringUtils.isEmpty(clientId)) { + clientId = MqttAsyncClient.generateClientId(); + } + MqttAsyncClient client = new MqttAsyncClient(MQTT_URL, clientId, new MemoryPersistence()); + + MqttConnectOptions options = new MqttConnectOptions(); + if (StringUtils.isNotEmpty(username)) { + options.setUserName(username); + } + if (StringUtils.isNotEmpty(password)) { + options.setPassword(password.toCharArray()); + } + client.connect(options).waitForCompletion(); + return client; + } + + private Device createDevice(String deviceName, BasicMqttCredentials clientIdCredValue) throws Exception { + Device device = new Device(); + device.setName(deviceName); + device.setType("default"); + + device = doPost("/api/device", device, Device.class); + + DeviceCredentials clientIdCred = + doGet("/api/device/" + device.getId().getId().toString() + "/credentials", DeviceCredentials.class); + + clientIdCred.setCredentialsType(DeviceCredentialsType.MQTT_BASIC); + + + clientIdCred.setCredentialsValue(JacksonUtil.toString(clientIdCredValue)); + doPost("/api/device/credentials", clientIdCred).andExpect(status().isOk()); + return device; + } + + private Device createDevice(String deviceName, String accessToken) throws Exception { + Device device = new Device(); + device.setName(deviceName); + device.setType("default"); + + device = doPost("/api/device", device, Device.class); + + DeviceCredentials clientIdCred = + doGet("/api/device/" + device.getId().getId().toString() + "/credentials", DeviceCredentials.class); + + clientIdCred.setCredentialsType(DeviceCredentialsType.ACCESS_TOKEN); + clientIdCred.setCredentialsId(accessToken); + doPost("/api/device/credentials", clientIdCred).andExpect(status().isOk()); + return device; + } +} diff --git a/common/dao-api/src/main/java/org/thingsboard/server/dao/device/DeviceCredentialsService.java b/common/dao-api/src/main/java/org/thingsboard/server/dao/device/DeviceCredentialsService.java index 29572bf51a..9afc735213 100644 --- a/common/dao-api/src/main/java/org/thingsboard/server/dao/device/DeviceCredentialsService.java +++ b/common/dao-api/src/main/java/org/thingsboard/server/dao/device/DeviceCredentialsService.java @@ -19,6 +19,8 @@ import org.thingsboard.server.common.data.id.DeviceId; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.security.DeviceCredentials; +import java.util.List; + public interface DeviceCredentialsService { DeviceCredentials findDeviceCredentialsByDeviceId(TenantId tenantId, DeviceId deviceId); @@ -32,4 +34,5 @@ public interface DeviceCredentialsService { void formatCredentials(DeviceCredentials deviceCredentials); void deleteDeviceCredentials(TenantId tenantId, DeviceCredentials deviceCredentials); + }