DefaultTransportApiService validations refactored to a sync style to avoid excessive context switching
This commit is contained in:
		
							parent
							
								
									f1fa6e6a01
								
							
						
					
					
						commit
						b30fe6ce26
					
				@ -179,18 +179,18 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
        if (transportApiRequestMsg.hasValidateTokenRequestMsg()) {
 | 
			
		||||
            ValidateDeviceTokenRequestMsg msg = transportApiRequestMsg.getValidateTokenRequestMsg();
 | 
			
		||||
            final String token = msg.getToken();
 | 
			
		||||
            result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(token, DeviceCredentialsType.ACCESS_TOKEN)), future -> future, MoreExecutors.directExecutor());
 | 
			
		||||
            result = handlerExecutor.submit(() -> validateCredentials(token, DeviceCredentialsType.ACCESS_TOKEN));
 | 
			
		||||
        } else if (transportApiRequestMsg.hasValidateBasicMqttCredRequestMsg()) {
 | 
			
		||||
            TransportProtos.ValidateBasicMqttCredRequestMsg msg = transportApiRequestMsg.getValidateBasicMqttCredRequestMsg();
 | 
			
		||||
            result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(msg)), future -> future, MoreExecutors.directExecutor());
 | 
			
		||||
            result = handlerExecutor.submit(() -> validateCredentials(msg));
 | 
			
		||||
        } else if (transportApiRequestMsg.hasValidateX509CertRequestMsg()) {
 | 
			
		||||
            ValidateDeviceX509CertRequestMsg msg = transportApiRequestMsg.getValidateX509CertRequestMsg();
 | 
			
		||||
            final String hash = msg.getHash();
 | 
			
		||||
            result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(hash, DeviceCredentialsType.X509_CERTIFICATE)), future -> future, MoreExecutors.directExecutor());
 | 
			
		||||
            result = handlerExecutor.submit(() -> validateCredentials(hash, DeviceCredentialsType.X509_CERTIFICATE));
 | 
			
		||||
        } else if (transportApiRequestMsg.hasValidateOrCreateX509CertRequestMsg()) {
 | 
			
		||||
            TransportProtos.ValidateOrCreateDeviceX509CertRequestMsg msg = transportApiRequestMsg.getValidateOrCreateX509CertRequestMsg();
 | 
			
		||||
            final String certChain = msg.getCertificateChain();
 | 
			
		||||
            result = Futures.transformAsync(handlerExecutor.submit(() -> validateOrCreateDeviceX509Certificate(certChain)), future -> future, MoreExecutors.directExecutor());
 | 
			
		||||
            result = handlerExecutor.submit(() -> validateOrCreateDeviceX509Certificate(certChain));
 | 
			
		||||
        } else if (transportApiRequestMsg.hasGetOrCreateDeviceRequestMsg()) {
 | 
			
		||||
            result = handle(transportApiRequestMsg.getGetOrCreateDeviceRequestMsg());
 | 
			
		||||
        } else if (transportApiRequestMsg.hasEntityProfileRequestMsg()) {
 | 
			
		||||
@ -200,7 +200,7 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
        } else if (transportApiRequestMsg.hasValidateDeviceLwM2MCredentialsRequestMsg()) {
 | 
			
		||||
            ValidateDeviceLwM2MCredentialsRequestMsg msg = transportApiRequestMsg.getValidateDeviceLwM2MCredentialsRequestMsg();
 | 
			
		||||
            final String credentialsId = msg.getCredentialsId();
 | 
			
		||||
            result = Futures.transformAsync(handlerExecutor.submit(() -> validateCredentials(credentialsId, DeviceCredentialsType.LWM2M_CREDENTIALS)), future -> future, MoreExecutors.directExecutor());
 | 
			
		||||
            result = handlerExecutor.submit(() -> validateCredentials(credentialsId, DeviceCredentialsType.LWM2M_CREDENTIALS));
 | 
			
		||||
        } else if (transportApiRequestMsg.hasProvisionDeviceRequestMsg()) {
 | 
			
		||||
            result = handle(transportApiRequestMsg.getProvisionDeviceRequestMsg());
 | 
			
		||||
        } else if (transportApiRequestMsg.hasResourceRequestMsg()) {
 | 
			
		||||
@ -222,24 +222,24 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
                MoreExecutors.directExecutor());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private ListenableFuture<TransportApiResponseMsg> validateCredentials(String credentialsId, DeviceCredentialsType credentialsType) {
 | 
			
		||||
    private TransportApiResponseMsg validateCredentials(String credentialsId, DeviceCredentialsType credentialsType) {
 | 
			
		||||
        //TODO: Make async and enable caching
 | 
			
		||||
        DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credentialsId);
 | 
			
		||||
        if (credentials != null && credentials.getCredentialsType() == credentialsType) {
 | 
			
		||||
            return getDeviceInfo(credentials);
 | 
			
		||||
        } else {
 | 
			
		||||
            return getEmptyTransportApiResponseFuture();
 | 
			
		||||
            return getEmptyTransportApiResponse();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private ListenableFuture<TransportApiResponseMsg> validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
 | 
			
		||||
    private TransportApiResponseMsg validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
 | 
			
		||||
        DeviceCredentials credentials;
 | 
			
		||||
        if (StringUtils.isEmpty(mqtt.getUserName())) {
 | 
			
		||||
            credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId()));
 | 
			
		||||
            if (credentials != null) {
 | 
			
		||||
                return getDeviceInfo(credentials);
 | 
			
		||||
            } else {
 | 
			
		||||
                return getEmptyTransportApiResponseFuture();
 | 
			
		||||
                return getEmptyTransportApiResponse();
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(
 | 
			
		||||
@ -249,7 +249,7 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
                if (VALID.equals(validationResult)) {
 | 
			
		||||
                    return getDeviceInfo(credentials);
 | 
			
		||||
                } else if (PASSWORD_MISMATCH.equals(validationResult)) {
 | 
			
		||||
                    return getEmptyTransportApiResponseFuture();
 | 
			
		||||
                    return getEmptyTransportApiResponse();
 | 
			
		||||
                } else {
 | 
			
		||||
                    return validateUserNameCredentials(mqtt);
 | 
			
		||||
                }
 | 
			
		||||
@ -259,7 +259,7 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    protected ListenableFuture<TransportApiResponseMsg> validateOrCreateDeviceX509Certificate(String certificateChain) {
 | 
			
		||||
    protected TransportApiResponseMsg validateOrCreateDeviceX509Certificate(String certificateChain) {
 | 
			
		||||
        List<String> chain = X509_CERTIFICATE_TRIM_CHAIN_PATTERN.matcher(certificateChain).results().map(match ->
 | 
			
		||||
                EncryptionUtil.certTrimNewLines(match.group())).collect(Collectors.toList());
 | 
			
		||||
        for (String certificateValue : chain) {
 | 
			
		||||
@ -279,16 +279,16 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
                    }
 | 
			
		||||
                } catch (ProvisionFailedException e) {
 | 
			
		||||
                    log.debug("[{}][{}] Failed to provision device with cert chain: {}", deviceProfile.getTenantId(), deviceProfile.getId(), provisionRequest, e);
 | 
			
		||||
                    return getEmptyTransportApiResponseFuture();
 | 
			
		||||
                    return getEmptyTransportApiResponse();
 | 
			
		||||
                }
 | 
			
		||||
            } else if (deviceProfile != null) {
 | 
			
		||||
                log.warn("[{}][{}] Device Profile provision configuration mismatched: expected {}, actual {}", deviceProfile.getTenantId(), deviceProfile.getId(), DeviceProfileProvisionType.X509_CERTIFICATE_CHAIN, deviceProfile.getProvisionType());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return getEmptyTransportApiResponseFuture();
 | 
			
		||||
        return getEmptyTransportApiResponse();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private ListenableFuture<TransportApiResponseMsg> validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
 | 
			
		||||
    private TransportApiResponseMsg validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
 | 
			
		||||
        DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName());
 | 
			
		||||
        if (credentials != null) {
 | 
			
		||||
            switch (credentials.getCredentialsType()) {
 | 
			
		||||
@ -298,11 +298,11 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
                    if (VALID.equals(validateMqttCredentials(mqtt, credentials))) {
 | 
			
		||||
                        return getDeviceInfo(credentials);
 | 
			
		||||
                    } else {
 | 
			
		||||
                        return getEmptyTransportApiResponseFuture();
 | 
			
		||||
                        return getEmptyTransportApiResponse();
 | 
			
		||||
                    }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return getEmptyTransportApiResponseFuture();
 | 
			
		||||
        return getEmptyTransportApiResponse();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, String credId) {
 | 
			
		||||
@ -543,31 +543,30 @@ public class DefaultTransportApiService implements TransportApiService {
 | 
			
		||||
                .build());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private ListenableFuture<TransportApiResponseMsg> getDeviceInfo(DeviceCredentials credentials) {
 | 
			
		||||
        return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, credentials.getDeviceId()), device -> {
 | 
			
		||||
            if (device == null) {
 | 
			
		||||
                log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId());
 | 
			
		||||
                return getEmptyTransportApiResponse();
 | 
			
		||||
    TransportApiResponseMsg getDeviceInfo(DeviceCredentials credentials) {
 | 
			
		||||
        Device device = deviceService.findDeviceById(TenantId.SYS_TENANT_ID, credentials.getDeviceId());
 | 
			
		||||
        if (device == null) {
 | 
			
		||||
            log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId());
 | 
			
		||||
            return getEmptyTransportApiResponse();
 | 
			
		||||
        }
 | 
			
		||||
        try {
 | 
			
		||||
            ValidateDeviceCredentialsResponseMsg.Builder builder = ValidateDeviceCredentialsResponseMsg.newBuilder();
 | 
			
		||||
            builder.setDeviceInfo(getDeviceInfoProto(device));
 | 
			
		||||
            DeviceProfile deviceProfile = deviceProfileCache.get(device.getTenantId(), device.getDeviceProfileId());
 | 
			
		||||
            if (deviceProfile != null) {
 | 
			
		||||
                builder.setProfileBody(ByteString.copyFrom(dataDecodingEncodingService.encode(deviceProfile)));
 | 
			
		||||
            } else {
 | 
			
		||||
                log.warn("[{}] Failed to find device profile [{}] for device. ", device.getId(), device.getDeviceProfileId());
 | 
			
		||||
            }
 | 
			
		||||
            try {
 | 
			
		||||
                ValidateDeviceCredentialsResponseMsg.Builder builder = ValidateDeviceCredentialsResponseMsg.newBuilder();
 | 
			
		||||
                builder.setDeviceInfo(getDeviceInfoProto(device));
 | 
			
		||||
                DeviceProfile deviceProfile = deviceProfileCache.get(device.getTenantId(), device.getDeviceProfileId());
 | 
			
		||||
                if (deviceProfile != null) {
 | 
			
		||||
                    builder.setProfileBody(ByteString.copyFrom(dataDecodingEncodingService.encode(deviceProfile)));
 | 
			
		||||
                } else {
 | 
			
		||||
                    log.warn("[{}] Failed to find device profile [{}] for device. ", device.getId(), device.getDeviceProfileId());
 | 
			
		||||
                }
 | 
			
		||||
                if (!StringUtils.isEmpty(credentials.getCredentialsValue())) {
 | 
			
		||||
                    builder.setCredentialsBody(credentials.getCredentialsValue());
 | 
			
		||||
                }
 | 
			
		||||
                return TransportApiResponseMsg.newBuilder()
 | 
			
		||||
                        .setValidateCredResponseMsg(builder.build()).build();
 | 
			
		||||
            } catch (JsonProcessingException e) {
 | 
			
		||||
                log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e);
 | 
			
		||||
                return getEmptyTransportApiResponse();
 | 
			
		||||
            if (!StringUtils.isEmpty(credentials.getCredentialsValue())) {
 | 
			
		||||
                builder.setCredentialsBody(credentials.getCredentialsValue());
 | 
			
		||||
            }
 | 
			
		||||
        }, MoreExecutors.directExecutor());
 | 
			
		||||
            return TransportApiResponseMsg.newBuilder()
 | 
			
		||||
                    .setValidateCredResponseMsg(builder.build()).build();
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e);
 | 
			
		||||
            return getEmptyTransportApiResponse();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private DeviceInfoProto getDeviceInfoProto(Device device) throws JsonProcessingException {
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,6 @@
 | 
			
		||||
package org.thingsboard.server.service.transport;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import com.google.common.util.concurrent.Futures;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.junit.Before;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
@ -46,6 +45,7 @@ import org.thingsboard.server.dao.ota.OtaPackageService;
 | 
			
		||||
import org.thingsboard.server.dao.queue.QueueService;
 | 
			
		||||
import org.thingsboard.server.dao.relation.RelationService;
 | 
			
		||||
import org.thingsboard.server.dao.tenant.TbTenantProfileCache;
 | 
			
		||||
import org.thingsboard.server.gen.transport.TransportProtos;
 | 
			
		||||
import org.thingsboard.server.queue.util.DataDecodingEncodingService;
 | 
			
		||||
import org.thingsboard.server.service.apiusage.TbApiUsageStateService;
 | 
			
		||||
import org.thingsboard.server.service.executors.DbCallbackExecutorService;
 | 
			
		||||
@ -62,6 +62,8 @@ import java.util.regex.Matcher;
 | 
			
		||||
import java.util.regex.Pattern;
 | 
			
		||||
 | 
			
		||||
import static org.mockito.ArgumentMatchers.any;
 | 
			
		||||
import static org.mockito.BDDMockito.willReturn;
 | 
			
		||||
import static org.mockito.Mockito.mock;
 | 
			
		||||
import static org.mockito.Mockito.times;
 | 
			
		||||
import static org.mockito.Mockito.verify;
 | 
			
		||||
import static org.mockito.Mockito.when;
 | 
			
		||||
@ -123,11 +125,13 @@ public class DefaultTransportApiServiceTest {
 | 
			
		||||
    @Test
 | 
			
		||||
    public void validateExistingDeviceByX509CertificateStrategy() {
 | 
			
		||||
        var device = createDevice();
 | 
			
		||||
        when(deviceService.findDeviceByIdAsync(any(), any())).thenReturn(Futures.immediateFuture(device));
 | 
			
		||||
 | 
			
		||||
        var deviceCredentials = createDeviceCredentials(chain[0], device.getId());
 | 
			
		||||
        when(deviceCredentialsService.findDeviceCredentialsByCredentialsId(any())).thenReturn(deviceCredentials);
 | 
			
		||||
 | 
			
		||||
        TransportProtos.TransportApiResponseMsg response = mock(TransportProtos.TransportApiResponseMsg.class);
 | 
			
		||||
        willReturn(response).given(service).getDeviceInfo(deviceCredentials);
 | 
			
		||||
 | 
			
		||||
        service.validateOrCreateDeviceX509Certificate(certificateChain);
 | 
			
		||||
        verify(deviceCredentialsService, times(1)).findDeviceCredentialsByCredentialsId(any());
 | 
			
		||||
    }
 | 
			
		||||
@ -139,7 +143,6 @@ public class DefaultTransportApiServiceTest {
 | 
			
		||||
 | 
			
		||||
        var device = createDevice();
 | 
			
		||||
        when(deviceService.findDeviceByTenantIdAndName(any(), any())).thenReturn(device);
 | 
			
		||||
        when(deviceService.findDeviceByIdAsync(any(), any())).thenReturn(Futures.immediateFuture(device));
 | 
			
		||||
 | 
			
		||||
        var deviceCredentials = createDeviceCredentials(chain[0], device.getId());
 | 
			
		||||
        when(deviceCredentialsService.findDeviceCredentialsByCredentialsId(any())).thenReturn(null);
 | 
			
		||||
@ -148,9 +151,12 @@ public class DefaultTransportApiServiceTest {
 | 
			
		||||
        var provisionResponse = createProvisionResponse(deviceCredentials);
 | 
			
		||||
        when(deviceProvisionService.provisionDeviceViaX509Chain(any(), any())).thenReturn(provisionResponse);
 | 
			
		||||
 | 
			
		||||
        TransportProtos.TransportApiResponseMsg response = mock(TransportProtos.TransportApiResponseMsg.class);
 | 
			
		||||
        willReturn(response).given(service).getDeviceInfo(deviceCredentials);
 | 
			
		||||
 | 
			
		||||
        service.validateOrCreateDeviceX509Certificate(certificateChain);
 | 
			
		||||
        verify(deviceProfileService, times(1)).findDeviceProfileByProvisionDeviceKey(any());
 | 
			
		||||
        verify(deviceService, times(1)).findDeviceByIdAsync(any(), any());
 | 
			
		||||
        verify(service, times(1)).getDeviceInfo(any());
 | 
			
		||||
        verify(deviceCredentialsService, times(1)).findDeviceCredentialsByCredentialsId(any());
 | 
			
		||||
        verify(deviceProvisionService, times(1)).provisionDeviceViaX509Chain(any(), any());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user