AI rule node: configure chat model on each message

This commit is contained in:
Dmytro Skarzhynets 2025-06-09 17:51:24 +03:00
parent a9633cf1ce
commit 19c234fcdf
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF

View File

@ -35,6 +35,7 @@ import org.thingsboard.rule.engine.api.TbNodeConfiguration;
import org.thingsboard.rule.engine.api.TbNodeException; import org.thingsboard.rule.engine.api.TbNodeException;
import org.thingsboard.rule.engine.api.util.TbNodeUtils; import org.thingsboard.rule.engine.api.util.TbNodeUtils;
import org.thingsboard.rule.engine.external.TbAbstractExternalNode; import org.thingsboard.rule.engine.external.TbAbstractExternalNode;
import org.thingsboard.server.common.data.id.AiSettingsId;
import org.thingsboard.server.common.data.plugin.ComponentType; import org.thingsboard.server.common.data.plugin.ComponentType;
import org.thingsboard.server.common.msg.TbMsg; import org.thingsboard.server.common.msg.TbMsg;
import org.thingsboard.server.dao.exception.DataValidationException; import org.thingsboard.server.dao.exception.DataValidationException;
@ -59,7 +60,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
private SystemMessage systemMessage; private SystemMessage systemMessage;
private PromptTemplate userPromptTemplate; private PromptTemplate userPromptTemplate;
private ResponseFormat responseFormat; private ResponseFormat responseFormat;
private ChatModel chatModel; private AiSettingsId aiSettingsId;
@Override @Override
public void init(TbContext ctx, TbNodeConfiguration configuration) throws TbNodeException { public void init(TbContext ctx, TbNodeConfiguration configuration) throws TbNodeException {
@ -86,7 +87,8 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
Rule engine message type: {{msgType}}""" Rule engine message type: {{msgType}}"""
.formatted(config.getUserPrompt()) .formatted(config.getUserPrompt())
); );
chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), config.getAiSettingsId());
aiSettingsId = config.getAiSettingsId();
} }
private static JsonSchema getJsonSchema(ResponseFormatType responseFormatType, JsonNode jsonSchema) { private static JsonSchema getJsonSchema(ResponseFormatType responseFormatType, JsonNode jsonSchema) {
@ -112,7 +114,9 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
.responseFormat(responseFormat) .responseFormat(responseFormat)
.build(); .build();
addCallback(sendChatRequest(ctx, chatRequest), new FutureCallback<>() { ChatModel chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), aiSettingsId);
addCallback(sendChatRequest(ctx, chatModel, chatRequest), new FutureCallback<>() {
@Override @Override
public void onSuccess(String response) { public void onSuccess(String response) {
if (!isValidJsonObject(response)) { if (!isValidJsonObject(response)) {
@ -130,7 +134,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
}, directExecutor()); }, directExecutor());
} }
private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatRequest chatRequest) { private ListenableFuture<String> sendChatRequest(TbContext ctx, ChatModel chatModel, ChatRequest chatRequest) {
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text()); return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
} }
@ -153,7 +157,6 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
systemMessage = null; systemMessage = null;
userPromptTemplate = null; userPromptTemplate = null;
responseFormat = null; responseFormat = null;
chatModel = null;
} }
} }