AI rule node: support timeout for Vertex AI
This commit is contained in:
parent
3f58ff01c3
commit
afb0259010
@ -15,10 +15,14 @@
|
|||||||
*/
|
*/
|
||||||
package org.thingsboard.server.service.ai;
|
package org.thingsboard.server.service.ai;
|
||||||
|
|
||||||
|
import com.google.api.gax.core.FixedCredentialsProvider;
|
||||||
|
import com.google.api.gax.retrying.RetrySettings;
|
||||||
import com.google.auth.oauth2.ServiceAccountCredentials;
|
import com.google.auth.oauth2.ServiceAccountCredentials;
|
||||||
import com.google.cloud.vertexai.Transport;
|
import com.google.cloud.vertexai.Transport;
|
||||||
import com.google.cloud.vertexai.VertexAI;
|
import com.google.cloud.vertexai.VertexAI;
|
||||||
import com.google.cloud.vertexai.api.GenerationConfig;
|
import com.google.cloud.vertexai.api.GenerationConfig;
|
||||||
|
import com.google.cloud.vertexai.api.PredictionServiceClient;
|
||||||
|
import com.google.cloud.vertexai.api.PredictionServiceSettings;
|
||||||
import com.google.cloud.vertexai.generativeai.GenerativeModel;
|
import com.google.cloud.vertexai.generativeai.GenerativeModel;
|
||||||
import dev.langchain4j.model.bedrock.BedrockChatModel;
|
import dev.langchain4j.model.bedrock.BedrockChatModel;
|
||||||
import dev.langchain4j.model.chat.ChatModel;
|
import dev.langchain4j.model.chat.ChatModel;
|
||||||
@ -96,17 +100,44 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
|||||||
// construct service account credentials using service account key JSON
|
// construct service account credentials using service account key JSON
|
||||||
ServiceAccountCredentials serviceAccountCredentials;
|
ServiceAccountCredentials serviceAccountCredentials;
|
||||||
try {
|
try {
|
||||||
serviceAccountCredentials = ServiceAccountCredentials
|
serviceAccountCredentials = ServiceAccountCredentials.fromStream(
|
||||||
.fromStream(new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey())));
|
new ByteArrayInputStream(JacksonUtil.writeValueAsBytes(providerConfig.serviceAccountKey()))
|
||||||
|
);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to parse service account key JSON", e);
|
throw new RuntimeException("Failed to parse service account key JSON", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PredictionServiceSettings predictionServiceClientSettings;
|
||||||
|
try {
|
||||||
|
// create prediction service settings for REST transport with service account key credentials
|
||||||
|
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder()
|
||||||
|
.setCredentialsProvider(FixedCredentialsProvider.create(serviceAccountCredentials));
|
||||||
|
|
||||||
|
// get the retry settings that control request timeout for generateContent RPC
|
||||||
|
RetrySettings.Builder retrySettings = settingsBuilder
|
||||||
|
.generateContentSettings()
|
||||||
|
.getRetrySettings()
|
||||||
|
.toBuilder();
|
||||||
|
|
||||||
|
// set request timeout from model config
|
||||||
|
if (modelConfig.timeoutSeconds() != null) {
|
||||||
|
retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(modelConfig.timeoutSeconds()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// set updated retry settings
|
||||||
|
settingsBuilder.generateContentSettings().setRetrySettings(retrySettings.build());
|
||||||
|
|
||||||
|
// build the client settings
|
||||||
|
predictionServiceClientSettings = settingsBuilder.build();
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("Failed to create prediction service client settings", e);
|
||||||
|
}
|
||||||
|
|
||||||
// construct Vertex AI instance
|
// construct Vertex AI instance
|
||||||
var vertexAI = new VertexAI.Builder()
|
var vertexAI = new VertexAI.Builder()
|
||||||
.setProjectId(providerConfig.projectId())
|
.setProjectId(providerConfig.projectId())
|
||||||
.setLocation(providerConfig.location())
|
.setLocation(providerConfig.location())
|
||||||
.setCredentials(serviceAccountCredentials)
|
.setPredictionClientSupplier(() -> createPredictionServiceClient(predictionServiceClientSettings))
|
||||||
.setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys
|
.setTransport(Transport.REST) // GRPC also possible, but likely does not work with service account keys
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@ -121,12 +152,19 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
|||||||
var generationConfig = generationConfigBuilder.build();
|
var generationConfig = generationConfigBuilder.build();
|
||||||
|
|
||||||
// construct generative model instance
|
// construct generative model instance
|
||||||
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI)
|
var generativeModel = new GenerativeModel(modelConfig.modelId(), vertexAI).withGenerationConfig(generationConfig);
|
||||||
.withGenerationConfig(generationConfig);
|
|
||||||
|
|
||||||
return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries());
|
return new VertexAiGeminiChatModel(generativeModel, generationConfig, modelConfig.maxRetries());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Creates a {@link PredictionServiceClient} from the given settings.
 * <p>
 * Passed to the Vertex AI builder as the prediction client supplier, so the client is built from
 * the customized settings prepared by the caller.
 * NOTE(review): presumably exists as a separate method because the supplier lambda cannot throw
 * the checked {@link IOException} - confirm against the supplier's functional interface.
 *
 * @param settings prediction service settings to create the client with
 * @return the client returned by {@code PredictionServiceClient.create(settings)}
 * @throws RuntimeException if client creation fails with an {@link IOException}; the original
 *                          exception is preserved as the cause
 */
private static PredictionServiceClient createPredictionServiceClient(PredictionServiceSettings settings) {
|
||||||
|
try {
|
||||||
|
return PredictionServiceClient.create(settings);
|
||||||
|
} catch (IOException e) {
|
||||||
|
// rethrow as unchecked, keeping the original exception as the cause
throw new RuntimeException("Failed to create prediction service client", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatModel configureChatModel(MistralAiChatModel chatModel) {
|
public ChatModel configureChatModel(MistralAiChatModel chatModel) {
|
||||||
MistralAiChatModel.Config modelConfig = chatModel.modelConfig();
|
MistralAiChatModel.Config modelConfig = chatModel.modelConfig();
|
||||||
|
|||||||
@ -31,7 +31,7 @@ public record GoogleVertexAiGeminiChatModel(
|
|||||||
String modelId,
|
String modelId,
|
||||||
Double temperature,
|
Double temperature,
|
||||||
Double topP,
|
Double topP,
|
||||||
Integer timeoutSeconds, // TODO: not supported by Vertex AI
|
Integer timeoutSeconds,
|
||||||
Integer maxRetries
|
Integer maxRetries
|
||||||
) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {}
|
) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user