AI rule node: add timeout support
This commit is contained in:
parent
d5c6ed1f61
commit
cb106760c1
@ -23,13 +23,16 @@ import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.thingsboard.rule.engine.api.RuleEngineAiService;
|
||||
import org.thingsboard.server.common.data.ai.AiSettings;
|
||||
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
|
||||
import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig;
|
||||
import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig;
|
||||
import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig;
|
||||
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
|
||||
import org.thingsboard.server.common.data.id.AiSettingsId;
|
||||
import org.thingsboard.server.common.data.id.TenantId;
|
||||
import org.thingsboard.server.dao.ai.AiSettingsService;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.Optional;
|
||||
|
||||
@ -46,37 +49,53 @@ class AiServiceImpl implements RuleEngineAiService {
|
||||
throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found");
|
||||
}
|
||||
var aiSettings = aiSettingsOpt.get();
|
||||
return configureChatModel(aiSettings.getProviderConfig(), aiSettings.getModelConfig());
|
||||
}
|
||||
|
||||
return switch (aiSettings.getProvider()) {
|
||||
@Override
|
||||
public ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig) {
|
||||
return switch (providerConfig.getProvider()) {
|
||||
case OPENAI -> {
|
||||
var modelBuilder = OpenAiChatModel.builder()
|
||||
.apiKey(aiSettings.getProviderConfig().getApiKey())
|
||||
.modelName(aiSettings.getModel());
|
||||
.apiKey(providerConfig.getApiKey())
|
||||
.modelName(modelConfig.getModel());
|
||||
|
||||
if (aiSettings.getModelConfig() instanceof OpenAiChatModelConfig config) {
|
||||
if (modelConfig instanceof OpenAiChatModelConfig config) {
|
||||
modelBuilder.temperature(config.getTemperature());
|
||||
if (config.getTimeoutSeconds() != null) {
|
||||
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
|
||||
}
|
||||
modelBuilder.maxRetries(config.getMaxRetries());
|
||||
}
|
||||
|
||||
yield modelBuilder.build();
|
||||
}
|
||||
case MISTRAL_AI -> {
|
||||
var modelBuilder = MistralAiChatModel.builder()
|
||||
.apiKey(aiSettings.getProviderConfig().getApiKey())
|
||||
.modelName(aiSettings.getModel());
|
||||
.apiKey(providerConfig.getApiKey())
|
||||
.modelName(modelConfig.getModel());
|
||||
|
||||
if (aiSettings.getModelConfig() instanceof MistralAiChatModelConfig config) {
|
||||
if (modelConfig instanceof MistralAiChatModelConfig config) {
|
||||
modelBuilder.temperature(config.getTemperature());
|
||||
if (config.getTimeoutSeconds() != null) {
|
||||
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
|
||||
}
|
||||
modelBuilder.maxRetries(config.getMaxRetries());
|
||||
}
|
||||
|
||||
yield modelBuilder.build();
|
||||
}
|
||||
case GOOGLE_AI_GEMINI -> {
|
||||
var modelBuilder = GoogleAiGeminiChatModel.builder()
|
||||
.apiKey(aiSettings.getProviderConfig().getApiKey())
|
||||
.modelName(aiSettings.getModel());
|
||||
.apiKey(providerConfig.getApiKey())
|
||||
.modelName(modelConfig.getModel());
|
||||
|
||||
if (aiSettings.getModelConfig() instanceof GoogleAiGeminiChatModelConfig config) {
|
||||
if (modelConfig instanceof GoogleAiGeminiChatModelConfig config) {
|
||||
modelBuilder.temperature(config.getTemperature());
|
||||
if (config.getTimeoutSeconds() != null) {
|
||||
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
|
||||
}
|
||||
modelBuilder.maxRetries(config.getMaxRetries());
|
||||
}
|
||||
|
||||
yield modelBuilder.build();
|
||||
|
||||
@ -44,4 +44,13 @@ public abstract class AiModelConfig {
|
||||
)
|
||||
private String model;
|
||||
|
||||
/**
 * Returns the timeout, in seconds, to use for calls to this model.
 * The exact scope of the timeout (connect only, or the whole call) is described
 * by each provider-specific subclass.
 *
 * @return timeout in seconds, or {@code null} if no timeout has been set
 */
public abstract Integer getTimeoutSeconds();
|
||||
|
||||
/**
 * Sets the timeout, in seconds, to use for calls to this model.
 *
 * @param timeoutSeconds timeout in seconds
 */
public abstract void setTimeoutSeconds(Integer timeoutSeconds);
|
||||
|
||||
/**
 * Returns the maximum number of times a failed call to this model is retried.
 *
 * @return maximum number of retries
 *         (NOTE(review): presumably {@code null} when unset — confirm against callers)
 */
public abstract Integer getMaxRetries();
|
||||
|
||||
/**
 * Sets the maximum number of times a failed call to this model is retried.
 *
 * @param maxRetries maximum number of retries; {@code 0} disables retries
 */
public abstract void setMaxRetries(Integer maxRetries);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@ -47,4 +47,16 @@ public final class GoogleAiGeminiChatModelConfig extends AiModelConfig {
|
||||
)
|
||||
private Double temperature;
|
||||
|
||||
@Schema(
|
||||
accessMode = Schema.AccessMode.READ_WRITE,
|
||||
description = "Timeout (in seconds) for establishing HTTP connection"
|
||||
)
|
||||
private Integer timeoutSeconds;
|
||||
|
||||
@Schema(
|
||||
accessMode = Schema.AccessMode.READ_WRITE,
|
||||
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
|
||||
)
|
||||
private Integer maxRetries;
|
||||
|
||||
}
|
||||
|
||||
@ -47,4 +47,16 @@ public final class MistralAiChatModelConfig extends AiModelConfig {
|
||||
)
|
||||
private Double temperature;
|
||||
|
||||
@Schema(
|
||||
accessMode = Schema.AccessMode.READ_WRITE,
|
||||
description = "Timeout (in seconds) for the entire HTTP call: applied to connect, read, and write operations"
|
||||
)
|
||||
private Integer timeoutSeconds;
|
||||
|
||||
@Schema(
|
||||
accessMode = Schema.AccessMode.READ_WRITE,
|
||||
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
|
||||
)
|
||||
private Integer maxRetries;
|
||||
|
||||
}
|
||||
|
||||
@ -47,4 +47,16 @@ public final class OpenAiChatModelConfig extends AiModelConfig {
|
||||
)
|
||||
private Double temperature;
|
||||
|
||||
@Schema(
|
||||
accessMode = Schema.AccessMode.READ_WRITE,
|
||||
description = "Timeout (in seconds) for both establishing HTTP connection and receiving a response"
|
||||
)
|
||||
private Integer timeoutSeconds;
|
||||
|
||||
@Schema(
|
||||
accessMode = Schema.AccessMode.READ_WRITE,
|
||||
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
|
||||
)
|
||||
private Integer maxRetries;
|
||||
|
||||
}
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
package org.thingsboard.rule.engine.api;
|
||||
|
||||
import dev.langchain4j.model.chat.ChatModel;
|
||||
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
|
||||
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
|
||||
import org.thingsboard.server.common.data.id.AiSettingsId;
|
||||
import org.thingsboard.server.common.data.id.TenantId;
|
||||
|
||||
@ -23,4 +25,6 @@ public interface RuleEngineAiService {
|
||||
|
||||
ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId);
|
||||
|
||||
ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig);
|
||||
|
||||
}
|
||||
|
||||
@ -35,6 +35,9 @@ import org.thingsboard.rule.engine.api.TbNodeConfiguration;
|
||||
import org.thingsboard.rule.engine.api.TbNodeException;
|
||||
import org.thingsboard.rule.engine.api.util.TbNodeUtils;
|
||||
import org.thingsboard.rule.engine.external.TbAbstractExternalNode;
|
||||
import org.thingsboard.server.common.data.ai.AiSettings;
|
||||
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
|
||||
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
|
||||
import org.thingsboard.server.common.data.id.AiSettingsId;
|
||||
import org.thingsboard.server.common.data.plugin.ComponentType;
|
||||
import org.thingsboard.server.common.msg.TbMsg;
|
||||
@ -42,6 +45,7 @@ import org.thingsboard.server.dao.exception.DataValidationException;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import static com.google.common.util.concurrent.Futures.addCallback;
|
||||
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
|
||||
@ -60,6 +64,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
private SystemMessage systemMessage;
|
||||
private PromptTemplate userPromptTemplate;
|
||||
private ResponseFormat responseFormat;
|
||||
private int timeoutSeconds;
|
||||
private AiSettingsId aiSettingsId;
|
||||
|
||||
@Override
|
||||
@ -88,6 +93,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
.formatted(config.getUserPrompt())
|
||||
);
|
||||
|
||||
timeoutSeconds = config.getTimeoutSeconds();
|
||||
aiSettingsId = config.getAiSettingsId();
|
||||
}
|
||||
|
||||
@ -95,11 +101,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
if (responseFormatType == ResponseFormatType.TEXT) {
|
||||
return null;
|
||||
}
|
||||
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
|
||||
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null && !jsonSchema.isNull() ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMsg(TbContext ctx, TbMsg msg) {
|
||||
public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException {
|
||||
var ackedMsg = ackIfNeeded(ctx, msg);
|
||||
|
||||
Map<String, Object> variables = Map.of(
|
||||
@ -114,7 +120,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
.responseFormat(responseFormat)
|
||||
.build();
|
||||
|
||||
ChatModel chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), aiSettingsId);
|
||||
ChatModel chatModel = configureChatModel(ctx);
|
||||
|
||||
addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() {
|
||||
@Override
|
||||
@ -134,6 +140,21 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
}, directExecutor());
|
||||
}
|
||||
|
||||
private ChatModel configureChatModel(TbContext ctx) throws TbNodeException {
|
||||
Optional<AiSettings> aiSettingsOpt = ctx.getAiSettingsService().findAiSettingsByTenantIdAndId(ctx.getTenantId(), aiSettingsId);
|
||||
if (aiSettingsOpt.isEmpty()) {
|
||||
throw new TbNodeException("AI settings with ID: " + aiSettingsId + " were not found", true);
|
||||
}
|
||||
|
||||
AiProviderConfig providerConfig = aiSettingsOpt.get().getProviderConfig();
|
||||
AiModelConfig modelConfig = aiSettingsOpt.get().getModelConfig();
|
||||
|
||||
modelConfig.setTimeoutSeconds(timeoutSeconds);
|
||||
modelConfig.setMaxRetries(0); // disable retries to respect timeout set in rule node config
|
||||
|
||||
return ctx.getAiService().configureChatModel(providerConfig, modelConfig);
|
||||
}
|
||||
|
||||
private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) {
|
||||
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
|
||||
}
|
||||
@ -157,6 +178,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
systemMessage = null;
|
||||
userPromptTemplate = null;
|
||||
responseFormat = null;
|
||||
aiSettingsId = null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import jakarta.validation.constraints.Max;
|
||||
import jakarta.validation.constraints.Min;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
@ -46,10 +48,14 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
|
||||
|
||||
private JsonNode jsonSchema;
|
||||
|
||||
@Min(value = 1, message = "must be at least 1 second")
|
||||
@Max(value = 600, message = "cannot exceed 600 seconds (10 minutes)")
|
||||
private int timeoutSeconds;
|
||||
|
||||
@JsonIgnore
|
||||
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
|
||||
public boolean isJsonSchemaValid() {
|
||||
return jsonSchema == null || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
|
||||
return jsonSchema == null || jsonSchema.isNull() || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -58,6 +64,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
|
||||
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
|
||||
configuration.setUserPrompt("Tell me a joke.");
|
||||
configuration.setResponseFormatType(ResponseFormatType.JSON);
|
||||
configuration.setTimeoutSeconds(60);
|
||||
return configuration;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user