Skip to content

Commit e2313cf

Browse files
committed
Adjusted the code according to dependency bump
1 parent 0dfbf55 commit e2313cf

File tree

10 files changed

+49
-51
lines changed

10 files changed

+49
-51
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ asyncpg==0.28.0
33
Authlib==1.2.1
44
fastapi==0.101.0
55
fastapi-pagination==0.12.7
6+
fastapi-users-db-sqlalchemy==6.0.1
67
fastapi-users[sqlalchemy]==12.1.1
78
numpy==1.25.2
89
orjson==3.9.4

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ web =
5959
Authlib==1.2.1
6060
fastapi==0.101.0
6161
fastapi-pagination==0.12.7
62+
fastapi-users-db-sqlalchemy==6.0.1
6263
fastapi-users[sqlalchemy]==12.1.1
6364
numpy==1.25.2
6465
orjson==3.9.4

src/dataset_image_annotator/api/http.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
import logging
2+
from contextlib import asynccontextmanager
23

34
from fastapi import FastAPI
45
from fastapi.middleware.cors import CORSMiddleware
56
from fastapi_pagination import add_pagination
7+
from starlette.applications import Starlette
68

79
from dataset_image_annotator.api.v1.endpoints import router, auth_router, users_router
810

11+
12+
logger = logging.getLogger(__name__)
913
origins = [
1014
'http://localhost',
1115
'http://localhost:3000',
1216
'http://localhost:8080',
1317
'http://localhost:5000',
1418
]
1519

16-
logger = logging.getLogger(__name__)
17-
app = FastAPI(docs_url='/api/docs', openapi_url='/api/v1/openapi.json')
20+
21+
@asynccontextmanager
22+
async def lifespan(app: Starlette):
23+
yield
24+
25+
26+
app = FastAPI(docs_url='/api/docs', openapi_url='/api/v1/openapi.json', lifespan=lifespan)
1827
app.add_middleware(
1928
CORSMiddleware,
2029
allow_origins=origins,

src/dataset_image_annotator/api/v1/endpoints.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fastapi_users.authentication import AuthenticationBackend, JWTStrategy, CookieTransport
1212
from pydantic import Json
1313
from python3_commons.db import connect_to_db
14+
from sqlalchemy.ext.asyncio import AsyncSession
1415

1516
from dataset_image_annotator import core
1617
from dataset_image_annotator.api import users
@@ -20,8 +21,8 @@
2021
)
2122
from dataset_image_annotator.conf import settings
2223
from dataset_image_annotator.core import upload_handler
23-
from dataset_image_annotator.db import database
2424
from dataset_image_annotator.db.models import User
25+
from dataset_image_annotator.db.user_db_helpers import get_async_session
2526

2627
logger = logging.getLogger(__name__)
2728
router = APIRouter()
@@ -44,16 +45,6 @@ def get_jwt_strategy() -> JWTStrategy:
4445
users_router = fastapi_users.get_users_router(UserRead, UserUpdate, requires_verification=True)
4546

4647

47-
@router.on_event('startup')
48-
async def startup():
49-
await connect_to_db(database, settings.db_dsn)
50-
51-
52-
@router.on_event('shutdown')
53-
async def shutdown():
54-
await database.disconnect()
55-
56-
5748
def _handle_exceptions_helper(status_code, *args):
5849
if args:
5950
raise HTTPException(status_code=status_code, detail=args[0])
@@ -81,8 +72,9 @@ async def wrapper(*args, **kwargs):
8172
@router.get('/users', response_class=ORJSONResponse, tags=['Admin'])
8273
@handle_exceptions
8374
async def get_users(search: Json | None = None, order_by: str | None = None,
84-
user=Depends(get_current_superuser)) -> Page[UserItem]:
85-
return await core.get_users(database, search, order_by)
75+
user=Depends(get_current_superuser),
76+
session: AsyncSession = Depends(get_async_session)) -> Page[UserItem]:
77+
return await core.get_users(session, search, order_by)
8678

8779

8880
@router.post('/users', response_class=ORJSONResponse, tags=['Admin'])
@@ -94,18 +86,20 @@ async def create_user(new_user: UserCreate, user=Depends(get_current_superuser))
9486
@router.get('/user-groups', response_class=ORJSONResponse, tags=['Admin'])
9587
@handle_exceptions
9688
async def get_user_groups(search: Json | None = None, order_by: str | None = None,
97-
user=Depends(get_current_user)) -> Sequence[UserGroup]:
98-
return await core.get_user_groups(database, search, order_by)
89+
user=Depends(get_current_user),
90+
session: AsyncSession = Depends(get_async_session)) -> Sequence[UserGroup]:
91+
return await core.get_user_groups(session, search, order_by)
9992

10093

10194
@router.post('/raw-file', response_class=ORJSONResponse, tags=['Admin'])
10295
@handle_exceptions
103-
async def upload_raw_file(image_file: UploadFile = File(...), user=Depends(get_current_superuser)) -> bool:
96+
async def upload_raw_file(image_file: UploadFile = File(...), user=Depends(get_current_superuser),
97+
session: AsyncSession = Depends(get_async_session)) -> bool:
10498
if not image_file.filename:
10599
raise HTTPException(status_code=400, detail='Missing file')
106100

107101
try:
108-
response = await upload_handler.handle_raw_file(database, image_file)
102+
response = await upload_handler.handle_raw_file(session, image_file)
109103
except TimeoutError as e:
110104
raise HTTPException(status_code=504, detail=str(e))
111105

@@ -114,5 +108,6 @@ async def upload_raw_file(image_file: UploadFile = File(...), user=Depends(get_c
114108

115109
@router.get('/image-samples', response_class=ORJSONResponse, tags=['Images'])
116110
@handle_exceptions
117-
async def get_image_samples(search: Json | None = None, order_by: str | None = None) -> Page[ImageSampleItem]:
118-
return await core.get_image_samples(search, order_by)
111+
async def get_image_samples(session: AsyncSession = Depends(get_async_session), search: Json | None = None,
112+
order_by: str | None = None) -> Page[ImageSampleItem]:
113+
return await core.get_image_samples(session, search, order_by)

src/dataset_image_annotator/core.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from typing import Mapping, Sequence
55

66
import sqlalchemy as sa
7-
from databases.backends.postgres import Record
87
from fastapi_pagination import Page
9-
from fastapi_pagination.ext.databases import paginate
8+
from fastapi_pagination.ext.sqlalchemy import paginate
9+
from sqlalchemy.ext.asyncio import AsyncSession
1010

1111
from dataset_image_annotator.api.v1.schemas import UserItem, ImageSampleItem
1212
from dataset_image_annotator.db.helpers import get_query
@@ -17,28 +17,32 @@
1717
_ = t.gettext
1818

1919

20-
async def get_users(database, search: Mapping[str, str] | None = None, order_by: str | None = None) -> Page[UserItem]:
20+
async def get_users(session: AsyncSession, search: Mapping[str, str] | None = None,
21+
order_by: str | None = None) -> Page[UserItem]:
2122
query = sa.select([User])
22-
result = await paginate(database, query)
23+
result = await paginate(session, query)
2324

2425
return result
2526

2627

27-
async def get_user(database, user_id: str) -> Record:
28-
query = sa.select([User]).where(User.id == user_id)
28+
async def get_user(session: AsyncSession, user_id: str) -> UserItem:
29+
query = sa.select(User).where(User.id == user_id)
30+
cursor = await session.execute(query)
31+
result = cursor.scalar_one()
2932

30-
return await database.fetch_row(query)
33+
return result
3134

3235

33-
async def get_user_groups(database, search: Mapping[str, str] | None = None,
36+
async def get_user_groups(session: AsyncSession, search: Mapping[str, str] | None = None,
3437
order_by: str | None = None) -> Sequence[UserGroup]:
35-
query = sa.select([UserGroup]).order_by(UserGroup.name)
36-
result = await database.fetch_all(query)
38+
query = sa.select(UserGroup).order_by(UserGroup.name)
39+
cursor = await session.execute(query)
40+
result = cursor.scalars()
3741

3842
return result
3943

4044

41-
async def get_image_samples(database, search: Mapping[str, str] | None = None,
45+
async def get_image_samples(session: AsyncSession, search: Mapping[str, str] | None = None,
4246
order_by: str | None = None) -> Page[ImageSampleItem]:
4347
columns = {
4448
'id': (ImageSample.id, False, int, True),
@@ -47,11 +51,9 @@ async def get_image_samples(database, search: Mapping[str, str] | None = None,
4751
'location': (ImageSample.location, True, str, False),
4852
}
4953
where_clause, order_by_clause = get_query(search, order_by, columns)
50-
query = sa.select([ImageSample.id, ImageSample.location]).order_by(order_by_clause)
54+
query = sa.select(ImageSample.id, ImageSample.location).order_by(order_by_clause)
5155

5256
if where_clause is not None:
5357
query = query.where(where_clause)
5458

55-
result = await database.fetch_all(query)
56-
57-
return result
59+
return await paginate(session, query)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import logging
22
from zoneinfo import ZoneInfo
33

4-
from databases import Database
4+
from sqlalchemy.ext.asyncio import AsyncSession
55

66
from dataset_image_annotator.conf import settings
77

88
logger = logging.getLogger(__name__)
99
timezone = ZoneInfo(settings.timezone)
1010

1111

12-
async def handle_raw_file(database: Database, image_file):
12+
async def handle_raw_file(session: AsyncSession, image_file):
1313
image_file_body = await image_file.read()
1414

1515
return True

src/dataset_image_annotator/db/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
import databases
21
from sqlalchemy import MetaData
32
from sqlalchemy.ext.declarative import declarative_base
43

5-
from dataset_image_annotator.conf import settings
6-
74
metadata = MetaData()
85
Base = declarative_base(metadata=metadata)
9-
database = databases.Database(settings.db_dsn)
106

117

128
async def is_healthy(pg) -> bool:

src/dataset_image_annotator/db/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class ImageSampleAnnotation(BaseDBModel, Base):
6161
color = Column(String(32))
6262
votes = Column(Integer, nullable=False, default=0)
6363

64-
6564
__table_args__ = (
6665
UniqueConstraint('user_id', 'image_sample_id', name='uq_image_sample_annotation_item'),
6766
)

src/dataset_image_annotator/db/user_db_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
from fastapi import Depends
55
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
66
from sqlalchemy import MetaData
7-
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
8-
from sqlalchemy.orm import declarative_base, sessionmaker
7+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
8+
from sqlalchemy.orm import declarative_base
99

1010
from dataset_image_annotator.conf import settings
1111
from dataset_image_annotator.db.models import User
1212

1313
metadata = MetaData()
1414
Base = declarative_base(metadata=metadata)
1515
engine = create_async_engine(settings.db_dsn)
16-
async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
16+
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
1717

1818

1919
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:

src/dataset_image_annotator/jobs.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dataset_image_annotator.api.users import get_user_manager_context
88
from dataset_image_annotator.api.v1.schemas import UserCreate
99
from dataset_image_annotator.conf import settings
10-
from dataset_image_annotator.db import database
1110
from dataset_image_annotator.db.user_db_helpers import get_async_session_context, get_user_db_context
1211

1312
logging.config.dictConfig({
@@ -55,8 +54,6 @@
5554

5655

5756
async def create_superuser():
58-
await database.connect()
59-
6057
try:
6158
async with get_async_session_context() as session:
6259
async with get_user_db_context(session) as user_db:
@@ -73,8 +70,6 @@ async def create_superuser():
7370
logger.info(f'User created: {settings.bootstrap_user_email}')
7471
except UserAlreadyExists:
7572
logger.warning(f'User already exists: {settings.bootstrap_user_email}')
76-
finally:
77-
await database.disconnect()
7873

7974

8075
def get_parsed_args():

0 commit comments

Comments
 (0)