AI rule node: support timeout for Vertex AI
This commit is contained in:
parent
3f58ff01c3
commit
afb0259010
@ -15,10 +15,14 @@
|
||||
*/
|
||||
package org.thingsboard.server.service.ai;
|
||||
|
||||
import com.google.api.gax.core.FixedCredentialsProvider;
|
||||
import com.google.api.gax.retrying.RetrySettings;
|
||||
import com.google.auth.oauth2.ServiceAccountCredentials;
|
||||
import com.google.cloud.vertexai.Transport;
|
||||
import com.google.cloud.vertexai.VertexAI;
|
||||
import com.google.cloud.vertexai.api.GenerationConfig;
|
||||
import com.google.cloud.vertexai.api.PredictionServiceClient;
|
||||
import com.google.cloud.vertexai.api.PredictionServiceSettings;
|
||||
import com.google.cloud.vertexai.generativeai.GenerativeModel;
|
||||
import dev.langchain4j.model.bedrock.BedrockChatModel;
|
||||
import dev.langchain4j.model.chat.ChatModel;
|
||||
@ -96,17 +100,44 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
||||
// construct service account credentials using service account key JSON
|
||||
ServiceAccountCredentials serviceAccountCredentials;
|
||||
try {
|
||||
serviceAccountCredentials = ServiceAccountCredentials
|
||||
.fromStream(new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey())));
|
||||
serviceAccountCredentials = ServiceAccountCredentials.fromStream(
|
||||
new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey()))
|
||||
);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to parse service account key JSON", e);
|
||||
}
|
||||
|
||||
PredictionServiceSettings predictionServiceClientSettings;
|
||||
try {
|
||||
// create prediction service settings for REST transport with service account key credentials
|
||||
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder()
|
||||
.setCredentialsProvider(FixedCredentialsProvider.create(serviceAccountCredentials));
|
||||
|
||||
// get the retry settings that control request timeout for generateContent RPC
|
||||
RetrySettings.Builder retrySettings = settingsBuilder
|
||||
.generateContentSettings()
|
||||
.getRetrySettings()
|
||||
.toBuilder();
|
||||
|
||||
// set request timeout from model config
|
||||
if (modelConfig.timeoutSeconds() != null) {
|
||||
retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(modelConfig.timeoutSeconds()));
|
||||
}
|
||||
|
||||
// set updated retry settings
|
||||
settingsBuilder.generateContentSettings().setRetrySettings(retrySettings.build());
|
||||
|
||||
// build the client settings
|
||||
predictionServiceClientSettings = settingsBuilder.build();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to create prediction service client settings", e);
|
||||
}
|
||||
|
||||
// construct Vertex AI instance
|
||||
var vertexAI = new VertexAI.Builder()
|
||||
.setProjectId(providerConfig.projectId())
|
||||
.setLocation(providerConfig.location())
|
||||
.setCredentials(serviceAccountCredentials)
|
||||
.setPredictionClientSupplier(() -> createPredictionServiceClient(predictionServiceClientSettings))
|
||||
.setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys
|
||||
.build();
|
||||
|
||||
@ -121,12 +152,19 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
||||
var generationConfig = generationConfigBuilder.build();
|
||||
|
||||
// construct generative model instance
|
||||
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI)
|
||||
.withGenerationConfig(generationConfig);
|
||||
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI).withGenerationConfig(generationConfig);
|
||||
|
||||
return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries());
|
||||
}
|
||||
|
||||
/**
 * Creates a {@link PredictionServiceClient} from the given settings.
 * <p>
 * Invoked from the lambda passed to the Vertex AI builder's prediction client supplier. The checked
 * {@link IOException} is converted to an unchecked exception here, presumably because that lambda
 * cannot propagate checked exceptions (NOTE(review): confirm against the supplier's signature).
 *
 * @param settings the prediction service settings the client is created from
 * @return the client returned by {@code PredictionServiceClient.create(settings)}
 * @throws RuntimeException if client creation fails with an {@link IOException};
 *                          the original exception is preserved as the cause
 */
private static PredictionServiceClient createPredictionServiceClient(PredictionServiceSettings settings) {
|
||||
try {
|
||||
return PredictionServiceClient.create(settings);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to create prediction service client", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatModel configureChatModel(MistralAiChatModel chatModel) {
|
||||
MistralAiChatModel.Config modelConfig = chatModel.modelConfig();
|
||||
|
||||
@ -31,7 +31,7 @@ public record GoogleVertexAiGeminiChatModel(
|
||||
String modelId,
|
||||
Double temperature,
|
||||
Double topP,
|
||||
Integer timeoutSeconds, // TODO: not supported by Vertex AI
|
||||
Integer timeoutSeconds,
|
||||
Integer maxRetries
|
||||
) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user