diff --git a/application/src/main/java/org/thingsboard/server/service/ai/AiServiceImpl.java b/application/src/main/java/org/thingsboard/server/service/ai/AiServiceImpl.java index 60d7dca7a8..eeddb5e6d2 100644 --- a/application/src/main/java/org/thingsboard/server/service/ai/AiServiceImpl.java +++ b/application/src/main/java/org/thingsboard/server/service/ai/AiServiceImpl.java @@ -23,13 +23,16 @@ import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; import org.thingsboard.rule.engine.api.RuleEngineAiService; import org.thingsboard.server.common.data.ai.AiSettings; +import org.thingsboard.server.common.data.ai.model.AiModelConfig; import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig; import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig; import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig; +import org.thingsboard.server.common.data.ai.provider.AiProviderConfig; import org.thingsboard.server.common.data.id.AiSettingsId; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.dao.ai.AiSettingsService; +import java.time.Duration; import java.util.NoSuchElementException; import java.util.Optional; @@ -46,37 +49,53 @@ class AiServiceImpl implements RuleEngineAiService { throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found"); } var aiSettings = aiSettingsOpt.get(); + return configureChatModel(aiSettings.getProviderConfig(), aiSettings.getModelConfig()); + } - return switch (aiSettings.getProvider()) { + /** Builds a ChatModel for the provider reported by providerConfig.getProvider(), using the API key from providerConfig and the model name from modelConfig. Temperature, timeout and max retries are applied only when modelConfig is the config class matching that provider. */ @Override + public ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig) { + return switch (providerConfig.getProvider()) { case OPENAI -> { var modelBuilder = OpenAiChatModel.builder() - .apiKey(aiSettings.getProviderConfig().getApiKey()) - .modelName(aiSettings.getModel()); + .apiKey(providerConfig.getApiKey()) + .modelName(modelConfig.getModel()); - if 
(aiSettings.getModelConfig() instanceof OpenAiChatModelConfig config) { + if (modelConfig instanceof OpenAiChatModelConfig config) { modelBuilder.temperature(config.getTemperature()); + /* the timeout is set on the builder only when one is configured; same pattern in the other provider branches */ if (config.getTimeoutSeconds() != null) { + modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds())); + } + /* NOTE(review): unlike the timeout, maxRetries is forwarded without a null check - confirm the builder accepts a null retry count */ modelBuilder.maxRetries(config.getMaxRetries()); } yield modelBuilder.build(); } case MISTRAL_AI -> { var modelBuilder = MistralAiChatModel.builder() - .apiKey(aiSettings.getProviderConfig().getApiKey()) - .modelName(aiSettings.getModel()); + .apiKey(providerConfig.getApiKey()) + .modelName(modelConfig.getModel()); - if (aiSettings.getModelConfig() instanceof MistralAiChatModelConfig config) { + if (modelConfig instanceof MistralAiChatModelConfig config) { modelBuilder.temperature(config.getTemperature()); + if (config.getTimeoutSeconds() != null) { + modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds())); + } + modelBuilder.maxRetries(config.getMaxRetries()); } yield modelBuilder.build(); } case GOOGLE_AI_GEMINI -> { var modelBuilder = GoogleAiGeminiChatModel.builder() - .apiKey(aiSettings.getProviderConfig().getApiKey()) - .modelName(aiSettings.getModel()); + .apiKey(providerConfig.getApiKey()) + .modelName(modelConfig.getModel()); - if (aiSettings.getModelConfig() instanceof GoogleAiGeminiChatModelConfig config) { + if (modelConfig instanceof GoogleAiGeminiChatModelConfig config) { modelBuilder.temperature(config.getTemperature()); + if (config.getTimeoutSeconds() != null) { + modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds())); + } + modelBuilder.maxRetries(config.getMaxRetries()); } yield modelBuilder.build(); diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/AiModelConfig.java b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/AiModelConfig.java index 5112df3213..779b75812f 100644 --- a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/AiModelConfig.java 
+++ b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/AiModelConfig.java @@ -44,4 +44,13 @@ public abstract class AiModelConfig { ) private String model; + /** Returns the configured timeout in seconds; may be null when none is configured. */ public abstract Integer getTimeoutSeconds(); + + /** Sets the timeout in seconds. */ public abstract void setTimeoutSeconds(Integer timeoutSeconds); + + /** Returns the configured maximum number of retries for a model call. */ public abstract Integer getMaxRetries(); + + /** Sets the maximum number of retries for a model call. */ public abstract void setMaxRetries(Integer maxRetries); + + } diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/GoogleAiGeminiChatModelConfig.java b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/GoogleAiGeminiChatModelConfig.java index 761ec998cc..1d85115f31 100644 --- a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/GoogleAiGeminiChatModelConfig.java +++ b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/GoogleAiGeminiChatModelConfig.java @@ -47,4 +47,16 @@ public final class GoogleAiGeminiChatModelConfig extends AiModelConfig { ) private Double temperature; + @Schema( + accessMode = Schema.AccessMode.READ_WRITE, + description = "Timeout (in seconds) for establishing HTTP connection" + ) + private Integer timeoutSeconds; + + @Schema( + accessMode = Schema.AccessMode.READ_WRITE, + description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)" + ) + private Integer maxRetries; + } diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/MistralAiChatModelConfig.java b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/MistralAiChatModelConfig.java index 20a8cfea7c..d333eb6283 100644 --- a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/MistralAiChatModelConfig.java +++ b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/MistralAiChatModelConfig.java @@ -47,4 +47,16 @@ public final class MistralAiChatModelConfig extends AiModelConfig { ) private Double 
temperature; + @Schema( + accessMode = Schema.AccessMode.READ_WRITE, + description = "Timeout (in seconds) for the entire HTTP call: applied to connect, read, and write operations" + ) + private Integer timeoutSeconds; + + @Schema( + accessMode = Schema.AccessMode.READ_WRITE, + description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)" + ) + private Integer maxRetries; + } diff --git a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/OpenAiChatModelConfig.java b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/OpenAiChatModelConfig.java index a7bc3725ef..accf5c92ec 100644 --- a/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/OpenAiChatModelConfig.java +++ b/common/data/src/main/java/org/thingsboard/server/common/data/ai/model/OpenAiChatModelConfig.java @@ -47,4 +47,16 @@ public final class OpenAiChatModelConfig extends AiModelConfig { ) private Double temperature; + @Schema( + accessMode = Schema.AccessMode.READ_WRITE, + description = "Timeout (in seconds) for both establishing HTTP connection and receiving a response" + ) + private Integer timeoutSeconds; + + @Schema( + accessMode = Schema.AccessMode.READ_WRITE, + description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)" + ) + private Integer maxRetries; + } diff --git a/rule-engine/rule-engine-api/src/main/java/org/thingsboard/rule/engine/api/RuleEngineAiService.java b/rule-engine/rule-engine-api/src/main/java/org/thingsboard/rule/engine/api/RuleEngineAiService.java index ae455d3b51..d493c5599a 100644 --- a/rule-engine/rule-engine-api/src/main/java/org/thingsboard/rule/engine/api/RuleEngineAiService.java +++ b/rule-engine/rule-engine-api/src/main/java/org/thingsboard/rule/engine/api/RuleEngineAiService.java @@ -16,6 +16,8 @@ package 
org.thingsboard.rule.engine.api; import dev.langchain4j.model.chat.ChatModel; +import org.thingsboard.server.common.data.ai.model.AiModelConfig; +import org.thingsboard.server.common.data.ai.provider.AiProviderConfig; import org.thingsboard.server.common.data.id.AiSettingsId; import org.thingsboard.server.common.data.id.TenantId; @@ -23,4 +25,6 @@ public interface RuleEngineAiService { ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId); + /** Builds a chat model from the given provider configuration and model configuration. */ ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig); + } diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java index 6b0bb3803b..7ee354d581 100644 --- a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java @@ -35,6 +35,9 @@ import org.thingsboard.rule.engine.api.TbNodeConfiguration; import org.thingsboard.rule.engine.api.TbNodeException; import org.thingsboard.rule.engine.api.util.TbNodeUtils; import org.thingsboard.rule.engine.external.TbAbstractExternalNode; +import org.thingsboard.server.common.data.ai.AiSettings; +import org.thingsboard.server.common.data.ai.model.AiModelConfig; +import org.thingsboard.server.common.data.ai.provider.AiProviderConfig; import org.thingsboard.server.common.data.id.AiSettingsId; import org.thingsboard.server.common.data.plugin.ComponentType; import org.thingsboard.server.common.msg.TbMsg; @@ -42,6 +45,7 @@ import org.thingsboard.server.dao.exception.DataValidationException; import java.util.List; import java.util.Map; +import java.util.Optional; import static com.google.common.util.concurrent.Futures.addCallback; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -60,6 +64,7 @@ public final class TbAiNode extends 
TbAbstractExternalNode implements TbNode { private SystemMessage systemMessage; private PromptTemplate userPromptTemplate; private ResponseFormat responseFormat; + /* timeout in seconds taken from the node configuration and written onto the model config for each message */ private int timeoutSeconds; private AiSettingsId aiSettingsId; @Override @@ -88,6 +93,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { .formatted(config.getUserPrompt()) ); + timeoutSeconds = config.getTimeoutSeconds(); aiSettingsId = config.getAiSettingsId(); } @@ -95,11 +101,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { if (responseFormatType == ResponseFormatType.TEXT) { return null; } - return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null; + /* the schema adapter is now also skipped when jsonSchema.isNull() is true, not only for a Java null reference */ return responseFormatType == ResponseFormatType.JSON && jsonSchema != null && !jsonSchema.isNull() ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null; } @Override - public void onMsg(TbContext ctx, TbMsg msg) { + public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException { var ackedMsg = ackIfNeeded(ctx, msg); Map variables = Map.of( @@ -114,7 +120,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { .responseFormat(responseFormat) .build(); - ChatModel chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), aiSettingsId); + /* NOTE(review): this can throw TbNodeException after ackIfNeeded has already run above - confirm the failure is still routed correctly for an acked message */ ChatModel chatModel = configureChatModel(ctx); addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() { @Override @@ -134,6 +140,21 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { }, directExecutor()); } + /** Looks up the AI settings for this node's aiSettingsId under the context's tenant, overrides the model config's timeout with the node-level timeoutSeconds and its max retries with 0, then delegates to the AI service to build the chat model. Throws TbNodeException (constructed with the flag true) when the settings are not found. */ private ChatModel configureChatModel(TbContext ctx) throws TbNodeException { + Optional aiSettingsOpt = ctx.getAiSettingsService().findAiSettingsByTenantIdAndId(ctx.getTenantId(), aiSettingsId); + if (aiSettingsOpt.isEmpty()) { + throw new TbNodeException("AI settings with ID: " + aiSettingsId + " were not found", true); + } + + AiProviderConfig providerConfig = 
aiSettingsOpt.get().getProviderConfig(); + AiModelConfig modelConfig = aiSettingsOpt.get().getModelConfig(); + + /* NOTE(review): the two setters below mutate the AiModelConfig instance returned by the settings service; if that service hands out a shared or cached object these overrides would leak to other callers - confirm it returns a fresh copy */ modelConfig.setTimeoutSeconds(timeoutSeconds); + modelConfig.setMaxRetries(0); // disable retries to respect timeout set in rule node config + + return ctx.getAiService().configureChatModel(providerConfig, modelConfig); + } + private ListenableFuture sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) { return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text()); } @@ -157,6 +178,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { systemMessage = null; userPromptTemplate = null; responseFormat = null; + aiSettingsId = null; } } diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java index f9341ae982..7d2b9745ab 100644 --- a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java @@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.databind.JsonNode; import dev.langchain4j.model.chat.request.ResponseFormatType; import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotNull; import lombok.Data; @@ -46,10 +48,14 @@ public class TbAiNodeConfiguration implements NodeConfiguration