AI models: add auth support for Ollama
This commit is contained in:
parent
dd6bdcf614
commit
478b26b223
@ -35,6 +35,7 @@ import dev.langchain4j.model.mistralai.MistralAiChatModel;
|
|||||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
import dev.langchain4j.model.vertexai.gemini.VertexAiGeminiChatModel;
|
import dev.langchain4j.model.vertexai.gemini.VertexAiGeminiChatModel;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.thingsboard.server.common.data.ai.model.chat.AmazonBedrockChatModelConfig;
|
import org.thingsboard.server.common.data.ai.model.chat.AmazonBedrockChatModelConfig;
|
||||||
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig;
|
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig;
|
||||||
@ -49,6 +50,7 @@ import org.thingsboard.server.common.data.ai.model.chat.OpenAiChatModelConfig;
|
|||||||
import org.thingsboard.server.common.data.ai.provider.AmazonBedrockProviderConfig;
|
import org.thingsboard.server.common.data.ai.provider.AmazonBedrockProviderConfig;
|
||||||
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig;
|
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig;
|
||||||
import org.thingsboard.server.common.data.ai.provider.GoogleVertexAiGeminiProviderConfig;
|
import org.thingsboard.server.common.data.ai.provider.GoogleVertexAiGeminiProviderConfig;
|
||||||
|
import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig;
|
||||||
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
||||||
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
||||||
import software.amazon.awssdk.regions.Region;
|
import software.amazon.awssdk.regions.Region;
|
||||||
@ -56,7 +58,11 @@ import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
|||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.Base64;
|
||||||
|
|
||||||
|
import static java.util.Collections.singletonMap;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigurer {
|
class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigurer {
|
||||||
@ -136,7 +142,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
|||||||
|
|
||||||
// set request timeout from model config
|
// set request timeout from model config
|
||||||
if (chatModelConfig.timeoutSeconds() != null) {
|
if (chatModelConfig.timeoutSeconds() != null) {
|
||||||
retrySettings.setTotalTimeout(org.threeten.bp.Duration.ofSeconds(chatModelConfig.timeoutSeconds()));
|
retrySettings.setTotalTimeoutDuration(Duration.ofSeconds(chatModelConfig.timeoutSeconds()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// set updated retry settings
|
// set updated retry settings
|
||||||
@ -266,7 +272,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatModel configureChatModel(OllamaChatModelConfig chatModelConfig) {
|
public ChatModel configureChatModel(OllamaChatModelConfig chatModelConfig) {
|
||||||
return OllamaChatModel.builder()
|
var builder = OllamaChatModel.builder()
|
||||||
.baseUrl(chatModelConfig.providerConfig().baseUrl())
|
.baseUrl(chatModelConfig.providerConfig().baseUrl())
|
||||||
.modelName(chatModelConfig.modelId())
|
.modelName(chatModelConfig.modelId())
|
||||||
.temperature(chatModelConfig.temperature())
|
.temperature(chatModelConfig.temperature())
|
||||||
@ -275,8 +281,22 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
|
|||||||
.numCtx(chatModelConfig.contextLength())
|
.numCtx(chatModelConfig.contextLength())
|
||||||
.numPredict(chatModelConfig.maxOutputTokens())
|
.numPredict(chatModelConfig.maxOutputTokens())
|
||||||
.timeout(toDuration(chatModelConfig.timeoutSeconds()))
|
.timeout(toDuration(chatModelConfig.timeoutSeconds()))
|
||||||
.maxRetries(chatModelConfig.maxRetries())
|
.maxRetries(chatModelConfig.maxRetries());
|
||||||
.build();
|
|
||||||
|
var auth = chatModelConfig.providerConfig().auth();
|
||||||
|
if (auth instanceof OllamaProviderConfig.OllamaAuth.Basic basicAuth) {
|
||||||
|
String credentials = basicAuth.username() + ":" + basicAuth.password();
|
||||||
|
String encodedCredentials = Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
|
||||||
|
builder.customHeaders(singletonMap(HttpHeaders.AUTHORIZATION, "Basic " + encodedCredentials));
|
||||||
|
} else if (auth instanceof OllamaProviderConfig.OllamaAuth.Token tokenAuth) {
|
||||||
|
builder.customHeaders(singletonMap(HttpHeaders.AUTHORIZATION, "Bearer " + tokenAuth.token()));
|
||||||
|
} else if (auth instanceof OllamaProviderConfig.OllamaAuth.None) {
|
||||||
|
// do nothing
|
||||||
|
} else {
|
||||||
|
throw new UnsupportedOperationException("Unknown authentication type: " + auth.getClass().getSimpleName());
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Duration toDuration(Integer timeoutSeconds) {
|
private static Duration toDuration(Integer timeoutSeconds) {
|
||||||
|
|||||||
@ -15,8 +15,34 @@
|
|||||||
*/
|
*/
|
||||||
package org.thingsboard.server.common.data.ai.provider;
|
package org.thingsboard.server.common.data.ai.provider;
|
||||||
|
|
||||||
import jakarta.validation.constraints.NotBlank;
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import jakarta.validation.Valid;
|
||||||
|
import jakarta.validation.constraints.NotNull;
|
||||||
|
|
||||||
public record OllamaProviderConfig(
|
public record OllamaProviderConfig(
|
||||||
@NotBlank String baseUrl
|
@NotNull String baseUrl,
|
||||||
) implements AiProviderConfig {}
|
@NotNull @Valid OllamaAuth auth
|
||||||
|
) implements AiProviderConfig {
|
||||||
|
|
||||||
|
@JsonTypeInfo(
|
||||||
|
use = JsonTypeInfo.Id.NAME,
|
||||||
|
include = JsonTypeInfo.As.PROPERTY,
|
||||||
|
property = "type"
|
||||||
|
)
|
||||||
|
@JsonSubTypes({
|
||||||
|
@JsonSubTypes.Type(value = OllamaAuth.None.class, name = "NONE"),
|
||||||
|
@JsonSubTypes.Type(value = OllamaAuth.Basic.class, name = "BASIC"),
|
||||||
|
@JsonSubTypes.Type(value = OllamaAuth.Token.class, name = "TOKEN")
|
||||||
|
})
|
||||||
|
public sealed interface OllamaAuth {
|
||||||
|
|
||||||
|
record None() implements OllamaAuth {}
|
||||||
|
|
||||||
|
record Basic(@NotNull String username, @NotNull String password) implements OllamaAuth {}
|
||||||
|
|
||||||
|
record Token(@NotNull String token) implements OllamaAuth {}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user