1
1
package oracleai .services ;
2
2
3
- import com .oracle .bmc .Region ;
3
+ import java .util .ArrayList ;
4
+ import java .util .List ;
5
+ import java .util .Objects ;
6
+ import java .util .stream .Collectors ;
7
+
8
+ import com .oracle .bmc .auth .BasicAuthenticationDetailsProvider ;
9
+ import com .oracle .bmc .generativeaiinference .GenerativeAiInference ;
4
10
import com .oracle .bmc .generativeaiinference .GenerativeAiInferenceClient ;
5
- import com .oracle .bmc .generativeaiinference .model .CohereLlmInferenceRequest ;
6
- import com .oracle .bmc .generativeaiinference .model .GenerateTextDetails ;
7
- import com .oracle .bmc .generativeaiinference .model .OnDemandServingMode ;
8
- import com .oracle .bmc .generativeaiinference .requests .GenerateTextRequest ;
9
- import com .oracle .bmc .generativeaiinference .responses .GenerateTextResponse ;
10
- import com .oracle .bmc .generativeaiinference .responses .GenerateTextResponse ;
11
+ import com .oracle .bmc .generativeaiinference .model .*;
12
+ import com .oracle .bmc .generativeaiinference .requests .ChatRequest ;
13
+ import com .oracle .bmc .generativeaiinference .responses .ChatResponse ;
14
+ import lombok .Builder ;
15
+ import lombok .Getter ;
16
+ import oracleai .AIApplication ;
17
+
18
+ /**
19
+ * OCI GenAI Chat
20
+ */
21
+ public class OracleGenAI {
22
+ private final GenerativeAiInference client ;
23
+ private final ServingMode servingMode ;
24
+ private final String compartment ;
25
+ private final String preambleOverride ;
26
+ private final Double temperature ;
27
+ private final Double frequencyPenalty ;
28
+ private final Integer maxTokens ;
29
+ private final Double presencePenalty ;
30
+ private final Double topP ;
31
+ private final Integer topK ;
32
+ private final InferenceRequestType inferenceRequestType ;
33
+ private List <CohereMessage > cohereChatMessages ;
34
+ private List <ChatChoice > genericChatMessages ;
11
35
36
+ @ Builder
37
+ public OracleGenAI (BasicAuthenticationDetailsProvider authProvider ,
38
+ ServingMode servingMode ,
39
+ String compartment ,
40
+ String preambleOverride ,
41
+ Double temperature ,
42
+ Double frequencyPenalty ,
43
+ Integer maxTokens ,
44
+ Double presencePenalty ,
45
+ Double topP ,
46
+ Integer topK ,
47
+ InferenceRequestType inferenceRequestType ) throws Exception {
12
48
13
- import oracleai .AIApplication ;
14
49
15
50
16
- public class OracleGenAI {
51
+ this .client = GenerativeAiInferenceClient .builder ()
52
+ .build (AuthProvider .getAuthenticationDetailsProvider ());
53
+ this .servingMode = servingMode ;
54
+ this .compartment = compartment ;
55
+ this .preambleOverride = preambleOverride ;
17
56
18
- public static String chat (String textcontent ) throws Exception {
19
- return new OracleGenAI ().doChat (textcontent );
57
+ this .temperature = Objects .requireNonNullElse (temperature , 1.0 );
58
+ this .frequencyPenalty = Objects .requireNonNullElse (
59
+ frequencyPenalty ,
60
+ 0.0
61
+ );
62
+ this .maxTokens = Objects .requireNonNullElse (maxTokens , 600 );
63
+ this .presencePenalty = Objects .requireNonNullElse (
64
+ presencePenalty ,
65
+ 0.0
66
+ );
67
+ this .topP = Objects .requireNonNullElse (topP , 0.75 );
68
+ this .inferenceRequestType = Objects .requireNonNullElse (
69
+ inferenceRequestType ,
70
+ InferenceRequestType .COHERE
71
+ );
72
+ this .topK = Objects .requireNonNullElseGet (topK , () -> {
73
+ if (this .inferenceRequestType == InferenceRequestType .COHERE ) {
74
+ return 0 ;
75
+ }
76
+ return -1 ;
77
+ });
20
78
}
21
-
22
- public String doChat (String textcontent ) throws Exception {
23
- final GenerativeAiInferenceClient generativeAiInferenceClient =
24
- new GenerativeAiInferenceClient (AuthProvider .getAuthenticationDetailsProvider ());
25
- // generativeAiInferenceClient.setEndpoint(ENDPOINT);
26
- generativeAiInferenceClient .setRegion (Region .US_CHICAGO_1 );
27
- CohereLlmInferenceRequest cohereLlmInferenceRequest =
28
- CohereLlmInferenceRequest .builder ()
29
- .prompt (textcontent )
30
- .maxTokens (600 )
31
- .temperature (0.75 )
32
- .frequencyPenalty (1.0 )
33
- .topP (0.7 )
34
- .isStream (false ) // SDK doesn't support streaming responses, feature is under development
35
- .isEcho (true )
36
- .build ();
37
- GenerateTextDetails generateTextDetails = GenerateTextDetails .builder ()
38
- .servingMode (OnDemandServingMode .builder ().modelId ("cohere.command" ).build ()) // "cohere.command-light" is also available to use
39
- // .servingMode(DedicatedServingMode.builder().endpointId("custom-model-endpoint").build()) // for custom model from Dedicated AI Cluster
40
- .compartmentId (AIApplication .COMPARTMENT_ID )
41
- .inferenceRequest (cohereLlmInferenceRequest )
79
    /**
     * Supported OCI GenAI inference backends: Cohere-family chat models or
     * generic (Llama-family) chat models. The {@code type} string mirrors the
     * enum constant name; {@code getType()} is generated by Lombok.
     */
    @Getter
    public enum InferenceRequestType {
        COHERE("COHERE"),
        LLAMA("LLAMA");

        private final String type;

        InferenceRequestType(String type) {
            this.type = type;
        }
    }
91
+
92
+ /**
93
+ * Chat using OCI GenAI.
94
+ * @param prompt Prompt text sent to OCI GenAI chat model.
95
+ * @return OCI GenAI ChatResponse
96
+ */
97
+ public String chat1 (String prompt ) {
98
+ return "whateve" ;
99
+ }
100
+
101
+ public String chat (String prompt ) {
102
+ ChatDetails chatDetails = ChatDetails .builder ()
103
+ // .compartmentId(AIApplication.COMPARTMENT_ID)
104
+ .compartmentId (compartment )
105
+ // .servingMode(OnDemandServingMode.builder().build())
106
+ .servingMode (servingMode )
107
+ .chatRequest (createChatRequest (prompt ))
42
108
.build ();
43
- GenerateTextRequest generateTextRequest = GenerateTextRequest .builder ()
44
- .generateTextDetails ( generateTextDetails )
109
+ ChatRequest chatRequest = ChatRequest .builder ()
110
+ .body$ ( chatDetails )
45
111
.build ();
46
- GenerateTextResponse generateTextResponse = generativeAiInferenceClient .generateText (generateTextRequest );
47
- System .out .println (generateTextResponse .toString ());
48
- return generateTextResponse .toString ();
112
+ ChatResponse response = client .chat (chatRequest );
113
+ saveChatHistory (response );
114
+ return extractText (response );
115
+ }
116
+
117
+ /**
118
+ * Create a ChatRequest from a text prompt. Supports COHERE or LLAMA inference.
119
+ * @param prompt To create a ChatRequest from.
120
+ * @return A COHERE or LLAMA ChatRequest.
121
+ */
122
+ private BaseChatRequest createChatRequest (String prompt ) {
123
+ switch (inferenceRequestType ) {
124
+ case COHERE :
125
+ return CohereChatRequest .builder ()
126
+ .frequencyPenalty (frequencyPenalty )
127
+ .maxTokens (maxTokens )
128
+ .presencePenalty (presencePenalty )
129
+ .message (prompt )
130
+ .temperature (temperature )
131
+ .topP (topP )
132
+ .topK (topK )
133
+ .chatHistory (cohereChatMessages )
134
+ .preambleOverride (preambleOverride )
135
+ .build ();
136
+ case LLAMA :
137
+ List <Message > messages = genericChatMessages == null ?
138
+ new ArrayList <>() :
139
+ genericChatMessages .stream ()
140
+ .map (ChatChoice ::getMessage )
141
+ .collect (Collectors .toList ());
142
+ ChatContent content = TextContent .builder ()
143
+ .text (prompt )
144
+ .build ();
145
+ List <ChatContent > contents = new ArrayList <>();
146
+ contents .add (content );
147
+ UserMessage message = UserMessage .builder ()
148
+ .name ("USER" )
149
+ .content (contents )
150
+ .build ();
151
+ messages .add (message );
152
+ return GenericChatRequest .builder ()
153
+ .messages (messages )
154
+ .frequencyPenalty (frequencyPenalty )
155
+ .temperature (temperature )
156
+ .maxTokens (maxTokens )
157
+ .presencePenalty (presencePenalty )
158
+ .topP (topP )
159
+ .topK (topK )
160
+ .build ();
161
+ }
162
+
163
+ throw new IllegalArgumentException (String .format (
164
+ "Unknown request type %s" ,
165
+ inferenceRequestType
166
+ ));
49
167
}
50
168
51
- }
169
+ /**
170
+ * Save the current chat history to memory.
171
+ * @param chatResponse The latest chat response.
172
+ */
173
+ private void saveChatHistory (ChatResponse chatResponse ) {
174
+ BaseChatResponse bcr = chatResponse .getChatResult ()
175
+ .getChatResponse ();
176
+ if (bcr instanceof CohereChatResponse resp ) {
177
+ cohereChatMessages = resp .getChatHistory ();
178
+ } else if (bcr instanceof GenericChatResponse resp ) {
179
+ genericChatMessages = resp .getChoices ();
180
+ } else {
181
+ throw new IllegalStateException (String .format (
182
+ "Unexpected chat response type: %s" ,
183
+ bcr .getClass ().getName ()
184
+ ));
185
+ }
186
+ }
187
+
188
+ /**
189
+ * Extract text from an OCI GenAI ChatResponse.
190
+ * @param chatResponse The response to extract text from.
191
+ * @return The chat response text.
192
+ */
193
+ private String extractText (ChatResponse chatResponse ) {
194
+ BaseChatResponse bcr = chatResponse
195
+ .getChatResult ()
196
+ .getChatResponse ();
197
+ if (bcr instanceof CohereChatResponse resp ) {
198
+ return resp .getText ();
199
+ } else if (bcr instanceof GenericChatResponse resp ) {
200
+ List <ChatChoice > choices = resp .getChoices ();
201
+ List <ChatContent > contents = choices .get (choices .size () - 1 )
202
+ .getMessage ()
203
+ .getContent ();
204
+ ChatContent content = contents .get (contents .size () - 1 );
205
+ if (content instanceof TextContent ) {
206
+ return ((TextContent ) content ).getText ();
207
+ }
208
+ }
209
+ throw new IllegalStateException (String .format (
210
+ "Unexpected chat response type: %s" ,
211
+ bcr .getClass ().getName ()
212
+ ));
213
+ }
214
+ }
0 commit comments