AI rule node: support timeout for Vertex AI

This commit is contained in:
Dmytro Skarzhynets 2025-06-27 18:41:21 +03:00
parent 3f58ff01c3
commit afb0259010
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF
2 changed files with 44 additions and 6 deletions

View File

@ -15,10 +15,14 @@
*/
package org.thingsboard.server.service.ai;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.retrying.RetrySettings;
import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import dev.langchain4j.model.bedrock.BedrockChatModel;
import dev.langchain4j.model.chat.ChatModel;
@ -96,17 +100,44 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
// construct service account credentials using service account key JSON
ServiceAccountCredentials serviceAccountCredentials;
try {
serviceAccountCredentials = ServiceAccountCredentials
.fromStream(new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey())));
serviceAccountCredentials = ServiceAccountCredentials.fromStream(
new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey()))
);
} catch (IOException e) {
throw new RuntimeException("Failed to parse service account key JSON", e);
}
PredictionServiceSettings predictionServiceClientSettings;
try {
// create prediction service settings for REST transport with service account key credentials
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder()
.setCredentialsProvider(FixedCredentialsProvider.create(serviceAccountCredentials));
// get the retry settings that control request timeout for generateContent RPC
RetrySettings.Builder retrySettings = settingsBuilder
.generateContentSettings()
.getRetrySettings()
.toBuilder();
// set request timeout from model config
if (modelConfig.timeoutSeconds() != null) {
retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(modelConfig.timeoutSeconds()));
}
// set updated retry settings
settingsBuilder.generateContentSettings().setRetrySettings(retrySettings.build());
// build the client settings
predictionServiceClientSettings = settingsBuilder.build();
} catch (IOException e) {
throw new RuntimeException("Failed to create prediction service client settings", e);
}
// construct Vertex AI instance
var vertexAI = new VertexAI.Builder()
.setProjectId(providerConfig.projectId())
.setLocation(providerConfig.location())
.setCredentials(serviceAccountCredentials)
.setPredictionClientSupplier(() -> createPredictionServiceClient(predictionServiceClientSettings))
.setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys
.build();
@ -121,12 +152,19 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
var generationConfig = generationConfigBuilder.build();
// construct generative model instance
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI)
.withGenerationConfig(generationConfig);
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI).withGenerationConfig(generationConfig);
return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries());
}
/**
 * Builds a {@link PredictionServiceClient} from the supplied settings.
 *
 * <p>{@code PredictionServiceClient.create} declares a checked {@link IOException};
 * this helper rethrows it as an unchecked {@link RuntimeException} (keeping the
 * original exception as the cause) so callers do not have to handle a checked exception.
 *
 * @param settings the prediction service settings to create the client with
 * @return the newly created prediction service client
 * @throws RuntimeException if client creation fails with an {@link IOException}
 */
private static PredictionServiceClient createPredictionServiceClient(PredictionServiceSettings settings) {
    final PredictionServiceClient client;
    try {
        client = PredictionServiceClient.create(settings);
    } catch (IOException cause) {
        // translate the checked failure into an unchecked one, preserving the cause
        throw new RuntimeException("Failed to create prediction service client", cause);
    }
    return client;
}
@Override
public ChatModel configureChatModel(MistralAiChatModel chatModel) {
MistralAiChatModel.Config modelConfig = chatModel.modelConfig();

View File

@ -31,7 +31,7 @@ public record GoogleVertexAiGeminiChatModel(
String modelId,
Double temperature,
Double topP,
Integer timeoutSeconds, // TODO: not supported by Vertex AI
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {}