1
1
package dev .victormartin .oci .genai .backend .backend .service ;
2
2
3
- import com .oracle .bmc .generativeaiinference .GenerativeAiInferenceClient ;
4
3
import com .oracle .bmc .generativeaiinference .model .*;
4
+ import com .oracle .bmc .generativeaiinference .model .Message ;
5
5
import com .oracle .bmc .generativeaiinference .requests .ChatRequest ;
6
- import com .oracle .bmc .generativeaiinference .requests .GenerateTextRequest ;
7
- import com .oracle .bmc .generativeaiinference .requests .SummarizeTextRequest ;
8
6
import com .oracle .bmc .generativeaiinference .responses .ChatResponse ;
9
- import com .oracle .bmc .generativeaiinference .responses .GenerateTextResponse ;
10
- import com .oracle .bmc .generativeaiinference .responses .SummarizeTextResponse ;
11
- import com .oracle .bmc .http .client .jersey .WrappedResponseInputStream ;
12
- import org .hibernate .boot .archive .scan .internal .StandardScanner ;
7
+ import dev .victormartin .oci .genai .backend .backend .dao .GenAiModel ;
8
+ import org .slf4j .Logger ;
9
+ import org .slf4j .LoggerFactory ;
13
10
import org .springframework .beans .factory .annotation .Autowired ;
14
11
import org .springframework .beans .factory .annotation .Value ;
15
12
import org .springframework .stereotype .Service ;
16
13
17
- import java .io .*;
18
- import java .nio .charset .StandardCharsets ;
14
+ import java .util .ArrayList ;
19
15
import java .util .List ;
20
- import java .util .stream .Collectors ;
21
16
22
17
@ Service
23
18
public class OCIGenAIService {
19
+
20
+ Logger log = LoggerFactory .getLogger (OCIGenAIService .class );
21
+
24
22
@ Value ("${genai.compartment_id}" )
25
23
private String COMPARTMENT_ID ;
26
24
27
25
@ Autowired
28
26
private GenAiInferenceClientService generativeAiInferenceClientService ;
29
27
30
- public String resolvePrompt (String input , String modelId , boolean finetune ) {
31
- CohereChatRequest cohereChatRequest = CohereChatRequest .builder ()
32
- .message (input )
33
- .maxTokens (600 )
34
- .temperature ((double ) 1 )
35
- .frequencyPenalty ((double ) 0 )
36
- .topP ((double ) 0.75 )
37
- .topK (0 )
38
- .isStream (false ) // TODO websockets and streams
39
- .build ();
28
+ @ Autowired
29
+ private GenAIModelsService genAIModelsService ;
40
30
41
- ChatDetails chatDetails = ChatDetails .builder ()
42
- .servingMode (OnDemandServingMode .builder ().modelId (modelId ).build ())
43
- .compartmentId (COMPARTMENT_ID )
44
- .chatRequest (cohereChatRequest )
45
- .build ();
31
+ public String resolvePrompt (String input , String modelId , boolean finetune , boolean summarization ) {
32
+
33
+ List <GenAiModel > models = genAIModelsService .getModels ();
34
+ GenAiModel currentModel = models .stream ()
35
+ .filter (m -> modelId .equals (m .id ()))
36
+ .findFirst ()
37
+ .orElseThrow ();
38
+
39
+ log .info ("Model {} with finetune {}" , currentModel .name (), finetune ? "yes" : "no" );
40
+
41
+ double temperature = summarization ?0.0 :0.5 ;
42
+
43
+ String inputText = summarization ?"Summarize this text:\n " + input : input ;
44
+
45
+ ChatDetails chatDetails ;
46
+ switch (currentModel .vendor ()) {
47
+ case "cohere" :
48
+ CohereChatRequest cohereChatRequest = CohereChatRequest .builder ()
49
+ .message (inputText )
50
+ .maxTokens (600 )
51
+ .temperature (temperature )
52
+ .frequencyPenalty ((double ) 0 )
53
+ .topP (0.75 )
54
+ .topK (0 )
55
+ .isStream (false ) // TODO websockets and streams
56
+ .build ();
57
+
58
+ chatDetails = ChatDetails .builder ()
59
+ .servingMode (OnDemandServingMode .builder ().modelId (currentModel .id ()).build ())
60
+ .compartmentId (COMPARTMENT_ID )
61
+ .chatRequest (cohereChatRequest )
62
+ .build ();
63
+ break ;
64
+ case "meta" :
65
+ ChatContent content = TextContent .builder ()
66
+ .text (inputText )
67
+ .build ();
68
+ List <ChatContent > contents = new ArrayList <>();
69
+ contents .add (content );
70
+ List <Message > messages = new ArrayList <>();
71
+ Message message = new UserMessage (contents , "user" );
72
+ messages .add (message );
73
+ GenericChatRequest genericChatRequest = GenericChatRequest .builder ()
74
+ .messages (messages )
75
+ .maxTokens (600 )
76
+ .temperature ((double )1 )
77
+ .frequencyPenalty ((double )0 )
78
+ .presencePenalty ((double )0 )
79
+ .topP (0.75 )
80
+ .topK (-1 )
81
+ .isStream (false )
82
+ .build ();
83
+ chatDetails = ChatDetails .builder ()
84
+ .servingMode (OnDemandServingMode .builder ().modelId (currentModel .id ()).build ())
85
+ .compartmentId (COMPARTMENT_ID )
86
+ .chatRequest (genericChatRequest )
87
+ .build ();
88
+ break ;
89
+ default :
90
+ throw new IllegalStateException ("Unexpected value: " + currentModel .vendor ());
91
+ }
46
92
47
93
ChatRequest request = ChatRequest .builder ()
48
94
.chatDetails (chatDetails )
@@ -65,7 +111,7 @@ public String resolvePrompt(String input, String modelId, boolean finetune) {
65
111
}
66
112
67
113
public String summaryText (String input , String modelId , boolean finetuned ) {
68
- String response = resolvePrompt ("Summarize this: \n " + input , modelId , finetuned );
114
+ String response = resolvePrompt (input , modelId , finetuned , true );
69
115
return response ;
70
116
}
71
117
}
0 commit comments