From d44bbe4dd8985dab843e3d1589981d7719566c14 Mon Sep 17 00:00:00 2001 From: Dmytro Skarzhynets Date: Wed, 21 May 2025 16:04:31 +0300 Subject: [PATCH] AI rule node: support text and JSON Schema response formats --- .../ai/Langchain4jJsonSchemaAdapter.java | 135 ++++++++++++++++++ .../thingsboard/rule/engine/ai/TbAiNode.java | 34 ++++- .../rule/engine/ai/TbAiNodeConfiguration.java | 5 + 3 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/Langchain4jJsonSchemaAdapter.java diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/Langchain4jJsonSchemaAdapter.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/Langchain4jJsonSchemaAdapter.java new file mode 100644 index 0000000000..9ad745f3ed --- /dev/null +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/Langchain4jJsonSchemaAdapter.java @@ -0,0 +1,135 @@ +/** + * Copyright © 2016-2025 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.rule.engine.ai; + +import com.fasterxml.jackson.databind.JsonNode; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNullSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; + +import java.util.ArrayList; +import java.util.List; + +/** + * Converts a Jackson {@link JsonNode} JSON Schema into a Langchain4j {@link JsonSchema} model. + */ +final class Langchain4jJsonSchemaAdapter { + + private Langchain4jJsonSchemaAdapter() { + throw new AssertionError("Can't instantiate utility class"); + } + + /** + * Creates a Langchain4j {@link JsonSchema} from the given root JSON Schema node. + * + * @param rootSchemaNode a valid JSON Schema as a Jackson {@link JsonNode} + * @return the corresponding Langchain4j {@link JsonSchema} + */ + public static JsonSchema fromJsonNode(JsonNode rootSchemaNode) { + return JsonSchema.builder() + .name(rootSchemaNode.get("title").textValue()) + .rootElement(parse(rootSchemaNode)) + .build(); + } + + private static JsonSchemaElement parse(JsonNode schemaNode) { + String description = schemaNode.hasNonNull("description") ? schemaNode.get("description").textValue() : null; + + if (schemaNode.has("enum")) { // enum schemas can be defined without 'type' + return parseEnum(schemaNode).description(description).build(); + } + + String type = schemaNode.get("type").textValue(); + + return switch (type) { + case "string" -> JsonStringSchema.builder().description(description).build(); + case "integer" -> JsonIntegerSchema.builder().description(description).build(); + case "boolean" -> JsonBooleanSchema.builder().description(description).build(); + case "number" -> JsonNumberSchema.builder().description(description).build(); + case "null" -> new JsonNullSchema(); + case "object" -> parseObject(schemaNode).description(description).build(); + case "array" -> parseArray(schemaNode).description(description).build(); + default -> throw new IllegalArgumentException("Unsupported JSON Schema type: " + type); + }; + } + + private static JsonEnumSchema.Builder parseEnum(JsonNode enumSchema) { + var builder = new JsonEnumSchema.Builder(); + + List enumValues = new ArrayList<>(); + for (JsonNode element : enumSchema.get("enum")) { + if (!element.isTextual()) { + throw new IllegalArgumentException("Expected each 'enum' element to be a string, but found: " + element.getNodeType()); + } + enumValues.add(element.textValue()); + } + builder.enumValues(enumValues); + + return builder; + } + + private static JsonObjectSchema.Builder parseObject(JsonNode objectSchema) { + var builder = new JsonObjectSchema.Builder(); + + JsonNode propertiesNode = objectSchema.get("properties"); + if (propertiesNode != null) { + propertiesNode.fields().forEachRemaining(entry -> { + String key = entry.getKey(); + JsonNode value = entry.getValue(); + builder.addProperty(key, parse(value)); + }); + } + + List required = new ArrayList<>(); + JsonNode requiredNode = objectSchema.get("required"); + if (requiredNode != null) { + for (JsonNode value : requiredNode) { + required.add(value.textValue()); + } + } + builder.required(required); + + boolean additionalProperties = true; // default value if 'additionalProperties' is not set + JsonNode additionalPropertiesNode = objectSchema.get("additionalProperties"); + if (additionalPropertiesNode != null) { + if (!additionalPropertiesNode.isBoolean()) { + throw new IllegalArgumentException("Expected 'additionalProperties' to be a boolean, but found: " + additionalPropertiesNode.getNodeType()); + } + additionalProperties = additionalPropertiesNode.booleanValue(); + } + builder.additionalProperties(additionalProperties); + + return builder; + } + + private static JsonArraySchema.Builder parseArray(JsonNode arraySchema) { + var builder = new JsonArraySchema.Builder(); + + if (arraySchema.hasNonNull("items")) { + builder.items(parse(arraySchema.get("items"))); + } + + return builder; + } + +} diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java index e3e9bdbf61..315a02a3b9 100644 --- a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNode.java @@ -15,6 +15,7 @@ */ package org.thingsboard.rule.engine.ai; +import com.fasterxml.jackson.databind.JsonNode; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.ListenableFuture; import dev.langchain4j.data.message.SystemMessage; @@ -22,6 +23,8 @@ import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.ResponseFormatType; +import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.model.input.PromptTemplate; import org.checkerframework.checker.nullness.qual.NonNull; import org.thingsboard.common.util.JacksonUtil; @@ -55,6 +58,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { private SystemMessage systemMessage; private PromptTemplate userPromptTemplate; + private ResponseFormat responseFormat; private ChatModel chatModel; @Override @@ -69,6 +73,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { throw new TbNodeException(e, true); } + responseFormat = ResponseFormat.builder() + .type(config.getResponseFormatType()) + .jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema())) + .build(); + systemMessage = SystemMessage.from(config.getSystemPrompt()); userPromptTemplate = PromptTemplate.from(""" User-provided task or question: %s @@ -80,6 +89,13 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), config.getAiSettingsId()); } + private static JsonSchema getJsonSchema(ResponseFormatType responseFormatType, JsonNode jsonSchema) { + if (responseFormatType == ResponseFormatType.TEXT) { + return null; + } + return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null; + } + @Override public void onMsg(TbContext ctx, TbMsg msg) { var ackedMsg = ackIfNeeded(ctx, msg); @@ -93,12 +109,15 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { var chatRequest = ChatRequest.builder() .messages(List.of(systemMessage, userMessage)) - .responseFormat(ResponseFormat.JSON) + .responseFormat(responseFormat) .build(); addCallback(sendChatRequest(ctx, chatRequest), new FutureCallback<>() { @Override public void onSuccess(String response) { + if (!isValidJson(response)) { + response = wrapInJsonObject(response); + } tellSuccess(ctx, ackedMsg.transform() .data(response) .build()); @@ -115,11 +134,24 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode { return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text()); } + private static boolean isValidJson(String jsonString) { + try { + return JacksonUtil.toJsonNode(jsonString) != null; + } catch (IllegalArgumentException e) { + return false; + } + } + + private static String wrapInJsonObject(String response) { + return JacksonUtil.newObjectNode().put("response", response).toString(); + } + @Override public void destroy() { super.destroy(); systemMessage = null; userPromptTemplate = null; + responseFormat = null; chatModel = null; } diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java index 380e5ccd55..c3234f61bb 100644 --- a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/ai/TbAiNodeConfiguration.java @@ -16,6 +16,7 @@ package org.thingsboard.rule.engine.ai; import com.fasterxml.jackson.databind.JsonNode; +import dev.langchain4j.model.chat.request.ResponseFormatType; import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotNull; @@ -39,6 +40,9 @@ public class TbAiNodeConfiguration implements NodeConfiguration