DefaultTransportApiService validations refactored to a sync style to avoid excessive context switching

This commit is contained in:
Sergey Matvienko 2023-11-01 10:51:09 +01:00
parent f1fa6e6a01
commit b30fe6ce26
2 changed files with 48 additions and 43 deletions

View File

@ -179,18 +179,18 @@ public class DefaultTransportApiService implements TransportApiService {
if (transportApiRequestMsg.hasValidateTokenRequestMsg()) { if (transportApiRequestMsg.hasValidateTokenRequestMsg()) {
ValidateDeviceTokenRequestMsg msg = transportApiRequestMsg.getValidateTokenRequestMsg(); ValidateDeviceTokenRequestMsg msg = transportApiRequestMsg.getValidateTokenRequestMsg();
final String token = msg.getToken(); final String token = msg.getToken();
result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(token, DeviceCredentialsType.ACCESS_TOKEN)), future -> future, MoreExecutors.directExecutor()); result = handlerExecutor.submit(() -> validateCredentials(token, DeviceCredentialsType.ACCESS_TOKEN));
} else if (transportApiRequestMsg.hasValidateBasicMqttCredRequestMsg()) { } else if (transportApiRequestMsg.hasValidateBasicMqttCredRequestMsg()) {
TransportProtos.ValidateBasicMqttCredRequestMsg msg = transportApiRequestMsg.getValidateBasicMqttCredRequestMsg(); TransportProtos.ValidateBasicMqttCredRequestMsg msg = transportApiRequestMsg.getValidateBasicMqttCredRequestMsg();
result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(msg)), future -> future, MoreExecutors.directExecutor()); result = handlerExecutor.submit(() -> validateCredentials(msg));
} else if (transportApiRequestMsg.hasValidateX509CertRequestMsg()) { } else if (transportApiRequestMsg.hasValidateX509CertRequestMsg()) {
ValidateDeviceX509CertRequestMsg msg = transportApiRequestMsg.getValidateX509CertRequestMsg(); ValidateDeviceX509CertRequestMsg msg = transportApiRequestMsg.getValidateX509CertRequestMsg();
final String hash = msg.getHash(); final String hash = msg.getHash();
result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(hash, DeviceCredentialsType.X509_CERTIFICATE)), future -> future, MoreExecutors.directExecutor()); result = handlerExecutor.submit(() -> validateCredentials(hash, DeviceCredentialsType.X509_CERTIFICATE));
} else if (transportApiRequestMsg.hasValidateOrCreateX509CertRequestMsg()) { } else if (transportApiRequestMsg.hasValidateOrCreateX509CertRequestMsg()) {
TransportProtos.ValidateOrCreateDeviceX509CertRequestMsg msg = transportApiRequestMsg.getValidateOrCreateX509CertRequestMsg(); TransportProtos.ValidateOrCreateDeviceX509CertRequestMsg msg = transportApiRequestMsg.getValidateOrCreateX509CertRequestMsg();
final String certChain = msg.getCertificateChain(); final String certChain = msg.getCertificateChain();
result = Futures.transformAsync(handlerExecutor.submit(() -> validateOrCreateDeviceX509Certificate(certChain)), future -> future, MoreExecutors.directExecutor()); result = handlerExecutor.submit(() -> validateOrCreateDeviceX509Certificate(certChain));
} else if (transportApiRequestMsg.hasGetOrCreateDeviceRequestMsg()) { } else if (transportApiRequestMsg.hasGetOrCreateDeviceRequestMsg()) {
result = handle(transportApiRequestMsg.getGetOrCreateDeviceRequestMsg()); result = handle(transportApiRequestMsg.getGetOrCreateDeviceRequestMsg());
} else if (transportApiRequestMsg.hasEntityProfileRequestMsg()) { } else if (transportApiRequestMsg.hasEntityProfileRequestMsg()) {
@ -200,7 +200,7 @@ public class DefaultTransportApiService implements TransportApiService {
} else if (transportApiRequestMsg.hasValidateDeviceLwM2MCredentialsRequestMsg()) { } else if (transportApiRequestMsg.hasValidateDeviceLwM2MCredentialsRequestMsg()) {
ValidateDeviceLwM2MCredentialsRequestMsg msg = transportApiRequestMsg.getValidateDeviceLwM2MCredentialsRequestMsg(); ValidateDeviceLwM2MCredentialsRequestMsg msg = transportApiRequestMsg.getValidateDeviceLwM2MCredentialsRequestMsg();
final String credentialsId = msg.getCredentialsId(); final String credentialsId = msg.getCredentialsId();
result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(credentialsId, DeviceCredentialsType.LWM2M_CREDENTIALS)), future -> future, MoreExecutors.directExecutor()); result = handlerExecutor.submit(() -> validateCredentials(credentialsId, DeviceCredentialsType.LWM2M_CREDENTIALS));
} else if (transportApiRequestMsg.hasProvisionDeviceRequestMsg()) { } else if (transportApiRequestMsg.hasProvisionDeviceRequestMsg()) {
result = handle(transportApiRequestMsg.getProvisionDeviceRequestMsg()); result = handle(transportApiRequestMsg.getProvisionDeviceRequestMsg());
} else if (transportApiRequestMsg.hasResourceRequestMsg()) { } else if (transportApiRequestMsg.hasResourceRequestMsg()) {
@ -222,24 +222,24 @@ public class DefaultTransportApiService implements TransportApiService {
MoreExecutors.directExecutor()); MoreExecutors.directExecutor());
} }
private ListenableFuture<TransportApiResponseMsg> validateCredentials(String credentialsId, DeviceCredentialsType credentialsType) { private TransportApiResponseMsg validateCredentials(String credentialsId, DeviceCredentialsType credentialsType) {
//TODO: Make async and enable caching //TODO: Make async and enable caching
DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credentialsId); DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credentialsId);
if (credentials != null && credentials.getCredentialsType() == credentialsType) { if (credentials != null && credentials.getCredentialsType() == credentialsType) {
return getDeviceInfo(credentials); return getDeviceInfo(credentials);
} else { } else {
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} }
} }
private ListenableFuture<TransportApiResponseMsg> validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) { private TransportApiResponseMsg validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
DeviceCredentials credentials; DeviceCredentials credentials;
if (StringUtils.isEmpty(mqtt.getUserName())) { if (StringUtils.isEmpty(mqtt.getUserName())) {
credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId())); credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId()));
if (credentials != null) { if (credentials != null) {
return getDeviceInfo(credentials); return getDeviceInfo(credentials);
} else { } else {
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} }
} else { } else {
credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId( credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(
@ -249,7 +249,7 @@ public class DefaultTransportApiService implements TransportApiService {
if (VALID.equals(validationResult)) { if (VALID.equals(validationResult)) {
return getDeviceInfo(credentials); return getDeviceInfo(credentials);
} else if (PASSWORD_MISMATCH.equals(validationResult)) { } else if (PASSWORD_MISMATCH.equals(validationResult)) {
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} else { } else {
return validateUserNameCredentials(mqtt); return validateUserNameCredentials(mqtt);
} }
@ -259,7 +259,7 @@ public class DefaultTransportApiService implements TransportApiService {
} }
} }
protected ListenableFuture<TransportApiResponseMsg> validateOrCreateDeviceX509Certificate(String certificateChain) { protected TransportApiResponseMsg validateOrCreateDeviceX509Certificate(String certificateChain) {
List<String> chain = X509_CERTIFICATE_TRIM_CHAIN_PATTERN.matcher(certificateChain).results().map(match -> List<String> chain = X509_CERTIFICATE_TRIM_CHAIN_PATTERN.matcher(certificateChain).results().map(match ->
EncryptionUtil.certTrimNewLines(match.group())).collect(Collectors.toList()); EncryptionUtil.certTrimNewLines(match.group())).collect(Collectors.toList());
for (String certificateValue : chain) { for (String certificateValue : chain) {
@ -279,16 +279,16 @@ public class DefaultTransportApiService implements TransportApiService {
} }
} catch (ProvisionFailedException e) { } catch (ProvisionFailedException e) {
log.debug("[{}][{}] Failed to provision device with cert chain: {}", deviceProfile.getTenantId(), deviceProfile.getId(), provisionRequest, e); log.debug("[{}][{}] Failed to provision device with cert chain: {}", deviceProfile.getTenantId(), deviceProfile.getId(), provisionRequest, e);
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} }
} else if (deviceProfile != null) { } else if (deviceProfile != null) {
log.warn("[{}][{}] Device Profile provision configuration mismatched: expected {}, actual {}", deviceProfile.getTenantId(), deviceProfile.getId(), DeviceProfileProvisionType.X509_CERTIFICATE_CHAIN, deviceProfile.getProvisionType()); log.warn("[{}][{}] Device Profile provision configuration mismatched: expected {}, actual {}", deviceProfile.getTenantId(), deviceProfile.getId(), DeviceProfileProvisionType.X509_CERTIFICATE_CHAIN, deviceProfile.getProvisionType());
} }
} }
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} }
private ListenableFuture<TransportApiResponseMsg> validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) { private TransportApiResponseMsg validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName()); DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName());
if (credentials != null) { if (credentials != null) {
switch (credentials.getCredentialsType()) { switch (credentials.getCredentialsType()) {
@ -298,11 +298,11 @@ public class DefaultTransportApiService implements TransportApiService {
if (VALID.equals(validateMqttCredentials(mqtt, credentials))) { if (VALID.equals(validateMqttCredentials(mqtt, credentials))) {
return getDeviceInfo(credentials); return getDeviceInfo(credentials);
} else { } else {
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} }
} }
} }
return getEmptyTransportApiResponseFuture(); return getEmptyTransportApiResponse();
} }
private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, String credId) { private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, String credId) {
@ -543,31 +543,30 @@ public class DefaultTransportApiService implements TransportApiService {
.build()); .build());
} }
private ListenableFuture<TransportApiResponseMsg> getDeviceInfo(DeviceCredentials credentials) { TransportApiResponseMsg getDeviceInfo(DeviceCredentials credentials) {
return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, credentials.getDeviceId()), device -> { Device device = deviceService.findDeviceById(TenantId.SYS_TENANT_ID, credentials.getDeviceId());
if (device == null) { if (device == null) {
log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId()); log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId());
return getEmptyTransportApiResponse(); return getEmptyTransportApiResponse();
}
try {
ValidateDeviceCredentialsResponseMsg.Builder builder = ValidateDeviceCredentialsResponseMsg.newBuilder();
builder.setDeviceInfo(getDeviceInfoProto(device));
DeviceProfile deviceProfile = deviceProfileCache.get(device.getTenantId(), device.getDeviceProfileId());
if (deviceProfile != null) {
builder.setProfileBody(ByteString.copyFrom(dataDecodingEncodingService.encode(deviceProfile)));
} else {
log.warn("[{}] Failed to find device profile [{}] for device. ", device.getId(), device.getDeviceProfileId());
} }
try { if (!StringUtils.isEmpty(credentials.getCredentialsValue())) {
ValidateDeviceCredentialsResponseMsg.Builder builder = ValidateDeviceCredentialsResponseMsg.newBuilder(); builder.setCredentialsBody(credentials.getCredentialsValue());
builder.setDeviceInfo(getDeviceInfoProto(device));
DeviceProfile deviceProfile = deviceProfileCache.get(device.getTenantId(), device.getDeviceProfileId());
if (deviceProfile != null) {
builder.setProfileBody(ByteString.copyFrom(dataDecodingEncodingService.encode(deviceProfile)));
} else {
log.warn("[{}] Failed to find device profile [{}] for device. ", device.getId(), device.getDeviceProfileId());
}
if (!StringUtils.isEmpty(credentials.getCredentialsValue())) {
builder.setCredentialsBody(credentials.getCredentialsValue());
}
return TransportApiResponseMsg.newBuilder()
.setValidateCredResponseMsg(builder.build()).build();
} catch (JsonProcessingException e) {
log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e);
return getEmptyTransportApiResponse();
} }
}, MoreExecutors.directExecutor()); return TransportApiResponseMsg.newBuilder()
.setValidateCredResponseMsg(builder.build()).build();
} catch (JsonProcessingException e) {
log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e);
return getEmptyTransportApiResponse();
}
} }
private DeviceInfoProto getDeviceInfoProto(Device device) throws JsonProcessingException { private DeviceInfoProto getDeviceInfoProto(Device device) throws JsonProcessingException {

View File

@ -16,7 +16,6 @@
package org.thingsboard.server.service.transport; package org.thingsboard.server.service.transport;
import com.google.common.util.concurrent.Futures;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -46,6 +45,7 @@ import org.thingsboard.server.dao.ota.OtaPackageService;
import org.thingsboard.server.dao.queue.QueueService; import org.thingsboard.server.dao.queue.QueueService;
import org.thingsboard.server.dao.relation.RelationService; import org.thingsboard.server.dao.relation.RelationService;
import org.thingsboard.server.dao.tenant.TbTenantProfileCache; import org.thingsboard.server.dao.tenant.TbTenantProfileCache;
import org.thingsboard.server.gen.transport.TransportProtos;
import org.thingsboard.server.queue.util.DataDecodingEncodingService; import org.thingsboard.server.queue.util.DataDecodingEncodingService;
import org.thingsboard.server.service.apiusage.TbApiUsageStateService; import org.thingsboard.server.service.apiusage.TbApiUsageStateService;
import org.thingsboard.server.service.executors.DbCallbackExecutorService; import org.thingsboard.server.service.executors.DbCallbackExecutorService;
@ -62,6 +62,8 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.willReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -123,11 +125,13 @@ public class DefaultTransportApiServiceTest {
@Test @Test
public void validateExistingDeviceByX509CertificateStrategy() { public void validateExistingDeviceByX509CertificateStrategy() {
var device = createDevice(); var device = createDevice();
when(deviceService.findDeviceByIdAsync(any(), any())).thenReturn(Futures.immediateFuture(device));
var deviceCredentials = createDeviceCredentials(chain[0], device.getId()); var deviceCredentials = createDeviceCredentials(chain[0], device.getId());
when(deviceCredentialsService.findDeviceCredentialsByCredentialsId(any())).thenReturn(deviceCredentials); when(deviceCredentialsService.findDeviceCredentialsByCredentialsId(any())).thenReturn(deviceCredentials);
TransportProtos.TransportApiResponseMsg response = mock(TransportProtos.TransportApiResponseMsg.class);
willReturn(response).given(service).getDeviceInfo(deviceCredentials);
service.validateOrCreateDeviceX509Certificate(certificateChain); service.validateOrCreateDeviceX509Certificate(certificateChain);
verify(deviceCredentialsService, times(1)).findDeviceCredentialsByCredentialsId(any()); verify(deviceCredentialsService, times(1)).findDeviceCredentialsByCredentialsId(any());
} }
@ -139,7 +143,6 @@ public class DefaultTransportApiServiceTest {
var device = createDevice(); var device = createDevice();
when(deviceService.findDeviceByTenantIdAndName(any(), any())).thenReturn(device); when(deviceService.findDeviceByTenantIdAndName(any(), any())).thenReturn(device);
when(deviceService.findDeviceByIdAsync(any(), any())).thenReturn(Futures.immediateFuture(device));
var deviceCredentials = createDeviceCredentials(chain[0], device.getId()); var deviceCredentials = createDeviceCredentials(chain[0], device.getId());
when(deviceCredentialsService.findDeviceCredentialsByCredentialsId(any())).thenReturn(null); when(deviceCredentialsService.findDeviceCredentialsByCredentialsId(any())).thenReturn(null);
@ -148,9 +151,12 @@ public class DefaultTransportApiServiceTest {
var provisionResponse = createProvisionResponse(deviceCredentials); var provisionResponse = createProvisionResponse(deviceCredentials);
when(deviceProvisionService.provisionDeviceViaX509Chain(any(), any())).thenReturn(provisionResponse); when(deviceProvisionService.provisionDeviceViaX509Chain(any(), any())).thenReturn(provisionResponse);
TransportProtos.TransportApiResponseMsg response = mock(TransportProtos.TransportApiResponseMsg.class);
willReturn(response).given(service).getDeviceInfo(deviceCredentials);
service.validateOrCreateDeviceX509Certificate(certificateChain); service.validateOrCreateDeviceX509Certificate(certificateChain);
verify(deviceProfileService, times(1)).findDeviceProfileByProvisionDeviceKey(any()); verify(deviceProfileService, times(1)).findDeviceProfileByProvisionDeviceKey(any());
verify(deviceService, times(1)).findDeviceByIdAsync(any(), any()); verify(service, times(1)).getDeviceInfo(any());
verify(deviceCredentialsService, times(1)).findDeviceCredentialsByCredentialsId(any()); verify(deviceCredentialsService, times(1)).findDeviceCredentialsByCredentialsId(any());
verify(deviceProvisionService, times(1)).provisionDeviceViaX509Chain(any(), any()); verify(deviceProvisionService, times(1)).provisionDeviceViaX509Chain(any(), any());
} }