diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsDomainParams.java b/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsDomainParams.java new file mode 100644 index 0000000000..d9dc07d8ea --- /dev/null +++ b/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsDomainParams.java @@ -0,0 +1,33 @@ +/** + * Copyright © 2016-2020 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.common.data.oauth2; + +import lombok.*; + +import java.util.List; + +@EqualsAndHashCode +@Data +@ToString +@Builder(toBuilder = true) +@NoArgsConstructor +@AllArgsConstructor +public class OAuth2ClientsDomainParams { + private String domainName; + private String adminSettingsId; + + private List clientRegistrations; +} \ No newline at end of file diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsParams.java b/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsParams.java index 3c0da5a9a0..ef08a892a1 100644 --- a/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsParams.java +++ b/common/data/src/main/java/org/thingsboard/server/common/data/oauth2/OAuth2ClientsParams.java @@ -26,8 +26,5 @@ import java.util.List; @NoArgsConstructor @AllArgsConstructor public class OAuth2ClientsParams { - private String domainName; - private String adminSettingsId; - - private List clientRegistrations; + private List clientsDomainsParams; } \ No newline at end of file diff --git a/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java b/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java index 62436bab96..a3bd1c90b8 100644 --- a/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java +++ b/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java @@ -47,6 +47,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.Stream; @Slf4j @Service @@ -79,7 +80,6 @@ public class OAuth2ServiceImpl implements OAuth2Service { return environment.acceptsProfiles("install"); } - // TODO do I need to add a field that invalidates cache in case write to cache fails after successful saving in DB? @PostConstruct public void init() { if (isInstall()) return; @@ -94,7 +94,7 @@ public class OAuth2ServiceImpl implements OAuth2Service { return clientsParams.entrySet().stream() .map(entry -> { TenantId tenantId = entry.getKey(); - OAuth2ClientRegistration clientRegistration = entry.getValue().getClientRegistrations().stream() + OAuth2ClientRegistration clientRegistration = toClientRegistrationStream(entry.getValue()) .filter(registration -> registrationId.equals(registration.getRegistrationId())) .findFirst() .orElse(null); @@ -115,9 +115,9 @@ public class OAuth2ServiceImpl implements OAuth2Service { @Override public List getOAuth2Clients(String domainName) { - OAuth2ClientsParams oAuth2ClientsParams = getMergedOAuth2ClientsParams(domainName); - return oAuth2ClientsParams != null && oAuth2ClientsParams.getClientRegistrations() != null ? - oAuth2ClientsParams.getClientRegistrations().stream() + OAuth2ClientsDomainParams oAuth2ClientsDomainParams = getMergedOAuth2ClientsParams(domainName); + return oAuth2ClientsDomainParams != null && oAuth2ClientsDomainParams.getClientRegistrations() != null ? + oAuth2ClientsDomainParams.getClientRegistrations().stream() .map(this::toClientInfo) .collect(Collectors.toList()) : Collections.emptyList() @@ -150,15 +150,19 @@ public class OAuth2ServiceImpl implements OAuth2Service { @Override public OAuth2ClientsParams saveTenantOAuth2ClientsParams(TenantId tenantId, OAuth2ClientsParams oAuth2ClientsParams) { - // TODO what if tenant saves config for several different domain names, do we need to check it + if (oAuth2ClientsParams.getClientsDomainsParams().size() != 1) { + throw new DataValidationException("Tenant can configure OAuth2 only for one domain!"); + } validate(oAuth2ClientsParams); validateRegistrationIdUniqueness(oAuth2ClientsParams, tenantId); cacheWriteLock.lock(); try { validateRegistrationIdUniqueness(oAuth2ClientsParams, tenantId); - String adminSettingsId = processTenantAdminSettings(tenantId, oAuth2ClientsParams.getDomainName(), oAuth2ClientsParams.getAdminSettingsId()); - oAuth2ClientsParams.setAdminSettingsId(adminSettingsId); + + OAuth2ClientsDomainParams oAuth2ClientsDomainParams = oAuth2ClientsParams.getClientsDomainsParams().get(0); + String adminSettingsId = processTenantAdminSettings(tenantId, oAuth2ClientsDomainParams.getDomainName(), oAuth2ClientsDomainParams.getAdminSettingsId()); + oAuth2ClientsDomainParams.setAdminSettingsId(adminSettingsId); List attributes = createOAuth2ClientsParamsAttributes(oAuth2ClientsParams); try { @@ -247,11 +251,20 @@ public class OAuth2ServiceImpl implements OAuth2Service { } private void validateRegistrationIdUniqueness(OAuth2ClientsParams inputOAuth2ClientsParams, TenantId tenantId) { - inputOAuth2ClientsParams.getClientRegistrations().stream() + long distinctRegistrationIds = toClientRegistrationStream(inputOAuth2ClientsParams) + .map(OAuth2ClientRegistration::getRegistrationId) + .distinct() + .count(); + long actualRegistrationIds = toClientRegistrationStream(inputOAuth2ClientsParams).count(); + if (distinctRegistrationIds != actualRegistrationIds) { + throw new DataValidationException("All registration IDs should be unique!"); + } + + toClientRegistrationStream(inputOAuth2ClientsParams) .map(OAuth2ClientRegistration::getRegistrationId) .forEach(registrationId -> { clientsParams.forEach((paramsTenantId, oAuth2ClientsParams) -> { - boolean registrationExists = oAuth2ClientsParams.getClientRegistrations().stream() + boolean registrationExists = toClientRegistrationStream(oAuth2ClientsParams) .map(OAuth2ClientRegistration::getRegistrationId) .anyMatch(registrationId::equals); if (registrationExists && !tenantId.equals(paramsTenantId)) { @@ -263,8 +276,27 @@ public class OAuth2ServiceImpl implements OAuth2Service { } private void validate(OAuth2ClientsParams oAuth2ClientsParams) { - for (OAuth2ClientRegistration clientRegistration : oAuth2ClientsParams.getClientRegistrations()) { - validator.accept(clientRegistration); + validateDomainNames(oAuth2ClientsParams); + + toClientRegistrationStream(oAuth2ClientsParams) + .forEach(validator); + } + + private void validateDomainNames(OAuth2ClientsParams oAuth2ClientsParams) { + oAuth2ClientsParams.getClientsDomainsParams() + .forEach(oAuth2ClientsDomainParams -> { + if (StringUtils.isEmpty(oAuth2ClientsDomainParams.getDomainName())) { + throw new DataValidationException("Domain name should be specified!"); + } + }); + + long distinctDomainNames = oAuth2ClientsParams.getClientsDomainsParams().stream() + .map(OAuth2ClientsDomainParams::getDomainName) + .distinct() + .count(); + long actualDomainNames = oAuth2ClientsParams.getClientsDomainsParams().size(); + if (distinctDomainNames != actualDomainNames) { + throw new DataValidationException("All domain names should be unique!"); } } @@ -300,7 +332,9 @@ public class OAuth2ServiceImpl implements OAuth2Service { entry -> constructOAuth2ClientsParams(entry.getValue()) )) : new HashMap<>(); - tenantClientParams.put(TenantId.SYS_TENANT_ID, systemOAuth2ClientsParams); + if (systemOAuth2ClientsParams.getClientsDomainsParams() != null) { + tenantClientParams.put(TenantId.SYS_TENANT_ID, systemOAuth2ClientsParams); + } return tenantClientParams; }, MoreExecutors.directExecutor() @@ -314,12 +348,12 @@ public class OAuth2ServiceImpl implements OAuth2Service { @Override public void deleteTenantOAuth2ClientsParams(TenantId tenantId) { OAuth2ClientsParams params = getTenantOAuth2ClientsParams(tenantId); - if (!StringUtils.isEmpty(params.getDomainName())) { - String settingsKey = constructAdminSettingsDomainKey(params.getDomainName()); - adminSettingsService.deleteAdminSettingsByKey(tenantId, settingsKey); - attributesService.removeAll(tenantId, tenantId, DataConstants.SERVER_SCOPE, Collections.singletonList(OAUTH2_CLIENT_REGISTRATIONS_PARAMS)); - clientsParams.remove(tenantId); - } + if (params == null) return; + OAuth2ClientsDomainParams domainParams = params.getClientsDomainsParams().get(0); + String settingsKey = constructAdminSettingsDomainKey(domainParams.getDomainName()); + adminSettingsService.deleteAdminSettingsByKey(tenantId, settingsKey); + attributesService.removeAll(tenantId, tenantId, DataConstants.SERVER_SCOPE, Collections.singletonList(OAUTH2_CLIENT_REGISTRATIONS_PARAMS)); + clientsParams.remove(tenantId); } @Override @@ -351,9 +385,17 @@ public class OAuth2ServiceImpl implements OAuth2Service { }, MoreExecutors.directExecutor()); } - private OAuth2ClientsParams getMergedOAuth2ClientsParams(String domainName) { + private OAuth2ClientsDomainParams getMergedOAuth2ClientsParams(String domainName) { AdminSettings oauth2ClientsSettings = adminSettingsService.findAdminSettingsByKey(TenantId.SYS_TENANT_ID, constructAdminSettingsDomainKey(domainName)); - OAuth2ClientsParams result; + OAuth2ClientsDomainParams result; + + OAuth2ClientsParams systemOAuth2ClientsParams = getSystemOAuth2ClientsParams(TenantId.SYS_TENANT_ID); + OAuth2ClientsDomainParams systemOAuth2ClientsDomainParams = systemOAuth2ClientsParams != null ? + systemOAuth2ClientsParams.getClientsDomainsParams().stream() + .filter(oAuth2ClientsDomainParams -> domainName.equals(oAuth2ClientsDomainParams.getDomainName())) + .findFirst() + .orElse(null) + : null; if (oauth2ClientsSettings != null) { String strEntityType = oauth2ClientsSettings.getJsonValue().get("entityType").asText(); String strEntityId = oauth2ClientsSettings.getJsonValue().get("entityId").asText(); @@ -363,17 +405,16 @@ public class OAuth2ServiceImpl implements OAuth2Service { throw new IllegalStateException("Only tenant can configure OAuth2 for certain domain!"); } TenantId tenantId = (TenantId) entityId; - result = getTenantOAuth2ClientsParams(tenantId); - OAuth2ClientsParams systemOAuth2ClientsParams = getSystemOAuth2ClientsParams(TenantId.SYS_TENANT_ID); - if (systemOAuth2ClientsParams != null) { + result = getTenantOAuth2ClientsParams(tenantId).getClientsDomainsParams().get(0); + if (systemOAuth2ClientsDomainParams != null) { ArrayList tenantClientRegistrations = new ArrayList<>(result.getClientRegistrations()); - tenantClientRegistrations.addAll(systemOAuth2ClientsParams.getClientRegistrations()); + tenantClientRegistrations.addAll(systemOAuth2ClientsDomainParams.getClientRegistrations()); result = result.toBuilder() .clientRegistrations(tenantClientRegistrations) .build(); } } else { - result = getSystemOAuth2ClientsParams(TenantId.SYS_TENANT_ID); + result = systemOAuth2ClientsDomainParams; } return result; } @@ -423,6 +464,11 @@ public class OAuth2ServiceImpl implements OAuth2Service { return client; } + private Stream toClientRegistrationStream(OAuth2ClientsParams oAuth2ClientsParams) { + return oAuth2ClientsParams.getClientsDomainsParams().stream() + .flatMap(oAuth2ClientsDomainParams -> oAuth2ClientsDomainParams.getClientRegistrations().stream()); + } + private final Consumer validator = clientRegistration -> { if (StringUtils.isEmpty(clientRegistration.getRegistrationId())) { throw new DataValidationException("Registration ID should be specified!");