AI rule node: add timeout support
This commit is contained in:
parent
d5c6ed1f61
commit
cb106760c1
@ -23,13 +23,16 @@ import lombok.RequiredArgsConstructor;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.thingsboard.rule.engine.api.RuleEngineAiService;
|
import org.thingsboard.rule.engine.api.RuleEngineAiService;
|
||||||
import org.thingsboard.server.common.data.ai.AiSettings;
|
import org.thingsboard.server.common.data.ai.AiSettings;
|
||||||
|
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
|
||||||
import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig;
|
import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig;
|
||||||
import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig;
|
import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig;
|
||||||
import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig;
|
import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig;
|
||||||
|
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
|
||||||
import org.thingsboard.server.common.data.id.AiSettingsId;
|
import org.thingsboard.server.common.data.id.AiSettingsId;
|
||||||
import org.thingsboard.server.common.data.id.TenantId;
|
import org.thingsboard.server.common.data.id.TenantId;
|
||||||
import org.thingsboard.server.dao.ai.AiSettingsService;
|
import org.thingsboard.server.dao.ai.AiSettingsService;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
import java.util.NoSuchElementException;
|
import java.util.NoSuchElementException;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
||||||
@ -46,37 +49,53 @@ class AiServiceImpl implements RuleEngineAiService {
|
|||||||
throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found");
|
throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found");
|
||||||
}
|
}
|
||||||
var aiSettings = aiSettingsOpt.get();
|
var aiSettings = aiSettingsOpt.get();
|
||||||
|
return configureChatModel(aiSettings.getProviderConfig(), aiSettings.getModelConfig());
|
||||||
|
}
|
||||||
|
|
||||||
return switch (aiSettings.getProvider()) {
|
@Override
|
||||||
|
public ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig) {
|
||||||
|
return switch (providerConfig.getProvider()) {
|
||||||
case OPENAI -> {
|
case OPENAI -> {
|
||||||
var modelBuilder = OpenAiChatModel.builder()
|
var modelBuilder = OpenAiChatModel.builder()
|
||||||
.apiKey(aiSettings.getProviderConfig().getApiKey())
|
.apiKey(providerConfig.getApiKey())
|
||||||
.modelName(aiSettings.getModel());
|
.modelName(modelConfig.getModel());
|
||||||
|
|
||||||
if (aiSettings.getModelConfig() instanceof OpenAiChatModelConfig config) {
|
if (modelConfig instanceof OpenAiChatModelConfig config) {
|
||||||
modelBuilder.temperature(config.getTemperature());
|
modelBuilder.temperature(config.getTemperature());
|
||||||
|
if (config.getTimeoutSeconds() != null) {
|
||||||
|
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
|
||||||
|
}
|
||||||
|
modelBuilder.maxRetries(config.getMaxRetries());
|
||||||
}
|
}
|
||||||
|
|
||||||
yield modelBuilder.build();
|
yield modelBuilder.build();
|
||||||
}
|
}
|
||||||
case MISTRAL_AI -> {
|
case MISTRAL_AI -> {
|
||||||
var modelBuilder = MistralAiChatModel.builder()
|
var modelBuilder = MistralAiChatModel.builder()
|
||||||
.apiKey(aiSettings.getProviderConfig().getApiKey())
|
.apiKey(providerConfig.getApiKey())
|
||||||
.modelName(aiSettings.getModel());
|
.modelName(modelConfig.getModel());
|
||||||
|
|
||||||
if (aiSettings.getModelConfig() instanceof MistralAiChatModelConfig config) {
|
if (modelConfig instanceof MistralAiChatModelConfig config) {
|
||||||
modelBuilder.temperature(config.getTemperature());
|
modelBuilder.temperature(config.getTemperature());
|
||||||
|
if (config.getTimeoutSeconds() != null) {
|
||||||
|
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
|
||||||
|
}
|
||||||
|
modelBuilder.maxRetries(config.getMaxRetries());
|
||||||
}
|
}
|
||||||
|
|
||||||
yield modelBuilder.build();
|
yield modelBuilder.build();
|
||||||
}
|
}
|
||||||
case GOOGLE_AI_GEMINI -> {
|
case GOOGLE_AI_GEMINI -> {
|
||||||
var modelBuilder = GoogleAiGeminiChatModel.builder()
|
var modelBuilder = GoogleAiGeminiChatModel.builder()
|
||||||
.apiKey(aiSettings.getProviderConfig().getApiKey())
|
.apiKey(providerConfig.getApiKey())
|
||||||
.modelName(aiSettings.getModel());
|
.modelName(modelConfig.getModel());
|
||||||
|
|
||||||
if (aiSettings.getModelConfig() instanceof GoogleAiGeminiChatModelConfig config) {
|
if (modelConfig instanceof GoogleAiGeminiChatModelConfig config) {
|
||||||
modelBuilder.temperature(config.getTemperature());
|
modelBuilder.temperature(config.getTemperature());
|
||||||
|
if (config.getTimeoutSeconds() != null) {
|
||||||
|
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
|
||||||
|
}
|
||||||
|
modelBuilder.maxRetries(config.getMaxRetries());
|
||||||
}
|
}
|
||||||
|
|
||||||
yield modelBuilder.build();
|
yield modelBuilder.build();
|
||||||
|
|||||||
@ -44,4 +44,13 @@ public abstract class AiModelConfig {
|
|||||||
)
|
)
|
||||||
private String model;
|
private String model;
|
||||||
|
|
||||||
|
public abstract Integer getTimeoutSeconds();
|
||||||
|
|
||||||
|
public abstract void setTimeoutSeconds(Integer timeoutSeconds);
|
||||||
|
|
||||||
|
public abstract Integer getMaxRetries();
|
||||||
|
|
||||||
|
public abstract void setMaxRetries(Integer timeoutSeconds);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -47,4 +47,16 @@ public final class GoogleAiGeminiChatModelConfig extends AiModelConfig {
|
|||||||
)
|
)
|
||||||
private Double temperature;
|
private Double temperature;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
accessMode = Schema.AccessMode.READ_WRITE,
|
||||||
|
description = "Timeout (in seconds) for establishing HTTP connection"
|
||||||
|
)
|
||||||
|
private Integer timeoutSeconds;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
accessMode = Schema.AccessMode.READ_WRITE,
|
||||||
|
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
|
||||||
|
)
|
||||||
|
private Integer maxRetries;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -47,4 +47,16 @@ public final class MistralAiChatModelConfig extends AiModelConfig {
|
|||||||
)
|
)
|
||||||
private Double temperature;
|
private Double temperature;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
accessMode = Schema.AccessMode.READ_WRITE,
|
||||||
|
description = "Timeout (in seconds) for the entire HTTP call: applied to connect, read, and write operations"
|
||||||
|
)
|
||||||
|
private Integer timeoutSeconds;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
accessMode = Schema.AccessMode.READ_WRITE,
|
||||||
|
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
|
||||||
|
)
|
||||||
|
private Integer maxRetries;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -47,4 +47,16 @@ public final class OpenAiChatModelConfig extends AiModelConfig {
|
|||||||
)
|
)
|
||||||
private Double temperature;
|
private Double temperature;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
accessMode = Schema.AccessMode.READ_WRITE,
|
||||||
|
description = "Timeout (in seconds) for both establishing HTTP connection and receiving a response"
|
||||||
|
)
|
||||||
|
private Integer timeoutSeconds;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
accessMode = Schema.AccessMode.READ_WRITE,
|
||||||
|
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
|
||||||
|
)
|
||||||
|
private Integer maxRetries;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,8 @@
|
|||||||
package org.thingsboard.rule.engine.api;
|
package org.thingsboard.rule.engine.api;
|
||||||
|
|
||||||
import dev.langchain4j.model.chat.ChatModel;
|
import dev.langchain4j.model.chat.ChatModel;
|
||||||
|
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
|
||||||
|
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
|
||||||
import org.thingsboard.server.common.data.id.AiSettingsId;
|
import org.thingsboard.server.common.data.id.AiSettingsId;
|
||||||
import org.thingsboard.server.common.data.id.TenantId;
|
import org.thingsboard.server.common.data.id.TenantId;
|
||||||
|
|
||||||
@ -23,4 +25,6 @@ public interface RuleEngineAiService {
|
|||||||
|
|
||||||
ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId);
|
ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId);
|
||||||
|
|
||||||
|
ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,6 +35,9 @@ import org.thingsboard.rule.engine.api.TbNodeConfiguration;
|
|||||||
import org.thingsboard.rule.engine.api.TbNodeException;
|
import org.thingsboard.rule.engine.api.TbNodeException;
|
||||||
import org.thingsboard.rule.engine.api.util.TbNodeUtils;
|
import org.thingsboard.rule.engine.api.util.TbNodeUtils;
|
||||||
import org.thingsboard.rule.engine.external.TbAbstractExternalNode;
|
import org.thingsboard.rule.engine.external.TbAbstractExternalNode;
|
||||||
|
import org.thingsboard.server.common.data.ai.AiSettings;
|
||||||
|
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
|
||||||
|
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
|
||||||
import org.thingsboard.server.common.data.id.AiSettingsId;
|
import org.thingsboard.server.common.data.id.AiSettingsId;
|
||||||
import org.thingsboard.server.common.data.plugin.ComponentType;
|
import org.thingsboard.server.common.data.plugin.ComponentType;
|
||||||
import org.thingsboard.server.common.msg.TbMsg;
|
import org.thingsboard.server.common.msg.TbMsg;
|
||||||
@ -42,6 +45,7 @@ import org.thingsboard.server.dao.exception.DataValidationException;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
import static com.google.common.util.concurrent.Futures.addCallback;
|
import static com.google.common.util.concurrent.Futures.addCallback;
|
||||||
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
|
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
|
||||||
@ -60,6 +64,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
private SystemMessage systemMessage;
|
private SystemMessage systemMessage;
|
||||||
private PromptTemplate userPromptTemplate;
|
private PromptTemplate userPromptTemplate;
|
||||||
private ResponseFormat responseFormat;
|
private ResponseFormat responseFormat;
|
||||||
|
private int timeoutSeconds;
|
||||||
private AiSettingsId aiSettingsId;
|
private AiSettingsId aiSettingsId;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -88,6 +93,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
.formatted(config.getUserPrompt())
|
.formatted(config.getUserPrompt())
|
||||||
);
|
);
|
||||||
|
|
||||||
|
timeoutSeconds = config.getTimeoutSeconds();
|
||||||
aiSettingsId = config.getAiSettingsId();
|
aiSettingsId = config.getAiSettingsId();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,11 +101,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
if (responseFormatType == ResponseFormatType.TEXT) {
|
if (responseFormatType == ResponseFormatType.TEXT) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
|
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null && !jsonSchema.isNull() ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onMsg(TbContext ctx, TbMsg msg) {
|
public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException {
|
||||||
var ackedMsg = ackIfNeeded(ctx, msg);
|
var ackedMsg = ackIfNeeded(ctx, msg);
|
||||||
|
|
||||||
Map<String, Object> variables = Map.of(
|
Map<String, Object> variables = Map.of(
|
||||||
@ -114,7 +120,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
.responseFormat(responseFormat)
|
.responseFormat(responseFormat)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
ChatModel chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), aiSettingsId);
|
ChatModel chatModel = configureChatModel(ctx);
|
||||||
|
|
||||||
addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() {
|
addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() {
|
||||||
@Override
|
@Override
|
||||||
@ -134,6 +140,21 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
}, directExecutor());
|
}, directExecutor());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ChatModel configureChatModel(TbContext ctx) throws TbNodeException {
|
||||||
|
Optional<AiSettings> aiSettingsOpt = ctx.getAiSettingsService().findAiSettingsByTenantIdAndId(ctx.getTenantId(), aiSettingsId);
|
||||||
|
if (aiSettingsOpt.isEmpty()) {
|
||||||
|
throw new TbNodeException("AI settings with ID: " + aiSettingsId + " were not found", true);
|
||||||
|
}
|
||||||
|
|
||||||
|
AiProviderConfig providerConfig = aiSettingsOpt.get().getProviderConfig();
|
||||||
|
AiModelConfig modelConfig = aiSettingsOpt.get().getModelConfig();
|
||||||
|
|
||||||
|
modelConfig.setTimeoutSeconds(timeoutSeconds);
|
||||||
|
modelConfig.setMaxRetries(0); // disable retries to respect timeout set in rule node config
|
||||||
|
|
||||||
|
return ctx.getAiService().configureChatModel(providerConfig, modelConfig);
|
||||||
|
}
|
||||||
|
|
||||||
private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) {
|
private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) {
|
||||||
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
|
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
|
||||||
}
|
}
|
||||||
@ -157,6 +178,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
systemMessage = null;
|
systemMessage = null;
|
||||||
userPromptTemplate = null;
|
userPromptTemplate = null;
|
||||||
responseFormat = null;
|
responseFormat = null;
|
||||||
|
aiSettingsId = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
|||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
||||||
import jakarta.validation.constraints.AssertTrue;
|
import jakarta.validation.constraints.AssertTrue;
|
||||||
|
import jakarta.validation.constraints.Max;
|
||||||
|
import jakarta.validation.constraints.Min;
|
||||||
import jakarta.validation.constraints.NotBlank;
|
import jakarta.validation.constraints.NotBlank;
|
||||||
import jakarta.validation.constraints.NotNull;
|
import jakarta.validation.constraints.NotNull;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@ -46,10 +48,14 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
|
|||||||
|
|
||||||
private JsonNode jsonSchema;
|
private JsonNode jsonSchema;
|
||||||
|
|
||||||
|
@Min(value = 1, message = "must be at least 1 second")
|
||||||
|
@Max(value = 600, message = "cannot exceed 600 seconds (10 minutes)")
|
||||||
|
private int timeoutSeconds;
|
||||||
|
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
|
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
|
||||||
public boolean isJsonSchemaValid() {
|
public boolean isJsonSchemaValid() {
|
||||||
return jsonSchema == null || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
|
return jsonSchema == null || jsonSchema.isNull() || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -58,6 +64,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
|
|||||||
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
|
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
|
||||||
configuration.setUserPrompt("Tell me a joke.");
|
configuration.setUserPrompt("Tell me a joke.");
|
||||||
configuration.setResponseFormatType(ResponseFormatType.JSON);
|
configuration.setResponseFormatType(ResponseFormatType.JSON);
|
||||||
|
configuration.setTimeoutSeconds(60);
|
||||||
return configuration;
|
return configuration;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user