From da5e3fe52f28c855a81c3fe8d8fbec97d734417b Mon Sep 17 00:00:00 2001 From: Jennifer Taylor Date: Thu, 26 Sep 2024 00:04:23 +0000 Subject: [PATCH] Finally upgrade to SQLAlchemy 2.0, fix a bunch of type issues relating to that from Python 3.12. --- bemani/data/data.py | 1 - bemani/data/mysql/api.py | 10 +++---- bemani/data/mysql/base.py | 8 ++++-- bemani/data/mysql/game.py | 16 +++++------ bemani/data/mysql/lobby.py | 8 +++--- bemani/data/mysql/machine.py | 24 ++++++++-------- bemani/data/mysql/music.py | 26 ++++++++--------- bemani/data/mysql/network.py | 8 +++--- bemani/data/mysql/user.py | 54 ++++++++++++++++++------------------ bemani/tests/helpers.py | 3 ++ bemani/utils/read.py | 31 ++++++++++----------- requirements.txt | 2 +- 12 files changed, 97 insertions(+), 94 deletions(-) diff --git a/bemani/data/data.py b/bemani/data/data.py index 9401989..f926b10 100644 --- a/bemani/data/data.py +++ b/bemani/data/data.py @@ -97,7 +97,6 @@ class Data: session_factory = sessionmaker( bind=config.database.engine, autoflush=True, - autocommit=True, ) self.__config = config self.__session = scoped_session(session_factory) diff --git a/bemani/data/mysql/api.py b/bemani/data/mysql/api.py index 263ae81..4f60893 100644 --- a/bemani/data/mysql/api.py +++ b/bemani/data/mysql/api.py @@ -55,7 +55,7 @@ class APIData(APIProviderInterface, BaseData): result["name"], result["token"], ) - for result in cursor + for result in cursor.mappings() ] def validate_client(self, token: str) -> bool: @@ -70,7 +70,7 @@ class APIData(APIProviderInterface, BaseData): """ sql = "SELECT count(*) AS count FROM client WHERE token = :token" cursor = self.execute(sql, {"token": token}) - return cursor.fetchone()["count"] == 1 + return cursor.mappings().fetchone()["count"] == 1 # type: ignore def create_client(self, name: str) -> int: """ @@ -109,7 +109,7 @@ class APIData(APIProviderInterface, BaseData): # Couldn't find an entry with this ID return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return Client( clientid, result["timestamp"], @@ -159,7 +159,7 @@ class APIData(APIProviderInterface, BaseData): sql = "SELECT id, timestamp, uri, token, config FROM server ORDER BY timestamp ASC" cursor = self.execute(sql) - return [format_result(result) for result in cursor] + return [format_result(result) for result in cursor.mappings()] def create_server(self, uri: str, token: str) -> int: """ @@ -199,7 +199,7 @@ class APIData(APIProviderInterface, BaseData): # Couldn't find an entry with this ID return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore allow_stats = (result["config"] & 0x1) == 0 allow_scores = (result["config"] & 0x2) == 0 return Server( diff --git a/bemani/data/mysql/base.py b/bemani/data/mysql/base.py index 015758c..1ca2fff 100644 --- a/bemani/data/mysql/base.py +++ b/bemani/data/mysql/base.py @@ -6,7 +6,7 @@ from typing_extensions import Final from bemani.common import Time from bemani.data.config import Config -from sqlalchemy.engine import CursorResult # type: ignore +from sqlalchemy.engine import CursorResult from sqlalchemy.orm import scoped_session from sqlalchemy.sql import text from sqlalchemy.types import String, Integer @@ -82,10 +82,12 @@ class BaseData: includes = all(s in lowered for s in write_statement_group) if includes and not safe_write_operation: raise Exception("Read-only mode is active!") - return self.__conn.execute( + result = self.__conn.execute( text(sql), params if params is not None else {}, ) + self.__conn.commit() + return result def serialize(self, data: Dict[str, Any]) -> str: """ @@ -141,7 +143,7 @@ class BaseData: # Couldn't find a user with this session return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["id"] def _create_session(self, opid: int, optype: str, expiration: int = (30 * 86400)) -> str: diff --git a/bemani/data/mysql/game.py b/bemani/data/mysql/game.py index c9e5cd6..f7bb29b 100644 --- a/bemani/data/mysql/game.py +++ b/bemani/data/mysql/game.py @@ -96,7 +96,7 @@ class GameData(BaseData): # Settings doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ValidatedDict(self.deserialize(result["data"])) def put_settings(self, game: GameConstants, userid: UserID, settings: Dict[str, Any]) -> None: @@ -158,7 +158,7 @@ class GameData(BaseData): # score doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ValidatedDict(self.deserialize(result["data"])) def get_achievements(self, game: GameConstants, userid: UserID) -> List[Achievement]: @@ -182,7 +182,7 @@ class GameData(BaseData): None, self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def put_achievement( @@ -251,7 +251,7 @@ class GameData(BaseData): # setting doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore retval = ValidatedDict(self.deserialize(result["data"])) retval["start_time"] = result["start_time"] retval["end_time"] = result["end_time"] @@ -288,7 +288,7 @@ class GameData(BaseData): "end_time": result["end_time"], } ) - for result in cursor + for result in cursor.mappings() ] def put_time_sensitive_settings( @@ -336,7 +336,7 @@ class GameData(BaseData): "end_time": end_time, }, ) - for result in cursor: + for result in cursor.mappings(): if result["start_time"] == start_time and result["end_time"] == end_time: # This is just this event being updated, that's fine. continue @@ -387,7 +387,7 @@ class GameData(BaseData): # entry doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ValidatedDict(self.deserialize(result["data"])) def get_items(self, game: GameConstants, version: int) -> List[Item]: @@ -410,5 +410,5 @@ class GameData(BaseData): result["id"], self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] diff --git a/bemani/data/mysql/lobby.py b/bemani/data/mysql/lobby.py index 31698fe..d75434f 100644 --- a/bemani/data/mysql/lobby.py +++ b/bemani/data/mysql/lobby.py @@ -82,7 +82,7 @@ class LobbyData(BaseData): # Settings doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore data = ValidatedDict(self.deserialize(result["data"])) data["id"] = result["id"] data["time"] = result["time"] @@ -119,7 +119,7 @@ class LobbyData(BaseData): data["time"] = result["time"] return data - return [(UserID(result["userid"]), format_result(result)) for result in cursor] + return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()] def put_play_session_info(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None: """ @@ -214,7 +214,7 @@ class LobbyData(BaseData): # Settings doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore data = ValidatedDict(self.deserialize(result["data"])) data["id"] = result["id"] data["time"] = result["time"] @@ -252,7 +252,7 @@ class LobbyData(BaseData): data["time"] = result["time"] return data - return [(UserID(result["userid"]), format_result(result)) for result in cursor] + return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()] def put_lobby(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None: """ diff --git a/bemani/data/mysql/machine.py b/bemani/data/mysql/machine.py index 95430a5..cf58216 100644 --- a/bemani/data/mysql/machine.py +++ b/bemani/data/mysql/machine.py @@ -102,7 +102,7 @@ class MachineData(BaseData): # Machine doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["pcbid"] def from_machine_id(self, machine_id: int) -> Optional[str]: @@ -122,7 +122,7 @@ class MachineData(BaseData): # Machine doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["pcbid"] def from_userid(self, userid: UserID) -> List[ArcadeID]: @@ -137,7 +137,7 @@ class MachineData(BaseData): """ sql = "SELECT arcadeid FROM arcade_owner WHERE userid = :userid" cursor = self.execute(sql, {"userid": userid}) - return [ArcadeID(result["arcadeid"]) for result in cursor] + return [ArcadeID(result["arcadeid"]) for result in cursor.mappings()] def from_session(self, session: str) -> Optional[ArcadeID]: """ @@ -173,7 +173,7 @@ class MachineData(BaseData): # Machine doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return Machine( result["id"], pcbid, @@ -212,7 +212,7 @@ class MachineData(BaseData): result["version"], self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def put_machine(self, machine: Machine) -> None: @@ -279,7 +279,7 @@ class MachineData(BaseData): port = None else: # Grab highest port - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore port = result["port"] if port is not None: port = port + 1 @@ -383,7 +383,7 @@ class MachineData(BaseData): # Arcade doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore sql = "SELECT userid FROM arcade_owner WHERE arcadeid = :id" cursor = self.execute(sql, {"id": arcadeid}) @@ -396,7 +396,7 @@ class MachineData(BaseData): result["pref"], result["area"] or None, self.deserialize(result["data"]), - [owner["userid"] for owner in cursor], + [owner["userid"] for owner in cursor.mappings()], ) def put_arcade(self, arcade: Arcade) -> None: @@ -464,7 +464,7 @@ class MachineData(BaseData): sql = "SELECT userid, arcadeid FROM arcade_owner" cursor = self.execute(sql) arcade_to_owners: Dict[int, List[UserID]] = {} - for row in cursor: + for row in cursor.mappings(): arcade = row["arcadeid"] owner = UserID(row["userid"]) if arcade not in arcade_to_owners: @@ -484,7 +484,7 @@ class MachineData(BaseData): self.deserialize(result["data"]), arcade_to_owners.get(result["id"], []), ) - for result in cursor + for result in cursor.mappings() ] def get_settings( @@ -512,7 +512,7 @@ class MachineData(BaseData): # Settings doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ValidatedDict(self.deserialize(result["data"])) def put_settings( @@ -566,7 +566,7 @@ class MachineData(BaseData): UserID(entry["userid"]), entry["balance"], ) - for entry in cursor + for entry in cursor.mappings() ] def create_session(self, arcadeid: ArcadeID, expiration: int = (30 * 86400)) -> str: diff --git a/bemani/data/mysql/music.py b/bemani/data/mysql/music.py index 35d7767..80444cc 100644 --- a/bemani/data/mysql/music.py +++ b/bemani/data/mysql/music.py @@ -107,7 +107,7 @@ class MusicData(BaseData): if cursor.rowcount != 1: # music doesn't exist raise Exception(f"Song {songid} chart {songchart} doesn't exist for game {game} version {version}") - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["id"] def put_score( @@ -296,7 +296,7 @@ class MusicData(BaseData): # score doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return Score( result["scorekey"], result["songid"], @@ -356,7 +356,7 @@ class MusicData(BaseData): # score doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ( UserID(result["userid"]), Score( @@ -439,7 +439,7 @@ class MusicData(BaseData): result["plays"], self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def get_most_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: @@ -472,7 +472,7 @@ class MusicData(BaseData): {"userid": userid, "game": game.value, "version": version, "count": count}, ) - return [(result["songid"], result["plays"]) for result in cursor] + return [(result["songid"], result["plays"]) for result in cursor.mappings()] def get_last_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: """ @@ -504,7 +504,7 @@ class MusicData(BaseData): {"userid": userid, "game": game.value, "version": version, "count": count}, ) - return [(result["songid"], result["timestamp"]) for result in cursor] + return [(result["songid"], result["timestamp"]) for result in cursor.mappings()] def get_hit_chart( self, @@ -551,7 +551,7 @@ class MusicData(BaseData): }, ) - return [(result["songid"], result["plays"]) for result in cursor] + return [(result["songid"], result["plays"]) for result in cursor.mappings()] def get_song( self, @@ -597,7 +597,7 @@ class MusicData(BaseData): if cursor.rowcount != 1: # music doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return Song( game, version, @@ -647,7 +647,7 @@ class MusicData(BaseData): result["genre"], self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def get_all_scores( @@ -752,7 +752,7 @@ class MusicData(BaseData): self.deserialize(result["data"]), ), ) - for result in cursor + for result in cursor.mappings() ] def get_all_records( @@ -863,7 +863,7 @@ class MusicData(BaseData): self.deserialize(result["data"]), ), ) - for result in cursor + for result in cursor.mappings() ] def get_attempt_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]: @@ -908,7 +908,7 @@ class MusicData(BaseData): # score doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ( UserID(result["userid"]), Attempt( @@ -1019,5 +1019,5 @@ class MusicData(BaseData): self.deserialize(result["data"]), ), ) - for result in cursor + for result in cursor.mappings() ] diff --git a/bemani/data/mysql/network.py b/bemani/data/mysql/network.py index 8c4c5aa..6152e5f 100644 --- a/bemani/data/mysql/network.py +++ b/bemani/data/mysql/network.py @@ -73,7 +73,7 @@ class NetworkData(BaseData): result["title"], result["body"], ) - for result in cursor + for result in cursor.mappings() ] def create_news(self, title: str, body: str) -> int: @@ -107,7 +107,7 @@ class NetworkData(BaseData): # Couldn't find an entry with this ID return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return News( newsid, result["timestamp"], @@ -181,7 +181,7 @@ class NetworkData(BaseData): # No scheduled work was registered, so time to get going! return True - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore if schedule == "daily": # Just look at the day and year, make sure it matches @@ -317,7 +317,7 @@ class NetworkData(BaseData): result["type"], self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def delete_events(self, oldest_event_ts: int) -> None: diff --git a/bemani/data/mysql/user.py b/bemani/data/mysql/user.py index c48089a..b4d9db6 100644 --- a/bemani/data/mysql/user.py +++ b/bemani/data/mysql/user.py @@ -194,7 +194,7 @@ class UserData(BaseData): # Couldn't find a user with this card return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return UserID(result["userid"]) def from_username(self, username: str) -> Optional[UserID]: @@ -213,7 +213,7 @@ class UserData(BaseData): # Couldn't find this username return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return UserID(result["id"]) def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]: @@ -238,7 +238,7 @@ class UserData(BaseData): # Couldn't find a user with this refid return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return UserID(result["userid"]) def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]: @@ -263,7 +263,7 @@ class UserData(BaseData): # Couldn't find a user with this refid return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return UserID(result["userid"]) def from_session(self, session: str) -> Optional[UserID]: @@ -297,7 +297,7 @@ class UserData(BaseData): # User doesn't exist, but we have a reference? return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return User(userid, result["username"], result["email"], result["admin"] == 1) def get_all_users(self) -> List[User]: @@ -316,7 +316,7 @@ class UserData(BaseData): result["email"], result["admin"] == 1, ) - for result in cursor + for result in cursor.mappings() ] def get_all_usernames(self) -> List[str]: @@ -331,7 +331,7 @@ class UserData(BaseData): """ sql = "SELECT username FROM user WHERE username is not null" cursor = self.execute(sql) - return [res["username"] for res in cursor] + return [res["username"] for res in cursor.mappings()] def get_all_cards(self) -> List[Tuple[str, UserID]]: """ @@ -342,7 +342,7 @@ class UserData(BaseData): """ sql = "SELECT id, userid FROM card" cursor = self.execute(sql) - return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor] + return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor.mappings()] def get_cards(self, userid: UserID) -> List[str]: """ @@ -356,7 +356,7 @@ class UserData(BaseData): """ sql = "SELECT id FROM card WHERE userid = :userid" cursor = self.execute(sql, {"userid": userid}) - return [str(res["id"]).upper() for res in cursor] + return [str(res["id"]).upper() for res in cursor.mappings()] def add_card(self, userid: UserID, cardid: str) -> None: """ @@ -437,7 +437,7 @@ class UserData(BaseData): # User doesn't exist, but we have a reference? return False - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return pin == result["pin"] def update_pin(self, userid: UserID, pin: str) -> None: @@ -468,7 +468,7 @@ class UserData(BaseData): # User doesn't exist, but we have a reference? return False - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore passhash = result["password"] try: @@ -517,7 +517,7 @@ class UserData(BaseData): # Profile doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return Profile( game, version, @@ -577,7 +577,7 @@ class UserData(BaseData): cursor = self.execute(sql, {"game": game.value, "userids": userids}) profilever: Dict[UserID, int] = {} - for result in cursor: + for result in cursor.mappings(): tuid = UserID(result["userid"]) tver = result["version"] @@ -626,7 +626,7 @@ class UserData(BaseData): vals["game"] = game.value cursor = self.execute(sql, vals) - return [(GameConstants(result["game"]), result["version"]) for result in cursor] + return [(GameConstants(result["game"]), result["version"]) for result in cursor.mappings()] def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, Profile]]: """ @@ -662,7 +662,7 @@ class UserData(BaseData): self.deserialize(result["data"]), ), ) - for result in cursor + for result in cursor.mappings() ] def get_all_players(self, game: GameConstants, version: int) -> List[UserID]: @@ -682,7 +682,7 @@ class UserData(BaseData): """ cursor = self.execute(sql, {"game": game.value, "version": version}) - return [UserID(result["userid"]) for result in cursor] + return [UserID(result["userid"]) for result in cursor.mappings()] def get_all_achievements( self, @@ -732,7 +732,7 @@ class UserData(BaseData): self.deserialize(result["data"]), ), ) - for result in cursor + for result in cursor.mappings() ] def put_profile(self, game: GameConstants, version: int, userid: UserID, profile: Profile) -> None: @@ -808,7 +808,7 @@ class UserData(BaseData): # score doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ValidatedDict(self.deserialize(result["data"])) def get_achievements(self, game: GameConstants, version: int, userid: UserID) -> List[Achievement]: @@ -834,7 +834,7 @@ class UserData(BaseData): None, self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def put_achievement( @@ -945,7 +945,7 @@ class UserData(BaseData): result["timestamp"], self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def put_time_based_achievement( @@ -1023,7 +1023,7 @@ class UserData(BaseData): self.deserialize(result["data"]), ), ) - for result in cursor + for result in cursor.mappings() ] def get_link( @@ -1074,7 +1074,7 @@ class UserData(BaseData): # score doesn't exist return None - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return ValidatedDict(self.deserialize(result["data"])) def get_links(self, game: GameConstants, version: int, userid: UserID) -> List[Link]: @@ -1103,7 +1103,7 @@ class UserData(BaseData): UserID(result["other_userid"]), self.deserialize(result["data"]), ) - for result in cursor + for result in cursor.mappings() ] def put_link( @@ -1196,7 +1196,7 @@ class UserData(BaseData): sql = "SELECT balance FROM balance WHERE userid = :userid AND arcadeid = :arcadeid" cursor = self.execute(sql, {"userid": userid, "arcadeid": arcadeid}) if cursor.rowcount == 1: - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["balance"] else: return 0 @@ -1242,7 +1242,7 @@ class UserData(BaseData): sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version" cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version}) if cursor.rowcount == 1: - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["refid"] else: return self.create_refid(game, version, userid) @@ -1265,7 +1265,7 @@ class UserData(BaseData): sql = "SELECT extid FROM extid WHERE userid = :userid AND game = :game" cursor = self.execute(sql, {"userid": userid, "game": game.value}) if cursor.rowcount == 1: - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["extid"] else: return None @@ -1371,7 +1371,7 @@ class UserData(BaseData): sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version" cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version}) if cursor.rowcount == 1: - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["refid"] # Shouldn't be possible, but here we are raise AccountCreationException("Failed to recover lost race refid!") diff --git a/bemani/tests/helpers.py b/bemani/tests/helpers.py index 7ca0771..54b6c25 100644 --- a/bemani/tests/helpers.py +++ b/bemani/tests/helpers.py @@ -26,6 +26,9 @@ class FakeCursor: self.rowcount = len(rows) self.pos = -1 + def mappings(self) -> "FakeCursor": + return self + def fetchone(self) -> Dict[str, Any]: if len(self.__rows) != 1: raise Exception(f"Tried to fetch one row and there are {len(self.__rows)} rows!") diff --git a/bemani/utils/read.py b/bemani/utils/read.py index c3bc0b2..b915016 100644 --- a/bemani/utils/read.py +++ b/bemani/utils/read.py @@ -4,14 +4,14 @@ import csv import argparse import copy import io -import jaconv # type: ignore +import jaconv import json import os import struct import xml.etree.ElementTree as ET from pathlib import Path -from sqlalchemy.engine import CursorResult # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.sql import text from sqlalchemy.exc import IntegrityError from typing import Any, Callable, Dict, List, Optional, Tuple @@ -70,17 +70,18 @@ class ImportBase: self.update = update self.no_combine = no_combine self.__config = config - self.__engine = self.__config.database.engine - self.__sessionmanager = sessionmaker(self.__engine) - self.__conn = self.__engine.connect() - self.__session = self.__sessionmanager(bind=self.__conn) self.__batch = False + # Set up DB connection stuff. + self.__engine = self.__config.database.engine + session_factory = sessionmaker(self.__engine) + self.__conn = scoped_session(session_factory) + def start_batch(self) -> None: self.__batch = True def finish_batch(self) -> None: - self.__session.commit() + self.__conn.commit() self.__batch = False def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> CursorResult: @@ -96,12 +97,12 @@ class ImportBase: ]: if write_statement in sql.lower(): raise Exception("Read-only mode is active!") - return self.__session.execute(text(sql), params if params is not None else {}) + return self.__conn.execute(text(sql), params if params is not None else {}) def remote_music(self, server: str, token: str) -> GlobalMusicData: api = ReadAPI(server, token) - user = UserData(self.__config, self.__session) - music = MusicData(self.__config, self.__session) + user = UserData(self.__config, self.__conn) + music = MusicData(self.__config, self.__conn) return GlobalMusicData(api, user, music) def remote_game(self, server: str, token: str) -> GlobalGameData: @@ -110,7 +111,7 @@ class ImportBase: def get_next_music_id(self) -> int: cursor = self.execute("SELECT MAX(id) AS next_id FROM `music`") - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore try: return result["next_id"] + 1 except TypeError: @@ -138,7 +139,7 @@ class ImportBase: }, ) if cursor.rowcount != 0: - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["id"] else: return None @@ -183,7 +184,7 @@ class ImportBase: }, ) if cursor.rowcount != 0: - result = cursor.fetchone() + result = cursor.mappings().fetchone() # type: ignore return result["id"] else: return None @@ -369,8 +370,6 @@ class ImportBase: # Make sure we don't leak connections after finising insertion. if self.__batch: raise Exception("Logic error, opened a batch without closing!") - if self.__session is not None: - self.__session.close() if self.__conn is not None: self.__conn.close() self.__conn = None diff --git a/requirements.txt b/requirements.txt index 270667c..7ed664f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ setuptools Cython -SQLAlchemy<2.0.0 +SQLAlchemy alembic PyYAML Flask