AI rule node: support timeout for Vertex AI

This commit is contained in:
Dmytro Skarzhynets 2025-06-27 18:41:21 +03:00
parent 3f58ff01c3
commit afb0259010
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF
2 changed files with 44 additions and 6 deletions

View File

@ -15,10 +15,14 @@
*/ */
package org.thingsboard.server.service.ai; package org.thingsboard.server.service.ai;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.retrying.RetrySettings;
import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.cloud.vertexai.Transport; import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.cloud.vertexai.generativeai.GenerativeModel; import com.google.cloud.vertexai.generativeai.GenerativeModel;
import dev.langchain4j.model.bedrock.BedrockChatModel; import dev.langchain4j.model.bedrock.BedrockChatModel;
import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.ChatModel;
@ -96,17 +100,44 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
// construct service account credentials using service account key JSON // construct service account credentials using service account key JSON
ServiceAccountCredentials serviceAccountCredentials; ServiceAccountCredentials serviceAccountCredentials;
try { try {
serviceAccountCredentials = ServiceAccountCredentials serviceAccountCredentials = ServiceAccountCredentials.fromStream(
.fromStream(new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey()))); new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey()))
);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Failed to parse service account key JSON", e); throw new RuntimeException("Failed to parse service account key JSON", e);
} }
PredictionServiceSettings predictionServiceClientSettings;
try {
// create prediction service settings for REST transport with service account key credentials
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder()
.setCredentialsProvider(FixedCredentialsProvider.create(serviceAccountCredentials));
// get the retry settings that control request timeout for generateContent RPC
RetrySettings.Builder retrySettings = settingsBuilder
.generateContentSettings()
.getRetrySettings()
.toBuilder();
// set request timeout from model config
if (modelConfig.timeoutSeconds() != null) {
retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(modelConfig.timeoutSeconds()));
}
// set updated retry settings
settingsBuilder.generateContentSettings().setRetrySettings(retrySettings.build());
// build the client settings
predictionServiceClientSettings = settingsBuilder.build();
} catch (IOException e) {
throw new RuntimeException("Failed to create prediction service client settings", e);
}
// construct Vertex AI instance // construct Vertex AI instance
var vertexAI = new VertexAI.Builder() var vertexAI = new VertexAI.Builder()
.setProjectId(providerConfig.projectId()) .setProjectId(providerConfig.projectId())
.setLocation(providerConfig.location()) .setLocation(providerConfig.location())
.setCredentials(serviceAccountCredentials) .setPredictionClientSupplier(() -> createPredictionServiceClient(predictionServiceClientSettings))
.setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys .setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys
.build(); .build();
@ -121,12 +152,19 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
var generationConfig = generationConfigBuilder.build(); var generationConfig = generationConfigBuilder.build();
// construct generative model instance // construct generative model instance
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI) var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI).withGenerationConfig(generationConfig);
.withGenerationConfig(generationConfig);
return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries()); return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries());
} }
/**
 * Builds a {@link PredictionServiceClient} from the supplied settings.
 *
 * <p>The checked {@link IOException} thrown by client creation is converted into an
 * unchecked exception so that this method can be called from a lambda that is not
 * allowed to throw checked exceptions.
 *
 * @param settings the prediction service settings to create the client with
 * @return the newly created prediction service client
 * @throws RuntimeException if client creation fails with an {@link IOException};
 *                          the original exception is preserved as the cause
 */
private static PredictionServiceClient createPredictionServiceClient(PredictionServiceSettings settings) {
    final PredictionServiceClient client;
    try {
        client = PredictionServiceClient.create(settings);
    } catch (IOException creationFailure) {
        // rethrow unchecked, keeping the original failure attached as the cause
        throw new RuntimeException("Failed to create prediction service client", creationFailure);
    }
    return client;
}
@Override @Override
public ChatModel configureChatModel(MistralAiChatModel chatModel) { public ChatModel configureChatModel(MistralAiChatModel chatModel) {
MistralAiChatModel.Config modelConfig = chatModel.modelConfig(); MistralAiChatModel.Config modelConfig = chatModel.modelConfig();

View File

@ -31,7 +31,7 @@ public record GoogleVertexAiGeminiChatModel(
String modelId, String modelId,
Double temperature, Double temperature,
Double topP, Double topP,
Integer timeoutSeconds, // TODO: not supported by Vertex AI Integer timeoutSeconds,
Integer maxRetries Integer maxRetries
) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {} ) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {}