AI rule node: add timeout support

This commit is contained in:
Dmytro Skarzhynets 2025-06-11 17:42:03 +03:00
parent d5c6ed1f61
commit cb106760c1
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF
8 changed files with 111 additions and 14 deletions

View File

@ -23,13 +23,16 @@ import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.thingsboard.rule.engine.api.RuleEngineAiService; import org.thingsboard.rule.engine.api.RuleEngineAiService;
import org.thingsboard.server.common.data.ai.AiSettings; import org.thingsboard.server.common.data.ai.AiSettings;
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig; import org.thingsboard.server.common.data.ai.model.GoogleAiGeminiChatModelConfig;
import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig; import org.thingsboard.server.common.data.ai.model.MistralAiChatModelConfig;
import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig; import org.thingsboard.server.common.data.ai.model.OpenAiChatModelConfig;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.id.AiSettingsId; import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.dao.ai.AiSettingsService; import org.thingsboard.server.dao.ai.AiSettingsService;
import java.time.Duration;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.Optional; import java.util.Optional;
@ -46,37 +49,53 @@ class AiServiceImpl implements RuleEngineAiService {
throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found"); throw new NoSuchElementException("AI settings with ID: " + aiSettingsId + " were not found");
} }
var aiSettings = aiSettingsOpt.get(); var aiSettings = aiSettingsOpt.get();
return configureChatModel(aiSettings.getProviderConfig(), aiSettings.getModelConfig());
}
return switch (aiSettings.getProvider()) { @Override
public ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig) {
return switch (providerConfig.getProvider()) {
case OPENAI -> { case OPENAI -> {
var modelBuilder = OpenAiChatModel.builder() var modelBuilder = OpenAiChatModel.builder()
.apiKey(aiSettings.getProviderConfig().getApiKey()) .apiKey(providerConfig.getApiKey())
.modelName(aiSettings.getModel()); .modelName(modelConfig.getModel());
if (aiSettings.getModelConfig() instanceof OpenAiChatModelConfig config) { if (modelConfig instanceof OpenAiChatModelConfig config) {
modelBuilder.temperature(config.getTemperature()); modelBuilder.temperature(config.getTemperature());
if (config.getTimeoutSeconds() != null) {
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
}
modelBuilder.maxRetries(config.getMaxRetries());
} }
yield modelBuilder.build(); yield modelBuilder.build();
} }
case MISTRAL_AI -> { case MISTRAL_AI -> {
var modelBuilder = MistralAiChatModel.builder() var modelBuilder = MistralAiChatModel.builder()
.apiKey(aiSettings.getProviderConfig().getApiKey()) .apiKey(providerConfig.getApiKey())
.modelName(aiSettings.getModel()); .modelName(modelConfig.getModel());
if (aiSettings.getModelConfig() instanceof MistralAiChatModelConfig config) { if (modelConfig instanceof MistralAiChatModelConfig config) {
modelBuilder.temperature(config.getTemperature()); modelBuilder.temperature(config.getTemperature());
if (config.getTimeoutSeconds() != null) {
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
}
modelBuilder.maxRetries(config.getMaxRetries());
} }
yield modelBuilder.build(); yield modelBuilder.build();
} }
case GOOGLE_AI_GEMINI -> { case GOOGLE_AI_GEMINI -> {
var modelBuilder = GoogleAiGeminiChatModel.builder() var modelBuilder = GoogleAiGeminiChatModel.builder()
.apiKey(aiSettings.getProviderConfig().getApiKey()) .apiKey(providerConfig.getApiKey())
.modelName(aiSettings.getModel()); .modelName(modelConfig.getModel());
if (aiSettings.getModelConfig() instanceof GoogleAiGeminiChatModelConfig config) { if (modelConfig instanceof GoogleAiGeminiChatModelConfig config) {
modelBuilder.temperature(config.getTemperature()); modelBuilder.temperature(config.getTemperature());
if (config.getTimeoutSeconds() != null) {
modelBuilder.timeout(Duration.ofSeconds(config.getTimeoutSeconds()));
}
modelBuilder.maxRetries(config.getMaxRetries());
} }
yield modelBuilder.build(); yield modelBuilder.build();

View File

@ -44,4 +44,13 @@ public abstract class AiModelConfig {
) )
private String model; private String model;
/**
 * Returns the timeout, in seconds, to apply to calls made with this model configuration.
 * May be {@code null}; what exactly the timeout covers is provider-specific.
 */
public abstract Integer getTimeoutSeconds();
/**
 * Sets the timeout, in seconds, to apply to calls made with this model configuration.
 *
 * @param timeoutSeconds timeout in seconds
 */
public abstract void setTimeoutSeconds(Integer timeoutSeconds);
/**
 * Returns the maximum number of times a model call is retried.
 */
public abstract Integer getMaxRetries();
/**
 * Sets the maximum number of times a model call is retried.
 *
 * @param maxRetries maximum number of retries
 */
public abstract void setMaxRetries(Integer maxRetries);
} }

View File

@ -47,4 +47,16 @@ public final class GoogleAiGeminiChatModelConfig extends AiModelConfig {
) )
private Double temperature; private Double temperature;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Timeout (in seconds) for establishing HTTP connection"
)
private Integer timeoutSeconds;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
)
private Integer maxRetries;
} }

View File

@ -47,4 +47,16 @@ public final class MistralAiChatModelConfig extends AiModelConfig {
) )
private Double temperature; private Double temperature;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Timeout (in seconds) for the entire HTTP call: applied to connect, read, and write operations"
)
private Integer timeoutSeconds;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
)
private Integer maxRetries;
} }

View File

@ -47,4 +47,16 @@ public final class OpenAiChatModelConfig extends AiModelConfig {
) )
private Double temperature; private Double temperature;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Timeout (in seconds) for both establishing HTTP connection and receiving a response"
)
private Integer timeoutSeconds;
@Schema(
accessMode = Schema.AccessMode.READ_WRITE,
description = "Maximum number of times to retry an LLM call upon exception (except for non-retriable ones like authentication or invalid request errors)"
)
private Integer maxRetries;
} }

View File

@ -16,6 +16,8 @@
package org.thingsboard.rule.engine.api; package org.thingsboard.rule.engine.api;
import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.ChatModel;
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.id.AiSettingsId; import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.id.TenantId;
@ -23,4 +25,6 @@ public interface RuleEngineAiService {
ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId); ChatModel configureChatModel(TenantId tenantId, AiSettingsId aiSettingsId);
ChatModel configureChatModel(AiProviderConfig providerConfig, AiModelConfig modelConfig);
} }

View File

@ -35,6 +35,9 @@ import org.thingsboard.rule.engine.api.TbNodeConfiguration;
import org.thingsboard.rule.engine.api.TbNodeException; import org.thingsboard.rule.engine.api.TbNodeException;
import org.thingsboard.rule.engine.api.util.TbNodeUtils; import org.thingsboard.rule.engine.api.util.TbNodeUtils;
import org.thingsboard.rule.engine.external.TbAbstractExternalNode; import org.thingsboard.rule.engine.external.TbAbstractExternalNode;
import org.thingsboard.server.common.data.ai.AiSettings;
import org.thingsboard.server.common.data.ai.model.AiModelConfig;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.id.AiSettingsId; import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.plugin.ComponentType; import org.thingsboard.server.common.data.plugin.ComponentType;
import org.thingsboard.server.common.msg.TbMsg; import org.thingsboard.server.common.msg.TbMsg;
@ -42,6 +45,7 @@ import org.thingsboard.server.dao.exception.DataValidationException;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import static com.google.common.util.concurrent.Futures.addCallback; import static com.google.common.util.concurrent.Futures.addCallback;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
@ -60,6 +64,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
private SystemMessage systemMessage; private SystemMessage systemMessage;
private PromptTemplate userPromptTemplate; private PromptTemplate userPromptTemplate;
private ResponseFormat responseFormat; private ResponseFormat responseFormat;
private int timeoutSeconds;
private AiSettingsId aiSettingsId; private AiSettingsId aiSettingsId;
@Override @Override
@ -88,6 +93,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
.formatted(config.getUserPrompt()) .formatted(config.getUserPrompt())
); );
timeoutSeconds = config.getTimeoutSeconds();
aiSettingsId = config.getAiSettingsId(); aiSettingsId = config.getAiSettingsId();
} }
@ -95,11 +101,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
if (responseFormatType == ResponseFormatType.TEXT) { if (responseFormatType == ResponseFormatType.TEXT) {
return null; return null;
} }
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null; return responseFormatType == ResponseFormatType.JSON && jsonSchema != null && !jsonSchema.isNull() ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
} }
@Override @Override
public void onMsg(TbContext ctx, TbMsg msg) { public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException {
var ackedMsg = ackIfNeeded(ctx, msg); var ackedMsg = ackIfNeeded(ctx, msg);
Map<String, Object> variables = Map.of( Map<String, Object> variables = Map.of(
@ -114,7 +120,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
.responseFormat(responseFormat) .responseFormat(responseFormat)
.build(); .build();
ChatModel chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), aiSettingsId); ChatModel chatModel = configureChatModel(ctx);
addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() { addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() {
@Override @Override
@ -134,6 +140,21 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
}, directExecutor()); }, directExecutor());
} }
/**
 * Builds the chat model for this node: loads the AI settings referenced by
 * {@code aiSettingsId} for the current tenant, overrides their timeout and retry
 * values with the ones dictated by the rule node, and delegates model creation
 * to the AI service.
 *
 * @param ctx rule engine context supplying the tenant ID, the AI settings service and the AI service
 * @return chat model configured from the stored provider and model settings
 * @throws TbNodeException if no AI settings are found for the configured ID
 */
private ChatModel configureChatModel(TbContext ctx) throws TbNodeException {
    AiSettings settings = ctx.getAiSettingsService()
            .findAiSettingsByTenantIdAndId(ctx.getTenantId(), aiSettingsId)
            .orElseThrow(() -> new TbNodeException("AI settings with ID: " + aiSettingsId + " were not found", true));

    AiProviderConfig providerConfig = settings.getProviderConfig();
    AiModelConfig modelConfig = settings.getModelConfig();

    // NOTE(review): these setters mutate the config object handed back by the settings
    // service — confirm that instance is not shared or cached between callers.
    modelConfig.setTimeoutSeconds(timeoutSeconds);
    // Retries are switched off so that the timeout from the rule node config is respected.
    modelConfig.setMaxRetries(0);

    return ctx.getAiService().configureChatModel(providerConfig, modelConfig);
}
// Submits the chat call to the context's external-call executor and returns a future
// that completes with the text of the AI message from the response.
private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) { private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) {
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text()); return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
} }
@ -157,6 +178,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
systemMessage = null; systemMessage = null;
userPromptTemplate = null; userPromptTemplate = null;
responseFormat = null; responseFormat = null;
aiSettingsId = null;
} }
} }

View File

@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.model.chat.request.ResponseFormatType; import dev.langchain4j.model.chat.request.ResponseFormatType;
import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
@ -46,10 +48,14 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
private JsonNode jsonSchema; private JsonNode jsonSchema;
@Min(value = 1, message = "must be at least 1 second")
@Max(value = 600, message = "cannot exceed 600 seconds (10 minutes)")
private int timeoutSeconds;
@JsonIgnore @JsonIgnore
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema") @AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
public boolean isJsonSchemaValid() { public boolean isJsonSchemaValid() {
return jsonSchema == null || JsonSchemaUtils.isValidJsonSchema(jsonSchema); return jsonSchema == null || jsonSchema.isNull() || JsonSchemaUtils.isValidJsonSchema(jsonSchema);
} }
@Override @Override
@ -58,6 +64,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format."); configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
configuration.setUserPrompt("Tell me a joke."); configuration.setUserPrompt("Tell me a joke.");
configuration.setResponseFormatType(ResponseFormatType.JSON); configuration.setResponseFormatType(ResponseFormatType.JSON);
configuration.setTimeoutSeconds(60);
return configuration; return configuration;
} }