Merge pull request #5614 from thingsboard/bug/mqtt-credentials

[3.3.3] Fix corner case when access token matches user name in credentials
This commit is contained in:
Andrew Shvayka 2021-11-24 12:14:10 +02:00 committed by GitHub
commit 4c08744817
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 310 additions and 41 deletions

View File

@ -0,0 +1,18 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.service.transport;
enum BasicCredentialsValidationResult {HASH_MISMATCH, PASSWORD_MISMATCH, VALID}

View File

@ -25,7 +25,6 @@ import com.google.protobuf.ByteString;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.server.cache.ota.OtaPackageDataCache;
import org.thingsboard.server.common.data.ApiUsageState;
@ -37,6 +36,7 @@ import org.thingsboard.server.common.data.EntityType;
import org.thingsboard.server.common.data.OtaPackage;
import org.thingsboard.server.common.data.OtaPackageInfo;
import org.thingsboard.server.common.data.ResourceType;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.TbResource;
import org.thingsboard.server.common.data.TenantProfile;
import org.thingsboard.server.common.data.device.credentials.BasicMqttCredentials;
@ -107,6 +107,9 @@ import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import static org.thingsboard.server.service.transport.BasicCredentialsValidationResult.PASSWORD_MISMATCH;
import static org.thingsboard.server.service.transport.BasicCredentialsValidationResult.VALID;
/**
* Created by ashvayka on 05.10.18.
*/
@ -181,71 +184,89 @@ public class DefaultTransportApiService implements TransportApiService {
//TODO: Make async and enable caching
DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credentialsId);
if (credentials != null && credentials.getCredentialsType() == credentialsType) {
return getDeviceInfo(credentials.getDeviceId(), credentials);
return getDeviceInfo(credentials);
} else {
return getEmptyTransportApiResponseFuture();
}
}
private ListenableFuture<TransportApiResponseMsg> validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
DeviceCredentials credentials = null;
if (!StringUtils.isEmpty(mqtt.getUserName())) {
credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName());
if (credentials != null) {
if (credentials.getCredentialsType() == DeviceCredentialsType.ACCESS_TOKEN) {
return getDeviceInfo(credentials.getDeviceId(), credentials);
} else if (credentials.getCredentialsType() == DeviceCredentialsType.MQTT_BASIC) {
if (!checkMqttCredentials(mqtt, credentials)) {
credentials = null;
}
} else {
return getEmptyTransportApiResponseFuture();
}
}
if (credentials == null) {
credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash("|", mqtt.getClientId(), mqtt.getUserName()));
}
}
if (credentials == null) {
DeviceCredentials credentials;
if (StringUtils.isEmpty(mqtt.getUserName())) {
credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId()));
}
if (credentials != null) {
return getDeviceInfo(credentials.getDeviceId(), credentials);
if (credentials != null) {
return getDeviceInfo(credentials);
} else {
return getEmptyTransportApiResponseFuture();
}
} else {
return getEmptyTransportApiResponseFuture();
credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(
EncryptionUtil.getSha3Hash("|", mqtt.getClientId(), mqtt.getUserName()));
if (checkIsMqttCredentials(credentials)) {
var validationResult = validateMqttCredentials(mqtt, credentials);
if (VALID.equals(validationResult)) {
return getDeviceInfo(credentials);
} else if (PASSWORD_MISMATCH.equals(validationResult)) {
return getEmptyTransportApiResponseFuture();
} else {
return validateUserNameCredentials(mqtt);
}
} else {
return validateUserNameCredentials(mqtt);
}
}
}
private ListenableFuture<TransportApiResponseMsg> validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName());
if (credentials != null) {
switch (credentials.getCredentialsType()) {
case ACCESS_TOKEN:
return getDeviceInfo(credentials);
case MQTT_BASIC:
if (VALID.equals(validateMqttCredentials(mqtt, credentials))) {
return getDeviceInfo(credentials);
} else {
return getEmptyTransportApiResponseFuture();
}
}
}
return getEmptyTransportApiResponseFuture();
}
private static boolean checkIsMqttCredentials(DeviceCredentials credentials) {
return credentials != null && DeviceCredentialsType.MQTT_BASIC.equals(credentials.getCredentialsType());
}
private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, String credId) {
DeviceCredentials deviceCredentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credId);
return checkMqttCredentials(clientCred, deviceCredentialsService.findDeviceCredentialsByCredentialsId(credId));
}
private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) {
if (deviceCredentials != null && deviceCredentials.getCredentialsType() == DeviceCredentialsType.MQTT_BASIC) {
if (!checkMqttCredentials(clientCred, deviceCredentials)) {
return null;
} else {
if (VALID.equals(validateMqttCredentials(clientCred, deviceCredentials))) {
return deviceCredentials;
}
}
return null;
}
private boolean checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) {
private BasicCredentialsValidationResult validateMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) {
BasicMqttCredentials dbCred = JacksonUtil.fromString(deviceCredentials.getCredentialsValue(), BasicMqttCredentials.class);
if (!StringUtils.isEmpty(dbCred.getClientId()) && !dbCred.getClientId().equals(clientCred.getClientId())) {
return false;
return BasicCredentialsValidationResult.HASH_MISMATCH;
}
if (!StringUtils.isEmpty(dbCred.getUserName()) && !dbCred.getUserName().equals(clientCred.getUserName())) {
return false;
return BasicCredentialsValidationResult.HASH_MISMATCH;
}
if (!StringUtils.isEmpty(dbCred.getPassword())) {
if (StringUtils.isEmpty(clientCred.getPassword())) {
return false;
return BasicCredentialsValidationResult.PASSWORD_MISMATCH;
} else {
if (!dbCred.getPassword().equals(clientCred.getPassword())) {
return false;
}
return dbCred.getPassword().equals(clientCred.getPassword()) ? VALID : BasicCredentialsValidationResult.PASSWORD_MISMATCH;
}
}
return true;
return VALID;
}
private ListenableFuture<TransportApiResponseMsg> handle(GetOrCreateDeviceFromGatewayRequestMsg requestMsg) {
@ -437,10 +458,10 @@ public class DefaultTransportApiService implements TransportApiService {
.build());
}
private ListenableFuture<TransportApiResponseMsg> getDeviceInfo(DeviceId deviceId, DeviceCredentials credentials) {
return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, deviceId), device -> {
private ListenableFuture<TransportApiResponseMsg> getDeviceInfo(DeviceCredentials credentials) {
return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, credentials.getDeviceId()), device -> {
if (device == null) {
log.trace("[{}] Failed to lookup device by id", deviceId);
log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId());
return getEmptyTransportApiResponse();
}
try {
@ -458,7 +479,7 @@ public class DefaultTransportApiService implements TransportApiService {
return TransportApiResponseMsg.newBuilder()
.setValidateCredResponseMsg(builder.build()).build();
} catch (JsonProcessingException e) {
log.warn("[{}] Failed to lookup device by id", deviceId, e);
log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e);
return getEmptyTransportApiResponse();
}
}, MoreExecutors.directExecutor());

View File

@ -33,6 +33,7 @@ import java.util.Arrays;
"org.thingsboard.server.transport.*.attributes.request.sql.*Test",
"org.thingsboard.server.transport.*.claim.sql.*Test",
"org.thingsboard.server.transport.*.provision.sql.*Test",
"org.thingsboard.server.transport.*.credentials.sql.*Test",
"org.thingsboard.server.transport.lwm2m.sql.*Test"
})
public class TransportSqlTestSuite {

View File

@ -0,0 +1,226 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.credentials.sql;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.MqttSecurityException;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.server.common.data.Device;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.Tenant;
import org.thingsboard.server.common.data.User;
import org.thingsboard.server.common.data.device.credentials.BasicMqttCredentials;
import org.thingsboard.server.common.data.device.profile.MqttTopics;
import org.thingsboard.server.common.data.security.Authority;
import org.thingsboard.server.common.data.security.DeviceCredentials;
import org.thingsboard.server.common.data.security.DeviceCredentialsType;
import org.thingsboard.server.dao.service.DaoSqlTest;
import org.thingsboard.server.transport.mqtt.AbstractMqttIntegrationTest;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@DaoSqlTest
public class BasicMqttCredentialsTest extends AbstractMqttIntegrationTest {
public static final String CLIENT_ID = "ClientId";
public static final String USER_NAME1 = "UserName1";
public static final String USER_NAME2 = "UserName2";
public static final String USER_NAME3 = "UserName3";
public static final String PASSWORD = "secret";
private Device clientIdDevice;
private Device clientIdAndUserNameDevice1;
private Device clientIdAndUserNameAndPasswordDevice2;
private Device clientIdAndUserNameAndPasswordDevice3;
private Device accessTokenDevice;
private Device accessToken2Device;
@Before
public void before() throws Exception {
loginSysAdmin();
Tenant tenant = new Tenant();
tenant.setTitle("My tenant");
savedTenant = doPost("/api/tenant", tenant, Tenant.class);
Assert.assertNotNull(savedTenant);
tenantAdmin = new User();
tenantAdmin.setAuthority(Authority.TENANT_ADMIN);
tenantAdmin.setTenantId(savedTenant.getId());
tenantAdmin.setEmail("tenant" + atomicInteger.getAndIncrement() + "@thingsboard.org");
tenantAdmin.setFirstName("Joe");
tenantAdmin.setLastName("Downs");
tenantAdmin = createUserAndLogin(tenantAdmin, "testPassword1");
BasicMqttCredentials credValue = new BasicMqttCredentials();
credValue.setClientId(CLIENT_ID);
clientIdDevice = createDevice("clientIdDevice", credValue);
credValue = new BasicMqttCredentials();
credValue.setClientId(CLIENT_ID);
credValue.setUserName(USER_NAME1);
clientIdAndUserNameDevice1 = createDevice("clientIdAndUserNameDevice", credValue);
credValue = new BasicMqttCredentials();
credValue.setClientId(CLIENT_ID);
credValue.setUserName(USER_NAME2);
credValue.setPassword(PASSWORD);
clientIdAndUserNameAndPasswordDevice2 = createDevice("clientIdAndUserNameAndPasswordDevice", credValue);
credValue = new BasicMqttCredentials();
credValue.setClientId(CLIENT_ID);
credValue.setUserName(USER_NAME3);
credValue.setPassword(PASSWORD);
clientIdAndUserNameAndPasswordDevice3 = createDevice("clientIdAndUserNameAndPasswordDevice2", credValue);
accessTokenDevice = createDevice("accessTokenDevice", USER_NAME1);
accessToken2Device = createDevice("accessToken2Device", USER_NAME2);
}
@Test
public void testCorrectCredentials() throws Exception {
// Check that correct devices receive telemetry
testTelemetryIsDelivered(accessTokenDevice, getMqttAsyncClient(null, USER_NAME1, null));
testTelemetryIsDelivered(clientIdDevice, getMqttAsyncClient(CLIENT_ID, null, null));
testTelemetryIsDelivered(clientIdAndUserNameDevice1, getMqttAsyncClient(CLIENT_ID, USER_NAME1, null));
testTelemetryIsDelivered(clientIdAndUserNameAndPasswordDevice2, getMqttAsyncClient(CLIENT_ID, USER_NAME2, PASSWORD));
// Also correct. Random clientId and password, but matches access token
testTelemetryIsDelivered(accessToken2Device, getMqttAsyncClient(RandomStringUtils.randomAlphanumeric(10), USER_NAME2, RandomStringUtils.randomAlphanumeric(10)));
}
@Test(expected = MqttSecurityException.class)
public void testCorrectClientIdAndUserNameButWrongPassword() throws Exception {
// Not correct. Correct clientId and username, but wrong password
testTelemetryIsNotDelivered(clientIdAndUserNameAndPasswordDevice3, getMqttAsyncClient(CLIENT_ID, USER_NAME3, "WRONG PASSWORD"));
}
private void testTelemetryIsDelivered(Device device, MqttAsyncClient client) throws Exception {
testTelemetryIsDelivered(device, client, true);
}
private void testTelemetryIsNotDelivered(Device device, MqttAsyncClient client) throws Exception {
testTelemetryIsDelivered(device, client, false);
}
private void testTelemetryIsDelivered(Device device, MqttAsyncClient client, boolean ok) throws Exception {
String randomKey = RandomStringUtils.randomAlphanumeric(10);
List<String> expectedKeys = Arrays.asList(randomKey);
publishMqttMsg(client, JacksonUtil.toString(JacksonUtil.newObjectNode().put(randomKey, true)).getBytes(), MqttTopics.DEVICE_TELEMETRY_TOPIC);
String deviceId = device.getId().getId().toString();
long start = System.currentTimeMillis();
long end = System.currentTimeMillis() + 5000;
List<String> actualKeys = null;
while (start <= end) {
actualKeys = doGetAsyncTyped("/api/plugins/telemetry/DEVICE/" + deviceId + "/keys/timeseries", new TypeReference<>() {
});
if (actualKeys.size() == expectedKeys.size()) {
break;
}
Thread.sleep(100);
start += 100;
}
if (ok) {
assertNotNull(actualKeys);
Set<String> actualKeySet = new HashSet<>(actualKeys);
Set<String> expectedKeySet = new HashSet<>(expectedKeys);
assertEquals(expectedKeySet, actualKeySet);
} else {
assertNull(actualKeys);
}
client.disconnect().waitForCompletion();
}
@After
public void after() throws Exception {
processAfterTest();
}
protected MqttAsyncClient getMqttAsyncClient(String clientId, String username, String password) throws MqttException {
if (StringUtils.isEmpty(clientId)) {
clientId = MqttAsyncClient.generateClientId();
}
MqttAsyncClient client = new MqttAsyncClient(MQTT_URL, clientId, new MemoryPersistence());
MqttConnectOptions options = new MqttConnectOptions();
if (StringUtils.isNotEmpty(username)) {
options.setUserName(username);
}
if (StringUtils.isNotEmpty(password)) {
options.setPassword(password.toCharArray());
}
client.connect(options).waitForCompletion();
return client;
}
private Device createDevice(String deviceName, BasicMqttCredentials clientIdCredValue) throws Exception {
Device device = new Device();
device.setName(deviceName);
device.setType("default");
device = doPost("/api/device", device, Device.class);
DeviceCredentials clientIdCred =
doGet("/api/device/" + device.getId().getId().toString() + "/credentials", DeviceCredentials.class);
clientIdCred.setCredentialsType(DeviceCredentialsType.MQTT_BASIC);
clientIdCred.setCredentialsValue(JacksonUtil.toString(clientIdCredValue));
doPost("/api/device/credentials", clientIdCred).andExpect(status().isOk());
return device;
}
private Device createDevice(String deviceName, String accessToken) throws Exception {
Device device = new Device();
device.setName(deviceName);
device.setType("default");
device = doPost("/api/device", device, Device.class);
DeviceCredentials clientIdCred =
doGet("/api/device/" + device.getId().getId().toString() + "/credentials", DeviceCredentials.class);
clientIdCred.setCredentialsType(DeviceCredentialsType.ACCESS_TOKEN);
clientIdCred.setCredentialsId(accessToken);
doPost("/api/device/credentials", clientIdCred).andExpect(status().isOk());
return device;
}
}

View File

@ -19,6 +19,8 @@ import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.security.DeviceCredentials;
import java.util.List;
public interface DeviceCredentialsService {
DeviceCredentials findDeviceCredentialsByDeviceId(TenantId tenantId, DeviceId deviceId);
@ -32,4 +34,5 @@ public interface DeviceCredentialsService {
void formatCredentials(DeviceCredentials deviceCredentials);
void deleteDeviceCredentials(TenantId tenantId, DeviceCredentials deviceCredentials);
}