AI rule node: support text and JSON Schema response formats
This commit is contained in:
		
							parent
							
								
									b64b5795a3
								
							
						
					
					
						commit
						d44bbe4dd8
					
				@ -0,0 +1,135 @@
 | 
			
		||||
/**
 | 
			
		||||
 * Copyright © 2016-2025 The Thingsboard Authors
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
package org.thingsboard.rule.engine.ai;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.databind.JsonNode;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonNullSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Converts a Jackson {@link JsonNode} JSON Schema into a Langchain4j {@link JsonSchema} model.
 | 
			
		||||
 */
 | 
			
		||||
final class Langchain4jJsonSchemaAdapter {
 | 
			
		||||
 | 
			
		||||
    private Langchain4jJsonSchemaAdapter() {
 | 
			
		||||
        throw new AssertionError("Can't instantiate utility class");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Creates a Langchain4j {@link JsonSchema} from the given root JSON Schema node.
 | 
			
		||||
     *
 | 
			
		||||
     * @param rootSchemaNode a valid JSON Schema as a Jackson {@link JsonNode}
 | 
			
		||||
     * @return the corresponding Langchain4j {@link JsonSchema}
 | 
			
		||||
     */
 | 
			
		||||
    public static JsonSchema fromJsonNode(JsonNode rootSchemaNode) {
 | 
			
		||||
        return JsonSchema.builder()
 | 
			
		||||
                .name(rootSchemaNode.get("title").textValue())
 | 
			
		||||
                .rootElement(parse(rootSchemaNode))
 | 
			
		||||
                .build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static JsonSchemaElement parse(JsonNode schemaNode) {
 | 
			
		||||
        String description = schemaNode.hasNonNull("description") ? schemaNode.get("description").textValue() : null;
 | 
			
		||||
 | 
			
		||||
        if (schemaNode.has("enum")) { // enum schemas can be defined without 'type'
 | 
			
		||||
            return parseEnum(schemaNode).description(description).build();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        String type = schemaNode.get("type").textValue();
 | 
			
		||||
 | 
			
		||||
        return switch (type) {
 | 
			
		||||
            case "string" -> JsonStringSchema.builder().description(description).build();
 | 
			
		||||
            case "integer" -> JsonIntegerSchema.builder().description(description).build();
 | 
			
		||||
            case "boolean" -> JsonBooleanSchema.builder().description(description).build();
 | 
			
		||||
            case "number" -> JsonNumberSchema.builder().description(description).build();
 | 
			
		||||
            case "null" -> new JsonNullSchema();
 | 
			
		||||
            case "object" -> parseObject(schemaNode).description(description).build();
 | 
			
		||||
            case "array" -> parseArray(schemaNode).description(description).build();
 | 
			
		||||
            default -> throw new IllegalArgumentException("Unsupported JSON Schema type: " + type);
 | 
			
		||||
        };
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static JsonEnumSchema.Builder parseEnum(JsonNode enumSchema) {
 | 
			
		||||
        var builder = new JsonEnumSchema.Builder();
 | 
			
		||||
 | 
			
		||||
        List<String> enumValues = new ArrayList<>();
 | 
			
		||||
        for (JsonNode element : enumSchema.get("enum")) {
 | 
			
		||||
            if (!element.isTextual()) {
 | 
			
		||||
                throw new IllegalArgumentException("Expected each 'enum' element to be a string, but found: " + element.getNodeType());
 | 
			
		||||
            }
 | 
			
		||||
            enumValues.add(element.textValue());
 | 
			
		||||
        }
 | 
			
		||||
        builder.enumValues(enumValues);
 | 
			
		||||
 | 
			
		||||
        return builder;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static JsonObjectSchema.Builder parseObject(JsonNode objectSchema) {
 | 
			
		||||
        var builder = new JsonObjectSchema.Builder();
 | 
			
		||||
 | 
			
		||||
        JsonNode propertiesNode = objectSchema.get("properties");
 | 
			
		||||
        if (propertiesNode != null) {
 | 
			
		||||
            propertiesNode.fields().forEachRemaining(entry -> {
 | 
			
		||||
                String key = entry.getKey();
 | 
			
		||||
                JsonNode value = entry.getValue();
 | 
			
		||||
                builder.addProperty(key, parse(value));
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        List<String> required = new ArrayList<>();
 | 
			
		||||
        JsonNode requiredNode = objectSchema.get("required");
 | 
			
		||||
        if (requiredNode != null) {
 | 
			
		||||
            for (JsonNode value : requiredNode) {
 | 
			
		||||
                required.add(value.textValue());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        builder.required(required);
 | 
			
		||||
 | 
			
		||||
        boolean additionalProperties = true; // default value if 'additionalProperties' is not set
 | 
			
		||||
        JsonNode additionalPropertiesNode = objectSchema.get("additionalProperties");
 | 
			
		||||
        if (additionalPropertiesNode != null) {
 | 
			
		||||
            if (!additionalPropertiesNode.isBoolean()) {
 | 
			
		||||
                throw new IllegalArgumentException("Expected 'additionalProperties' to be a boolean, but found: " + additionalPropertiesNode.getNodeType());
 | 
			
		||||
            }
 | 
			
		||||
            additionalProperties = additionalPropertiesNode.booleanValue();
 | 
			
		||||
        }
 | 
			
		||||
        builder.additionalProperties(additionalProperties);
 | 
			
		||||
 | 
			
		||||
        return builder;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static JsonArraySchema.Builder parseArray(JsonNode arraySchema) {
 | 
			
		||||
        var builder = new JsonArraySchema.Builder();
 | 
			
		||||
 | 
			
		||||
        if (arraySchema.hasNonNull("items")) {
 | 
			
		||||
            builder.items(parse(arraySchema.get("items")));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return builder;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -15,6 +15,7 @@
 | 
			
		||||
 */
 | 
			
		||||
package org.thingsboard.rule.engine.ai;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.databind.JsonNode;
 | 
			
		||||
import com.google.common.util.concurrent.FutureCallback;
 | 
			
		||||
import com.google.common.util.concurrent.ListenableFuture;
 | 
			
		||||
import dev.langchain4j.data.message.SystemMessage;
 | 
			
		||||
@ -22,6 +23,8 @@ import dev.langchain4j.data.message.UserMessage;
 | 
			
		||||
import dev.langchain4j.model.chat.ChatModel;
 | 
			
		||||
import dev.langchain4j.model.chat.request.ChatRequest;
 | 
			
		||||
import dev.langchain4j.model.chat.request.ResponseFormat;
 | 
			
		||||
import dev.langchain4j.model.chat.request.ResponseFormatType;
 | 
			
		||||
import dev.langchain4j.model.chat.request.json.JsonSchema;
 | 
			
		||||
import dev.langchain4j.model.input.PromptTemplate;
 | 
			
		||||
import org.checkerframework.checker.nullness.qual.NonNull;
 | 
			
		||||
import org.thingsboard.common.util.JacksonUtil;
 | 
			
		||||
@ -55,6 +58,7 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
 | 
			
		||||
 | 
			
		||||
    private SystemMessage systemMessage;
 | 
			
		||||
    private PromptTemplate userPromptTemplate;
 | 
			
		||||
    private ResponseFormat responseFormat;
 | 
			
		||||
    private ChatModel chatModel;
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@ -69,6 +73,11 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
 | 
			
		||||
            throw new TbNodeException(e, true);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        responseFormat = ResponseFormat.builder()
 | 
			
		||||
                .type(config.getResponseFormatType())
 | 
			
		||||
                .jsonSchema(getJsonSchema(config.getResponseFormatType(), config.getJsonSchema()))
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        systemMessage = SystemMessage.from(config.getSystemPrompt());
 | 
			
		||||
        userPromptTemplate = PromptTemplate.from("""
 | 
			
		||||
                User-provided task or question: %s
 | 
			
		||||
@ -80,6 +89,13 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
 | 
			
		||||
        chatModel = ctx.getAiService().configureChatModel(ctx.getTenantId(), config.getAiSettingsId());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static JsonSchema getJsonSchema(ResponseFormatType responseFormatType, JsonNode jsonSchema) {
 | 
			
		||||
        if (responseFormatType == ResponseFormatType.TEXT) {
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
        return responseFormatType == ResponseFormatType.JSON && jsonSchema != null ? Langchain4jJsonSchemaAdapter.fromJsonNode(jsonSchema) : null;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void onMsg(TbContext ctx, TbMsg msg) {
 | 
			
		||||
        var ackedMsg = ackIfNeeded(ctx, msg);
 | 
			
		||||
@ -93,12 +109,15 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
 | 
			
		||||
 | 
			
		||||
        var chatRequest = ChatRequest.builder()
 | 
			
		||||
                .messages(List.of(systemMessage, userMessage))
 | 
			
		||||
                .responseFormat(ResponseFormat.JSON)
 | 
			
		||||
                .responseFormat(responseFormat)
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        addCallback(sendChatRequest(ctx, chatRequest), new FutureCallback<>() {
 | 
			
		||||
            @Override
 | 
			
		||||
            public void onSuccess(String response) {
 | 
			
		||||
                if (!isValidJson(response)) {
 | 
			
		||||
                    response = wrapInJsonObject(response);
 | 
			
		||||
                }
 | 
			
		||||
                tellSuccess(ctx, ackedMsg.transform()
 | 
			
		||||
                        .data(response)
 | 
			
		||||
                        .build());
 | 
			
		||||
@ -115,11 +134,24 @@ public final class TbAiNode extends TbAbstractExternalNode implements TbNode {
 | 
			
		||||
        return ctx.getExternalCallExecutor().submit(() -> chatModel.chat(chatRequest).aiMessage().text());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static boolean isValidJson(String jsonString) {
 | 
			
		||||
        try {
 | 
			
		||||
            return JacksonUtil.toJsonNode(jsonString) != null;
 | 
			
		||||
        } catch (IllegalArgumentException e) {
 | 
			
		||||
            return false;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static String wrapInJsonObject(String response) {
 | 
			
		||||
        return JacksonUtil.newObjectNode().put("response", response).toString();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void destroy() {
 | 
			
		||||
        super.destroy();
 | 
			
		||||
        systemMessage = null;
 | 
			
		||||
        userPromptTemplate = null;
 | 
			
		||||
        responseFormat = null;
 | 
			
		||||
        chatModel = null;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@
 | 
			
		||||
package org.thingsboard.rule.engine.ai;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.databind.JsonNode;
 | 
			
		||||
import dev.langchain4j.model.chat.request.ResponseFormatType;
 | 
			
		||||
import jakarta.validation.constraints.AssertTrue;
 | 
			
		||||
import jakarta.validation.constraints.NotBlank;
 | 
			
		||||
import jakarta.validation.constraints.NotNull;
 | 
			
		||||
@ -39,6 +40,9 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
 | 
			
		||||
    @Length(min = 1, max = 1000)
 | 
			
		||||
    private String userPrompt;
 | 
			
		||||
 | 
			
		||||
    @NotNull
 | 
			
		||||
    private ResponseFormatType responseFormatType;
 | 
			
		||||
 | 
			
		||||
    private JsonNode jsonSchema;
 | 
			
		||||
 | 
			
		||||
    @AssertTrue(message = "provided JSON Schema must conform to the Draft 2020-12 meta-schema")
 | 
			
		||||
@ -51,6 +55,7 @@ public class TbAiNodeConfiguration implements NodeConfiguration<TbAiNodeConfigur
 | 
			
		||||
        var configuration = new TbAiNodeConfiguration();
 | 
			
		||||
        configuration.setSystemPrompt("You are helpful assistant. Your response must be in JSON format.");
 | 
			
		||||
        configuration.setUserPrompt("Tell me a joke.");
 | 
			
		||||
        configuration.setResponseFormatType(ResponseFormatType.JSON);
 | 
			
		||||
        return configuration;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user