Skip to content

Commit 7a41b57

Browse files
authored
Merge pull request #34 from oracle-devrel/backendSupportforDedicatedServingMode
updating to support dedicated serving mode with fine-tuned models
2 parents aea7905 + 0343695 commit 7a41b57

File tree

8 files changed

+170
-67
lines changed

8 files changed

+170
-67
lines changed

app/src/components/content/index.tsx

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ type Chat = {
2525
answer?: string;
2626
loading?: string;
2727
};
28+
type Model = {
29+
id: string;
30+
name: string;
31+
vendor: string;
32+
version: string;
33+
capabilities: Array<string>;
34+
timeCreated: string;
35+
};
2836

2937
const defaultServiceType: string = localStorage.getItem("service") || "text";
3038
const defaultBackendType: string = localStorage.getItem("backend") || "java";
@@ -46,6 +54,7 @@ const Content = () => {
4654
const question = useRef<string>();
4755
const chatData = useRef<Array<object>>([]);
4856
const socket = useRef<WebSocket>();
57+
const finetune = useRef<boolean>(false);
4958
const [client, setClient] = useState<Client | null>(null);
5059

5160
const messagesDP = useRef(
@@ -167,7 +176,13 @@ const Content = () => {
167176
JSON.stringify({ msgType: "question", data: question.current })
168177
);
169178
} else {
170-
sendPrompt(client, question.current!, modelId!, conversationId!);
179+
sendPrompt(
180+
client,
181+
question.current!,
182+
modelId!,
183+
conversationId!,
184+
finetune.current
185+
);
171186
}
172187
}
173188
};
@@ -199,9 +214,9 @@ const Content = () => {
199214
localStorage.setItem("backend", backend);
200215
location.reload();
201216
};
202-
const modelIdChangeHandler = (event: CustomEvent) => {
203-
console.log("model Id: ", event.detail.value);
204-
if (event.detail.value != null) setModelId(event.detail.value);
217+
const modelIdChangeHandler = (value: string, modelType: boolean) => {
218+
if (value != null) setModelId(value);
219+
finetune.current = modelType;
205220
};
206221
const clearSummary = () => {
207222
setSummaryResults("");

app/src/components/content/settings.tsx

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import "oj-c/select-single";
55
import "ojs/ojlistitemlayout";
66
import "ojs/ojhighlighttext";
77
import MutableArrayDataProvider = require("ojs/ojmutablearraydataprovider");
8+
import { ojSelectSingle } from "@oracle/oraclejet/ojselectsingle";
89

910
type ServiceTypeVal = "text" | "summary" | "sim";
1011
type BackendTypeVal = "java" | "python";
@@ -17,7 +18,7 @@ type Props = {
1718
backendType: BackendTypeVal;
1819
aiServiceChange: (service: ServiceTypeVal) => void;
1920
backendChange: (backend: BackendTypeVal) => void;
20-
modelIdChange: (modelName: any) => void;
21+
modelIdChange: (modelId: any, modelData: any) => void;
2122
};
2223

2324
const serviceTypes = [
@@ -30,6 +31,21 @@ const backendTypes = [
3031
{ value: "java", label: "Java" },
3132
{ value: "python", label: "Python" },
3233
];
34+
type Model = {
35+
id: string;
36+
name: string;
37+
vendor: string;
38+
version: string;
39+
capabilities: Array<string>;
40+
timeCreated: string;
41+
};
42+
type Endpoint = {
43+
id: string;
44+
name: string;
45+
state: string;
46+
model: string;
47+
timeCreated: string;
48+
};
3349
const serviceOptionsDP = new MutableArrayDataProvider<
3450
Services["value"],
3551
Services
@@ -50,8 +66,11 @@ export const Settings = (props: Props) => {
5066
};
5167

5268
const modelDP = useRef(
53-
new MutableArrayDataProvider<string, {}>([], { keyAttributes: "id" })
69+
new MutableArrayDataProvider<string, {}>([], {
70+
keyAttributes: "id",
71+
})
5472
);
73+
const endpoints = useRef<Array<Endpoint>>();
5574

5675
const fetchModels = async () => {
5776
try {
@@ -60,9 +79,8 @@ export const Settings = (props: Props) => {
6079
throw new Error(`Response status: ${response.status}`);
6180
}
6281
const json = await response.json();
63-
const result = json.filter((model: any) => {
82+
const result = json.filter((model: Model) => {
6483
if (
65-
// model.capabilities.includes("FINE_TUNE") &&
6684
model.capabilities.includes("TEXT_GENERATION") &&
6785
(model.vendor == "cohere" || model.vendor == "") &&
6886
model.version != "14.2"
@@ -77,11 +95,55 @@ export const Settings = (props: Props) => {
7795
);
7896
}
7997
};
98+
const fetchEndpoints = async () => {
99+
try {
100+
const response = await fetch("/api/genai/endpoints");
101+
if (!response.ok) {
102+
throw new Error(`Response status: ${response.status}`);
103+
}
104+
const json = await response.json();
105+
const result = json.filter((endpoint: Endpoint) => {
106+
// add filtering code here
107+
return endpoint;
108+
});
109+
endpoints.current = result;
110+
} catch (error: any) {
111+
console.log(
112+
"Java service not available for fetching list of Endpoints: ",
113+
error.message
114+
);
115+
}
116+
};
80117

81118
useEffect(() => {
119+
fetchEndpoints();
82120
fetchModels();
83121
}, []);
84122

123+
const modelChangeHandler = async (
124+
event: ojSelectSingle.valueChanged<string, {}>
125+
) => {
126+
let selected = event.detail.value;
127+
let finetune = false;
128+
const asyncIterator = modelDP.current.fetchFirst()[Symbol.asyncIterator]();
129+
let result = await asyncIterator.next();
130+
let value = result.value;
131+
let data = value.data as Array<Model>;
132+
let idx = data.find((e: Model) => {
133+
if (e.id === selected) return e;
134+
});
135+
if (idx?.capabilities.includes("FINE_TUNE")) {
136+
finetune = true;
137+
let endpointId = endpoints.current?.find((e: Endpoint) => {
138+
if (e.model === event.detail.value) {
139+
return e.id;
140+
}
141+
});
142+
selected = endpointId ? endpointId.id : event.detail.value;
143+
}
144+
props.modelIdChange(selected, finetune);
145+
};
146+
85147
const modelTemplate = (item: any) => {
86148
return (
87149
<oj-list-item-layout class="oj-listitemlayout-padding-off">
@@ -134,7 +196,7 @@ export const Settings = (props: Props) => {
134196
data={modelDP.current}
135197
labelHint={"Model"}
136198
itemText={"name"}
137-
onvalueChanged={props.modelIdChange}
199+
onvalueChanged={modelChangeHandler}
138200
>
139201
<template slot="itemTemplate" render={modelTemplate}></template>
140202
</oj-c-select-single>

app/src/components/content/stomp-interface.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ export const sendPrompt = (
124124
client: Client | null,
125125
prompt: string,
126126
modelId: string,
127-
convoId: string
127+
convoId: string,
128+
finetune: boolean
128129
) => {
129130
if (client?.connected) {
130131
console.log("Sending prompt: ", prompt);
@@ -134,6 +135,7 @@ export const sendPrompt = (
134135
conversationId: convoId,
135136
content: prompt,
136137
modelId: modelId,
138+
finetune: finetune,
137139
}),
138140
});
139141
} else {

backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/GenAIController.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
import com.oracle.bmc.generativeai.GenerativeAiClient;
44
import com.oracle.bmc.generativeai.model.ModelCapability;
55
import com.oracle.bmc.generativeai.requests.ListModelsRequest;
6+
import com.oracle.bmc.generativeai.requests.ListEndpointsRequest;
67
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
8+
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
9+
import com.oracle.bmc.generativeai.model.EndpointSummary;
710
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
11+
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
812
import dev.victormartin.oci.genai.backend.backend.service.GenerativeAiClientService;
913
import org.slf4j.Logger;
1014
import org.slf4j.LoggerFactory;
@@ -33,11 +37,25 @@ public List<GenAiModel> getModels() {
3337
GenerativeAiClient client = generativeAiClientService.getClient();
3438
ListModelsResponse response = client.listModels(listModelsRequest);
3539
return response.getModelCollection().getItems().stream().map(m -> {
36-
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue).collect(Collectors.toList());
37-
GenAiModel model = new GenAiModel(m.getId(),m.getDisplayName(), m.getVendor(), m.getVersion(),
38-
capabilities,
39-
m.getTimeCreated());
40+
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue)
41+
.collect(Collectors.toList());
42+
GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(),
43+
capabilities, m.getTimeCreated());
4044
return model;
4145
}).collect(Collectors.toList());
4246
}
47+
48+
@GetMapping("/api/genai/endpoints")
49+
public List<GenAiEndpoint> getEndpoints() {
50+
logger.info("getEndpoints()");
51+
ListEndpointsRequest listEndpointsRequest = ListEndpointsRequest.builder().compartmentId(COMPARTMENT_ID)
52+
.build();
53+
GenerativeAiClient client = generativeAiClientService.getClient();
54+
ListEndpointsResponse response = client.listEndpoints(listEndpointsRequest);
55+
return response.getEndpointCollection().getItems().stream().map(e -> {
56+
GenAiEndpoint endpoint = new GenAiEndpoint(e.getId(), e.getDisplayName(), e.getLifecycleState(),
57+
e.getModelId(), e.getTimeCreated());
58+
return endpoint;
59+
}).collect(Collectors.toList());
60+
}
4361
}

backend/src/main/java/dev/victormartin/oci/genai/backend/backend/controller/PromptController.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public PromptController(InteractionRepository interactionRepository, OCIGenAISer
4545
@SendToUser("/queue/answer")
4646
public Answer handlePrompt(Prompt prompt) {
4747
String promptEscaped = HtmlUtils.htmlEscape(prompt.content());
48+
boolean finetune = prompt.finetune();
4849
String activeModel = (prompt.modelId() == null) ? hardcodedChatModelId : prompt.modelId();
4950
logger.info("Prompt " + promptEscaped + " received, on model " + activeModel);
5051

@@ -59,11 +60,8 @@ public Answer handlePrompt(Prompt prompt) {
5960
if (prompt.content().isEmpty()) {
6061
throw new InvalidPromptRequest();
6162
}
62-
// if (prompt.modelId() == null ||
63-
// !prompt.modelId().startsWith("ocid1.generativeaimodel.")) { throw new
64-
// InvalidPromptRequest(); }
6563
saved.setDatetimeResponse(new Date());
66-
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel);
64+
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune);
6765
saved.setResponse(responseFromGenAI);
6866
interactionRepository.save(saved);
6967
return new Answer(responseFromGenAI, "");
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package dev.victormartin.oci.genai.backend.backend.dao;
2+
3+
import java.util.Date;
4+
import com.oracle.bmc.generativeai.model.Endpoint;
5+
6+
public record GenAiEndpoint(String id, String name, Endpoint.LifecycleState state, String model, Date timeCreated) {
7+
}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
package dev.victormartin.oci.genai.backend.backend.dao;
22

3-
public record Prompt(String content, String conversationId, String modelId) {};
3+
public record Prompt(String content, String conversationId, String modelId, boolean finetune) {
4+
};

backend/src/main/java/dev/victormartin/oci/genai/backend/backend/service/OCIGenAIService.java

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,56 +14,56 @@
1414

1515
@Service
1616
public class OCIGenAIService {
17-
@Value("${genai.compartment_id}")
18-
private String COMPARTMENT_ID;
17+
@Value("${genai.compartment_id}")
18+
private String COMPARTMENT_ID;
1919

20-
@Autowired
21-
private GenerativeAiInferenceClientService generativeAiInferenceClientService;
20+
@Autowired
21+
private GenerativeAiInferenceClientService generativeAiInferenceClientService;
2222

23-
public String resolvePrompt(String input, String modelId) {
24-
// Build generate text request, send, and get response
25-
CohereLlmInferenceRequest llmInferenceRequest =
26-
CohereLlmInferenceRequest.builder()
27-
.prompt(input)
28-
.maxTokens(600)
29-
.temperature((double)1)
30-
.frequencyPenalty((double)0)
31-
.topP((double)0.75)
32-
.isStream(false)
33-
.isEcho(false)
34-
.build();
23+
public String resolvePrompt(String input, String modelId, boolean finetune) {
24+
// Build generate text request, send, and get response
25+
CohereLlmInferenceRequest llmInferenceRequest = CohereLlmInferenceRequest.builder()
26+
.prompt(input)
27+
.maxTokens(600)
28+
.temperature((double) 1)
29+
.frequencyPenalty((double) 0)
30+
.topP((double) 0.75)
31+
.isStream(false)
32+
.isEcho(false)
33+
.build();
3534

36-
GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
37-
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
38-
.compartmentId(COMPARTMENT_ID)
39-
.inferenceRequest(llmInferenceRequest)
40-
.build();
41-
GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
42-
.generateTextDetails(generateTextDetails)
43-
.build();
44-
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
45-
GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
46-
CohereLlmInferenceResponse response =
47-
(CohereLlmInferenceResponse) generateTextResponse.getGenerateTextResult().getInferenceResponse();
48-
String responseTexts = response.getGeneratedTexts()
49-
.stream()
50-
.map(t -> t.getText())
51-
.collect(Collectors.joining(","));
52-
return responseTexts;
53-
}
35+
GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
36+
.servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build()
37+
: OnDemandServingMode.builder().modelId(modelId).build())
38+
.compartmentId(COMPARTMENT_ID)
39+
.inferenceRequest(llmInferenceRequest)
40+
.build();
41+
GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
42+
.generateTextDetails(generateTextDetails)
43+
.build();
44+
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
45+
GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
46+
CohereLlmInferenceResponse response = (CohereLlmInferenceResponse) generateTextResponse
47+
.getGenerateTextResult().getInferenceResponse();
48+
String responseTexts = response.getGeneratedTexts()
49+
.stream()
50+
.map(t -> t.getText())
51+
.collect(Collectors.joining(","));
52+
return responseTexts;
53+
}
5454

55-
public String summaryText(String input, String modelId) {
56-
SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder()
57-
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
58-
.compartmentId(COMPARTMENT_ID)
59-
.input(input)
60-
.build();
61-
SummarizeTextRequest request = SummarizeTextRequest.builder()
62-
.summarizeTextDetails(summarizeTextDetails)
63-
.build();
64-
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
65-
SummarizeTextResponse summarizeTextResponse = client.summarizeText(request);
66-
String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary();
67-
return summaryText;
68-
}
55+
public String summaryText(String input, String modelId) {
56+
SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder()
57+
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
58+
.compartmentId(COMPARTMENT_ID)
59+
.input(input)
60+
.build();
61+
SummarizeTextRequest request = SummarizeTextRequest.builder()
62+
.summarizeTextDetails(summarizeTextDetails)
63+
.build();
64+
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
65+
SummarizeTextResponse summarizeTextResponse = client.summarizeText(request);
66+
String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary();
67+
return summaryText;
68+
}
6969
}

0 commit comments

Comments
 (0)