diff --git a/libs/community/langchain_community/indexes/_sql_record_manager.py b/libs/community/langchain_community/indexes/_sql_record_manager.py index 544e828df2a82..f70a1dc4f4864 100644 --- a/libs/community/langchain_community/indexes/_sql_record_manager.py +++ b/libs/community/langchain_community/indexes/_sql_record_manager.py @@ -16,7 +16,17 @@ import contextlib import decimal import uuid -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union +from typing import ( + Any, + AsyncGenerator, + Dict, + Generator, + List, + Optional, + Sequence, + Union, + cast, +) from sqlalchemy import ( URL, @@ -175,10 +185,10 @@ def _make_session(self) -> Generator[Session, None, None]: async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]: """Create a session and close it after use.""" - if not isinstance(self.session_factory, async_sessionmaker): + if not isinstance(self.engine, AsyncEngine): raise AssertionError("This method is not supported for sync engines.") - async with self.session_factory() as session: + async with cast(AsyncSession, self.session_factory()) as session: yield session def get_time(self) -> float: