1
0
mirror of synced 2025-02-15 18:02:39 +01:00

Merge pull request '[database] fix invalid transaction being left open' (#187) from beerpsi/artemis:fix/invalid-transaction into develop

Reviewed-on: https://gitea.tendokyu.moe/Hay1tsme/artemis/pulls/187
This commit is contained in:
Hay1tsme 2024-12-19 06:14:45 +00:00
commit 6a305d2514
2 changed files with 23 additions and 21 deletions

View File

@ -1,12 +1,11 @@
import logging import logging
import os import os
import secrets import secrets
import ssl
import string import string
import warnings import warnings
from hashlib import sha256 from hashlib import sha256
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from typing import Any, ClassVar, Optional from typing import ClassVar, Optional
import alembic.config import alembic.config
import bcrypt import bcrypt
@ -17,6 +16,7 @@ from sqlalchemy.ext.asyncio import (
AsyncSession, AsyncSession,
create_async_engine, create_async_engine,
) )
from sqlalchemy.orm import sessionmaker
from core.config import CoreConfig from core.config import CoreConfig
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
@ -25,7 +25,7 @@ from core.utils import MISSING, Utils
class Data: class Data:
engine: ClassVar[AsyncEngine] = MISSING engine: ClassVar[AsyncEngine] = MISSING
session: ClassVar[AsyncSession] = MISSING session: ClassVar["sessionmaker[AsyncSession]"] = MISSING
user: ClassVar[UserData] = MISSING user: ClassVar[UserData] = MISSING
arcade: ClassVar[ArcadeData] = MISSING arcade: ClassVar[ArcadeData] = MISSING
card: ClassVar[CardData] = MISSING card: ClassVar[CardData] = MISSING
@ -53,7 +53,7 @@ class Data:
self.__engine = Data.engine self.__engine = Data.engine
if Data.session is MISSING: if Data.session is MISSING:
Data.session = AsyncSession(Data.engine, expire_on_commit=False) Data.session = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
if Data.user is MISSING: if Data.user is MISSING:
Data.user = UserData(self.config, self.session) Data.user = UserData(self.config, self.session)

View File

@ -9,6 +9,7 @@ from sqlalchemy.engine import Row
from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import ForeignKey
from sqlalchemy.sql import func, text from sqlalchemy.sql import func, text
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
@ -38,7 +39,7 @@ event_log: Table = Table(
class BaseData: class BaseData:
def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None: def __init__(self, cfg: CoreConfig, conn: "sessionmaker[AsyncSession]") -> None:
self.config = cfg self.config = cfg
self.conn = conn self.conn = conn
self.logger = logging.getLogger("database") self.logger = logging.getLogger("database")
@ -46,21 +47,10 @@ class BaseData:
async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]: async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]:
res = None res = None
try: async with self.conn() as session:
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
res = await self.conn.execute(text(sql), opts)
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
return None
except UnicodeEncodeError as e:
self.logger.error(f"UnicodeEncodeError error {e}")
return None
except Exception:
try: try:
res = await self.conn.execute(sql, opts) self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
res = await session.execute(text(sql), opts)
except SQLAlchemyError as e: except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}") self.logger.error(f"SQLAlchemy error {e}")
@ -71,8 +61,20 @@ class BaseData:
return None return None
except Exception: except Exception:
self.logger.error(f"Unknown error") try:
raise res = await session.execute(sql, opts)
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
return None
except UnicodeEncodeError as e:
self.logger.error(f"UnicodeEncodeError error {e}")
return None
except Exception:
self.logger.error(f"Unknown error")
raise
return res return res