AI rule node: support text and JSON Schema response formats
This commit is contained in:
parent
b64b5795a3
commit
d44bbe4dd8
@ -0,0 +1,135 @@
|
|||||||
|
/**
|
||||||
|
* Copyright © 2016-2025 The Thingsboard Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.thingsboard.rule.engine.ai;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonNullSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a Jackson {@link JsonNode} JSON Schema into a Langchain4j {@link JsonSchema} model.
|
||||||
|
*/
|
||||||
|
final class Langchain4jJsonSchemaAdapter {
|
||||||
|
|
||||||
|
private Langchain4jJsonSchemaAdapter() {
|
||||||
|
throw new AssertionError("Can't instantiate utility class");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a Langchain4j {@link JsonSchema} from the given root JSON Schema node.
|
||||||
|
*
|
||||||
|
* @param rootSchemaNode a valid JSON Schema as a Jackson {@link JsonNode}
|
||||||
|
* @return the corresponding Langchain4j {@link JsonSchema}
|
||||||
|
*/
|
||||||
|
public static JsonSchema fromJsonNode(JsonNode rootSchemaNode) {
|
||||||
|
return JsonSchema.builder()
|
||||||
|
.name(rootSchemaNode.get("title").textValue())
|
||||||
|
.rootElement(parse(rootSchemaNode))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonSchemaElement parse(JsonNode schemaNode) {
|
||||||
|
String description = schemaNode.hasNonNull("description") ? schemaNode.get("description").textValue() : null;
|
||||||
|
|
||||||
|
if (schemaNode.has("enum")) { // enum schemas can be defined without 'type'
|
||||||
|
return parseEnum(schemaNode).description(description).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
String type = schemaNode.get("type").textValue();
|
||||||
|
|
||||||
|
return switch (type) {
|
||||||
|
case "string" -> JsonStringSchema.builder().description(description).build();
|
||||||
|
case "integer" -> JsonIntegerSchema.builder().description(description).build();
|
||||||
|
case "boolean" -> JsonBooleanSchema.builder().description(description).build();
|
||||||
|
case "number" -> JsonNumberSchema.builder().description(description).build();
|
||||||
|
case "null" -> new JsonNullSchema();
|
||||||
|
case "object" -> parseObject(schemaNode).description(description).build();
|
||||||
|
case "array" -> parseArray(schemaNode).description(description).build();
|
||||||
|
default -> throw new IllegalArgumentException("Unsupported JSON Schema type: " + type);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonEnumSchema.Builder parseEnum(JsonNode enumSchema) {
|
||||||
|
var builder = new JsonEnumSchema.Builder();
|
||||||
|
|
||||||
|
List<String> enumValues = new ArrayList<>();
|
||||||
|
for (JsonNode element : enumSchema.get("enum")) {
|
||||||
|
if (!element.isTextual()) {
|
||||||
|
throw new IllegalArgumentException("Expected each 'enum' element to be a string, but found: " + element.getNodeType());
|
||||||
|
}
|
||||||
|
enumValues.add(element.textValue());
|
||||||
|
}
|
||||||
|
builder.enumValues(enumValues);
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonObjectSchema.Builder parseObject(JsonNode objectSchema) {
|
||||||
|
var builder = new JsonObjectSchema.Builder();
|
||||||
|
|
||||||
|
JsonNode propertiesNode = objectSchema.get("properties");
|
||||||
|
if (propertiesNode != null) {
|
||||||
|
propertiesNode.fields().forEachRemaining(entry -> {
|
||||||
|
String key = entry.getKey();
|
||||||
|
JsonNode value = entry.getValue();
|
||||||
|
builder.addProperty(key, parse(value));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> required = new ArrayList<>();
|
||||||
|
JsonNode requiredNode = objectSchema.get("required");
|
||||||
|
if (requiredNode != null) {
|
||||||
|
for (JsonNode value : requiredNode) {
|
||||||
|
required.add(value.textValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.required(required);
|
||||||
|
|
||||||
|
boolean additionalProperties = true; // default value if 'additionalProperties' is not set
|
||||||
|
JsonNode additionalPropertiesNode = objectSchema.get("additionalProperties");
|
||||||
|
if (additionalPropertiesNode != null) {
|
||||||
|
if (!additionalPropertiesNode.isBoolean()) {
|
||||||
|
throw new IllegalArgumentException("Expected 'additionalProperties' to be a boolean, but found: " + additionalPropertiesNode.getNodeType());
|
||||||
|
}
|
||||||
|
additionalProperties = additionalPropertiesNode.booleanValue();
|
||||||
|
}
|
||||||
|
builder.additionalProperties(additionalProperties);
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static JsonArraySchema.Builder parseArray(JsonNode arraySchema) {
|
||||||
|
var builder = new JsonArraySchema.Builder();
|
||||||
|
|
||||||
|
if (arraySchema.hasNonNull("items")) {
|
||||||
|
builder.items(parse(arraySchema.get("items")));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -15,6 +15,7 @@
|
|||||||
*/
|
*/
|
||||||
package org.thingsboard.rule.engine.ai;
|
package org.thingsboard.rule.engine.ai;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.google.common.util.concurrent.FutureCallback;
|
import com.google.common.util.concurrent.FutureCallback;
|
||||||
import com.google.common.util.concurrent.ListenableFuture;
|
import com.google.common.util.concurrent.ListenableFuture;
|
||||||
import dev.langchain4j.data.message.SystemMessage;
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
@ -22,6 +23,8 @@ import dev.langchain4j.data.message.UserMessage;
|
|||||||
import dev.langchain4j.model.chat.ChatModel;
|
import dev.langchain4j.model.chat.ChatModel;
|
||||||
import dev.langchain4j.model.chat.request.ChatRequest;
|
import dev.langchain4j.model.chat.request.ChatRequest;
|
||||||
import dev.langchain4j.model.chat.request.ResponseFormat;
|
import dev.langchain4j.model.chat.request.ResponseFormat;
|
||||||
|
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
||||||
|
import dev.langchain4j.model.chat.request.json.JsonSchema;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||||
import org.thingsboard.common.util.JacksonUtil;
|
import org.thingsboard.common.util.JacksonUtil;
|
||||||
@ -55,6 +58,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
|
|
||||||
private SystemMessage systemMessage;
|
private SystemMessage systemMessage;
|
||||||
private PromptTemplate userPromptTemplate;
|
private PromptTemplate userPromptTemplate;
|
||||||
|
private ResponseFormat responseFormat;
|
||||||
private ChatModel chatModel;
|
private ChatModel chatModel;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -69,6 +73,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
throw new TbNodeException(e, true);
|
throw new TbNodeException(e, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
responseFormat = ResponseFormat.builder()
|
||||||
|
.type(config.getResponseFormatType())
|
||||||
|
.jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema()))
|
||||||
|
.build();
|
||||||
|
|
||||||
systemMessage = SystemMessage.from(config.getSystemPrompt());
|
systemMessage = SystemMessage.from(config.getSystemPrompt());
|
||||||
userPromptTemplate = PromptTemplate.from("""
|
userPromptTemplate = PromptTemplate.from("""
|
||||||
User-provided task or question: %s
|
User-provided task or question: %s
|
||||||
@ -80,6 +89,13 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), config.getAiSettingsId());
|
chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), config.getAiSettingsId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static JsonSchema getJsonSchema(ResponseFormatType responseFormatType, JsonNode jsonSchema) {
|
||||||
|
if (responseFormatType == ResponseFormatType.TEXT) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onMsg(TbContext ctx, TbMsg msg) {
|
public void onMsg(TbContext ctx, TbMsg msg) {
|
||||||
var ackedMsg = ackIfNeeded(ctx, msg);
|
var ackedMsg = ackIfNeeded(ctx, msg);
|
||||||
@ -93,12 +109,15 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
|
|
||||||
var chatRequest = ChatRequest.builder()
|
var chatRequest = ChatRequest.builder()
|
||||||
.messages(List.of(systemMessage, userMessage))
|
.messages(List.of(systemMessage, userMessage))
|
||||||
.responseFormat(ResponseFormat.JSON)
|
.responseFormat(responseFormat)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
addCallback(sendChatRequest(ctx, chatRequest), new FutureCallback<>() {
|
addCallback(sendChatRequest(ctx, chatRequest), new FutureCallback<>() {
|
||||||
@Override
|
@Override
|
||||||
public void onSuccess(String response) {
|
public void onSuccess(String response) {
|
||||||
|
if (!isValidJson(response)) {
|
||||||
|
response = wrapInJsonObject(response);
|
||||||
|
}
|
||||||
tellSuccess(ctx, ackedMsg.transform()
|
tellSuccess(ctx, ackedMsg.transform()
|
||||||
.data(response)
|
.data(response)
|
||||||
.build());
|
.build());
|
||||||
@ -115,11 +134,24 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
|
|||||||
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
|
return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static boolean isValidJson(String jsonString) {
|
||||||
|
try {
|
||||||
|
return JacksonUtil.toJsonNode(jsonString) != null;
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String wrapInJsonObject(String response) {
|
||||||
|
return JacksonUtil.newObjectNode().put("response", response).toString();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void destroy() {
|
public void destroy() {
|
||||||
super.destroy();
|
super.destroy();
|
||||||
systemMessage = null;
|
systemMessage = null;
|
||||||
userPromptTemplate = null;
|
userPromptTemplate = null;
|
||||||
|
responseFormat = null;
|
||||||
chatModel = null;
|
chatModel = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
package org.thingsboard.rule.engine.ai;
|
package org.thingsboard.rule.engine.ai;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import dev.langchain4j.model.chat.request.ResponseFormatType;
|
||||||
import jakarta.validation.constraints.AssertTrue;
|
import jakarta.validation.constraints.AssertTrue;
|
||||||
import jakarta.validation.constraints.NotBlank;
|
import jakarta.validation.constraints.NotBlank;
|
||||||
import jakarta.validation.constraints.NotNull;
|
import jakarta.validation.constraints.NotNull;
|
||||||
@ -39,6 +40,9 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
|
|||||||
@Length(min = 1, max = 1000)
|
@Length(min = 1, max = 1000)
|
||||||
private String userPrompt;
|
private String userPrompt;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
private ResponseFormatType responseFormatType;
|
||||||
|
|
||||||
private JsonNode jsonSchema;
|
private JsonNode jsonSchema;
|
||||||
|
|
||||||
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
|
@AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
|
||||||
@ -51,6 +55,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
|
|||||||
var configuration = new TbAiNodeConfiguration();
|
var configuration = new TbAiNodeConfiguration();
|
||||||
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
|
configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
|
||||||
configuration.setUserPrompt("Tell me a joke.");
|
configuration.setUserPrompt("Tell me a joke.");
|
||||||
|
configuration.setResponseFormatType(ResponseFormatType.JSON);
|
||||||
return configuration;
|
return configuration;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user