Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace synchronous database driver with async #1131

Draft
wants to merge 1 commit into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def assemble_db_connection(cls, v: Optional[str], info: ValidationInfo) -> Any:
if isinstance(v, str):
return v
pg_url = PostgresDsn.build(
scheme="postgresql",
scheme="postgresql+psycopg",
username=info.data.get("FMTM_DB_USER"),
password=info.data.get("FMTM_DB_PASSWORD"),
host=info.data.get("FMTM_DB_HOST"),
Expand Down
73 changes: 62 additions & 11 deletions src/backend/app/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,73 @@

"""Config for the FMTM database connection."""

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
import contextlib
from typing import AsyncIterator

from app.config import settings
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import declarative_base

engine = create_engine(settings.FMTM_DB_URL.unicode_string())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
from app.config import settings

Base = declarative_base()
FmtmMetadata = Base.metadata


def get_db():
class DatabaseSessionManager:
"""Manage the database sessions for FastAPI."""

def __init__(self, host: str):
"""Init DatabaseSessionManager object."""
self._engine = create_async_engine(host)
self._sessionmaker = async_sessionmaker(autocommit=False, bind=self._engine)

async def close(self):
"""Close database engine and session."""
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")
await self._engine.dispose()

self._engine = None
self._sessionmaker = None

@contextlib.asynccontextmanager
async def connect(self) -> AsyncIterator[AsyncConnection]:
"""Connect to database engine."""
if self._engine is None:
raise Exception("DatabaseSessionManager is not initialized")

async with self._engine.begin() as connection:
try:
yield connection
except Exception:
await connection.rollback()
raise

@contextlib.asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]:
"""Create a session, handling error rollback and closing."""
if self._sessionmaker is None:
raise Exception("DatabaseSessionManager is not initialized")

session = self._sessionmaker()
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()


sessionmanager = DatabaseSessionManager(settings.FMTM_DB_URL.unicode_string())


async def get_db():
"""Create SQLAlchemy DB session."""
db = SessionLocal()
try:
yield db
finally:
db.close()
async with sessionmanager.session() as session:
yield session
12 changes: 9 additions & 3 deletions src/backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
from app.auth import auth_routes
from app.central import central_routes
from app.config import settings
from app.db.database import get_db
from app.db.database import get_db, sessionmanager
from app.organisations import organisation_routes
from app.projects import project_routes
from app.projects.project_crud import read_xlsforms
from app.projects.project_crud import insert_xlsforms_in_db
from app.submissions import submission_routes
from app.tasks import tasks_routes
from app.users import user_routes
Expand All @@ -53,13 +53,19 @@
async def lifespan(app: FastAPI):
"""FastAPI startup/shutdown event."""
log.debug("Starting up FastAPI server.")

log.debug("Reading XLSForms from DB.")
await read_xlsforms(next(get_db()), xlsforms_path)
async for db_session in get_db():
async with db_session as session:
await insert_xlsforms_in_db(session, xlsforms_path)

yield

# Shutdown events
log.debug("Shutting down FastAPI server.")
if sessionmanager._engine is not None:
# Close the DB connection
await sessionmanager.close()


def get_application() -> FastAPI:
Expand Down
8 changes: 4 additions & 4 deletions src/backend/app/projects/project_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,11 +960,11 @@ async def update_project_with_zip(
# ---------------------------


async def read_xlsforms(
async def insert_xlsforms_in_db(
db: Session,
directory: str,
):
"""Read the list of XLSForms from the disk."""
"""Read the list of XLSForms from the disk and insert in database."""
xlsforms = list()
package_name = "osm_fieldwork"
package_files = pkg_files(package_name)
Expand Down Expand Up @@ -998,8 +998,8 @@ async def read_xlsforms(
sql = ins.on_conflict_do_update(
constraint="xlsforms_title_key", set_=dict(title=name, xls=data)
)
db.execute(sql)
db.commit()
await db.execute(sql)
await db.commit()

return xlsforms

Expand Down
16 changes: 15 additions & 1 deletion src/backend/pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"osm-fieldwork==0.4.1",
"osm-rawdata==0.1.7",
"fmtm-splitter==1.0.0rc0",
"psycopg>=3.1.17",
]
requires-python = ">=3.10"
readme = "../../README.md"
Expand Down
70 changes: 33 additions & 37 deletions src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,22 @@

import logging
import os
from contextlib import ExitStack
from typing import Any, Generator

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from loguru import logger as log
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import create_database, database_exists

from app.central import central_crud
from app.config import settings
from app.db.database import Base, get_db
from app.db.database import get_db, sessionmanager
from app.db.db_models import DbOrganisation, DbUser
from app.main import get_application
from app.projects import project_crud
from app.projects.project_schemas import ODKCentral, ProjectInfo, ProjectUpload
from app.users.user_schemas import User

engine = create_engine(settings.FMTM_DB_URL.unicode_string())
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)


def pytest_configure(config):
"""Configure pytest runs."""
Expand All @@ -53,35 +46,47 @@ def pytest_configure(config):
@pytest.fixture(autouse=True)
def app() -> Generator[FastAPI, Any, None]:
"""Get the FastAPI test server."""
yield get_application()
with ExitStack():
# Use ExitStack to correctly close and cleanup resources
yield get_application()


@pytest.fixture(scope="function")
def client(app, db):
"""The FastAPI test server."""
app.dependency_overrides[get_db] = lambda: db

with TestClient(app) as c:
yield c

@pytest.fixture(scope="session")
def db_engine():
"""The SQLAlchemy database engine to init."""
engine = create_engine(settings.FMTM_DB_URL.unicode_string())
if not database_exists:
create_database(engine.url)

Base.metadata.create_all(bind=engine)
yield engine
@pytest.fixture(scope="function", autouse=True)
async def transactional_session():
"""Each test function is a clean slate."""
async with sessionmanager.session() as session:
try:
await session.begin()
yield session
finally:
# Roll back the outer transaction
await session.rollback()


@pytest.fixture(scope="function")
def db(db_engine):
"""Database session using db_engine."""
connection = db_engine.connect()
async def db(transactional_session):
"""SQLAlchemy session init."""
yield transactional_session

# begin a non-ORM transaction
connection.begin()

# bind an individual Session to the connection
db = TestingSessionLocal(bind=connection)
@pytest.fixture(scope="function", autouse=True)
async def session_override(app, db_session):
"""Replace get_db with session override."""

yield db
async def get_db_session_override():
"""Yield the session override."""
yield db[0]

db.rollback()
connection.close()
app.dependency_overrides[get_db] = get_db_session_override


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -178,12 +183,3 @@ async def project(db, user, organisation):
# }
# log.debug(f"get_ids return: {data}")
# return data


@pytest.fixture(scope="function")
def client(app, db):
"""The FastAPI test server."""
app.dependency_overrides[get_db] = lambda: db

with TestClient(app) as c:
yield c