Skip to content

Commit fbb1c41

Browse files
add async to memory and postprocessor base classes (#18438)
1 parent 64f6048 commit fbb1c41

File tree

8 files changed

+85
-27
lines changed

8 files changed

+85
-27
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
1111
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
1212

1313
test: ## Run tests via pants
14-
pants --level=info --no-local-cache --changed-since=origin/main --changed-dependents=transitive --no-test-use-coverage test
14+
pants --level=error --no-local-cache --changed-since=origin/main --changed-dependents=transitive --no-test-use-coverage test
1515

1616
test-core: ## Run tests via pants
1717
pants --no-local-cache test llama-index-core/::

llama-index-core/llama_index/core/agent/workflow/codeact_agent.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,7 @@ async def finalize(
344344
Adds all in-progress messages to memory.
345345
"""
346346
scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[])
347-
for msg in scratchpad:
348-
await memory.aput(msg)
347+
await memory.aput_messages(scratchpad)
349348

350349
# reset scratchpad
351350
await ctx.set(self.scratchpad_key, [])

llama-index-core/llama_index/core/agent/workflow/function_agent.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ async def finalize(
122122
Adds all in-progress messages to memory.
123123
"""
124124
scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[])
125-
for msg in scratchpad:
126-
await memory.aput(msg)
125+
await memory.aput_messages(scratchpad)
127126

128127
# reset scratchpad
129128
await ctx.set(self.scratchpad_key, [])

llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput:
308308

309309
# First set chat history if it exists
310310
if chat_history:
311-
memory.set(chat_history)
311+
await memory.aset(chat_history)
312312

313313
# Then add user message if it exists
314314
if user_msg:
@@ -335,7 +335,7 @@ async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput:
335335
raise ValueError("Must provide either user_msg or chat_history")
336336

337337
# Get all messages from memory
338-
input_messages = memory.get()
338+
input_messages = await memory.aget()
339339

340340
# send to the current agent
341341
current_agent_name: str = await ctx.get("current_agent_name")
@@ -526,7 +526,7 @@ async def aggregate_tool_results(
526526
return StopEvent(result=result)
527527

528528
user_msg_str = await ctx.get("user_msg_str")
529-
input_messages = memory.get(input=user_msg_str)
529+
input_messages = await memory.aget(input=user_msg_str)
530530

531531
# get this again, in case it changed
532532
agent_name = await ctx.get("current_agent_name")

llama-index-core/llama_index/core/memory/types.py

+46-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from abc import abstractmethod
23
from typing import Any, List, Optional
34

@@ -11,10 +12,7 @@
1112

1213

1314
class BaseMemory(BaseComponent):
14-
"""Base class for all memory types.
15-
16-
NOTE: The interface for memory is not yet finalized and is subject to change.
17-
"""
15+
"""Base class for all memory types."""
1816

1917
@classmethod
2018
def class_name(cls) -> str:
@@ -33,17 +31,27 @@ def from_defaults(
3331
def get(self, input: Optional[str] = None, **kwargs: Any) -> List[ChatMessage]:
3432
"""Get chat history."""
3533

34+
async def aget(
35+
self, input: Optional[str] = None, **kwargs: Any
36+
) -> List[ChatMessage]:
37+
"""Get chat history."""
38+
return await asyncio.to_thread(self.get, input=input, **kwargs)
39+
3640
@abstractmethod
3741
def get_all(self) -> List[ChatMessage]:
3842
"""Get all chat history."""
3943

44+
async def aget_all(self) -> List[ChatMessage]:
45+
"""Get all chat history."""
46+
return await asyncio.to_thread(self.get_all)
47+
4048
@abstractmethod
4149
def put(self, message: ChatMessage) -> None:
4250
"""Put chat history."""
4351

4452
async def aput(self, message: ChatMessage) -> None:
4553
"""Put chat history."""
46-
self.put(message)
54+
await asyncio.to_thread(self.put, message)
4755

4856
def put_messages(self, messages: List[ChatMessage]) -> None:
4957
"""Put chat history."""
@@ -52,23 +60,27 @@ def put_messages(self, messages: List[ChatMessage]) -> None:
5260

5361
async def aput_messages(self, messages: List[ChatMessage]) -> None:
5462
"""Put chat history."""
55-
for message in messages:
56-
await self.aput(message)
63+
await asyncio.to_thread(self.put_messages, messages)
5764

5865
@abstractmethod
5966
def set(self, messages: List[ChatMessage]) -> None:
6067
"""Set chat history."""
6168

69+
async def aset(self, messages: List[ChatMessage]) -> None:
70+
"""Set chat history."""
71+
await asyncio.to_thread(self.set, messages)
72+
6273
@abstractmethod
6374
def reset(self) -> None:
6475
"""Reset chat history."""
6576

77+
async def areset(self) -> None:
78+
"""Reset chat history."""
79+
await asyncio.to_thread(self.reset)
80+
6681

6782
class BaseChatStoreMemory(BaseMemory):
68-
"""Base class for any .
69-
70-
NOTE: The interface for memory is not yet finalized and is subject to change.
71-
"""
83+
"""Base class for storing multi-tenant chat history."""
7284

7385
chat_store: SerializeAsAny[BaseChatStore] = Field(default_factory=SimpleChatStore)
7486
chat_store_key: str = Field(default=DEFAULT_CHAT_STORE_KEY)
@@ -98,6 +110,20 @@ def get_all(self) -> List[ChatMessage]:
98110
"""Get all chat history."""
99111
return self.chat_store.get_messages(self.chat_store_key)
100112

113+
async def aget_all(self) -> List[ChatMessage]:
114+
"""Get all chat history."""
115+
return await self.chat_store.aget_messages(self.chat_store_key)
116+
117+
def get(self, input: Optional[str] = None, **kwargs: Any) -> List[ChatMessage]:
118+
"""Get chat history."""
119+
return self.chat_store.get_messages(self.chat_store_key, **kwargs)
120+
121+
async def aget(
122+
self, input: Optional[str] = None, **kwargs: Any
123+
) -> List[ChatMessage]:
124+
"""Get chat history."""
125+
return await self.chat_store.aget_messages(self.chat_store_key, **kwargs)
126+
101127
def put(self, message: ChatMessage) -> None:
102128
"""Put chat history."""
103129
# ensure everything is serialized
@@ -112,6 +138,15 @@ def set(self, messages: List[ChatMessage]) -> None:
112138
"""Set chat history."""
113139
self.chat_store.set_messages(self.chat_store_key, messages)
114140

141+
async def aset(self, messages: List[ChatMessage]) -> None:
142+
"""Set chat history."""
143+
# delegate to the chat store's async setter (serialization, if any, happens there)
144+
await self.chat_store.aset_messages(self.chat_store_key, messages)
145+
115146
def reset(self) -> None:
116147
"""Reset chat history."""
117148
self.chat_store.delete_messages(self.chat_store_key)
149+
150+
async def areset(self) -> None:
151+
"""Reset chat history."""
152+
await self.chat_store.adelete_messages(self.chat_store_key)

llama-index-core/llama_index/core/postprocessor/types.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from abc import ABC, abstractmethod
23
from typing import Any, Dict, List, Optional
34

@@ -61,6 +62,29 @@ def _postprocess_nodes(
6162
) -> List[NodeWithScore]:
6263
"""Postprocess nodes."""
6364

65+
async def apostprocess_nodes(
66+
self,
67+
nodes: List[NodeWithScore],
68+
query_bundle: Optional[QueryBundle] = None,
69+
query_str: Optional[str] = None,
70+
) -> List[NodeWithScore]:
71+
"""Postprocess nodes (async)."""
72+
if query_str is not None and query_bundle is not None:
73+
raise ValueError("Cannot specify both query_str and query_bundle")
74+
elif query_str is not None:
75+
query_bundle = QueryBundle(query_str)
76+
else:
77+
pass
78+
return await self._apostprocess_nodes(nodes, query_bundle)
79+
80+
async def _apostprocess_nodes(
81+
self,
82+
nodes: List[NodeWithScore],
83+
query_bundle: Optional[QueryBundle] = None,
84+
) -> List[NodeWithScore]:
85+
"""Postprocess nodes (async)."""
86+
return await asyncio.to_thread(self._postprocess_nodes, nodes, query_bundle)
87+
6488
def _as_query_component(self, **kwargs: Any) -> QueryComponent:
6589
"""As query component."""
6690
return PostprocessorComponent(postprocessor=self)

llama-index-core/llama_index/core/storage/chat_store/base.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Base interface class for storing chat history per user."""
2+
import asyncio
23
from abc import abstractmethod
34
from typing import List, Optional
45

@@ -49,28 +50,28 @@ def get_keys(self) -> List[str]:
4950

5051
async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None:
5152
"""Async version of Set messages for a key."""
52-
self.set_messages(key, messages)
53+
await asyncio.to_thread(self.set_messages, key, messages)
5354

5455
async def aget_messages(self, key: str) -> List[ChatMessage]:
5556
"""Async version of Get messages for a key."""
56-
return self.get_messages(key)
57+
return await asyncio.to_thread(self.get_messages, key)
5758

5859
async def async_add_message(self, key: str, message: ChatMessage) -> None:
5960
"""Async version of Add a message for a key."""
60-
self.add_message(key, message)
61+
await asyncio.to_thread(self.add_message, key, message)
6162

6263
async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]:
6364
"""Async version of Delete messages for a key."""
64-
return self.delete_messages(key)
65+
return await asyncio.to_thread(self.delete_messages, key)
6566

6667
async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
6768
"""Async version of Delete specific message for a key."""
68-
return self.delete_message(key, idx)
69+
return await asyncio.to_thread(self.delete_message, key, idx)
6970

7071
async def adelete_last_message(self, key: str) -> Optional[ChatMessage]:
7172
"""Async version of Delete last message for a key."""
72-
return self.delete_last_message(key)
73+
return await asyncio.to_thread(self.delete_last_message, key)
7374

7475
async def aget_keys(self) -> List[str]:
7576
"""Async version of Get all keys."""
76-
return self.get_keys()
77+
return await asyncio.to_thread(self.get_keys)

llama-index-core/tests/agent/workflow/test_code_act_agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,4 @@ async def test_code_act_agent_tool_handling(
166166
# Finalize
167167
final_output = await agent.finalize(ctx, output, mock_memory)
168168
assert isinstance(final_output, AgentOutput)
169-
assert mock_memory.aput.called # Verify memory was updated
169+
assert mock_memory.aput_messages.called # Verify memory was updated

0 commit comments

Comments
 (0)