Skip to content

support meta models #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions app/src/components/content/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ export const Settings = (props: Props) => {
const json = await response.json();
const result = json.filter((model: Model) => {
if (
model.capabilities.includes("TEXT_GENERATION") &&
(model.vendor == "cohere" || model.vendor == "") &&
model.version != "14.2"
model.capabilities.includes("CHAT") &&
(model.vendor == "cohere" || model.vendor == "meta")
)
return model;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
import dev.victormartin.oci.genai.backend.backend.service.GenAIModelsService;
import dev.victormartin.oci.genai.backend.backend.service.GenAiClientService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -29,19 +30,16 @@ public class GenAIController {
@Autowired
private GenAiClientService generativeAiClientService;

@Autowired
private GenAIModelsService genAIModelsService;

@GetMapping("/api/genai/models")
public List<GenAiModel> getModels() {
logger.info("getModels()");
ListModelsRequest listModelsRequest = ListModelsRequest.builder().compartmentId(COMPARTMENT_ID).build();
GenerativeAiClient client = generativeAiClientService.getClient();
ListModelsResponse response = client.listModels(listModelsRequest);
return response.getModelCollection().getItems().stream().map(m -> {
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue)
.collect(Collectors.toList());
GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(),
capabilities, m.getTimeCreated());
return model;
}).collect(Collectors.toList());
List<GenAiModel> models = genAIModelsService.getModels();
return models.stream()
.filter(m -> m.capabilities().contains("CHAT"))
.collect(Collectors.toList());
}

@GetMapping("/api/genai/endpoints")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public Answer handlePrompt(Prompt prompt) {
throw new InvalidPromptRequest();
}
saved.setDatetimeResponse(new Date());
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune);
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune, false);
saved.setResponse(responseFromGenAI);
interactionRepository.save(saved);
return new Answer(responseFromGenAI, "");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dev.victormartin.oci.genai.backend.backend.service;

import com.oracle.bmc.generativeai.GenerativeAiClient;
import com.oracle.bmc.generativeai.model.ModelCapability;
import com.oracle.bmc.generativeai.requests.ListModelsRequest;
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.stream.Collectors;

/**
 * Service that lists the Generative AI models available in the configured
 * OCI compartment and maps them to the application's {@link GenAiModel} DAO.
 *
 * <p>Thread-safety: stateless apart from Spring-injected fields; safe to use
 * as a singleton bean.
 */
@Service
public class GenAIModelsService {
    // SLF4J convention: loggers are private, static and final.
    private static final Logger log = LoggerFactory.getLogger(GenAIModelsService.class);

    // OCID of the compartment to query, injected from configuration.
    @Value("${genai.compartment_id}")
    private String COMPARTMENT_ID;

    @Autowired
    private GenAiClientService generativeAiClientService;

    /**
     * Fetches every model in the compartment via the OCI GenerativeAI client.
     *
     * <p>Each SDK model is converted to a {@link GenAiModel} carrying its id,
     * display name, vendor, version, capability values (as strings) and
     * creation time. No capability filtering is done here; callers filter
     * (e.g. for "CHAT") themselves.
     *
     * @return the models found in the compartment; empty list if none
     */
    public List<GenAiModel> getModels() {
        log.info("getModels()");
        ListModelsRequest listModelsRequest = ListModelsRequest.builder()
                .compartmentId(COMPARTMENT_ID)
                .build();
        GenerativeAiClient client = generativeAiClientService.getClient();
        ListModelsResponse response = client.listModels(listModelsRequest);
        return response.getModelCollection().getItems().stream()
                .map(m -> {
                    List<String> capabilities = m.getCapabilities().stream()
                            .map(ModelCapability::getValue)
                            .collect(Collectors.toList());
                    return new GenAiModel(
                            m.getId(), m.getDisplayName(), m.getVendor(),
                            m.getVersion(), capabilities, m.getTimeCreated());
                })
                .collect(Collectors.toList());
    }
}
Original file line number Diff line number Diff line change
@@ -1,48 +1,94 @@
package dev.victormartin.oci.genai.backend.backend.service;

import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
import com.oracle.bmc.generativeaiinference.model.*;
import com.oracle.bmc.generativeaiinference.model.Message;
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
import com.oracle.bmc.generativeaiinference.requests.GenerateTextRequest;
import com.oracle.bmc.generativeaiinference.requests.SummarizeTextRequest;
import com.oracle.bmc.generativeaiinference.responses.ChatResponse;
import com.oracle.bmc.generativeaiinference.responses.GenerateTextResponse;
import com.oracle.bmc.generativeaiinference.responses.SummarizeTextResponse;
import com.oracle.bmc.http.client.jersey.WrappedResponseInputStream;
import org.hibernate.boot.archive.scan.internal.StandardScanner;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

@Service
public class OCIGenAIService {

Logger log = LoggerFactory.getLogger(OCIGenAIService.class);

@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;

@Autowired
private GenAiInferenceClientService generativeAiInferenceClientService;

public String resolvePrompt(String input, String modelId, boolean finetune) {
CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
.message(input)
.maxTokens(600)
.temperature((double) 1)
.frequencyPenalty((double) 0)
.topP((double) 0.75)
.topK(0)
.isStream(false) // TODO websockets and streams
.build();
@Autowired
private GenAIModelsService genAIModelsService;

ChatDetails chatDetails = ChatDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.chatRequest(cohereChatRequest)
.build();
public String resolvePrompt(String input, String modelId, boolean finetune, boolean summarization) {

List<GenAiModel> models = genAIModelsService.getModels();
GenAiModel currentModel = models.stream()
.filter(m-> modelId.equals(m.id()))
.findFirst()
.orElseThrow();

log.info("Model {} with finetune {}", currentModel.name(), finetune? "yes" : "no");

double temperature = summarization?0.0:0.5;

String inputText = summarization?"Summarize this text:\n" + input: input;

ChatDetails chatDetails;
switch (currentModel.vendor()) {
case "cohere":
CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
.message(inputText)
.maxTokens(600)
.temperature(temperature)
.frequencyPenalty((double) 0)
.topP(0.75)
.topK(0)
.isStream(false) // TODO websockets and streams
.build();

chatDetails = ChatDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build())
.compartmentId(COMPARTMENT_ID)
.chatRequest(cohereChatRequest)
.build();
break;
case "meta":
ChatContent content = TextContent.builder()
.text(inputText)
.build();
List<ChatContent> contents = new ArrayList<>();
contents.add(content);
List<Message> messages = new ArrayList<>();
Message message = new UserMessage(contents, "user");
messages.add(message);
GenericChatRequest genericChatRequest = GenericChatRequest.builder()
.messages(messages)
.maxTokens(600)
.temperature((double)1)
.frequencyPenalty((double)0)
.presencePenalty((double)0)
.topP(0.75)
.topK(-1)
.isStream(false)
.build();
chatDetails = ChatDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build())
.compartmentId(COMPARTMENT_ID)
.chatRequest(genericChatRequest)
.build();
break;
default:
throw new IllegalStateException("Unexpected value: " + currentModel.vendor());
}

ChatRequest request = ChatRequest.builder()
.chatDetails(chatDetails)
Expand All @@ -65,7 +111,7 @@ public String resolvePrompt(String input, String modelId, boolean finetune) {
}

/**
 * Summarizes the given text by delegating to
 * {@code resolvePrompt(input, modelId, finetuned, true)}; with the
 * summarization flag set, resolvePrompt lowers the temperature to 0.0 and
 * prefixes the input with a "Summarize this text:" instruction.
 *
 * @param input     raw text to summarize
 * @param modelId   OCID of the model to run the summarization with
 * @param finetuned whether a fine-tuned serving mode was requested
 * @return the model's summary response
 */
public String summaryText(String input, String modelId, boolean finetuned) {
    // NOTE(review): the diff artifact showed both the old 3-arg call and the
    // new 4-arg call, which would redeclare `response`; only the new call
    // (summarization = true) is kept, and the temp variable is unnecessary.
    return resolvePrompt(input, modelId, finetuned, true);
}
}