Skip to content

Commit 057516f

Browse files
authored
Merge pull request #37 from oracle-devrel/support-meta-models
support meta models
2 parents ef53870 + aeedf91 commit 057516f

File tree

5 files changed

+127
-40
lines changed

5 files changed

+127
-40
lines changed

app/src/components/content/settings.tsx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,8 @@ export const Settings = (props: Props) => {
8181
const json = await response.json();
8282
const result = json.filter((model: Model) => {
8383
if (
84-
model.capabilities.includes("TEXT_GENERATION") &&
85-
(model.vendor == "cohere" || model.vendor == "") &&
86-
model.version != "14.2"
84+
model.capabilities.includes("CHAT") &&
85+
(model.vendor == "cohere" || model.vendor == "meta")
8786
)
8887
return model;
8988
});

backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
99
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
1010
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
11+
import dev.victormartin.oci.genai.backend.backend.service.GenAIModelsService;
1112
import dev.victormartin.oci.genai.backend.backend.service.GenAiClientService;
1213
import org.slf4j.Logger;
1314
import org.slf4j.LoggerFactory;
@@ -29,19 +30,16 @@ public class GenAIController {
2930
@Autowired
3031
private GenAiClientService generativeAiClientService;
3132

33+
@Autowired
34+
private GenAIModelsService genAIModelsService;
35+
3236
@GetMapping("/api/genai/models")
3337
public List<GenAiModel> getModels() {
3438
logger.info("getModels()");
35-
ListModelsRequest listModelsRequest = ListModelsRequest.builder().compartmentId(COMPARTMENT_ID).build();
36-
GenerativeAiClient client = generativeAiClientService.getClient();
37-
ListModelsResponse response = client.listModels(listModelsRequest);
38-
return response.getModelCollection().getItems().stream().map(m -> {
39-
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue)
40-
.collect(Collectors.toList());
41-
GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(),
42-
capabilities, m.getTimeCreated());
43-
return model;
44-
}).collect(Collectors.toList());
39+
List<GenAiModel> models = genAIModelsService.getModels();
40+
return models.stream()
41+
.filter(m -> m.capabilities().contains("CHAT"))
42+
.collect(Collectors.toList());
4543
}
4644

4745
@GetMapping("/api/genai/endpoints")

backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public Answer handlePrompt(Prompt prompt) {
6161
throw new InvalidPromptRequest();
6262
}
6363
saved.setDatetimeResponse(new Date());
64-
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune);
64+
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune, false);
6565
saved.setResponse(responseFromGenAI);
6666
interactionRepository.save(saved);
6767
return new Answer(responseFromGenAI, "");
backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/GenAIModelsService.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package dev.victormartin.oci.genai.backend.backend.service;
2+
3+
import com.oracle.bmc.generativeai.GenerativeAiClient;
4+
import com.oracle.bmc.generativeai.model.ModelCapability;
5+
import com.oracle.bmc.generativeai.requests.ListModelsRequest;
6+
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
7+
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
import org.springframework.beans.factory.annotation.Autowired;
11+
import org.springframework.beans.factory.annotation.Value;
12+
import org.springframework.stereotype.Service;
13+
14+
import java.util.List;
15+
import java.util.stream.Collectors;
16+
17+
@Service
18+
public class GenAIModelsService {
19+
Logger log = LoggerFactory.getLogger(GenAIModelsService.class);
20+
21+
@Value("${genai.compartment_id}")
22+
private String COMPARTMENT_ID;
23+
24+
@Autowired
25+
private GenAiClientService generativeAiClientService;
26+
27+
public List<GenAiModel> getModels() {
28+
log.info("getModels()");
29+
ListModelsRequest listModelsRequest = ListModelsRequest.builder()
30+
.compartmentId(COMPARTMENT_ID)
31+
.build();
32+
GenerativeAiClient client = generativeAiClientService.getClient();
33+
ListModelsResponse response = client.listModels(listModelsRequest);
34+
return response.getModelCollection().getItems().stream()
35+
.map(m -> {
36+
List<String> capabilities = m.getCapabilities().stream()
37+
.map(ModelCapability::getValue).collect(Collectors.toList());
38+
GenAiModel model = new GenAiModel(
39+
m.getId(), m.getDisplayName(), m.getVendor(),
40+
m.getVersion(), capabilities, m.getTimeCreated());
41+
return model;
42+
}).collect(Collectors.toList());
43+
}
44+
}
backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,94 @@
11
package dev.victormartin.oci.genai.backend.backend.service;
22

3-
import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
43
import com.oracle.bmc.generativeaiinference.model.*;
4+
import com.oracle.bmc.generativeaiinference.model.Message;
55
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
6-
import com.oracle.bmc.generativeaiinference.requests.GenerateTextRequest;
7-
import com.oracle.bmc.generativeaiinference.requests.SummarizeTextRequest;
86
import com.oracle.bmc.generativeaiinference.responses.ChatResponse;
9-
import com.oracle.bmc.generativeaiinference.responses.GenerateTextResponse;
10-
import com.oracle.bmc.generativeaiinference.responses.SummarizeTextResponse;
11-
import com.oracle.bmc.http.client.jersey.WrappedResponseInputStream;
12-
import org.hibernate.boot.archive.scan.internal.StandardScanner;
7+
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
1310
import org.springframework.beans.factory.annotation.Autowired;
1411
import org.springframework.beans.factory.annotation.Value;
1512
import org.springframework.stereotype.Service;
1613

17-
import java.io.*;
18-
import java.nio.charset.StandardCharsets;
14+
import java.util.ArrayList;
1915
import java.util.List;
20-
import java.util.stream.Collectors;
2116

2217
@Service
2318
public class OCIGenAIService {
19+
20+
Logger log = LoggerFactory.getLogger(OCIGenAIService.class);
21+
2422
@Value("${genai.compartment_id}")
2523
private String COMPARTMENT_ID;
2624

2725
@Autowired
2826
private GenAiInferenceClientService generativeAiInferenceClientService;
2927

30-
public String resolvePrompt(String input, String modelId, boolean finetune) {
31-
CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
32-
.message(input)
33-
.maxTokens(600)
34-
.temperature((double) 1)
35-
.frequencyPenalty((double) 0)
36-
.topP((double) 0.75)
37-
.topK(0)
38-
.isStream(false) // TODO websockets and streams
39-
.build();
28+
@Autowired
29+
private GenAIModelsService genAIModelsService;
4030

41-
ChatDetails chatDetails = ChatDetails.builder()
42-
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
43-
.compartmentId(COMPARTMENT_ID)
44-
.chatRequest(cohereChatRequest)
45-
.build();
31+
public String resolvePrompt(String input, String modelId, boolean finetune, boolean summarization) {
32+
33+
List<GenAiModel> models = genAIModelsService.getModels();
34+
GenAiModel currentModel = models.stream()
35+
.filter(m-> modelId.equals(m.id()))
36+
.findFirst()
37+
.orElseThrow();
38+
39+
log.info("Model {} with finetune {}", currentModel.name(), finetune? "yes" : "no");
40+
41+
double temperature = summarization?0.0:0.5;
42+
43+
String inputText = summarization?"Summarize this text:\n" + input: input;
44+
45+
ChatDetails chatDetails;
46+
switch (currentModel.vendor()) {
47+
case "cohere":
48+
CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
49+
.message(inputText)
50+
.maxTokens(600)
51+
.temperature(temperature)
52+
.frequencyPenalty((double) 0)
53+
.topP(0.75)
54+
.topK(0)
55+
.isStream(false) // TODO websockets and streams
56+
.build();
57+
58+
chatDetails = ChatDetails.builder()
59+
.servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build())
60+
.compartmentId(COMPARTMENT_ID)
61+
.chatRequest(cohereChatRequest)
62+
.build();
63+
break;
64+
case "meta":
65+
ChatContent content = TextContent.builder()
66+
.text(inputText)
67+
.build();
68+
List<ChatContent> contents = new ArrayList<>();
69+
contents.add(content);
70+
List<Message> messages = new ArrayList<>();
71+
Message message = new UserMessage(contents, "user");
72+
messages.add(message);
73+
GenericChatRequest genericChatRequest = GenericChatRequest.builder()
74+
.messages(messages)
75+
.maxTokens(600)
76+
.temperature((double)1)
77+
.frequencyPenalty((double)0)
78+
.presencePenalty((double)0)
79+
.topP(0.75)
80+
.topK(-1)
81+
.isStream(false)
82+
.build();
83+
chatDetails = ChatDetails.builder()
84+
.servingMode(OnDemandServingMode.builder().modelId(currentModel.id()).build())
85+
.compartmentId(COMPARTMENT_ID)
86+
.chatRequest(genericChatRequest)
87+
.build();
88+
break;
89+
default:
90+
throw new IllegalStateException("Unexpected value: " + currentModel.vendor());
91+
}
4692

4793
ChatRequest request = ChatRequest.builder()
4894
.chatDetails(chatDetails)
@@ -65,7 +111,7 @@ public String resolvePrompt(String input, String modelId, boolean finetune) {
65111
}
66112

67113
public String summaryText(String input, String modelId, boolean finetuned) {
68-
String response = resolvePrompt("Summarize this:\n" + input, modelId, finetuned);
114+
String response = resolvePrompt(input, modelId, finetuned, true);
69115
return response;
70116
}
71117
}

0 commit comments

Comments
 (0)