AI models: add auth support for Ollama
This commit is contained in:
		
							parent
							
								
									dd6bdcf614
								
							
						
					
					
						commit
						478b26b223
					
				@ -35,6 +35,7 @@ import dev.langchain4j.model.mistralai.MistralAiChatModel;
 | 
			
		||||
import dev.langchain4j.model.ollama.OllamaChatModel;
 | 
			
		||||
import dev.langchain4j.model.openai.OpenAiChatModel;
 | 
			
		||||
import dev.langchain4j.model.vertexai.gemini.VertexAiGeminiChatModel;
 | 
			
		||||
import org.springframework.http.HttpHeaders;
 | 
			
		||||
import org.springframework.stereotype.Component;
 | 
			
		||||
import org.thingsboard.server.common.data.ai.model.chat.AmazonBedrockChatModelConfig;
 | 
			
		||||
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig;
 | 
			
		||||
@ -49,6 +50,7 @@ import org.thingsboard.server.common.data.ai.model.chat.OpenAiChatModelConfig;
 | 
			
		||||
import org.thingsboard.server.common.data.ai.provider.AmazonBedrockProviderConfig;
 | 
			
		||||
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig;
 | 
			
		||||
import org.thingsboard.server.common.data.ai.provider.GoogleVertexAiGeminiProviderConfig;
 | 
			
		||||
import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig;
 | 
			
		||||
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
 | 
			
		||||
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
 | 
			
		||||
import software.amazon.awssdk.regions.Region;
 | 
			
		||||
@ -56,7 +58,11 @@ import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
 | 
			
		||||
 | 
			
		||||
import java.io.ByteArrayInputStream;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.nio.charset.StandardCharsets;
 | 
			
		||||
import java.time.Duration;
 | 
			
		||||
import java.util.Base64;
 | 
			
		||||
 | 
			
		||||
import static java.util.Collections.singletonMap;
 | 
			
		||||
 | 
			
		||||
@Component
 | 
			
		||||
class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigurer {
 | 
			
		||||
@ -136,7 +142,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
 | 
			
		||||
 | 
			
		||||
            // set request timeout from model config
 | 
			
		||||
            if (chatModelConfig.timeoutSeconds() != null) {
 | 
			
		||||
                retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(chatModelConfig.timeoutSeconds()));
 | 
			
		||||
                retrySettings.setTotalTimeoutDuration(Duration.ofSeconds(chatModelConfig.timeoutSeconds()));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // set updated retry settings
 | 
			
		||||
@ -266,7 +272,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public ChatModel configureChatModel(OllamaChatModelConfig chatModelConfig) {
 | 
			
		||||
        return OllamaChatModel.builder()
 | 
			
		||||
        var builder = OllamaChatModel.builder()
 | 
			
		||||
                .baseUrl(chatModelConfig.providerConfig().baseUrl())
 | 
			
		||||
                .modelName(chatModelConfig.modelId())
 | 
			
		||||
                .temperature(chatModelConfig.temperature())
 | 
			
		||||
@ -275,8 +281,22 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
 | 
			
		||||
                .numCtx(chatModelConfig.contextLength())
 | 
			
		||||
                .numPredict(chatModelConfig.maxOutputTokens())
 | 
			
		||||
                .timeout(toDuration(chatModelConfig.timeoutSeconds()))
 | 
			
		||||
                .maxRetries(chatModelConfig.maxRetries())
 | 
			
		||||
                .build();
 | 
			
		||||
                .maxRetries(chatModelConfig.maxRetries());
 | 
			
		||||
 | 
			
		||||
        var auth = chatModelConfig.providerConfig().auth();
 | 
			
		||||
        if (auth instanceof OllamaProviderConfig.OllamaAuth.Basic basicAuth) {
 | 
			
		||||
            String credentials = basicAuth.username() + ":" + basicAuth.password();
 | 
			
		||||
            String encodedCredentials = Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
 | 
			
		||||
            builder.customHeaders(singletonMap(HttpHeaders.AUTHORIZATION, "Basic " + encodedCredentials));
 | 
			
		||||
        } else if (auth instanceof OllamaProviderConfig.OllamaAuth.Token tokenAuth) {
 | 
			
		||||
            builder.customHeaders(singletonMap(HttpHeaders.AUTHORIZATION, "Bearer " + tokenAuth.token()));
 | 
			
		||||
        } else if (auth instanceof OllamaProviderConfig.OllamaAuth.None) {
 | 
			
		||||
            // do nothing
 | 
			
		||||
        } else {
 | 
			
		||||
            throw new UnsupportedOperationException("Unknown authentication type: " + auth.getClass().getSimpleName());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return builder.build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static Duration toDuration(Integer timeoutSeconds) {
 | 
			
		||||
 | 
			
		||||
@ -15,8 +15,34 @@
 | 
			
		||||
 */
 | 
			
		||||
package org.thingsboard.server.common.data.ai.provider;
 | 
			
		||||
 | 
			
		||||
import jakarta.validation.constraints.NotBlank;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
 | 
			
		||||
import jakarta.validation.Valid;
 | 
			
		||||
import jakarta.validation.constraints.NotNull;
 | 
			
		||||
 | 
			
		||||
public record OllamaProviderConfig(
 | 
			
		||||
        @NotBlank String baseUrl
 | 
			
		||||
) implements AiProviderConfig {}
 | 
			
		||||
        @NotNull String baseUrl,
 | 
			
		||||
        @NotNull @Valid OllamaAuth auth
 | 
			
		||||
) implements AiProviderConfig {
 | 
			
		||||
 | 
			
		||||
    @JsonTypeInfo(
 | 
			
		||||
            use = JsonTypeInfo.Id.NAME,
 | 
			
		||||
            include = JsonTypeInfo.As.PROPERTY,
 | 
			
		||||
            property = "type"
 | 
			
		||||
    )
 | 
			
		||||
    @JsonSubTypes({
 | 
			
		||||
            @JsonSubTypes.Type(value = OllamaAuth.None.class, name = "NONE"),
 | 
			
		||||
            @JsonSubTypes.Type(value = OllamaAuth.Basic.class, name = "BASIC"),
 | 
			
		||||
            @JsonSubTypes.Type(value = OllamaAuth.Token.class, name = "TOKEN")
 | 
			
		||||
    })
 | 
			
		||||
    public sealed interface OllamaAuth {
 | 
			
		||||
 | 
			
		||||
        record None() implements OllamaAuth {}
 | 
			
		||||
 | 
			
		||||
        record Basic(@NotNull String username, @NotNull String password) implements OllamaAuth {}
 | 
			
		||||
 | 
			
		||||
        record Token(@NotNull String token) implements OllamaAuth {}
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user