1
0
mirror of synced 2024-11-30 16:54:30 +01:00

Finally upgrade to SQLAlchemy 2.0, fix a bunch of type issues relating to that from Python 3.12.

This commit is contained in:
Jennifer Taylor 2024-09-26 00:04:23 +00:00
parent eb8264c553
commit da5e3fe52f
12 changed files with 97 additions and 94 deletions

View File

@ -97,7 +97,6 @@ class Data:
session_factory = sessionmaker( session_factory = sessionmaker(
bind=config.database.engine, bind=config.database.engine,
autoflush=True, autoflush=True,
autocommit=True,
) )
self.__config = config self.__config = config
self.__session = scoped_session(session_factory) self.__session = scoped_session(session_factory)

View File

@ -55,7 +55,7 @@ class APIData(APIProviderInterface, BaseData):
result["name"], result["name"],
result["token"], result["token"],
) )
for result in cursor for result in cursor.mappings()
] ]
def validate_client(self, token: str) -> bool: def validate_client(self, token: str) -> bool:
@ -70,7 +70,7 @@ class APIData(APIProviderInterface, BaseData):
""" """
sql = "SELECT count(*) AS count FROM client WHERE token = :token" sql = "SELECT count(*) AS count FROM client WHERE token = :token"
cursor = self.execute(sql, {"token": token}) cursor = self.execute(sql, {"token": token})
return cursor.fetchone()["count"] == 1 return cursor.mappings().fetchone()["count"] == 1 # type: ignore
def create_client(self, name: str) -> int: def create_client(self, name: str) -> int:
""" """
@ -109,7 +109,7 @@ class APIData(APIProviderInterface, BaseData):
# Couldn't find an entry with this ID # Couldn't find an entry with this ID
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return Client( return Client(
clientid, clientid,
result["timestamp"], result["timestamp"],
@ -159,7 +159,7 @@ class APIData(APIProviderInterface, BaseData):
sql = "SELECT id, timestamp, uri, token, config FROM server ORDER BY timestamp ASC" sql = "SELECT id, timestamp, uri, token, config FROM server ORDER BY timestamp ASC"
cursor = self.execute(sql) cursor = self.execute(sql)
return [format_result(result) for result in cursor] return [format_result(result) for result in cursor.mappings()]
def create_server(self, uri: str, token: str) -> int: def create_server(self, uri: str, token: str) -> int:
""" """
@ -199,7 +199,7 @@ class APIData(APIProviderInterface, BaseData):
# Couldn't find an entry with this ID # Couldn't find an entry with this ID
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
allow_stats = (result["config"] & 0x1) == 0 allow_stats = (result["config"] & 0x1) == 0
allow_scores = (result["config"] & 0x2) == 0 allow_scores = (result["config"] & 0x2) == 0
return Server( return Server(

View File

@ -6,7 +6,7 @@ from typing_extensions import Final
from bemani.common import Time from bemani.common import Time
from bemani.data.config import Config from bemani.data.config import Config
from sqlalchemy.engine import CursorResult # type: ignore from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from sqlalchemy.sql import text from sqlalchemy.sql import text
from sqlalchemy.types import String, Integer from sqlalchemy.types import String, Integer
@ -82,10 +82,12 @@ class BaseData:
includes = all(s in lowered for s in write_statement_group) includes = all(s in lowered for s in write_statement_group)
if includes and not safe_write_operation: if includes and not safe_write_operation:
raise Exception("Read-only mode is active!") raise Exception("Read-only mode is active!")
return self.__conn.execute( result = self.__conn.execute(
text(sql), text(sql),
params if params is not None else {}, params if params is not None else {},
) )
self.__conn.commit()
return result
def serialize(self, data: Dict[str, Any]) -> str: def serialize(self, data: Dict[str, Any]) -> str:
""" """
@ -141,7 +143,7 @@ class BaseData:
# Couldn't find a user with this session # Couldn't find a user with this session
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["id"] return result["id"]
def _create_session(self, opid: int, optype: str, expiration: int = (30 * 86400)) -> str: def _create_session(self, opid: int, optype: str, expiration: int = (30 * 86400)) -> str:

View File

@ -96,7 +96,7 @@ class GameData(BaseData):
# Settings doesn't exist # Settings doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"])) return ValidatedDict(self.deserialize(result["data"]))
def put_settings(self, game: GameConstants, userid: UserID, settings: Dict[str, Any]) -> None: def put_settings(self, game: GameConstants, userid: UserID, settings: Dict[str, Any]) -> None:
@ -158,7 +158,7 @@ class GameData(BaseData):
# score doesn't exist # score doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"])) return ValidatedDict(self.deserialize(result["data"]))
def get_achievements(self, game: GameConstants, userid: UserID) -> List[Achievement]: def get_achievements(self, game: GameConstants, userid: UserID) -> List[Achievement]:
@ -182,7 +182,7 @@ class GameData(BaseData):
None, None,
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def put_achievement( def put_achievement(
@ -251,7 +251,7 @@ class GameData(BaseData):
# setting doesn't exist # setting doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
retval = ValidatedDict(self.deserialize(result["data"])) retval = ValidatedDict(self.deserialize(result["data"]))
retval["start_time"] = result["start_time"] retval["start_time"] = result["start_time"]
retval["end_time"] = result["end_time"] retval["end_time"] = result["end_time"]
@ -288,7 +288,7 @@ class GameData(BaseData):
"end_time": result["end_time"], "end_time": result["end_time"],
} }
) )
for result in cursor for result in cursor.mappings()
] ]
def put_time_sensitive_settings( def put_time_sensitive_settings(
@ -336,7 +336,7 @@ class GameData(BaseData):
"end_time": end_time, "end_time": end_time,
}, },
) )
for result in cursor: for result in cursor.mappings():
if result["start_time"] == start_time and result["end_time"] == end_time: if result["start_time"] == start_time and result["end_time"] == end_time:
# This is just this event being updated, that's fine. # This is just this event being updated, that's fine.
continue continue
@ -387,7 +387,7 @@ class GameData(BaseData):
# entry doesn't exist # entry doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"])) return ValidatedDict(self.deserialize(result["data"]))
def get_items(self, game: GameConstants, version: int) -> List[Item]: def get_items(self, game: GameConstants, version: int) -> List[Item]:
@ -410,5 +410,5 @@ class GameData(BaseData):
result["id"], result["id"],
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]

View File

@ -82,7 +82,7 @@ class LobbyData(BaseData):
# Settings doesn't exist # Settings doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
data = ValidatedDict(self.deserialize(result["data"])) data = ValidatedDict(self.deserialize(result["data"]))
data["id"] = result["id"] data["id"] = result["id"]
data["time"] = result["time"] data["time"] = result["time"]
@ -119,7 +119,7 @@ class LobbyData(BaseData):
data["time"] = result["time"] data["time"] = result["time"]
return data return data
return [(UserID(result["userid"]), format_result(result)) for result in cursor] return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()]
def put_play_session_info(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None: def put_play_session_info(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None:
""" """
@ -214,7 +214,7 @@ class LobbyData(BaseData):
# Settings doesn't exist # Settings doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
data = ValidatedDict(self.deserialize(result["data"])) data = ValidatedDict(self.deserialize(result["data"]))
data["id"] = result["id"] data["id"] = result["id"]
data["time"] = result["time"] data["time"] = result["time"]
@ -252,7 +252,7 @@ class LobbyData(BaseData):
data["time"] = result["time"] data["time"] = result["time"]
return data return data
return [(UserID(result["userid"]), format_result(result)) for result in cursor] return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()]
def put_lobby(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None: def put_lobby(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None:
""" """

View File

@ -102,7 +102,7 @@ class MachineData(BaseData):
# Machine doesn't exist # Machine doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["pcbid"] return result["pcbid"]
def from_machine_id(self, machine_id: int) -> Optional[str]: def from_machine_id(self, machine_id: int) -> Optional[str]:
@ -122,7 +122,7 @@ class MachineData(BaseData):
# Machine doesn't exist # Machine doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["pcbid"] return result["pcbid"]
def from_userid(self, userid: UserID) -> List[ArcadeID]: def from_userid(self, userid: UserID) -> List[ArcadeID]:
@ -137,7 +137,7 @@ class MachineData(BaseData):
""" """
sql = "SELECT arcadeid FROM arcade_owner WHERE userid = :userid" sql = "SELECT arcadeid FROM arcade_owner WHERE userid = :userid"
cursor = self.execute(sql, {"userid": userid}) cursor = self.execute(sql, {"userid": userid})
return [ArcadeID(result["arcadeid"]) for result in cursor] return [ArcadeID(result["arcadeid"]) for result in cursor.mappings()]
def from_session(self, session: str) -> Optional[ArcadeID]: def from_session(self, session: str) -> Optional[ArcadeID]:
""" """
@ -173,7 +173,7 @@ class MachineData(BaseData):
# Machine doesn't exist # Machine doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return Machine( return Machine(
result["id"], result["id"],
pcbid, pcbid,
@ -212,7 +212,7 @@ class MachineData(BaseData):
result["version"], result["version"],
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def put_machine(self, machine: Machine) -> None: def put_machine(self, machine: Machine) -> None:
@ -279,7 +279,7 @@ class MachineData(BaseData):
port = None port = None
else: else:
# Grab highest port # Grab highest port
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
port = result["port"] port = result["port"]
if port is not None: if port is not None:
port = port + 1 port = port + 1
@ -383,7 +383,7 @@ class MachineData(BaseData):
# Arcade doesn't exist # Arcade doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
sql = "SELECT userid FROM arcade_owner WHERE arcadeid = :id" sql = "SELECT userid FROM arcade_owner WHERE arcadeid = :id"
cursor = self.execute(sql, {"id": arcadeid}) cursor = self.execute(sql, {"id": arcadeid})
@ -396,7 +396,7 @@ class MachineData(BaseData):
result["pref"], result["pref"],
result["area"] or None, result["area"] or None,
self.deserialize(result["data"]), self.deserialize(result["data"]),
[owner["userid"] for owner in cursor], [owner["userid"] for owner in cursor.mappings()],
) )
def put_arcade(self, arcade: Arcade) -> None: def put_arcade(self, arcade: Arcade) -> None:
@ -464,7 +464,7 @@ class MachineData(BaseData):
sql = "SELECT userid, arcadeid FROM arcade_owner" sql = "SELECT userid, arcadeid FROM arcade_owner"
cursor = self.execute(sql) cursor = self.execute(sql)
arcade_to_owners: Dict[int, List[UserID]] = {} arcade_to_owners: Dict[int, List[UserID]] = {}
for row in cursor: for row in cursor.mappings():
arcade = row["arcadeid"] arcade = row["arcadeid"]
owner = UserID(row["userid"]) owner = UserID(row["userid"])
if arcade not in arcade_to_owners: if arcade not in arcade_to_owners:
@ -484,7 +484,7 @@ class MachineData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
arcade_to_owners.get(result["id"], []), arcade_to_owners.get(result["id"], []),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_settings( def get_settings(
@ -512,7 +512,7 @@ class MachineData(BaseData):
# Settings doesn't exist # Settings doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"])) return ValidatedDict(self.deserialize(result["data"]))
def put_settings( def put_settings(
@ -566,7 +566,7 @@ class MachineData(BaseData):
UserID(entry["userid"]), UserID(entry["userid"]),
entry["balance"], entry["balance"],
) )
for entry in cursor for entry in cursor.mappings()
] ]
def create_session(self, arcadeid: ArcadeID, expiration: int = (30 * 86400)) -> str: def create_session(self, arcadeid: ArcadeID, expiration: int = (30 * 86400)) -> str:

View File

@ -107,7 +107,7 @@ class MusicData(BaseData):
if cursor.rowcount != 1: if cursor.rowcount != 1:
# music doesn't exist # music doesn't exist
raise Exception(f"Song {songid} chart {songchart} doesn't exist for game {game} version {version}") raise Exception(f"Song {songid} chart {songchart} doesn't exist for game {game} version {version}")
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["id"] return result["id"]
def put_score( def put_score(
@ -296,7 +296,7 @@ class MusicData(BaseData):
# score doesn't exist # score doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return Score( return Score(
result["scorekey"], result["scorekey"],
result["songid"], result["songid"],
@ -356,7 +356,7 @@ class MusicData(BaseData):
# score doesn't exist # score doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ( return (
UserID(result["userid"]), UserID(result["userid"]),
Score( Score(
@ -439,7 +439,7 @@ class MusicData(BaseData):
result["plays"], result["plays"],
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_most_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: def get_most_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]:
@ -472,7 +472,7 @@ class MusicData(BaseData):
{"userid": userid, "game": game.value, "version": version, "count": count}, {"userid": userid, "game": game.value, "version": version, "count": count},
) )
return [(result["songid"], result["plays"]) for result in cursor] return [(result["songid"], result["plays"]) for result in cursor.mappings()]
def get_last_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: def get_last_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]:
""" """
@ -504,7 +504,7 @@ class MusicData(BaseData):
{"userid": userid, "game": game.value, "version": version, "count": count}, {"userid": userid, "game": game.value, "version": version, "count": count},
) )
return [(result["songid"], result["timestamp"]) for result in cursor] return [(result["songid"], result["timestamp"]) for result in cursor.mappings()]
def get_hit_chart( def get_hit_chart(
self, self,
@ -551,7 +551,7 @@ class MusicData(BaseData):
}, },
) )
return [(result["songid"], result["plays"]) for result in cursor] return [(result["songid"], result["plays"]) for result in cursor.mappings()]
def get_song( def get_song(
self, self,
@ -597,7 +597,7 @@ class MusicData(BaseData):
if cursor.rowcount != 1: if cursor.rowcount != 1:
# music doesn't exist # music doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return Song( return Song(
game, game,
version, version,
@ -647,7 +647,7 @@ class MusicData(BaseData):
result["genre"], result["genre"],
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_all_scores( def get_all_scores(
@ -752,7 +752,7 @@ class MusicData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
), ),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_all_records( def get_all_records(
@ -863,7 +863,7 @@ class MusicData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
), ),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_attempt_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]: def get_attempt_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]:
@ -908,7 +908,7 @@ class MusicData(BaseData):
# score doesn't exist # score doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ( return (
UserID(result["userid"]), UserID(result["userid"]),
Attempt( Attempt(
@ -1019,5 +1019,5 @@ class MusicData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
), ),
) )
for result in cursor for result in cursor.mappings()
] ]

View File

@ -73,7 +73,7 @@ class NetworkData(BaseData):
result["title"], result["title"],
result["body"], result["body"],
) )
for result in cursor for result in cursor.mappings()
] ]
def create_news(self, title: str, body: str) -> int: def create_news(self, title: str, body: str) -> int:
@ -107,7 +107,7 @@ class NetworkData(BaseData):
# Couldn't find an entry with this ID # Couldn't find an entry with this ID
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return News( return News(
newsid, newsid,
result["timestamp"], result["timestamp"],
@ -181,7 +181,7 @@ class NetworkData(BaseData):
# No scheduled work was registered, so time to get going! # No scheduled work was registered, so time to get going!
return True return True
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
if schedule == "daily": if schedule == "daily":
# Just look at the day and year, make sure it matches # Just look at the day and year, make sure it matches
@ -317,7 +317,7 @@ class NetworkData(BaseData):
result["type"], result["type"],
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def delete_events(self, oldest_event_ts: int) -> None: def delete_events(self, oldest_event_ts: int) -> None:

View File

@ -194,7 +194,7 @@ class UserData(BaseData):
# Couldn't find a user with this card # Couldn't find a user with this card
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return UserID(result["userid"]) return UserID(result["userid"])
def from_username(self, username: str) -> Optional[UserID]: def from_username(self, username: str) -> Optional[UserID]:
@ -213,7 +213,7 @@ class UserData(BaseData):
# Couldn't find this username # Couldn't find this username
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return UserID(result["id"]) return UserID(result["id"])
def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]: def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]:
@ -238,7 +238,7 @@ class UserData(BaseData):
# Couldn't find a user with this refid # Couldn't find a user with this refid
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return UserID(result["userid"]) return UserID(result["userid"])
def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]: def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]:
@ -263,7 +263,7 @@ class UserData(BaseData):
# Couldn't find a user with this refid # Couldn't find a user with this refid
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return UserID(result["userid"]) return UserID(result["userid"])
def from_session(self, session: str) -> Optional[UserID]: def from_session(self, session: str) -> Optional[UserID]:
@ -297,7 +297,7 @@ class UserData(BaseData):
# User doesn't exist, but we have a reference? # User doesn't exist, but we have a reference?
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return User(userid, result["username"], result["email"], result["admin"] == 1) return User(userid, result["username"], result["email"], result["admin"] == 1)
def get_all_users(self) -> List[User]: def get_all_users(self) -> List[User]:
@ -316,7 +316,7 @@ class UserData(BaseData):
result["email"], result["email"],
result["admin"] == 1, result["admin"] == 1,
) )
for result in cursor for result in cursor.mappings()
] ]
def get_all_usernames(self) -> List[str]: def get_all_usernames(self) -> List[str]:
@ -331,7 +331,7 @@ class UserData(BaseData):
""" """
sql = "SELECT username FROM user WHERE username is not null" sql = "SELECT username FROM user WHERE username is not null"
cursor = self.execute(sql) cursor = self.execute(sql)
return [res["username"] for res in cursor] return [res["username"] for res in cursor.mappings()]
def get_all_cards(self) -> List[Tuple[str, UserID]]: def get_all_cards(self) -> List[Tuple[str, UserID]]:
""" """
@ -342,7 +342,7 @@ class UserData(BaseData):
""" """
sql = "SELECT id, userid FROM card" sql = "SELECT id, userid FROM card"
cursor = self.execute(sql) cursor = self.execute(sql)
return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor] return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor.mappings()]
def get_cards(self, userid: UserID) -> List[str]: def get_cards(self, userid: UserID) -> List[str]:
""" """
@ -356,7 +356,7 @@ class UserData(BaseData):
""" """
sql = "SELECT id FROM card WHERE userid = :userid" sql = "SELECT id FROM card WHERE userid = :userid"
cursor = self.execute(sql, {"userid": userid}) cursor = self.execute(sql, {"userid": userid})
return [str(res["id"]).upper() for res in cursor] return [str(res["id"]).upper() for res in cursor.mappings()]
def add_card(self, userid: UserID, cardid: str) -> None: def add_card(self, userid: UserID, cardid: str) -> None:
""" """
@ -437,7 +437,7 @@ class UserData(BaseData):
# User doesn't exist, but we have a reference? # User doesn't exist, but we have a reference?
return False return False
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return pin == result["pin"] return pin == result["pin"]
def update_pin(self, userid: UserID, pin: str) -> None: def update_pin(self, userid: UserID, pin: str) -> None:
@ -468,7 +468,7 @@ class UserData(BaseData):
# User doesn't exist, but we have a reference? # User doesn't exist, but we have a reference?
return False return False
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
passhash = result["password"] passhash = result["password"]
try: try:
@ -517,7 +517,7 @@ class UserData(BaseData):
# Profile doesn't exist # Profile doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return Profile( return Profile(
game, game,
version, version,
@ -577,7 +577,7 @@ class UserData(BaseData):
cursor = self.execute(sql, {"game": game.value, "userids": userids}) cursor = self.execute(sql, {"game": game.value, "userids": userids})
profilever: Dict[UserID, int] = {} profilever: Dict[UserID, int] = {}
for result in cursor: for result in cursor.mappings():
tuid = UserID(result["userid"]) tuid = UserID(result["userid"])
tver = result["version"] tver = result["version"]
@ -626,7 +626,7 @@ class UserData(BaseData):
vals["game"] = game.value vals["game"] = game.value
cursor = self.execute(sql, vals) cursor = self.execute(sql, vals)
return [(GameConstants(result["game"]), result["version"]) for result in cursor] return [(GameConstants(result["game"]), result["version"]) for result in cursor.mappings()]
def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, Profile]]: def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, Profile]]:
""" """
@ -662,7 +662,7 @@ class UserData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
), ),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_all_players(self, game: GameConstants, version: int) -> List[UserID]: def get_all_players(self, game: GameConstants, version: int) -> List[UserID]:
@ -682,7 +682,7 @@ class UserData(BaseData):
""" """
cursor = self.execute(sql, {"game": game.value, "version": version}) cursor = self.execute(sql, {"game": game.value, "version": version})
return [UserID(result["userid"]) for result in cursor] return [UserID(result["userid"]) for result in cursor.mappings()]
def get_all_achievements( def get_all_achievements(
self, self,
@ -732,7 +732,7 @@ class UserData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
), ),
) )
for result in cursor for result in cursor.mappings()
] ]
def put_profile(self, game: GameConstants, version: int, userid: UserID, profile: Profile) -> None: def put_profile(self, game: GameConstants, version: int, userid: UserID, profile: Profile) -> None:
@ -808,7 +808,7 @@ class UserData(BaseData):
# score doesn't exist # score doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"])) return ValidatedDict(self.deserialize(result["data"]))
def get_achievements(self, game: GameConstants, version: int, userid: UserID) -> List[Achievement]: def get_achievements(self, game: GameConstants, version: int, userid: UserID) -> List[Achievement]:
@ -834,7 +834,7 @@ class UserData(BaseData):
None, None,
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def put_achievement( def put_achievement(
@ -945,7 +945,7 @@ class UserData(BaseData):
result["timestamp"], result["timestamp"],
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def put_time_based_achievement( def put_time_based_achievement(
@ -1023,7 +1023,7 @@ class UserData(BaseData):
self.deserialize(result["data"]), self.deserialize(result["data"]),
), ),
) )
for result in cursor for result in cursor.mappings()
] ]
def get_link( def get_link(
@ -1074,7 +1074,7 @@ class UserData(BaseData):
# score doesn't exist # score doesn't exist
return None return None
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"])) return ValidatedDict(self.deserialize(result["data"]))
def get_links(self, game: GameConstants, version: int, userid: UserID) -> List[Link]: def get_links(self, game: GameConstants, version: int, userid: UserID) -> List[Link]:
@ -1103,7 +1103,7 @@ class UserData(BaseData):
UserID(result["other_userid"]), UserID(result["other_userid"]),
self.deserialize(result["data"]), self.deserialize(result["data"]),
) )
for result in cursor for result in cursor.mappings()
] ]
def put_link( def put_link(
@ -1196,7 +1196,7 @@ class UserData(BaseData):
sql = "SELECT balance FROM balance WHERE userid = :userid AND arcadeid = :arcadeid" sql = "SELECT balance FROM balance WHERE userid = :userid AND arcadeid = :arcadeid"
cursor = self.execute(sql, {"userid": userid, "arcadeid": arcadeid}) cursor = self.execute(sql, {"userid": userid, "arcadeid": arcadeid})
if cursor.rowcount == 1: if cursor.rowcount == 1:
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["balance"] return result["balance"]
else: else:
return 0 return 0
@ -1242,7 +1242,7 @@ class UserData(BaseData):
sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version" sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version"
cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version}) cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version})
if cursor.rowcount == 1: if cursor.rowcount == 1:
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["refid"] return result["refid"]
else: else:
return self.create_refid(game, version, userid) return self.create_refid(game, version, userid)
@ -1265,7 +1265,7 @@ class UserData(BaseData):
sql = "SELECT extid FROM extid WHERE userid = :userid AND game = :game" sql = "SELECT extid FROM extid WHERE userid = :userid AND game = :game"
cursor = self.execute(sql, {"userid": userid, "game": game.value}) cursor = self.execute(sql, {"userid": userid, "game": game.value})
if cursor.rowcount == 1: if cursor.rowcount == 1:
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["extid"] return result["extid"]
else: else:
return None return None
@ -1371,7 +1371,7 @@ class UserData(BaseData):
sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version" sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version"
cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version}) cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version})
if cursor.rowcount == 1: if cursor.rowcount == 1:
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["refid"] return result["refid"]
# Shouldn't be possible, but here we are # Shouldn't be possible, but here we are
raise AccountCreationException("Failed to recover lost race refid!") raise AccountCreationException("Failed to recover lost race refid!")

View File

@ -26,6 +26,9 @@ class FakeCursor:
self.rowcount = len(rows) self.rowcount = len(rows)
self.pos = -1 self.pos = -1
def mappings(self) -> "FakeCursor":
return self
def fetchone(self) -> Dict[str, Any]: def fetchone(self) -> Dict[str, Any]:
if len(self.__rows) != 1: if len(self.__rows) != 1:
raise Exception(f"Tried to fetch one row and there are {len(self.__rows)} rows!") raise Exception(f"Tried to fetch one row and there are {len(self.__rows)} rows!")

View File

@ -4,14 +4,14 @@ import csv
import argparse import argparse
import copy import copy
import io import io
import jaconv # type: ignore import jaconv
import json import json
import os import os
import struct import struct
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from pathlib import Path from pathlib import Path
from sqlalchemy.engine import CursorResult # type: ignore from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.sql import text from sqlalchemy.sql import text
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
@ -70,17 +70,18 @@ class ImportBase:
self.update = update self.update = update
self.no_combine = no_combine self.no_combine = no_combine
self.__config = config self.__config = config
self.__engine = self.__config.database.engine
self.__sessionmanager = sessionmaker(self.__engine)
self.__conn = self.__engine.connect()
self.__session = self.__sessionmanager(bind=self.__conn)
self.__batch = False self.__batch = False
# Set up DB connection stuff.
self.__engine = self.__config.database.engine
session_factory = sessionmaker(self.__engine)
self.__conn = scoped_session(session_factory)
def start_batch(self) -> None: def start_batch(self) -> None:
self.__batch = True self.__batch = True
def finish_batch(self) -> None: def finish_batch(self) -> None:
self.__session.commit() self.__conn.commit()
self.__batch = False self.__batch = False
def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> CursorResult: def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> CursorResult:
@ -96,12 +97,12 @@ class ImportBase:
]: ]:
if write_statement in sql.lower(): if write_statement in sql.lower():
raise Exception("Read-only mode is active!") raise Exception("Read-only mode is active!")
return self.__session.execute(text(sql), params if params is not None else {}) return self.__conn.execute(text(sql), params if params is not None else {})
def remote_music(self, server: str, token: str) -> GlobalMusicData: def remote_music(self, server: str, token: str) -> GlobalMusicData:
api = ReadAPI(server, token) api = ReadAPI(server, token)
user = UserData(self.__config, self.__session) user = UserData(self.__config, self.__conn)
music = MusicData(self.__config, self.__session) music = MusicData(self.__config, self.__conn)
return GlobalMusicData(api, user, music) return GlobalMusicData(api, user, music)
def remote_game(self, server: str, token: str) -> GlobalGameData: def remote_game(self, server: str, token: str) -> GlobalGameData:
@ -110,7 +111,7 @@ class ImportBase:
def get_next_music_id(self) -> int: def get_next_music_id(self) -> int:
cursor = self.execute("SELECT MAX(id) AS next_id FROM `music`") cursor = self.execute("SELECT MAX(id) AS next_id FROM `music`")
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
try: try:
return result["next_id"] + 1 return result["next_id"] + 1
except TypeError: except TypeError:
@ -138,7 +139,7 @@ class ImportBase:
}, },
) )
if cursor.rowcount != 0: if cursor.rowcount != 0:
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["id"] return result["id"]
else: else:
return None return None
@ -183,7 +184,7 @@ class ImportBase:
}, },
) )
if cursor.rowcount != 0: if cursor.rowcount != 0:
result = cursor.fetchone() result = cursor.mappings().fetchone() # type: ignore
return result["id"] return result["id"]
else: else:
return None return None
@ -369,8 +370,6 @@ class ImportBase:
# Make sure we don't leak connections after finising insertion. # Make sure we don't leak connections after finising insertion.
if self.__batch: if self.__batch:
raise Exception("Logic error, opened a batch without closing!") raise Exception("Logic error, opened a batch without closing!")
if self.__session is not None:
self.__session.close()
if self.__conn is not None: if self.__conn is not None:
self.__conn.close() self.__conn.close()
self.__conn = None self.__conn = None

View File

@ -1,6 +1,6 @@
setuptools setuptools
Cython Cython
SQLAlchemy<2.0.0 SQLAlchemy
alembic alembic
PyYAML PyYAML
Flask Flask