AI rule node: support node patterns in prompts
This commit is contained in:
parent
d2d22a44c2
commit
c096397010
@ -27,7 +27,6 @@ import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.thingsboard.common.util.JacksonUtil;
|
||||
import org.thingsboard.rule.engine.api.RuleNode;
|
||||
@ -46,11 +45,9 @@ import org.thingsboard.server.common.msg.TbMsg;
|
||||
import org.thingsboard.server.dao.exception.DataValidationException;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.NoSuchElementException;
|
||||
|
||||
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
|
||||
import static java.util.Objects.requireNonNullElse;
|
||||
import static org.thingsboard.server.dao.service.ConstraintValidator.validateFields;
|
||||
|
||||
@RuleNode(
|
||||
@ -63,8 +60,8 @@ import static org.thingsboard.server.dao.service.ConstraintValidator.validateFie
|
||||
)
|
||||
public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
|
||||
private SystemMessage systemMessage;
|
||||
private PromptTemplate userPromptTemplate;
|
||||
private String systemPrompt;
|
||||
private String userPrompt;
|
||||
private ResponseFormat responseFormat;
|
||||
private int timeoutSeconds;
|
||||
private AiSettingsId aiSettingsId;
|
||||
@ -86,15 +83,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
.jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema()))
|
||||
.build();
|
||||
|
||||
systemMessage = SystemMessage.from(config.getSystemPrompt());
|
||||
userPromptTemplate = PromptTemplate.from("""
|
||||
User-provided task or question: %s
|
||||
Rule engine message payload: {{msgPayload}}
|
||||
Rule engine message metadata: {{msgMetadata}}
|
||||
Rule engine message type: {{msgType}}"""
|
||||
.formatted(config.getUserPrompt())
|
||||
);
|
||||
|
||||
systemPrompt = config.getSystemPrompt();
|
||||
userPrompt = config.getUserPrompt();
|
||||
timeoutSeconds = config.getTimeoutSeconds();
|
||||
|
||||
if (!aiSettingsExist(ctx, config.getAiSettingsId())) {
|
||||
@ -118,12 +108,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
public void onMsg(TbContext ctx, TbMsg msg) throws TbNodeException {
|
||||
var ackedMsg = ackIfNeeded(ctx, msg);
|
||||
|
||||
Map<String, Object> variables = Map.of(
|
||||
"msgPayload", msg.getData(),
|
||||
"msgMetadata", requireNonNullElse(JacksonUtil.toString(msg.getMetaData().getData()), "{}"),
|
||||
"msgType", msg.getType()
|
||||
);
|
||||
UserMessage userMessage = userPromptTemplate.apply(variables).toUserMessage();
|
||||
var systemMessage = SystemMessage.from(TbNodeUtils.processPattern(systemPrompt, ackedMsg));
|
||||
var userMessage = UserMessage.from(TbNodeUtils.processPattern(userPrompt, ackedMsg));
|
||||
|
||||
var chatRequest = ChatRequest.builder()
|
||||
.messages(List.of(systemMessage, userMessage))
|
||||
@ -183,8 +169,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
||||
@Override
|
||||
public void destroy() {
|
||||
super.destroy();
|
||||
systemMessage = null;
|
||||
userPromptTemplate = null;
|
||||
systemPrompt = null;
|
||||
userPrompt = null;
|
||||
responseFormat = null;
|
||||
aiSettingsId = null;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user