AI rule node: add timeout support

This commit is contained in:
Dmytro Skarzhynets 2025-06-11 17:42:03 +03:00
parent d5c6ed1f61
commit cb106760c1
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF
8 changed files with 111 additions and 14 deletions

View File

@ -23,13 +23,16 @@ import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.thingsboard.rule.engine.api.RuleEngineAiService;
import org.thingsboard.server.common.data.ai.AiSettings;
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig;
import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig;
import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.dao.ai.AiSettingsService;
import java.time.Duration;
import java.util.NoSuchElementException;
import java.util.Optional;
@ -46,37 +49,53 @@ class AiServiceImpl implements RuleEngineAiService {
throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found");
}
var aiSettings = aiSettingsOpt.get();
return configureChatModel(aiSettings.getProviderConfig(), aiSettings.getModelConfig());
}
return switch (aiSettings.getProvider()) {
@Override
public ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig) {
return switch (providerConfig.getProvider()) {
case OPENAI -> {
var modelBuilder = OpenAiChatModel.builder()
.apiKey(aiSettings.getProviderConfig().getApiKey())
.modelName(aiSettings.getModel());
.apiKey(providerConfig.getApiKey())
.modelName(modelConfig.getModel());
if (aiSettings.getModelConfig() instanceof OpenAiChatModelConfig config) {
if (modelConfig instanceof OpenAiChatModelConfig config) {
modelBuilder.temperature(config.getTemperature());
if (config.getTimeoutSeconds() != null) {
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
}
modelBuilder.maxRetries(config.getMaxRetries());
}
yield modelBuilder.build();
}
case MISTRAL_AI -> {
var modelBuilder = MistralAiChatModel.builder()
.apiKey(aiSettings.getProviderConfig().getApiKey())
.modelName(aiSettings.getModel());
.apiKey(providerConfig.getApiKey())
.modelName(modelConfig.getModel());
if (aiSettings.getModelConfig() instanceof MistralAiChatModelConfig config) {
if (modelConfig instanceof MistralAiChatModelConfig config) {
modelBuilder.temperature(config.getTemperature());
if (config.getTimeoutSeconds() != null) {
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
}
modelBuilder.maxRetries(config.getMaxRetries());
}
yield modelBuilder.build();
}
case GOOGLE_AI_GEMINI -> {
var modelBuilder = GoogleAiGeminiChatModel.builder()
.apiKey(aiSettings.getProviderConfig().getApiKey())
.modelName(aiSettings.getModel());
.apiKey(providerConfig.getApiKey())
.modelName(modelConfig.getModel());
if (aiSettings.getModelConfig() instanceof GoogleAiGeminiChatModelConfig config) {
if (modelConfig instanceof GoogleAiGeminiChatModelConfig config) {
modelBuilder.temperature(config.getTemperature());
if (config.getTimeoutSeconds() != null) {
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
}
modelBuilder.maxRetries(config.getMaxRetries());
}
yield modelBuilder.build();

View File

@ -44,4 +44,13 @@ public abstract class AiModelConfig {
)
private String model;
/**
 * Returns the configured timeout in seconds.
 *
 * @return the timeout in seconds; may be {@code null} (AiServiceImpl checks for null before applying it to a model builder)
 */
public abstract Integer getTimeoutSeconds();

/**
 * Sets the timeout in seconds.
 *
 * @param timeoutSeconds the timeout in seconds
 */
public abstract void setTimeoutSeconds(Integer timeoutSeconds);

/**
 * Returns the configured maximum number of retries for an LLM call.
 *
 * @return the maximum number of retries
 */
public abstract Integer getMaxRetries();

/**
 * Sets the maximum number of retries for an LLM call.
 *
 * @param maxRetries the maximum number of retries
 */
public abstract void setMaxRetries(Integer maxRetries);
}

View File

@ -47,4 +47,16 @@ public final class GoogleAiGeminiChatModelConfig extends AiModelConfig {
)
private Double temperature;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Timeout (in seconds) for establishing HTTP connection"
)
private Integer timeoutSeconds;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
)
private Integer maxRetries;
}

View File

@ -47,4 +47,16 @@ public final class MistralAiChatModelConfig extends AiModelConfig {
)
private Double temperature;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Timeout (in seconds) for the entire HTTP call: applied to connect, read, and write operations"
)
private Integer timeoutSeconds;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
)
private Integer maxRetries;
}

View File

@ -47,4 +47,16 @@ public final class OpenAiChatModelConfig extends AiModelConfig {
)
private Double temperature;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Timeout (in seconds) for both establishing HTTP connection and receiving a response"
)
private Integer timeoutSeconds;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
)
private Integer maxRetries;
}

View File

@ -16,6 +16,8 @@
package org.thingsboard.rule.engine.api;
import dev.langchain4j.model.chat.ChatModel;
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.id.TenantId;
@ -23,4 +25,6 @@ public interface RuleEngineAiService {
/**
 * Configures a chat model from the AI settings identified by the given tenant and settings IDs.
 *
 * @param tenantId     ID of the tenant that owns the AI settings
 * @param aiSettingsId ID of the AI settings to build the model from
 * @return a chat model configured from the referenced AI settings
 */
ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId);
/**
 * Configures a chat model directly from the given provider and model configurations.
 *
 * @param providerConfig provider configuration (selects the provider and supplies its API key)
 * @param modelConfig    model configuration (model name plus provider-specific options)
 * @return a chat model built for the provider named in {@code providerConfig}
 */
ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig);
}

View File

@ -35,6 +35,9 @@ import org.thingsboard.rule.engine.api.TbNodeConfiguration;
import org.thingsboard.rule.engine.api.TbNodeException;
import org.thingsboard.rule.engine.api.util.TbNodeUtils;
import org.thingsboard.rule.engine.external.TbAbstractExternalNode;
import org.thingsboard.server.common.data.ai.AiSettings;
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.plugin.ComponentType;
import org.thingsboard.server.common.msg.TbMsg;
@ -42,6 +45,7 @@ import org.thingsboard.server.dao.exception.DataValidationException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static com.google.common.util.concurrent.Futures.addCallback;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
@ -60,6 +64,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
private SystemMessage systemMessage;
private PromptTemplate userPromptTemplate;
private ResponseFormat responseFormat;
private int timeoutSeconds;
private AiSettingsId aiSettingsId;
@Override
@ -88,6 +93,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
.formatted(config.getUserPrompt())
);
timeoutSeconds = config.getTimeoutSeconds();
aiSettingsId = config.getAiSettingsId();
}
@ -95,11 +101,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
if (responseFormatType == ResponseFormatType.TEXT) {
return null;
}
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null && !jsonSchema.isNull() ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
}
@Override
public void onMsg(TbContext ctx, TbMsg msg) {
public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException {
var ackedMsg = ackIfNeeded(ctx, msg);
Map<String, Object> variables = Map.of(
@ -114,7 +120,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
.responseFormat(responseFormat)
.build();
ChatModel chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), aiSettingsId);
ChatModel chatModel = configureChatModel(ctx);
addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() {
@Override
@ -134,6 +140,21 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
}, directExecutor());
}
/**
 * Builds the chat model for this node from the tenant's stored AI settings,
 * applying the timeout taken from the rule node configuration and disabling retries.
 *
 * @param ctx rule engine context used to look up the AI settings and to reach the AI service
 * @return chat model configured from the stored provider config and the adjusted model config
 * @throws TbNodeException if no AI settings are found for {@code aiSettingsId}
 */
private ChatModel configureChatModel(TbContext ctx) throws TbNodeException {
    AiSettings settings = ctx.getAiSettingsService()
            .findAiSettingsByTenantIdAndId(ctx.getTenantId(), aiSettingsId)
            .orElseThrow(() -> new TbNodeException("AI settings with ID: " + aiSettingsId + " were not found", true));

    AiProviderConfig provider = settings.getProviderConfig();
    AiModelConfig model = settings.getModelConfig();

    // NOTE(review): these setters mutate the config instance handed back by the settings service;
    // confirm that instance is not cached/shared, otherwise the overrides leak to other users of it.
    model.setTimeoutSeconds(timeoutSeconds);
    model.setMaxRetries(0); // disable retries to respect timeout set in rule node config

    return ctx.getAiService().configureChatModel(provider, model);
}
/**
 * Submits the chat request to the external call executor and exposes the reply text as a future.
 *
 * @param ctx         rule engine context providing the external call executor
 * @param chatModel   model that performs the chat call
 * @param chatRequest request to send
 * @return future completing with the text of the AI message in the response
 */
private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) {
    return ctx.getExternalCallExecutor().submit(() -> {
        var chatResponse = chatModel.chat(chatRequest);
        return chatResponse.aiMessage().text();
    });
}
@ -157,6 +178,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
systemMessage = null;
userPromptTemplate = null;
responseFormat = null;
aiSettingsId = null;
}
}

View File

@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@ -46,10 +48,14 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
private JsonNode jsonSchema;
@Min(value = 1, message = "must be at least 1 second")
@Max(value = 600, message = "cannot exceed 600 seconds (10 minutes)")
private int timeoutSeconds;
@JsonIgnore
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
public boolean isJsonSchemaValid() {
return jsonSchema == null || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
return jsonSchema == null || jsonSchema.isNull() || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
}
@Override
@ -58,6 +64,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
configuration.setUserPrompt("Tell me a joke.");
configuration.setResponseFormatType(ResponseFormatType.JSON);
configuration.setTimeoutSeconds(60);
return configuration;
}