AI rule node: support node patterns in prompts

This commit is contained in:
Dmytro Skarzhynets 2025-06-17 18:39:16 +03:00
parent d2d22a44c2
commit c096397010
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF

View File

@ -27,7 +27,6 @@ import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType; import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.PromptTemplate;
import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.NonNull;
import org.thingsboard.common.util.JacksonUtil; import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.rule.engine.api.RuleNode; import org.thingsboard.rule.engine.api.RuleNode;
@ -46,11 +45,9 @@ import org.thingsboard.server.common.msg.TbMsg;
import org.thingsboard.server.dao.exception.DataValidationException; import org.thingsboard.server.dao.exception.DataValidationException;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNullElse;
import static org.thingsboard.server.dao.service.ConstraintValidator.validateFields; import static org.thingsboard.server.dao.service.ConstraintValidator.validateFields;
@RuleNode( @RuleNode(
@ -63,8 +60,8 @@ import static org.thingsboard.server.dao.service.ConstraintValidator.validateFie
) )
public final class TbAiNode extends TbAbstractExternalNode implements TbNode { public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
private SystemMessage systemMessage; private String systemPrompt;
private PromptTemplate userPromptTemplate; private String userPrompt;
private ResponseFormat responseFormat; private ResponseFormat responseFormat;
private int timeoutSeconds; private int timeoutSeconds;
private AiSettingsId aiSettingsId; private AiSettingsId aiSettingsId;
@ -86,15 +83,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
.jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema())) .jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema()))
.build(); .build();
systemMessage = SystemMessage.from(config.getSystemPrompt()); systemPrompt = config.getSystemPrompt();
userPromptTemplate = PromptTemplate.from(""" userPrompt = config.getUserPrompt();
User-provided task or question: %s
Rule engine message payload: {{msgPayload}}
Rule engine message metadata: {{msgMetadata}}
Rule engine message type: {{msgType}}"""
.formatted(config.getUserPrompt())
);
timeoutSeconds = config.getTimeoutSeconds(); timeoutSeconds = config.getTimeoutSeconds();
if (!aiSettingsExist(ctx, config.getAiSettingsId())) { if (!aiSettingsExist(ctx, config.getAiSettingsId())) {
@ -118,12 +108,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException { public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException {
var ackedMsg = ackIfNeeded(ctx, msg); var ackedMsg = ackIfNeeded(ctx, msg);
Map<String, Object> variables = Map.of( var systemMessage = SystemMessage.from(TbNodeUtils.processPattern(systemPrompt, ackedMsg));
"msgPayload", msg.getData(), var userMessage = UserMessage.from(TbNodeUtils.processPattern(userPrompt, ackedMsg));
"msgMetadata", requireNonNullElse(JacksonUtil.toString(msg.getMetaData().getData()), "{}"),
"msgType", msg.getType()
);
UserMessage userMessage = userPromptTemplate.apply(variables).toUserMessage();
var chatRequest = ChatRequest.builder() var chatRequest = ChatRequest.builder()
.messages(List.of(systemMessage, userMessage)) .messages(List.of(systemMessage, userMessage))
@ -183,8 +169,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
@Override @Override
public void destroy() { public void destroy() {
super.destroy(); super.destroy();
systemMessage = null; systemPrompt = null;
userPromptTemplate = null; userPrompt = null;
responseFormat = null; responseFormat = null;
aiSettingsId = null; aiSettingsId = null;
} }