spring ai의 openAiChatModel로 Perplexity같은 다른 회사와 통신하기
일단 주의점부터 말하고가면
- 그냥 rest같은거로 통신하는게 더 깔끔할수있음
- 서브클래싱으로 상속을 사용하고,슈퍼클래스보다 서브클래스가 기능이 적어지는등 리스코프치환원칙 위반인 코드임
- 제대로 하려면 ChatModel부터 시작해서 다 만드는게 정석적인 방법같음
같은 문제점이 있는 코드임
일단 가능은 하다는데 의의를 두겠음
완성코드는 맨밑에 있으니 과정 안궁금하면 맨밑으로 내려가면됨
일단 배경스토리부터 말해보자면
나는 이미 SpringAi로 OpenAi의 gpt로 통신하는 코드가 다 짜여져있었음
근데 스토리상 웹검색을 하고 그 결과를 바탕으로 출력을 해야해서 gpt로 웹검색을 해야했는데,내가 웹 검색해서 던져주거나 하는거는 토큰이 얼만큼 튈지도 가늠이 안되고,크롤링을 하면 막힐테니 api로 해야하는데 api비용도 만만찮고(SERP API같은거 알아봤음) ,그냥 Perplexity에다가 던지는게 가장 날로먹을수있는 방법같았음
pro 이용중이라 api크레딧 5달러씩 매달 주기도했고
근데 springAi에는 Perplexity에 대해 만들어진 구현체가 없고,커스텀으로 만들기위한 가이드같은거도 없었음
그래서 대충 OpenAiChatModel 까보니까 ChatModel구현해서 어케하면 될거같아서
처음에는 ChatModel부터 다 만들려고 했는데,OpenAiChatModel에서 코드를 보니까 상당히 복잡한거임
굳이 하려면 못할거같진않은데,스트리밍이 들어가는 기능의 경우 WebFlex가 들어가기도 하고 그래서 날로 함 먹어보자하고 OpenAiChatModel에서 url과 메시지바디만 변경해서 던져보자 하고 나온 결과물임
일단 원형인 OpenAi에 통신하는거부터 보자면
private val gptOption: OpenAiChatOptions =
OpenAiChatOptions
.builder()
.withModel("gpt-4o-mini")
.withTemperature(0.8F)
.build()
이렇게 옵션을 빌더로 만든다음
private val chatClient: ChatClient =
ChatClient.builder(OpenAiChatModel(OpenAiApi(ApiConfig.getGptKey()), gptOption)).build()
이렇게 OpenAiApi를 Option과 api키를 넣어서 만들고, 그걸 넣어서 OpenAiChatModel을 만들고,그걸넣어서 ChatClient를 만드는 구조임
여기서 ChatClient는 최종 퍼사드클래스라서(정확히는 인터페이스고 DefaultChatClient가 build시 나오는 구체클래스임) 중요하지않고,
OpenAiChatModel은 나중에 좀 중요함
그리고 우리가 중점적으로 봐야할건 OpenAiChatOptions와, OpenAiApi 임
OpenAiChatOptions는 api통신할때 보낼 각종 파라미터를 관리하는곳임
보면 json으로 필드들이 쫙 나열되어있고,세터와 게터가 열려있는 형태로 되어있는걸 볼수있음
또한 빌더가 있고,빌더패턴으로 조립하는 형태로 보임
그래서 이게 어떻게 통신할때 조립될까를 보면,OpenAiChatOptions에는 이 객체를 메시지바디로 변경하는 코드가 없음
그러니 한칸 위로 올라가서 OpenAiChatModel에서 Option으로 검색해보면,createRequest라는 메서드에서 ModelOptionsUtils.merge()를 사용해서 뭔가 하는게 보임
//24번째줄
OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
Set<String> functionsForThisRequest = new HashSet();
List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map((m) -> {
List<OpenAiApi.ChatCompletionMessage.MediaContent> contents = new ArrayList(List.of(new OpenAiApi.ChatCompletionMessage.MediaContent(m.getContent())));
if (!CollectionUtils.isEmpty(m.getMedia())) {
contents.addAll(m.getMedia().stream().map((media) -> {
return new OpenAiApi.ChatCompletionMessage.MediaContent(new OpenAiApi.ChatCompletionMessage.MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData())));
}).toList());
}
return new OpenAiApi.ChatCompletionMessage(contents, Role.valueOf(m.getMessageType().name()));
}).toList();
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
if (prompt.getOptions() != null) {
ModelOptions var7 = prompt.getOptions();
if (!(var7 instanceof ChatOptions)) {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + prompt.getOptions().getClass().getSimpleName());
}
ChatOptions runtimeOptions = (ChatOptions)var7;
OpenAiChatOptions updatedRuntimeOptions = (OpenAiChatOptions)ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, OpenAiChatOptions.class);
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, true);
functionsForThisRequest.addAll(promptEnabledFunctions);
//여기
request = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
}
if (this.defaultOptions != null) {
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, false);
functionsForThisRequest.addAll(defaultEnabledFunctions);
request = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
}
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
request = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge(OpenAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), request, OpenAiApi.ChatCompletionRequest.class);
}
return request;
그래서 해당 유틸클래스로 들어가보면
public static <T> T merge(Object source, Object target, Class<T> clazz, List<String> acceptedFieldNames) {
if (source == null) {
source = Map.of();
}
List<String> requestFieldNames = CollectionUtils.isEmpty(acceptedFieldNames) ? (List)REQUEST_FIELD_NAMES_PER_CLASS.computeIfAbsent(clazz, ModelOptionsUtils::getJsonPropertyValues) : acceptedFieldNames;
if (CollectionUtils.isEmpty(requestFieldNames)) {
throw new IllegalArgumentException("No @JsonProperty fields found in the " + clazz.getName());
} else {
Map<String, Object> sourceMap = objectToMap(source);
Map<String, Object> targetMap = objectToMap(target);
targetMap.putAll((Map)sourceMap.entrySet().stream().filter((e) -> {
return e.getValue() != null;
}).collect(Collectors.toMap((e) -> {
return (String)e.getKey();
}, (e) -> {
return e.getValue();
})));
targetMap = (Map)targetMap.entrySet().stream().filter((e) -> {
return requestFieldNames.contains(e.getKey());
}).collect(Collectors.toMap((e) -> {
return (String)e.getKey();
}, (e) -> {
return e.getValue();
}));
return mapToClass(targetMap, clazz);
}
}
이렇게 소스와 타겟을 objectToMap라는걸로 녹여서 합치는 로직으로 보이고,objectToMap메서드는
public static Map<String, Object> objectToMap(Object source) {
if (source == null) {
return new HashMap();
} else {
try {
String json = OBJECT_MAPPER.writeValueAsString(source);
return (Map)((Map)OBJECT_MAPPER.readValue(json, new TypeReference<Map<String, Object>>() {
})).entrySet().stream().filter((ex) -> {
return ex.getValue() != null;
}).collect(Collectors.toMap((ex) -> {
return (String)ex.getKey();
}, (ex) -> {
return ex.getValue();
}));
} catch (JsonProcessingException var2) {
JsonProcessingException e = var2;
throw new RuntimeException(e);
}
}
}
OBJECT_MAPPER라는 변수를 사용해서 오브젝트를 맵으로 변경시킨다는걸 알았음
이름부터 오브젝트 매퍼니까 그거겠지만,혹시모르니 확인해보면
public static final ObjectMapper OBJECT_MAPPER;
맨날쓰던 그거 맞음
그러니 우리는 Option클래스에 퍼블릭으로 필드를 추가하면,사용시에 자동으로 필드명이나 어노테이션과 함께 메시지바디에 실린다는걸 알게됐음
근데 문제는,OpenAiChatModel에 있는 Option필드는 OpenAiChatOptions로 선언돼있고,이건 구체클래스라는거임
public class OpenAiChatModel extends AbstractFunctionCallSupport<OpenAiApi.ChatCompletionMessage, OpenAiApi.ChatCompletionRequest, ResponseEntity<OpenAiApi.ChatCompletion>> implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);
private OpenAiChatOptions defaultOptions;
private final RetryTemplate retryTemplate;
private final OpenAiApi openAiApi;
그래서 어쩔수없이 우리는 OpenAiChatOptions클래스를 상속받아서 서브클래싱을 목적으로 클래스를 만들어야함
문제는,각 밴더마다 요구하는 파라미터들이 다를수있고,이것때문에 슈퍼클래스의 사용할수없는 파라미터들이 다 노출되는걸 막을수없다는거임
이러면 사용자입장에서는 값을 넣어도 바뀌는게 없으니 엄청 혼란스러울거처럼 보임,그래서 이걸쓰려면 문서화를 잘해두든가 해야할거같음
나는 Perplexity를 사용할거니,PerplexityChatOptions라는 이름으로 클래스를 만들거임
여기에 해당 밴더가 사용할 필드를 퍼블릭으로 추가하고,@JsonProperty로 Json명을 표시해주면 됨
class PerplexityChatOptions : OpenAiChatOptions() {
@JsonProperty("search_domain_filter")
var searchDomainFilter = ""
@JsonProperty("search_recency_filter")
var searchRecencyFilter = ""
그리고 빌더를 만들어야함,원본이 빌더를 썼으니 패턴을 맞춰주는게 좋음
companion object {
@JvmStatic
fun builder(): PerplexityBuilder = PerplexityBuilder()
}
스태틱메서드를 추가하고
내부 빌더클래스도 OpenAiChatOptions의 Builder클래스를 상속하면됨
그리고 우리가 리턴해야할건 PerplexityChatOptions니 해당 필드를 하나 만들고,내가 필요로하는 필드들을 추가하면됨
class PerplexityBuilder : Builder() {
private val perplexityChatOptions = PerplexityChatOptions()
fun searchDomainFilter(searchDomainFilter: String): PerplexityBuilder {
perplexityChatOptions.searchDomainFilter = searchDomainFilter
return this
}
fun searchRecencyFilter(searchRecencyFilter: String): PerplexityBuilder {
perplexityChatOptions.searchRecencyFilter = searchRecencyFilter
return this
}
그리고 원본에서 있던 사용할 필드들을 오버라이드해주고
override fun withTemperature(temperature: Float): PerplexityBuilder {
perplexityChatOptions.temperature = temperature
return this
}
override fun withModel(model: String): PerplexityBuilder {
perplexityChatOptions.model = model
return this
}
그리고 build메서드를 오버라이드해서,원본을 빌드한다음에 거기서 값을 빼와서 새로 빌드해서 리턴해주면됨
override fun build(): PerplexityChatOptions {
val baseOptions = super.build()
perplexityChatOptions.apply {
frequencyPenalty = baseOptions.frequencyPenalty
maxTokens = baseOptions.maxTokens
topP = baseOptions.topP
presencePenalty = baseOptions.presencePenalty
}
return perplexityChatOptions
}
굳이 builder를 상속한 이유는,옵션들의 기본값위임때문에 한거라서 상황에따라 없애는게 나을수도있음
이게 문제가,빌더패턴의 경우에 이렇게되면 중간에 오버라이딩하지않은 메서드가 섞여버리면 바로 상위가 나오게되고,거기서부터 기존메서드를 사용하게되서 꼬이기쉬움
일단 동작하는거부터 만들고 수정하려고 난 이렇게했지만,그냥 builder은 상속받지않고 하는게 나아보이긴함(0925 바로 상속제거하게 수정함,마지막 최종코드에만 적용해둠)
그리고 추가적으로,별로 필요없나 싶긴한데 원본의 Equals와 hashCode가 오버라이딩되어있고,그 구현이 모든 필드가 같으면 같은객체로 취급하게 되어있었음
그걸 맞춰주기위해
override fun equals(other: Any?): Boolean {
val baseEquals = super.equals(other)
if (!baseEquals) {
return false
}
val castOther = other as PerplexityChatOptions
if (this.searchDomainFilter != castOther.searchDomainFilter) {
return false
}
if (this.searchRecencyFilter != castOther.searchRecencyFilter) {
return false
}
return true
}
override fun hashCode(): Int {
var result = super.hashCode()
result = 31 * result + (searchDomainFilter.hashCode())
result = 31 * result + (searchRecencyFilter.hashCode())
return result
}
이렇게 슈퍼클래스를 계산해서 불러오고,거기에 내가만든거만 똑같은방식으로 추가로 붙여주는식으로 해결했음
이러면 일단 옵션은 끝났고,다음은 OpenAiApi임
이건 처음에 슬쩍보면 그냥 baseUrl있으니까 생성자로 저거 바꿔서 하면되겠다 싶지만
public OpenAiApi(String openAiToken) {
this("https://api.openai.com", openAiToken);
}
public OpenAiApi(String baseUrl, String openAiToken) {
this(baseUrl, openAiToken, RestClient.builder(), WebClient.builder());
문제는 그 뒤부분이 하드코딩돼있다는게 문제임
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest) {
Assert.notNull(chatRequest, "The request body can not be null.");
Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false.");
// 여기에서 uri("/v1/chat/completions" 이부분
return ((RestClient.RequestBodySpec)this.restClient.post().uri("/v1/chat/completions", new Object[0])).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
}
저거때문에 어쩔수없이 상속해서 저 메서드를 만들어야함
여기서 우리는 이론적으로는(리스코프치환원칙적으로는),3가지 메서드를 구현해야함
일반통신,스트림통신,임베딩
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest)
public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest)
public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<T> embeddingRequest)
근데 나는 일반통신밖에 쓰지않을거라서 저거만 구현하고,나머지는 호출하면 예외가 발생하게 해버렸음
OpenAiApi를 보면
public class OpenAiApi {
public static final String DEFAULT_CHAT_MODEL;
public static final String DEFAULT_EMBEDDING_MODEL;
private static final Predicate<String> SSE_DONE_PREDICATE;
private final RestClient restClient;
private final WebClient webClient;
private OpenAiStreamFunctionCallingHelper chunkMerger;
restClient가 프라이빗 파이널이니 저거부터 새로 만들어줘야함
밑에 OpenAiApi는 저걸 어떻게 만드는가 보면
this.restClient = restClientBuilder
.baseUrl(baseUrl)
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
.defaultStatusHandler(responseErrorHandler)
.build();
이렇게 만들고있음,여기서 토큰은 apiKey니까 바꾸면될거고,responseErrorHandler은 위에있는 생성자를 보면
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
this(baseUrl, openAiToken, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER를 가져다가 쓰는걸 볼수있음
그래서 일단 클래스를 만들고 restClient를 만들어보면
class PerplexityApi(
key: String,
) : OpenAiApi("https://api.perplexity.ai", key) {
private val restClient =
RestClient
.builder()
.baseUrl(
"https://api.perplexity.ai",
).defaultHeaders(ApiUtils.getJsonContentHeaders(key))
.defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
.build()
이렇게 나오게됨
그리고
chatCompletionEntity를 구현하면됨
이것도 그냥 원본과 똑같이 만든다음,url만 바꿔줄거임
원본을 보면
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest) {
Assert.notNull(chatRequest, "The request body can not be null.");
Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false.");
return ((RestClient.RequestBodySpec)this.restClient.post().uri("/v1/chat/completions", new Object[0])).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
}
이렇게 돼있으니
override fun chatCompletionEntity(chatRequest: ChatCompletionRequest?): ResponseEntity<ChatCompletion> {
Assert.notNull(chatRequest, "The request body can not be null.")
Assert.isTrue(!chatRequest!!.stream(), "Request must set the steam property to false.")
return restClient
.post()
.uri("/chat/completions", *arrayOfNulls(0))
.body(
chatRequest,
).retrieve()
.toEntity(
ChatCompletion::class.java,
)
}
override fun chatCompletionStream(chatRequest: ChatCompletionRequest?): Flux<ChatCompletionChunk> = throw RuntimeException("지원되지않음")
override fun <T : Any?> embeddings(embeddingRequest: EmbeddingRequest<T>?): ResponseEntity<EmbeddingList<Embedding>> =
throw RuntimeException("지원되지않음")
이렇게 url만 바꿔주고,나머지 안쓰는애들은 예외를 던지게했음(일단 런타임으로 했음 뭘로해야할지 생각이 안나서,나중에 바꿀예정)
전체코드는 맨밑에서 다시 올릴거임
그리고 사용은
private val perplexityOptions =
PerplexityChatOptions
.builder()
.searchRecencyFilter("week")
.withModel("llama-3.1-sonar-small-128k-online")
.withTemperature(0.8F)
.build()
private val chatClient: ChatClient =
ChatClient
.builder(
OpenAiChatModel(PerplexityApi(ApiConfig.getPerplexityKey()), perplexityOptions),
).build()
private fun fetch(prompt: Prompt): ChatResponse {
val response =
chatClient
.prompt(prompt)
.call()
.chatResponse()
return response
}
이렇게 똑같이 사용하면됨(이때 옵션만드는거 순서는 빌더 상속받았으면 조심하셈)
일단 이렇게 동작은 하긴하는데,솔직히 별로 맘에들진않는듯
어떻게 마개조해서 쓰는느낌이라,깨지기도 진짜 쉬워보이고,정석대로하려면 직접 구현해야할거같음
스프링ai팀에서 이런식으로 di해서 쓸수있게 하나 열어줬으면좋겠음
구현부분 제외하면 글케 어려울거같진않은데
일단 인터넷에 아무정보도 없던거 어케 클래스뒤져가면서 동작은 되게 만들었다는거에 의의를 두자
풀버전
PerplexityChatOptions
package rkrk.whyprice.share.adapter
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonProperty
import org.springframework.ai.openai.OpenAiChatOptions
@JsonInclude(JsonInclude.Include.NON_NULL)
class PerplexityChatOptions : OpenAiChatOptions() {
@JsonProperty("search_domain_filter")
var searchDomainFilter = ""
@JsonProperty("search_recency_filter")
var searchRecencyFilter = ""
companion object {
@JvmStatic
fun builder(): PerplexityBuilder = PerplexityBuilder()
}
override fun equals(other: Any?): Boolean {
val baseEquals = super.equals(other)
if (!baseEquals) {
return false
}
val castOther = other as PerplexityChatOptions
if (this.searchDomainFilter != castOther.searchDomainFilter) {
return false
}
if (this.searchRecencyFilter != castOther.searchRecencyFilter) {
return false
}
return true
}
override fun hashCode(): Int {
var result = super.hashCode()
result = 31 * result + (searchDomainFilter.hashCode())
result = 31 * result + (searchRecencyFilter.hashCode())
return result
}
class PerplexityBuilder {
private val perplexityChatOptions = PerplexityChatOptions()
fun searchDomainFilter(searchDomainFilter: String): PerplexityBuilder {
perplexityChatOptions.searchDomainFilter = searchDomainFilter
return this
}
fun searchRecencyFilter(searchRecencyFilter: String): PerplexityBuilder {
perplexityChatOptions.searchRecencyFilter = searchRecencyFilter
return this
}
fun withTemperature(temperature: Float): PerplexityBuilder {
perplexityChatOptions.temperature = temperature
return this
}
fun withModel(model: String): PerplexityBuilder {
perplexityChatOptions.model = model
return this
}
fun withFrequencyPenalty(frequencyPenalty: Float): PerplexityBuilder {
perplexityChatOptions.frequencyPenalty = frequencyPenalty
return this
}
fun withMaxTokens(maxTokens: Int): PerplexityBuilder {
perplexityChatOptions.maxTokens = maxTokens
return this
}
fun withTopP(topP: Float): PerplexityBuilder {
perplexityChatOptions.topP = topP
return this
}
fun withPresencePenalty(presencePenalty: Float): PerplexityBuilder {
perplexityChatOptions.presencePenalty = presencePenalty
return this
}
fun build(): PerplexityChatOptions = perplexityChatOptions
}
}
PerplexityApi
package rkrk.whyprice.share.adapter
import org.springframework.ai.openai.api.ApiUtils
import org.springframework.ai.openai.api.OpenAiApi
import org.springframework.ai.retry.RetryUtils
import org.springframework.http.ResponseEntity
import org.springframework.util.Assert
import org.springframework.web.client.RestClient
import reactor.core.publisher.Flux
class PerplexityApi(
key: String,
) : OpenAiApi("https://api.perplexity.ai", key) {
private val restClient =
RestClient
.builder()
.baseUrl(
"https://api.perplexity.ai",
).defaultHeaders(ApiUtils.getJsonContentHeaders(key))
.defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
.build()
override fun chatCompletionEntity(chatRequest: ChatCompletionRequest?): ResponseEntity<ChatCompletion> {
Assert.notNull(chatRequest, "The request body can not be null.")
Assert.isTrue(!chatRequest!!.stream(), "Request must set the steam property to false.")
return restClient
.post()
.uri("/chat/completions", *arrayOfNulls(0))
.body(
chatRequest,
).retrieve()
.toEntity(
ChatCompletion::class.java,
)
}
override fun chatCompletionStream(chatRequest: ChatCompletionRequest?): Flux<ChatCompletionChunk> = throw RuntimeException("지원되지않음")
override fun <T : Any?> embeddings(embeddingRequest: EmbeddingRequest<T>?): ResponseEntity<EmbeddingList<Embedding>> =
throw RuntimeException("지원되지않음")
}
이렇게됨
사용은 다른코드도 있어서 해당부분만 잘랐음
private val perplexityOptions =
PerplexityChatOptions
.builder()
.searchRecencyFilter("week")
.withModel("llama-3.1-sonar-small-128k-online")
.withTemperature(0.8F)
.build()
private val chatClient: ChatClient =
ChatClient
.builder(
OpenAiChatModel(PerplexityApi(ApiConfig.getPerplexityKey()), perplexityOptions),
).build()
private fun fetch(prompt: Prompt): ChatResponse {
val response =
chatClient
.prompt(prompt)
.call()
.chatResponse()
return response
}