Skip to content

Commit 46088d0

Browse files
committed
Move to milvus from chroma
1 parent 542a7e6 commit 46088d0

21 files changed

+660
-362
lines changed

backend/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
"@types/ioredis": "^4.28.8",
3737
"@types/jsonwebtoken": "^8.5.8",
3838
"@types/node": "^17.0.10",
39+
"@zilliz/milvus2-sdk-node": "^2.2.7",
3940
"add": "^2.0.6",
4041
"bcrypt": "^5.0.1",
4142
"body-parser": "^1.19.1",
42-
"chromadb": "^1.3.1",
4343
"commonjs": "^0.0.1",
4444
"cors": "^2.8.5",
4545
"dotenv": "^14.2.0",

backend/src/helpers/clients/ChromaClient.ts

-17
This file was deleted.

backend/src/helpers/clients/ElasticSearchClient.ts

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { Client } from '@elastic/elasticsearch';
22

3-
if (!process.env.ELASTICSEARCH_URL) {
4-
throw new Error('ELASTICSEARCH_URL is not defined');
3+
if (!process.env.ELASTICSEARCH_URL && !process.env.ELASTICSEARCH_CLOUD_ID) {
4+
throw new Error('ELASTICSEARCH_URL is not defined or ELASTICSEARCH_CLOUD_ID is not defined');
55
}
66

77
if (!process.env.ELASTICSEARCH_PASSWORD) {
@@ -12,8 +12,18 @@ if (!process.env.ELASTICSEARCH_USERNAME) {
1212
throw new Error('ELASTICSEARCH_USERNAME is not defined');
1313
}
1414

15+
const URLConnection = process.env.ELASTICSEARCH_URL !== undefined
16+
? {
17+
node: process.env.ELASTICSEARCH_URL as string,
18+
}
19+
: {
20+
cloud: {
21+
id: process.env.ELASTICSEARCH_CLOUD_ID as string,
22+
}
23+
}
24+
1525
const client = new Client({
16-
node: process.env.ELASTICSEARCH_URL,
26+
...URLConnection,
1727
auth: {
1828
password: process.env.ELASTICSEARCH_PASSWORD,
1929
username: process.env.ELASTICSEARCH_USERNAME,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import { MilvusClient, DataType } from "@zilliz/milvus2-sdk-node";
2+
3+
if (!process.env.MILVUS_URL) {
4+
throw new Error('MILVUS_URL is not defined');
5+
}
6+
7+
const EMBEDDING_DIM = 1536;
8+
9+
const BLOCK_FIELDS = [
10+
{
11+
name: 'block_id',
12+
description: 'The ID of the block (Same as the ID of the block in mongoDB)',
13+
data_type: DataType.VarChar,
14+
max_length: 24,
15+
is_primary_key: true,
16+
},
17+
{
18+
name: 'page_id',
19+
description: 'The ID of the page that the block belongs to',
20+
max_length: 24,
21+
data_type: DataType.VarChar,
22+
},
23+
{
24+
name: 'embedding',
25+
description: 'The embedding of the block (Generated by OpenAI)',
26+
data_type: DataType.FloatVector,
27+
dim: EMBEDDING_DIM,
28+
},
29+
{
30+
name: 'content',
31+
description: 'The content of the block',
32+
data_type: DataType.VarChar,
33+
max_length: 1000,
34+
},
35+
{
36+
name: 'context',
37+
description: 'The context of the block',
38+
data_type: DataType.VarChar,
39+
max_length: (24 + 2) * 11 + 2,
40+
}
41+
]
42+
43+
const milvusClient = new MilvusClient(
44+
process.env.MILVUS_URL,
45+
true,
46+
process.env.MILVUS_USERNAME,
47+
process.env.MILVUS_PASSWORD
48+
);
49+
50+
(async () => {
51+
const doesCollectionExist = await milvusClient.hasCollection({
52+
collection_name: 'blocks',
53+
});
54+
55+
if (doesCollectionExist.status.error_code !== 'Success') {
56+
throw new Error(doesCollectionExist.status.reason);
57+
}
58+
59+
if (doesCollectionExist.value) return;
60+
61+
await milvusClient.createCollection({
62+
collection_name: 'blocks',
63+
description: 'Collection for storing block embeddings',
64+
fields: BLOCK_FIELDS,
65+
});
66+
67+
await milvusClient.createIndex({
68+
collection_name: 'blocks',
69+
field_name: 'embedding',
70+
index_type: 'IVF_FLAT',
71+
params: {
72+
M: 8,
73+
efConstruction: 64
74+
}
75+
})
76+
})();
77+
78+
export default milvusClient;

backend/src/helpers/getChatResponse.ts

+40-44
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import { DataType } from '@zilliz/milvus2-sdk-node/dist/milvus';
2+
import type { ChatCompletionRequestMessage } from 'openai';
3+
14
import OpenAIClient from './clients/OpenAIClient';
2-
import ChromaClient from './clients/ChromaClient';
5+
import MilvusClient from './clients/MilvusClient';
36

4-
import type { ChatCompletionRequestMessage } from 'openai';
57

68
const RELATIVE_TEXT_COUNT = 3;
79

@@ -13,30 +15,15 @@ const PRE_PROMPT = `You are a helpful AI assistant. Use the following pieces of
1315
If you don't know the answer, just say you don't know. DO NOT try to make up an answer.
1416
Try to keep your answers helpful, short and to the point using markdown formatting.`;
1517

16-
interface QueryResponse {
17-
ids: string[][];
18-
documents: string[][];
19-
metadatas: {
20-
userID: string;
21-
context: string[];
22-
}[][];
23-
}
24-
25-
interface GetQueryResponse {
26-
ids: string[];
27-
documents: string[];
28-
metadatas: {
29-
userID: string;
30-
context: string[];
31-
}[];
32-
}
33-
3418
const getChatResponse = async (
3519
messages: ChatCompletionRequestMessage[],
3620
question: string,
37-
user: string
21+
page: string
3822
): Promise<string> => {
39-
const blockCollection = await ChromaClient.getCollection('blocks');
23+
// ~ Load the block collection
24+
await MilvusClient.loadCollection({
25+
collection_name: 'blocks',
26+
});
4027

4128
// ~ Create an embedding for the latest message
4229
const embeddings = await OpenAIClient.createEmbedding({
@@ -45,45 +32,54 @@ const getChatResponse = async (
4532
});
4633

4734
try {
48-
// ~ Query the database for the most similar message
49-
const similarMessages = await blockCollection.query(
50-
embeddings.data.data[0].embedding,
51-
RELATIVE_TEXT_COUNT,
52-
{
53-
userID: user,
35+
const similarMessages = await MilvusClient.search({
36+
collection_name: 'blocks',
37+
limit: RELATIVE_TEXT_COUNT,
38+
vector_type: DataType.FloatVector,
39+
params: {
40+
anns_field: 'block_id',
41+
topk: `${RELATIVE_TEXT_COUNT}`,
42+
metric_type: "L2",
43+
params: JSON.stringify({ nprobe: 10 }),
5444
},
55-
) as QueryResponse;
45+
vector: embeddings.data.data[0].embedding,
46+
expr: `page_id == ${page}`,
47+
output_fields: ['content', 'context']
48+
})
5649

5750
// ~ If there are no similar messages, return a default message
58-
if (!similarMessages.documents) {
51+
if (!similarMessages.results) {
5952
return 'I don\'t know what to say.';
6053
}
6154

55+
console.log(JSON.stringify(
56+
similarMessages
57+
))
58+
6259
// ~ Get the context messages
6360
const contextIDs = Array.from(
6461
new Set(
65-
similarMessages.metadatas
62+
similarMessages.results
6663
.flat()
67-
.flatMap((metadata) => metadata.context)
64+
.flatMap((metadata) => JSON.parse(metadata.context) as string[])
6865
)
69-
)
66+
);
7067

71-
const contextMessages = await blockCollection.get(
72-
contextIDs,
73-
{
74-
userID: user,
75-
}
76-
) as GetQueryResponse;
68+
const contextMessages = await MilvusClient.query({
69+
collection_name: 'blocks',
70+
expr: `block_id in [${contextIDs.join(', ')}]`,
71+
output_fields: ['content', 'block_id'],
72+
})
7773

7874
const contextMessagesMap: Record<string, string> = {};
7975

80-
contextMessages.documents.forEach((document, index) => {
81-
contextMessagesMap[contextMessages.ids[index]] = document;
76+
contextMessages.data.forEach((result) => {
77+
contextMessagesMap[result.block_id] = result.content;
8278
});
8379

84-
const context = similarMessages.metadatas
85-
.flat(2)
86-
.flatMap((metadata) => metadata.context)
80+
const context = similarMessages.results
81+
.map((result) => result.context)
82+
.flat()
8783
.map((id) => contextMessagesMap[id])
8884
.filter((message) => message)
8985
.map((message) => message.trim())

backend/src/helpers/refreshEmbeds.ts

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import ChromaClient from './clients/ChromaClient';
1+
import MilvusClient from './clients/MilvusClient';
22
import OpenAIClient from './clients/OpenAIClient';
33

44
export interface EmbedOperation {
@@ -14,15 +14,10 @@ export interface EmbedOperation {
1414
* @param page The page to update the embeds for.
1515
* @param pageData The page data of the page that is being updated.
1616
*/
17-
const refreshEmbeds = async (updates: EmbedOperation[], page: string, pageOwner: string) => {
18-
const blockCollection = await ChromaClient.getCollection('blocks');
19-
20-
await blockCollection.delete(
21-
updates.map((operation) => `${page}.${operation.id}`),
22-
{
23-
userID: pageOwner,
24-
}
25-
);
17+
const refreshEmbeds = async (updates: EmbedOperation[], page: string) => {
18+
await MilvusClient.loadCollection({
19+
collection_name: 'blocks'
20+
});
2621

2722
const updateOperations = updates.filter((update) => update.type === 'update');
2823

@@ -31,19 +26,32 @@ const refreshEmbeds = async (updates: EmbedOperation[], page: string, pageOwner:
3126
model: 'text-embedding-ada-002',
3227
});
3328

34-
const metaData = new Array(updateOperations.length)
29+
const fieldsData = new Array(updateOperations.length)
3530
.fill(undefined)
3631
.map((_, index) => ({
37-
userID: pageOwner,
38-
context: updateOperations[index].context.map((context) => `${page}-${context}`),
32+
block_id: updateOperations[index].id,
33+
page_id: page,
34+
embedding: embeddings.data.data[index].embedding,
35+
content: updateOperations[index].value,
36+
context: JSON.stringify(updateOperations[index].context),
3937
}));
4038

41-
await blockCollection.add(
42-
updateOperations.map((operation) => operation.id),
43-
embeddings.data.data.map((embedding) => embedding.embedding),
44-
metaData,
45-
updateOperations.map((operation) => operation?.value || ''),
46-
);
39+
40+
await MilvusClient.insert({
41+
collection_name: 'blocks',
42+
fields_data: fieldsData
43+
});
44+
45+
const deleteOperations = updates.filter((update) => update.type === 'delete');
46+
47+
await MilvusClient.deleteEntities({
48+
collection_name: 'blocks',
49+
expr: `block_id in [${deleteOperations.map((operation) => operation.id).join(', ')}]`,
50+
});
51+
52+
await MilvusClient.flush({
53+
collection_names: ['blocks'],
54+
});
4755
};
4856

4957
export default refreshEmbeds;

backend/src/index.ts

+6-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ import setupAuth from './setupAuth';
1010

1111
// -=- Connect to MongoDB with dotenv file -=-
1212
dotenv.config();
13-
mongoose.connect(process.env.MONGO_URL ?? '');
13+
mongoose.connect(
14+
process.env.MONGO_URL ?? '',
15+
).catch((err) => {
16+
console.log(err);
17+
});
1418

1519
// -=- Setup express -=-
1620
const app = express();
17-
const port = 8000;
21+
const port = 8000;
1822

1923
// -=- Setup Super Tokens -=-
2024
setupAuth();

backend/src/routes/account/chat.ts

+11-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ router.get(
1212
'/chat',
1313
verifySession(),
1414
async (req: SessionRequest, res) => {
15-
const { message, previousMessages } = req.query;
15+
const { message, previousMessages, pageID } = req.query;
1616

1717
if (typeof message !== 'string') {
1818
res.statusCode = 401;
@@ -23,6 +23,15 @@ router.get(
2323
return;
2424
}
2525

26+
if (typeof pageID !== 'string') {
27+
res.statusCode = 401;
28+
res.json({
29+
status: 'error',
30+
message: 'Please enter a pageID!',
31+
});
32+
return;
33+
}
34+
2635
if (previousMessages && typeof previousMessages !== 'string') {
2736
res.statusCode = 401;
2837
res.json({
@@ -55,7 +64,7 @@ router.get(
5564
messages = messages.slice(-10);
5665
}
5766

58-
const response = await getChatResponse(messages, message, req.session!.getUserId());
67+
const response = await getChatResponse(messages, message, pageID);
5968

6069
messages.push(
6170
{

0 commit comments

Comments
 (0)