diff --git a/core/data/database.py b/core/data/database.py index 170665e..fb36ebd 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -1,12 +1,11 @@ import logging import os import secrets -import ssl import string import warnings from hashlib import sha256 from logging.handlers import TimedRotatingFileHandler -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import alembic.config import bcrypt @@ -17,6 +16,7 @@ from sqlalchemy.ext.asyncio import ( AsyncSession, create_async_engine, ) +from sqlalchemy.orm import sessionmaker from core.config import CoreConfig from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata @@ -25,7 +25,7 @@ from core.utils import MISSING, Utils class Data: engine: ClassVar[AsyncEngine] = MISSING - session: ClassVar[AsyncSession] = MISSING + session: ClassVar["sessionmaker[AsyncSession]"] = MISSING user: ClassVar[UserData] = MISSING arcade: ClassVar[ArcadeData] = MISSING card: ClassVar[CardData] = MISSING @@ -53,7 +53,7 @@ class Data: self.__engine = Data.engine if Data.session is MISSING: - Data.session = AsyncSession(Data.engine, expire_on_commit=False) + Data.session = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession) if Data.user is MISSING: Data.user = UserData(self.config, self.session) diff --git a/core/data/schema/base.py b/core/data/schema/base.py index cb44272..80ab74b 100644 --- a/core/data/schema/base.py +++ b/core/data/schema/base.py @@ -9,6 +9,7 @@ from sqlalchemy.engine import Row from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import ForeignKey from sqlalchemy.sql import func, text from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String @@ -38,7 +39,7 @@ event_log: Table = Table( class BaseData: - def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None: + def __init__(self, cfg: CoreConfig, conn: "sessionmaker[AsyncSession]") -> None: self.config = cfg self.conn = conn self.logger = logging.getLogger("database") @@ -46,21 +47,10 @@ class BaseData: async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]: res = None - try: - self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") - res = await self.conn.execute(text(sql), opts) - - except SQLAlchemyError as e: - self.logger.error(f"SQLAlchemy error {e}") - return None - - except UnicodeEncodeError as e: - self.logger.error(f"UnicodeEncodeError error {e}") - return None - - except Exception: + async with self.conn() as session: try: - res = await self.conn.execute(sql, opts) + self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") + res = await session.execute(text(sql), opts) except SQLAlchemyError as e: self.logger.error(f"SQLAlchemy error {e}") @@ -71,8 +61,20 @@ class BaseData: return None except Exception: - self.logger.error(f"Unknown error") - raise + try: + res = await session.execute(sql, opts) + + except SQLAlchemyError as e: + self.logger.error(f"SQLAlchemy error {e}") + return None + + except UnicodeEncodeError as e: + self.logger.error(f"UnicodeEncodeError error {e}") + return None + + except Exception: + self.logger.error(f"Unknown error") + raise return res