Skip to content

Commit 9bcefb3

Browse files
homanp and ValentaTomas authored
Add support for querying code interpreter (#66)
* Add support for queryig code interpreter * Fix formatting * Ensure the sandbox close is called on exceptions * Update service/code_interpreter.py Co-authored-by: Tomas Valenta <[email protected]> * Update service/code_interpreter.py Co-authored-by: Tomas Valenta <[email protected]> * Update service/router.py Co-authored-by: Tomas Valenta <[email protected]> * Update service/code_interpreter.py Co-authored-by: Tomas Valenta <[email protected]> * Add system prompt * Format code * Bump dependencies * Minor tweaks --------- Co-authored-by: Tomas Valenta <[email protected]>
1 parent 8264094 commit 9bcefb3

File tree

10 files changed

+351
-289
lines changed

10 files changed

+351
-289
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ Input example:
4949
"encoder": {
5050
"type": "openai",
5151
"name": "text-embedding-3-small",
52-
}
52+
},
53+
"interpreter_mode": False, # Set to True if you wish to run computation Q&A with a code interpreter
5354
"session_id": "my_session_id" # keeps micro-vm sessions and enables caching
5455
}
5556
```

api/ingest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from models.ingest import RequestPayload
88
from service.embedding import EmbeddingService, get_encoder
9-
from service.ingest import handle_urls, handle_google_drive
9+
from service.ingest import handle_google_drive, handle_urls
1010
from utils.summarise import SUMMARY_SUFFIX
1111

1212
router = APIRouter()

models/ingest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from pydantic import BaseModel
55

66
from models.file import File
7-
from models.vector_database import VectorDatabase
87
from models.google_drive import GoogleDrive
8+
from models.vector_database import VectorDatabase
99

1010

1111
class EncoderEnum(str, Enum):

models/query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class RequestPayload(BaseModel):
1212
index_name: str
1313
encoder: Encoder
1414
session_id: Optional[str] = None
15+
interpreter_mode: Optional[bool] = False
1516

1617

1718
class ResponseData(BaseModel):

poetry.lock

Lines changed: 242 additions & 241 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ cmake = "^3.28.1"
2828
pypdf = "^4.0.1"
2929
docx2txt = "^0.8"
3030
python-dotenv = "^1.0.1"
31-
e2b = "^0.14.4"
31+
e2b = "^0.14.7"
3232
gunicorn = "^21.2.0"
3333
unstructured-client = "^0.18.0"
3434
unstructured = {extras = ["google-drive"], version = "^0.12.4"}

service/code_interpreter.py

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
import asyncio
22
import logging
3+
import re
4+
import textwrap
35
import time
46
from typing import List
57

8+
import pandas as pd
9+
from decouple import config
610
from e2b import Sandbox
11+
from openai import AsyncOpenAI
712

813
logging.getLogger("e2b").setLevel(logging.INFO)
914

15+
client = AsyncOpenAI(
16+
api_key=config("OPENAI_API_KEY"),
17+
)
18+
19+
SYSTEM_PROMPT = "You are a world-class python programmer that can complete any data analysis tasks by coding."
20+
1021

1122
class CodeInterpreterService:
1223
timeout = 3 * 60 # 3 minutes
@@ -61,48 +72,87 @@ def __init__(
6172
self.sandbox = self._ensure_sandbox(session_id)
6273

6374
async def __aenter__(self):
64-
if not self._is_initialized:
65-
self._is_initialized = True
66-
for file_url in self.file_urls:
67-
await self._upload_file(file_url)
75+
try:
76+
if not self._is_initialized:
77+
self._is_initialized = True
78+
for file_url in self.file_urls:
79+
await self._upload_file(file_url)
80+
except:
81+
self.self.sandbox.close()
82+
raise
6883

6984
return self
7085

7186
async def __aexit__(self, _exc_type, _exc_value, _traceback):
72-
if self.session_id:
73-
self.sandbox.keep_alive(self.timeout)
74-
self.sandbox.close()
87+
try:
88+
if self.session_id:
89+
self.sandbox.keep_alive(self.timeout)
90+
finally:
91+
self.sandbox.close()
7592

76-
def get_files_code(self):
93+
def get_dataframe(self):
7794
"""
7895
Get the code to read the files in the sandbox.
7996
This can be used for instructing the LLM how to access the loaded files.
8097
"""
81-
82-
# TODO: Add support for xslx, json
83-
files_code = "\n".join(
84-
f'df{i} = pd.read_csv("{self._get_file_path(url)}") # {url}'
85-
for i, url in enumerate(self.file_urls)
98+
# TODO: Add support for multiple dataframes
99+
df = pd.read_csv(self.file_urls[0])
100+
return df, self.file_urls[0]
101+
102+
def generate_prompt(self, query: str) -> str:
103+
df, url = self.get_dataframe()
104+
return textwrap.dedent(
105+
f"""
106+
You are provided with a following pandas dataframe (`df`):
107+
{df.info()}
108+
109+
Using the provided dataframe (`df`), update the following python code using pandas that returns the answer to question: \"{query}\"
110+
111+
This is the initial python code to be updated:
112+
113+
```python
114+
import pandas as pd
115+
116+
df = pd.read_csv("{url}")
117+
1. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
118+
2. Analyze: Conducting the actual analysis
119+
3. Output: Returning the answer as a string
120+
```
121+
"""
86122
)
87123

88-
return f"""
89-
import pandas as pd
124+
def extract_code(self, code: str) -> str:
125+
pattern = r"```(?:python)?(.*?)```"
126+
matches = re.findall(pattern, code, re.DOTALL)
127+
if matches:
128+
return matches[0].strip()
129+
return ""
90130

91-
{files_code}
92-
93-
"""
131+
async def generate_code(
132+
self,
133+
query: str,
134+
) -> str:
135+
content = self.generate_prompt(query=query)
136+
completion = await client.chat.completions.create(
137+
messages=[
138+
{
139+
"role": "system",
140+
"content": SYSTEM_PROMPT,
141+
},
142+
{
143+
"role": "user",
144+
"content": content,
145+
},
146+
],
147+
model="gpt-3.5-turbo-0125",
148+
)
149+
output = completion.choices[0].message.content
150+
return self.extract_code(code=output)
94151

95152
async def run_python(self, code: str):
96-
files_code = self.get_files_code()
97-
98-
templated_code = f"""
99-
{files_code}
100-
{code}
101-
"""
102-
103153
epoch_time = time.time()
104154
codefile_path = f"/tmp/main-{epoch_time}.py"
105-
self.sandbox.filesystem.write(codefile_path, templated_code)
155+
self.sandbox.filesystem.write(codefile_path, code)
106156
process = await asyncio.to_thread(
107157
self.sandbox.process.start_and_wait,
108158
f"python {codefile_path}",

service/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from models.file import File
2222
from models.google_drive import GoogleDrive
2323
from models.ingest import Encoder, EncoderEnum
24+
from utils.file import get_file_extension_from_url
2425
from utils.logger import logger
2526
from utils.summarise import completion
26-
from utils.file import get_file_extension_from_url
2727
from vectordbs import get_vector_service
2828

2929

service/router.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
from uuid import uuid4
2+
13
from decouple import config
24
from semantic_router.encoders import CohereEncoder
35
from semantic_router.layer import RouteLayer
46
from semantic_router.route import Route
57

68
from models.document import BaseDocumentChunk
79
from models.query import RequestPayload
8-
9-
# from service.code_interpreter import CodeInterpreterService
10+
from service.code_interpreter import CodeInterpreterService
1011
from service.embedding import get_encoder
1112
from utils.logger import logger
1213
from utils.summarise import SUMMARY_SUFFIX
1314
from vectordbs import BaseVectorDatabase, get_vector_service
1415

16+
STRUTURED_DATA = [".xlsx", ".csv", ".json"]
17+
1518

1619
def create_route_layer() -> RouteLayer:
1720
routes = [
@@ -35,12 +38,29 @@ async def get_documents(
3538
*, vector_service: BaseVectorDatabase, payload: RequestPayload
3639
) -> list[BaseDocumentChunk]:
3740
chunks = await vector_service.query(input=payload.input, top_k=5)
38-
3941
if not len(chunks):
4042
logger.error(f"No documents found for query: {payload.input}")
4143
return []
42-
43-
reranked_chunks = await vector_service.rerank(query=payload.input, documents=chunks)
44+
is_structured = chunks[0].metadata.get("document_type") in STRUTURED_DATA
45+
reranked_chunks = []
46+
if is_structured and payload.interpreter_mode:
47+
async with CodeInterpreterService(
48+
session_id=payload.session_id, file_urls=[chunks[0].metadata.get("doc_url")]
49+
) as service:
50+
code = await service.generate_code(query=payload.input)
51+
response = await service.run_python(code=code)
52+
output = response.stdout
53+
reranked_chunks.append(
54+
BaseDocumentChunk(
55+
id=str(uuid4()),
56+
document_id=str(uuid4()),
57+
content=output,
58+
doc_url=chunks[0].metadata.get("doc_url"),
59+
)
60+
)
61+
reranked_chunks.extend(
62+
await vector_service.rerank(query=payload.input, documents=chunks)
63+
)
4464
return reranked_chunks
4565

4666

@@ -63,15 +83,4 @@ async def query(payload: RequestPayload) -> list[BaseDocumentChunk]:
6383
encoder=encoder,
6484
)
6585

66-
# async with CodeInterpreterService(
67-
# session_id=payload.session_id,
68-
# file_urls=[
69-
# "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
70-
# ],
71-
# ) as service:
72-
# code = "df0.info()"
73-
# output = await service.run_python(code=code)
74-
# print(output.stderr)
75-
# print(output.stdout)
76-
7786
return await get_documents(vector_service=vector_service, payload=payload)

utils/file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from urllib.parse import urlparse
21
import os
2+
from urllib.parse import urlparse
33

44

55
def get_file_extension_from_url(url: str) -> str:

0 commit comments

Comments (0)