Skip to content

New chat api #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions app/src/components/content/summary.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const hostname =
? "localhost:8080"
: window.location.hostname;
const serviceRootURL = `${protocol}${hostname}`;
const acceptArr: string[] = ["application/pdf", "*.pdf"];
const acceptArr: string[] = ["application/pdf", "*.pdf", "text/plain", "*.txt"];
const messages: { id: number; severity: string; summary: string }[] = [];
const FILE_SIZE = 120000;

Expand Down Expand Up @@ -271,7 +271,7 @@ export const Summary = ({
<div class="oj-flex-item oj-sm-margin-4x">
<h1>Document Summarization</h1>
<div class="oj-typography-body-md oj-sm-padding-1x-bottom">
Upload a PDF file
Upload a PDF/TXT file
</div>
<oj-validation-group ref={valGroupRef}>
<oj-c-file-picker
Expand All @@ -283,7 +283,7 @@ export const Summary = ({
onojBeforeSelect={beforeSelectListener}
secondaryText={`Maximum file size is ${
FILE_SIZE / 1000
}KB per PDF file.`}
}KB per PDF or TXT file.`}
></oj-c-file-picker>
{backendType === "python" && (
<oj-c-input-text
Expand Down
13 changes: 7 additions & 6 deletions backend/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-rest'
implementation 'org.springframework.boot:spring-boot-starter-websocket'
implementation 'org.springframework.boot:spring-boot-starter-actuator'
implementation 'com.oracle.oci.sdk:oci-java-sdk-shaded-full:3.33.0'
implementation 'com.oracle.oci.sdk:oci-java-sdk-core:3.35.0'
implementation 'com.oracle.oci.sdk:oci-java-sdk-common:3.35.0'
implementation 'com.oracle.oci.sdk:oci-java-sdk-addons-oke-workload-identity:3.35.0'
implementation 'com.oracle.oci.sdk:oci-java-sdk-generativeai:3.35.0'
implementation 'org.netbeans.external:org-apache-commons-io:RELEASE113'
implementation 'com.oracle.oci.sdk:oci-java-sdk-shaded-full:3.52.1'
implementation 'com.oracle.oci.sdk:oci-java-sdk-core:3.52.1'
implementation 'com.oracle.oci.sdk:oci-java-sdk-common:3.52.1'
implementation 'com.oracle.oci.sdk:oci-java-sdk-addons-oke-workload-identity:3.52.1'
implementation 'com.oracle.oci.sdk:oci-java-sdk-generativeai:3.52.1'
implementation 'com.oracle.database.jdbc:ojdbc11-production:21.8.0.0'
implementation 'com.oracle.database.jdbc:ucp:21.8.0.0'
implementation 'com.oracle.database.security:oraclepki:21.8.0.0'
implementation 'com.oracle.database.security:osdt_cert:21.8.0.0'
implementation 'com.oracle.database.security:osdt_core:21.8.0.0'
implementation 'org.apache.pdfbox:pdfbox:3.0.2' exclude(group: 'commons-logging', module: 'commons-logging')
implementation 'org.apache.pdfbox:pdfbox:3.0.3' exclude(group: 'commons-logging', module: 'commons-logging')
testImplementation 'org.springframework.boot:spring-boot-starter-test'
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import com.oracle.bmc.generativeai.requests.ListEndpointsRequest;
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
import com.oracle.bmc.generativeai.model.EndpointSummary;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
import dev.victormartin.oci.genai.backend.backend.service.GenerativeAiClientService;
import dev.victormartin.oci.genai.backend.backend.service.GenAiClientService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -28,7 +27,7 @@ public class GenAIController {
private String COMPARTMENT_ID;

@Autowired
private GenerativeAiClientService generativeAiClientService;
private GenAiClientService generativeAiClientService;

@GetMapping("/api/genai/models")
public List<GenAiModel> getModels() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dev.victormartin.oci.genai.backend.backend.data.InteractionType;
import dev.victormartin.oci.genai.backend.backend.service.OCIGenAIService;
import dev.victormartin.oci.genai.backend.backend.service.PDFConvertorService;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -22,6 +23,8 @@
import org.springframework.web.util.HtmlUtils;

import java.io.File;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Date;

@RestController
Expand Down Expand Up @@ -49,6 +52,7 @@ public Answer fileUploading(@RequestParam("file") MultipartFile multipartFile,
@RequestHeader("modelId") String modelId) {
String filename = StringUtils.cleanPath(multipartFile.getOriginalFilename());
log.info("File uploaded {} {} bytes ({})", filename, multipartFile.getSize(), multipartFile.getContentType());
String contentType = multipartFile.getContentType();// application/pdf
try {
if (filename.contains("..")) {
throw new Exception("Filename contains invalid path sequence");
Expand All @@ -60,7 +64,18 @@ public Answer fileUploading(@RequestParam("file") MultipartFile multipartFile,
File file = new File(fileDestinationPath + File.separator + filename);
multipartFile.transferTo(file);
log.info("File destination path: {}", file.getAbsolutePath());
String convertedText = pdfConvertorService.convert(file.getAbsolutePath());
String convertedText;
switch (contentType) {
case "text/plain":
convertedText = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
break;
case "application/pdf":
convertedText = pdfConvertorService.convert(file.getAbsolutePath());
break;
default:
convertedText= "";
break;
}
String textEscaped = HtmlUtils.htmlEscape(convertedText);
Interaction interaction = new Interaction();
interaction.setType(InteractionType.SUMMARY_FILE);
Expand All @@ -69,7 +84,7 @@ public Answer fileUploading(@RequestParam("file") MultipartFile multipartFile,
interaction.setModelId(summarizationModelId);
interaction.setRequest(textEscaped);
Interaction saved = interactionRepository.save(interaction);
String summaryText = ociGenAIService.summaryText(textEscaped, summarizationModelId);
String summaryText = ociGenAIService.summaryText(textEscaped, summarizationModelId, false);
saved.setDatetimeResponse(new Date());
saved.setResponse(summaryText);
interactionRepository.save(saved);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Answer postSummaryText(@RequestBody SummaryRequest summaryRequest,
interaction.setRequest(contentEscaped);
Interaction saved = interactionRepository.save(interaction);
try {
String summaryText = ociGenAIService.summaryText(contentEscaped, summarizationModelId);
String summaryText = ociGenAIService.summaryText(contentEscaped, summarizationModelId, false);
saved.setDatetimeResponse(new Date());
saved.setResponse(summaryText);
interactionRepository.save(saved);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import java.io.IOException;

@Service
public class GenerativeAiClientService {
public class GenAiClientService {

Logger log = LoggerFactory.getLogger(GenerativeAiClientService.class);
Logger log = LoggerFactory.getLogger(GenAiClientService.class);

@Autowired
private Environment environment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;
import com.oracle.bmc.auth.InstancePrincipalsAuthenticationDetailsProvider;
import com.oracle.bmc.auth.okeworkloadidentity.OkeWorkloadIdentityAuthenticationDetailsProvider;
import com.oracle.bmc.generativeai.GenerativeAiClient;
import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
import jakarta.annotation.PostConstruct;
import org.slf4j.Logger;
Expand All @@ -19,9 +18,9 @@
import java.io.IOException;

@Service
public class GenerativeAiInferenceClientService {
public class GenAiInferenceClientService {

Logger log = LoggerFactory.getLogger(GenerativeAiInferenceClientService.class);
Logger log = LoggerFactory.getLogger(GenAiInferenceClientService.class);

private GenerativeAiInferenceClient client;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@

import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
import com.oracle.bmc.generativeaiinference.model.*;
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
import com.oracle.bmc.generativeaiinference.requests.GenerateTextRequest;
import com.oracle.bmc.generativeaiinference.requests.SummarizeTextRequest;
import com.oracle.bmc.generativeaiinference.responses.ChatResponse;
import com.oracle.bmc.generativeaiinference.responses.GenerateTextResponse;
import com.oracle.bmc.generativeaiinference.responses.SummarizeTextResponse;
import com.oracle.bmc.http.client.jersey.WrappedResponseInputStream;
import org.hibernate.boot.archive.scan.internal.StandardScanner;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;

@Service
Expand All @@ -18,52 +25,47 @@ public class OCIGenAIService {
private String COMPARTMENT_ID;

@Autowired
private GenerativeAiInferenceClientService generativeAiInferenceClientService;
private GenAiInferenceClientService generativeAiInferenceClientService;

/**
 * Sends {@code input} as a single chat message through the Generative AI
 * inference client and returns the text of the reply.
 *
 * <p>The request is a Cohere chat request with fixed generation settings
 * (600 max tokens, temperature 1, frequency penalty 0, top-p 0.75, top-k 0)
 * and streaming disabled.
 *
 * @param input    the message text to send
 * @param modelId  identifier handed to the serving mode: used as the model id
 *                 for on-demand serving, or as the endpoint id when
 *                 {@code finetune} is {@code true}
 * @param finetune {@code true} to call a dedicated endpoint
 *                 ({@code DedicatedServingMode}); {@code false} for
 *                 {@code OnDemandServingMode}
 * @return the reply text: the Cohere response text, or for a generic response
 *         the last text content of the last choice
 * @throws IllegalStateException if the response is missing, of an unsupported
 *                               type, or carries no text content
 */
public String resolvePrompt(String input, String modelId, boolean finetune) {
    CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
            .message(input)
            .maxTokens(600)
            .temperature((double) 1)
            .frequencyPenalty((double) 0)
            .topP((double) 0.75)
            .topK(0)
            .isStream(false) // TODO websockets and streams
            .build();

    // Honour the finetune flag: a fine-tuned model is reached through its
    // dedicated endpoint, everything else through on-demand serving.
    ChatDetails chatDetails = ChatDetails.builder()
            .servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build()
                    : OnDemandServingMode.builder().modelId(modelId).build())
            .compartmentId(COMPARTMENT_ID)
            .chatRequest(cohereChatRequest)
            .build();

    ChatRequest request = ChatRequest.builder()
            .chatDetails(chatDetails)
            .build();
    ChatResponse response = generativeAiInferenceClientService.getClient().chat(request);
    BaseChatResponse baseChatResponse = response.getChatResult().getChatResponse();

    if (baseChatResponse instanceof CohereChatResponse) {
        return ((CohereChatResponse) baseChatResponse).getText();
    }
    if (baseChatResponse instanceof GenericChatResponse) {
        List<ChatChoice> choices = ((GenericChatResponse) baseChatResponse).getChoices();
        // Guard against empty lists so a reply without choices/content ends in
        // the IllegalStateException below rather than an index error.
        if (choices != null && !choices.isEmpty()) {
            List<ChatContent> contents = choices.get(choices.size() - 1).getMessage().getContent();
            if (contents != null && !contents.isEmpty()) {
                ChatContent content = contents.get(contents.size() - 1);
                if (content instanceof TextContent) {
                    return ((TextContent) content).getText();
                }
            }
        }
    }
    throw new IllegalStateException("Unexpected chat response type: "
            + (baseChatResponse == null ? "null" : baseChatResponse.getClass().getName()));
}

public String summaryText(String input, String modelId) {
SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.input(input)
.build();
SummarizeTextRequest request = SummarizeTextRequest.builder()
.summarizeTextDetails(summarizeTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
SummarizeTextResponse summarizeTextResponse = client.summarizeText(request);
String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary();
return summaryText;
public String summaryText(String input, String modelId, boolean finetuned) {
String response = resolvePrompt("Summarize this:\n" + input, modelId, finetuned);
return response;
}
}
Loading