AI rule node: support text and JSON Schema response formats

This commit is contained in:
Dmytro Skarzhynets 2025-05-21 16:04:31 +03:00
parent b64b5795a3
commit d44bbe4dd8
No known key found for this signature in database
GPG Key ID: 2B51652F224037DF
3 changed files with 173 additions and 1 deletions

View File

@ -0,0 +1,135 @@
/**
* Copyright © 2016-2025 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.rule.engine.ai;
import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNullSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import java.util.ArrayList;
import java.util.List;
/**
* Converts a Jackson {@link JsonNode} JSON Schema into a Langchain4j {@link JsonSchema} model.
*/
final class Langchain4jJsonSchemaAdapter {
private Langchain4jJsonSchemaAdapter() {
throw new AssertionError("Can't instantiate utility class");
}
/**
* Creates a Langchain4j {@link JsonSchema} from the given root JSON Schema node.
*
* @param rootSchemaNode a valid JSON Schema as a Jackson {@link JsonNode}
* @return the corresponding Langchain4j {@link JsonSchema}
*/
public static JsonSchema fromJsonNode(JsonNode rootSchemaNode) {
return JsonSchema.builder()
.name(rootSchemaNode.get("title").textValue())
.rootElement(parse(rootSchemaNode))
.build();
}
private static JsonSchemaElement parse(JsonNode schemaNode) {
String description = schemaNode.hasNonNull("description") ? schemaNode.get("description").textValue() : null;
if (schemaNode.has("enum")) { // enum schemas can be defined without 'type'
return parseEnum(schemaNode).description(description).build();
}
String type = schemaNode.get("type").textValue();
return switch (type) {
case "string" -> JsonStringSchema.builder().description(description).build();
case "integer" -> JsonIntegerSchema.builder().description(description).build();
case "boolean" -> JsonBooleanSchema.builder().description(description).build();
case "number" -> JsonNumberSchema.builder().description(description).build();
case "null" -> new JsonNullSchema();
case "object" -> parseObject(schemaNode).description(description).build();
case "array" -> parseArray(schemaNode).description(description).build();
default -> throw new IllegalArgumentException("Unsupported JSON Schema type: " + type);
};
}
private static JsonEnumSchema.Builder parseEnum(JsonNode enumSchema) {
var builder = new JsonEnumSchema.Builder();
List<String> enumValues = new ArrayList<>();
for (JsonNode element : enumSchema.get("enum")) {
if (!element.isTextual()) {
throw new IllegalArgumentException("Expected each 'enum' element to be a string, but found: " + element.getNodeType());
}
enumValues.add(element.textValue());
}
builder.enumValues(enumValues);
return builder;
}
private static JsonObjectSchema.Builder parseObject(JsonNode objectSchema) {
var builder = new JsonObjectSchema.Builder();
JsonNode propertiesNode = objectSchema.get("properties");
if (propertiesNode != null) {
propertiesNode.fields().forEachRemaining(entry -> {
String key = entry.getKey();
JsonNode value = entry.getValue();
builder.addProperty(key, parse(value));
});
}
List<String> required = new ArrayList<>();
JsonNode requiredNode = objectSchema.get("required");
if (requiredNode != null) {
for (JsonNode value : requiredNode) {
required.add(value.textValue());
}
}
builder.required(required);
boolean additionalProperties = true; // default value if 'additionalProperties' is not set
JsonNode additionalPropertiesNode = objectSchema.get("additionalProperties");
if (additionalPropertiesNode != null) {
if (!additionalPropertiesNode.isBoolean()) {
throw new IllegalArgumentException("Expected 'additionalProperties' to be a boolean, but found: " + additionalPropertiesNode.getNodeType());
}
additionalProperties = additionalPropertiesNode.booleanValue();
}
builder.additionalProperties(additionalProperties);
return builder;
}
private static JsonArraySchema.Builder parseArray(JsonNode arraySchema) {
var builder = new JsonArraySchema.Builder();
if (arraySchema.hasNonNull("items")) {
builder.items(parse(arraySchema.get("items")));
}
return builder;
}
}

View File

@ -15,6 +15,7 @@
*/
package org.thingsboard.rule.engine.ai;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.ListenableFuture;
import dev.langchain4j.data.message.SystemMessage;
@ -22,6 +23,8 @@ import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.input.PromptTemplate;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.thingsboard.common.util.JacksonUtil;
@ -55,6 +58,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
private SystemMessage systemMessage;
private PromptTemplate userPromptTemplate;
private ResponseFormat responseFormat;
private ChatModel chatModel;
@Override
@ -69,6 +73,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
throw new TbNodeException(e, true);
}
responseFormat = ResponseFormat.builder()
.type(config.getResponseFormatType())
.jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema()))
.build();
systemMessage = SystemMessage.from(config.getSystemPrompt());
userPromptTemplate = PromptTemplate.from("""
User-provided task or question: %s
@ -80,6 +89,13 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), config.getAiSettingsId());
}
private static JsonSchema getJsonSchema(ResponseFormatType responseFormatType, JsonNode jsonSchema) {
if (responseFormatType == ResponseFormatType.TEXT) {
return null;
}
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
}
@Override
public void onMsg(TbContext ctx, TbMsg msg) {
var ackedMsg = ackIfNeeded(ctx, msg);
@ -93,12 +109,15 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
var chatRequest = ChatRequest.builder()
.messages(List.of(systemMessage, userMessage))
.responseFormat(ResponseFormat.JSON)
.responseFormat(responseFormat)
.build();
addCallback(sendChatRequest(ctx, chatRequest), new FutureCallback<>() {
@Override
public void onSuccess(String response) {
if (!isValidJson(response)) {
response = wrapInJsonObject(response);
}
tellSuccess(ctx, ackedMsg.transform()
.data(response)
.build());
@ -115,11 +134,24 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
}
private static boolean isValidJson(String jsonString) {
try {
return JacksonUtil.toJsonNode(jsonString) != null;
} catch (IllegalArgumentException e) {
return false;
}
}
private static String wrapInJsonObject(String response) {
return JacksonUtil.newObjectNode().put("response", response).toString();
}
@Override
public void destroy() {
super.destroy();
systemMessage = null;
userPromptTemplate = null;
responseFormat = null;
chatModel = null;
}

View File

@ -16,6 +16,7 @@
package org.thingsboard.rule.engine.ai;
import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
@ -39,6 +40,9 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
@Length(min = 1, max = 1000)
private String userPrompt;
@NotNull
private ResponseFormatType responseFormatType;
private JsonNode jsonSchema;
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
@ -51,6 +55,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
var configuration = new TbAiNodeConfiguration();
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
configuration.setUserPrompt("Tell me a joke.");
configuration.setResponseFormatType(ResponseFormatType.JSON);
return configuration;
}