Skip to content

Commit 2ffe374

Browse files
authored
feat: add index batch size setting for lightrag (#720)
1 parent 79a5f06 commit 2ffe374

File tree

5 files changed

+61
-25
lines changed

5 files changed

+61
-25
lines changed

Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ RUN --mount=type=ssh \
8686
ENV USE_LIGHTRAG=true
8787
RUN --mount=type=ssh \
8888
--mount=type=cache,target=/root/.cache/pip \
89-
pip install aioboto3 nano-vectordb ollama xxhash "lightrag-hku<=0.0.8"
89+
pip install aioboto3 nano-vectordb ollama xxhash "lightrag-hku<=1.3.0"
9090

9191
RUN --mount=type=ssh \
9292
--mount=type=cache,target=/root/.cache/pip \

libs/ktem/ktem/index/file/graph/light_graph_index.py

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
5252
pipeline.prompts = striped_settings
5353
# set collection graph id
5454
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
55+
# set index batch size
56+
pipeline.index_batch_size = striped_settings.get(
57+
"batch_size", pipeline.index_batch_size
58+
)
5559
return pipeline
5660

5761
def get_retriever_pipelines(

libs/ktem/ktem/index/file/graph/lightrag_pipelines.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
243243

244244
prompts: dict[str, str] = {}
245245
collection_graph_id: str
246+
index_batch_size: int = INDEX_BATCHSIZE
246247

247248
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
248249
if not settings.USE_GLOBAL_GRAPHRAG:
@@ -283,18 +284,31 @@ def get_user_settings(cls) -> dict:
283284
from lightrag.prompt import PROMPTS
284285

285286
blacklist_keywords = ["default", "response", "process"]
286-
return {
287-
prompt_name: {
288-
"name": f"Prompt for '{prompt_name}'",
289-
"value": content,
290-
"component": "text",
287+
settings_dict = {
288+
"batch_size": {
289+
"name": (
290+
"Index batch size " "(reduce if you have rate limit issues)"
291+
),
292+
"value": INDEX_BATCHSIZE,
293+
"component": "number",
291294
}
292-
for prompt_name, content in PROMPTS.items()
293-
if all(
294-
keyword not in prompt_name.lower() for keyword in blacklist_keywords
295-
)
296-
and isinstance(content, str)
297295
}
296+
settings_dict.update(
297+
{
298+
prompt_name: {
299+
"name": f"Prompt for '{prompt_name}'",
300+
"value": content,
301+
"component": "text",
302+
}
303+
for prompt_name, content in PROMPTS.items()
304+
if all(
305+
keyword not in prompt_name.lower()
306+
for keyword in blacklist_keywords
307+
)
308+
and isinstance(content, str)
309+
}
310+
)
311+
return settings_dict
298312
except ImportError as e:
299313
print(e)
300314
return {}
@@ -359,8 +373,8 @@ def call_graphrag_index(self, graph_id: str, docs: list[Document]):
359373
),
360374
)
361375

362-
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
363-
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
376+
for doc_id in range(0, len(all_docs), self.index_batch_size):
377+
cur_docs = all_docs[doc_id : doc_id + self.index_batch_size]
364378
combined_doc = "\n".join(cur_docs)
365379

366380
# Use insert for incremental updates

libs/ktem/ktem/index/file/graph/nano_graph_index.py

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
5252
pipeline.prompts = striped_settings
5353
# set collection graph id
5454
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
55+
# set index batch size
56+
pipeline.index_batch_size = striped_settings.get(
57+
"batch_size", pipeline.index_batch_size
58+
)
5559
return pipeline
5660

5761
def get_retriever_pipelines(

libs/ktem/ktem/index/file/graph/nano_pipelines.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
239239

240240
prompts: dict[str, str] = {}
241241
collection_graph_id: str
242+
index_batch_size: int = INDEX_BATCHSIZE
242243

243244
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
244245
if not settings.USE_GLOBAL_GRAPHRAG:
@@ -279,18 +280,31 @@ def get_user_settings(cls) -> dict:
279280
from nano_graphrag.prompt import PROMPTS
280281

281282
blacklist_keywords = ["default", "response", "process"]
282-
return {
283-
prompt_name: {
284-
"name": f"Prompt for '{prompt_name}'",
285-
"value": content,
286-
"component": "text",
283+
settings_dict = {
284+
"batch_size": {
285+
"name": (
286+
"Index batch size " "(reduce if you have rate limit issues)"
287+
),
288+
"value": INDEX_BATCHSIZE,
289+
"component": "number",
287290
}
288-
for prompt_name, content in PROMPTS.items()
289-
if all(
290-
keyword not in prompt_name.lower() for keyword in blacklist_keywords
291-
)
292-
and isinstance(content, str)
293291
}
292+
settings_dict.update(
293+
{
294+
prompt_name: {
295+
"name": f"Prompt for '{prompt_name}'",
296+
"value": content,
297+
"component": "text",
298+
}
299+
for prompt_name, content in PROMPTS.items()
300+
if all(
301+
keyword not in prompt_name.lower()
302+
for keyword in blacklist_keywords
303+
)
304+
and isinstance(content, str)
305+
}
306+
)
307+
return settings_dict
294308
except ImportError as e:
295309
print(e)
296310
return {}
@@ -355,8 +369,8 @@ def call_graphrag_index(self, graph_id: str, docs: list[Document]):
355369
),
356370
)
357371

358-
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
359-
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
372+
for doc_id in range(0, len(all_docs), self.index_batch_size):
373+
cur_docs = all_docs[doc_id : doc_id + self.index_batch_size]
360374
combined_doc = "\n".join(cur_docs)
361375

362376
# Use insert for incremental updates

0 commit comments

Comments (0)