diff --git a/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java b/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java index 97863a0a4a..156949fdc5 100644 --- a/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java +++ b/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java @@ -15,10 +15,14 @@ */ package org.thingsboard.server.service.ai; +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.cloud.vertexai.Transport; import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.GenerationConfig; +import com.google.cloud.vertexai.api.PredictionServiceClient; +import com.google.cloud.vertexai.api.PredictionServiceSettings; import com.google.cloud.vertexai.generativeai.GenerativeModel; import dev.langchain4j.model.bedrock.BedrockChatModel; import dev.langchain4j.model.chat.ChatModel; @@ -96,17 +100,44 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur // construct service account credentials using service account key JSON ServiceAccountCredentials serviceAccountCredentials; try { - serviceAccountCredentials = ServiceAccountCredentials - .fromStream(new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey()))); + serviceAccountCredentials = ServiceAccountCredentials.fromStream( + new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey())) + ); } catch (IOException e) { throw new RuntimeException("Failed to parse service account key JSON", e); } + PredictionServiceSettings predictionServiceClientSettings; + try { + // create prediction service settings for REST transport with service account key credentials + PredictionServiceSettings.Builder 
settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(serviceAccountCredentials)); + + // get the retry settings that control request timeout for generateContent RPC + RetrySettings.Builder retrySettings = settingsBuilder + .generateContentSettings() + .getRetrySettings() + .toBuilder(); + + // set request timeout from model config; when timeoutSeconds is null, the total timeout read from the builder above is kept + if (modelConfig.timeoutSeconds() != null) { + retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(modelConfig.timeoutSeconds())); + } + + // set updated retry settings + settingsBuilder.generateContentSettings().setRetrySettings(retrySettings.build()); + + // build the client settings + predictionServiceClientSettings = settingsBuilder.build(); + } catch (IOException e) { + throw new RuntimeException("Failed to create prediction service client settings", e); + } + // construct Vertex AI instance var vertexAI = new VertexAI.Builder() .setProjectId(providerConfig.projectId()) .setLocation(providerConfig.location()) - .setCredentials(serviceAccountCredentials) + .setPredictionClientSupplier(() -> createPredictionServiceClient(predictionServiceClientSettings)) .setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys .build(); @@ -121,12 +152,19 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur var generationConfig = generationConfigBuilder.build(); // construct generative model instance - var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI) - .withGenerationConfig(generationConfig); + var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI).withGenerationConfig(generationConfig); return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries()); } + private static PredictionServiceClient createPredictionServiceClient(PredictionServiceSettings settings) { /* wraps the checked IOException from PredictionServiceClient.create in a RuntimeException; called from the lambda passed to setPredictionClientSupplier */ + try { + return 
PredictionServiceClient.create(settings); + } catch (IOException e) { + throw new RuntimeException("Failed to create prediction service client", e); + } + } + @Override public ChatModel configureChatModel(MistralAiChatModel chatModel) { MistralAiChatModel.Config modelConfig = chatModel.modelConfig(); diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleVertexAiGeminiChatModel.java b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleVertexAiGeminiChatModel.java index 5659aed918..67a852e33c 100644 --- a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleVertexAiGeminiChatModel.java +++ b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleVertexAiGeminiChatModel.java @@ -31,7 +31,7 @@ public record GoogleVertexAiGeminiChatModel( String modelId, Double temperature, Double topP, - Integer timeoutSeconds, // TODO: not supported by Vertex AI + Integer timeoutSeconds, Integer maxRetries ) implements AiChatModelConfig {}