Finally upgrade to SQLAlchemy 2.0, fix a bunch of type issues relating to that from Python 3.12.
This commit is contained in:
parent
eb8264c553
commit
da5e3fe52f
@ -97,7 +97,6 @@ class Data:
|
||||
session_factory = sessionmaker(
|
||||
bind=config.database.engine,
|
||||
autoflush=True,
|
||||
autocommit=True,
|
||||
)
|
||||
self.__config = config
|
||||
self.__session = scoped_session(session_factory)
|
||||
|
@ -55,7 +55,7 @@ class APIData(APIProviderInterface, BaseData):
|
||||
result["name"],
|
||||
result["token"],
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def validate_client(self, token: str) -> bool:
|
||||
@ -70,7 +70,7 @@ class APIData(APIProviderInterface, BaseData):
|
||||
"""
|
||||
sql = "SELECT count(*) AS count FROM client WHERE token = :token"
|
||||
cursor = self.execute(sql, {"token": token})
|
||||
return cursor.fetchone()["count"] == 1
|
||||
return cursor.mappings().fetchone()["count"] == 1 # type: ignore
|
||||
|
||||
def create_client(self, name: str) -> int:
|
||||
"""
|
||||
@ -109,7 +109,7 @@ class APIData(APIProviderInterface, BaseData):
|
||||
# Couldn't find an entry with this ID
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return Client(
|
||||
clientid,
|
||||
result["timestamp"],
|
||||
@ -159,7 +159,7 @@ class APIData(APIProviderInterface, BaseData):
|
||||
|
||||
sql = "SELECT id, timestamp, uri, token, config FROM server ORDER BY timestamp ASC"
|
||||
cursor = self.execute(sql)
|
||||
return [format_result(result) for result in cursor]
|
||||
return [format_result(result) for result in cursor.mappings()]
|
||||
|
||||
def create_server(self, uri: str, token: str) -> int:
|
||||
"""
|
||||
@ -199,7 +199,7 @@ class APIData(APIProviderInterface, BaseData):
|
||||
# Couldn't find an entry with this ID
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
allow_stats = (result["config"] & 0x1) == 0
|
||||
allow_scores = (result["config"] & 0x2) == 0
|
||||
return Server(
|
||||
|
@ -6,7 +6,7 @@ from typing_extensions import Final
|
||||
from bemani.common import Time
|
||||
from bemani.data.config import Config
|
||||
|
||||
from sqlalchemy.engine import CursorResult # type: ignore
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.types import String, Integer
|
||||
@ -82,10 +82,12 @@ class BaseData:
|
||||
includes = all(s in lowered for s in write_statement_group)
|
||||
if includes and not safe_write_operation:
|
||||
raise Exception("Read-only mode is active!")
|
||||
return self.__conn.execute(
|
||||
result = self.__conn.execute(
|
||||
text(sql),
|
||||
params if params is not None else {},
|
||||
)
|
||||
self.__conn.commit()
|
||||
return result
|
||||
|
||||
def serialize(self, data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
@ -141,7 +143,7 @@ class BaseData:
|
||||
# Couldn't find a user with this session
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["id"]
|
||||
|
||||
def _create_session(self, opid: int, optype: str, expiration: int = (30 * 86400)) -> str:
|
||||
|
@ -96,7 +96,7 @@ class GameData(BaseData):
|
||||
# Settings doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return ValidatedDict(self.deserialize(result["data"]))
|
||||
|
||||
def put_settings(self, game: GameConstants, userid: UserID, settings: Dict[str, Any]) -> None:
|
||||
@ -158,7 +158,7 @@ class GameData(BaseData):
|
||||
# score doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return ValidatedDict(self.deserialize(result["data"]))
|
||||
|
||||
def get_achievements(self, game: GameConstants, userid: UserID) -> List[Achievement]:
|
||||
@ -182,7 +182,7 @@ class GameData(BaseData):
|
||||
None,
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_achievement(
|
||||
@ -251,7 +251,7 @@ class GameData(BaseData):
|
||||
# setting doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
retval = ValidatedDict(self.deserialize(result["data"]))
|
||||
retval["start_time"] = result["start_time"]
|
||||
retval["end_time"] = result["end_time"]
|
||||
@ -288,7 +288,7 @@ class GameData(BaseData):
|
||||
"end_time": result["end_time"],
|
||||
}
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_time_sensitive_settings(
|
||||
@ -336,7 +336,7 @@ class GameData(BaseData):
|
||||
"end_time": end_time,
|
||||
},
|
||||
)
|
||||
for result in cursor:
|
||||
for result in cursor.mappings():
|
||||
if result["start_time"] == start_time and result["end_time"] == end_time:
|
||||
# This is just this event being updated, that's fine.
|
||||
continue
|
||||
@ -387,7 +387,7 @@ class GameData(BaseData):
|
||||
# entry doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return ValidatedDict(self.deserialize(result["data"]))
|
||||
|
||||
def get_items(self, game: GameConstants, version: int) -> List[Item]:
|
||||
@ -410,5 +410,5 @@ class GameData(BaseData):
|
||||
result["id"],
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
@ -82,7 +82,7 @@ class LobbyData(BaseData):
|
||||
# Settings doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
data = ValidatedDict(self.deserialize(result["data"]))
|
||||
data["id"] = result["id"]
|
||||
data["time"] = result["time"]
|
||||
@ -119,7 +119,7 @@ class LobbyData(BaseData):
|
||||
data["time"] = result["time"]
|
||||
return data
|
||||
|
||||
return [(UserID(result["userid"]), format_result(result)) for result in cursor]
|
||||
return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()]
|
||||
|
||||
def put_play_session_info(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
@ -214,7 +214,7 @@ class LobbyData(BaseData):
|
||||
# Settings doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
data = ValidatedDict(self.deserialize(result["data"]))
|
||||
data["id"] = result["id"]
|
||||
data["time"] = result["time"]
|
||||
@ -252,7 +252,7 @@ class LobbyData(BaseData):
|
||||
data["time"] = result["time"]
|
||||
return data
|
||||
|
||||
return [(UserID(result["userid"]), format_result(result)) for result in cursor]
|
||||
return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()]
|
||||
|
||||
def put_lobby(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
|
@ -102,7 +102,7 @@ class MachineData(BaseData):
|
||||
# Machine doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["pcbid"]
|
||||
|
||||
def from_machine_id(self, machine_id: int) -> Optional[str]:
|
||||
@ -122,7 +122,7 @@ class MachineData(BaseData):
|
||||
# Machine doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["pcbid"]
|
||||
|
||||
def from_userid(self, userid: UserID) -> List[ArcadeID]:
|
||||
@ -137,7 +137,7 @@ class MachineData(BaseData):
|
||||
"""
|
||||
sql = "SELECT arcadeid FROM arcade_owner WHERE userid = :userid"
|
||||
cursor = self.execute(sql, {"userid": userid})
|
||||
return [ArcadeID(result["arcadeid"]) for result in cursor]
|
||||
return [ArcadeID(result["arcadeid"]) for result in cursor.mappings()]
|
||||
|
||||
def from_session(self, session: str) -> Optional[ArcadeID]:
|
||||
"""
|
||||
@ -173,7 +173,7 @@ class MachineData(BaseData):
|
||||
# Machine doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return Machine(
|
||||
result["id"],
|
||||
pcbid,
|
||||
@ -212,7 +212,7 @@ class MachineData(BaseData):
|
||||
result["version"],
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_machine(self, machine: Machine) -> None:
|
||||
@ -279,7 +279,7 @@ class MachineData(BaseData):
|
||||
port = None
|
||||
else:
|
||||
# Grab highest port
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
port = result["port"]
|
||||
if port is not None:
|
||||
port = port + 1
|
||||
@ -383,7 +383,7 @@ class MachineData(BaseData):
|
||||
# Arcade doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
|
||||
sql = "SELECT userid FROM arcade_owner WHERE arcadeid = :id"
|
||||
cursor = self.execute(sql, {"id": arcadeid})
|
||||
@ -396,7 +396,7 @@ class MachineData(BaseData):
|
||||
result["pref"],
|
||||
result["area"] or None,
|
||||
self.deserialize(result["data"]),
|
||||
[owner["userid"] for owner in cursor],
|
||||
[owner["userid"] for owner in cursor.mappings()],
|
||||
)
|
||||
|
||||
def put_arcade(self, arcade: Arcade) -> None:
|
||||
@ -464,7 +464,7 @@ class MachineData(BaseData):
|
||||
sql = "SELECT userid, arcadeid FROM arcade_owner"
|
||||
cursor = self.execute(sql)
|
||||
arcade_to_owners: Dict[int, List[UserID]] = {}
|
||||
for row in cursor:
|
||||
for row in cursor.mappings():
|
||||
arcade = row["arcadeid"]
|
||||
owner = UserID(row["userid"])
|
||||
if arcade not in arcade_to_owners:
|
||||
@ -484,7 +484,7 @@ class MachineData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
arcade_to_owners.get(result["id"], []),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_settings(
|
||||
@ -512,7 +512,7 @@ class MachineData(BaseData):
|
||||
# Settings doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return ValidatedDict(self.deserialize(result["data"]))
|
||||
|
||||
def put_settings(
|
||||
@ -566,7 +566,7 @@ class MachineData(BaseData):
|
||||
UserID(entry["userid"]),
|
||||
entry["balance"],
|
||||
)
|
||||
for entry in cursor
|
||||
for entry in cursor.mappings()
|
||||
]
|
||||
|
||||
def create_session(self, arcadeid: ArcadeID, expiration: int = (30 * 86400)) -> str:
|
||||
|
@ -107,7 +107,7 @@ class MusicData(BaseData):
|
||||
if cursor.rowcount != 1:
|
||||
# music doesn't exist
|
||||
raise Exception(f"Song {songid} chart {songchart} doesn't exist for game {game} version {version}")
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["id"]
|
||||
|
||||
def put_score(
|
||||
@ -296,7 +296,7 @@ class MusicData(BaseData):
|
||||
# score doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return Score(
|
||||
result["scorekey"],
|
||||
result["songid"],
|
||||
@ -356,7 +356,7 @@ class MusicData(BaseData):
|
||||
# score doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return (
|
||||
UserID(result["userid"]),
|
||||
Score(
|
||||
@ -439,7 +439,7 @@ class MusicData(BaseData):
|
||||
result["plays"],
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_most_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]:
|
||||
@ -472,7 +472,7 @@ class MusicData(BaseData):
|
||||
{"userid": userid, "game": game.value, "version": version, "count": count},
|
||||
)
|
||||
|
||||
return [(result["songid"], result["plays"]) for result in cursor]
|
||||
return [(result["songid"], result["plays"]) for result in cursor.mappings()]
|
||||
|
||||
def get_last_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
@ -504,7 +504,7 @@ class MusicData(BaseData):
|
||||
{"userid": userid, "game": game.value, "version": version, "count": count},
|
||||
)
|
||||
|
||||
return [(result["songid"], result["timestamp"]) for result in cursor]
|
||||
return [(result["songid"], result["timestamp"]) for result in cursor.mappings()]
|
||||
|
||||
def get_hit_chart(
|
||||
self,
|
||||
@ -551,7 +551,7 @@ class MusicData(BaseData):
|
||||
},
|
||||
)
|
||||
|
||||
return [(result["songid"], result["plays"]) for result in cursor]
|
||||
return [(result["songid"], result["plays"]) for result in cursor.mappings()]
|
||||
|
||||
def get_song(
|
||||
self,
|
||||
@ -597,7 +597,7 @@ class MusicData(BaseData):
|
||||
if cursor.rowcount != 1:
|
||||
# music doesn't exist
|
||||
return None
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return Song(
|
||||
game,
|
||||
version,
|
||||
@ -647,7 +647,7 @@ class MusicData(BaseData):
|
||||
result["genre"],
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_all_scores(
|
||||
@ -752,7 +752,7 @@ class MusicData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_all_records(
|
||||
@ -863,7 +863,7 @@ class MusicData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_attempt_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]:
|
||||
@ -908,7 +908,7 @@ class MusicData(BaseData):
|
||||
# score doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return (
|
||||
UserID(result["userid"]),
|
||||
Attempt(
|
||||
@ -1019,5 +1019,5 @@ class MusicData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
@ -73,7 +73,7 @@ class NetworkData(BaseData):
|
||||
result["title"],
|
||||
result["body"],
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def create_news(self, title: str, body: str) -> int:
|
||||
@ -107,7 +107,7 @@ class NetworkData(BaseData):
|
||||
# Couldn't find an entry with this ID
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return News(
|
||||
newsid,
|
||||
result["timestamp"],
|
||||
@ -181,7 +181,7 @@ class NetworkData(BaseData):
|
||||
# No scheduled work was registered, so time to get going!
|
||||
return True
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
|
||||
if schedule == "daily":
|
||||
# Just look at the day and year, make sure it matches
|
||||
@ -317,7 +317,7 @@ class NetworkData(BaseData):
|
||||
result["type"],
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def delete_events(self, oldest_event_ts: int) -> None:
|
||||
|
@ -194,7 +194,7 @@ class UserData(BaseData):
|
||||
# Couldn't find a user with this card
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return UserID(result["userid"])
|
||||
|
||||
def from_username(self, username: str) -> Optional[UserID]:
|
||||
@ -213,7 +213,7 @@ class UserData(BaseData):
|
||||
# Couldn't find this username
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return UserID(result["id"])
|
||||
|
||||
def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]:
|
||||
@ -238,7 +238,7 @@ class UserData(BaseData):
|
||||
# Couldn't find a user with this refid
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return UserID(result["userid"])
|
||||
|
||||
def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]:
|
||||
@ -263,7 +263,7 @@ class UserData(BaseData):
|
||||
# Couldn't find a user with this refid
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return UserID(result["userid"])
|
||||
|
||||
def from_session(self, session: str) -> Optional[UserID]:
|
||||
@ -297,7 +297,7 @@ class UserData(BaseData):
|
||||
# User doesn't exist, but we have a reference?
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return User(userid, result["username"], result["email"], result["admin"] == 1)
|
||||
|
||||
def get_all_users(self) -> List[User]:
|
||||
@ -316,7 +316,7 @@ class UserData(BaseData):
|
||||
result["email"],
|
||||
result["admin"] == 1,
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_all_usernames(self) -> List[str]:
|
||||
@ -331,7 +331,7 @@ class UserData(BaseData):
|
||||
"""
|
||||
sql = "SELECT username FROM user WHERE username is not null"
|
||||
cursor = self.execute(sql)
|
||||
return [res["username"] for res in cursor]
|
||||
return [res["username"] for res in cursor.mappings()]
|
||||
|
||||
def get_all_cards(self) -> List[Tuple[str, UserID]]:
|
||||
"""
|
||||
@ -342,7 +342,7 @@ class UserData(BaseData):
|
||||
"""
|
||||
sql = "SELECT id, userid FROM card"
|
||||
cursor = self.execute(sql)
|
||||
return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor]
|
||||
return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor.mappings()]
|
||||
|
||||
def get_cards(self, userid: UserID) -> List[str]:
|
||||
"""
|
||||
@ -356,7 +356,7 @@ class UserData(BaseData):
|
||||
"""
|
||||
sql = "SELECT id FROM card WHERE userid = :userid"
|
||||
cursor = self.execute(sql, {"userid": userid})
|
||||
return [str(res["id"]).upper() for res in cursor]
|
||||
return [str(res["id"]).upper() for res in cursor.mappings()]
|
||||
|
||||
def add_card(self, userid: UserID, cardid: str) -> None:
|
||||
"""
|
||||
@ -437,7 +437,7 @@ class UserData(BaseData):
|
||||
# User doesn't exist, but we have a reference?
|
||||
return False
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return pin == result["pin"]
|
||||
|
||||
def update_pin(self, userid: UserID, pin: str) -> None:
|
||||
@ -468,7 +468,7 @@ class UserData(BaseData):
|
||||
# User doesn't exist, but we have a reference?
|
||||
return False
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
passhash = result["password"]
|
||||
|
||||
try:
|
||||
@ -517,7 +517,7 @@ class UserData(BaseData):
|
||||
# Profile doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return Profile(
|
||||
game,
|
||||
version,
|
||||
@ -577,7 +577,7 @@ class UserData(BaseData):
|
||||
cursor = self.execute(sql, {"game": game.value, "userids": userids})
|
||||
profilever: Dict[UserID, int] = {}
|
||||
|
||||
for result in cursor:
|
||||
for result in cursor.mappings():
|
||||
tuid = UserID(result["userid"])
|
||||
tver = result["version"]
|
||||
|
||||
@ -626,7 +626,7 @@ class UserData(BaseData):
|
||||
vals["game"] = game.value
|
||||
|
||||
cursor = self.execute(sql, vals)
|
||||
return [(GameConstants(result["game"]), result["version"]) for result in cursor]
|
||||
return [(GameConstants(result["game"]), result["version"]) for result in cursor.mappings()]
|
||||
|
||||
def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, Profile]]:
|
||||
"""
|
||||
@ -662,7 +662,7 @@ class UserData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_all_players(self, game: GameConstants, version: int) -> List[UserID]:
|
||||
@ -682,7 +682,7 @@ class UserData(BaseData):
|
||||
"""
|
||||
cursor = self.execute(sql, {"game": game.value, "version": version})
|
||||
|
||||
return [UserID(result["userid"]) for result in cursor]
|
||||
return [UserID(result["userid"]) for result in cursor.mappings()]
|
||||
|
||||
def get_all_achievements(
|
||||
self,
|
||||
@ -732,7 +732,7 @@ class UserData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_profile(self, game: GameConstants, version: int, userid: UserID, profile: Profile) -> None:
|
||||
@ -808,7 +808,7 @@ class UserData(BaseData):
|
||||
# score doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return ValidatedDict(self.deserialize(result["data"]))
|
||||
|
||||
def get_achievements(self, game: GameConstants, version: int, userid: UserID) -> List[Achievement]:
|
||||
@ -834,7 +834,7 @@ class UserData(BaseData):
|
||||
None,
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_achievement(
|
||||
@ -945,7 +945,7 @@ class UserData(BaseData):
|
||||
result["timestamp"],
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_time_based_achievement(
|
||||
@ -1023,7 +1023,7 @@ class UserData(BaseData):
|
||||
self.deserialize(result["data"]),
|
||||
),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def get_link(
|
||||
@ -1074,7 +1074,7 @@ class UserData(BaseData):
|
||||
# score doesn't exist
|
||||
return None
|
||||
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return ValidatedDict(self.deserialize(result["data"]))
|
||||
|
||||
def get_links(self, game: GameConstants, version: int, userid: UserID) -> List[Link]:
|
||||
@ -1103,7 +1103,7 @@ class UserData(BaseData):
|
||||
UserID(result["other_userid"]),
|
||||
self.deserialize(result["data"]),
|
||||
)
|
||||
for result in cursor
|
||||
for result in cursor.mappings()
|
||||
]
|
||||
|
||||
def put_link(
|
||||
@ -1196,7 +1196,7 @@ class UserData(BaseData):
|
||||
sql = "SELECT balance FROM balance WHERE userid = :userid AND arcadeid = :arcadeid"
|
||||
cursor = self.execute(sql, {"userid": userid, "arcadeid": arcadeid})
|
||||
if cursor.rowcount == 1:
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["balance"]
|
||||
else:
|
||||
return 0
|
||||
@ -1242,7 +1242,7 @@ class UserData(BaseData):
|
||||
sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version"
|
||||
cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version})
|
||||
if cursor.rowcount == 1:
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["refid"]
|
||||
else:
|
||||
return self.create_refid(game, version, userid)
|
||||
@ -1265,7 +1265,7 @@ class UserData(BaseData):
|
||||
sql = "SELECT extid FROM extid WHERE userid = :userid AND game = :game"
|
||||
cursor = self.execute(sql, {"userid": userid, "game": game.value})
|
||||
if cursor.rowcount == 1:
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["extid"]
|
||||
else:
|
||||
return None
|
||||
@ -1371,7 +1371,7 @@ class UserData(BaseData):
|
||||
sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version"
|
||||
cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version})
|
||||
if cursor.rowcount == 1:
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["refid"]
|
||||
# Shouldn't be possible, but here we are
|
||||
raise AccountCreationException("Failed to recover lost race refid!")
|
||||
|
@ -26,6 +26,9 @@ class FakeCursor:
|
||||
self.rowcount = len(rows)
|
||||
self.pos = -1
|
||||
|
||||
def mappings(self) -> "FakeCursor":
|
||||
return self
|
||||
|
||||
def fetchone(self) -> Dict[str, Any]:
|
||||
if len(self.__rows) != 1:
|
||||
raise Exception(f"Tried to fetch one row and there are {len(self.__rows)} rows!")
|
||||
|
@ -4,14 +4,14 @@ import csv
|
||||
import argparse
|
||||
import copy
|
||||
import io
|
||||
import jaconv # type: ignore
|
||||
import jaconv
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from sqlalchemy.engine import CursorResult # type: ignore
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
@ -70,17 +70,18 @@ class ImportBase:
|
||||
self.update = update
|
||||
self.no_combine = no_combine
|
||||
self.__config = config
|
||||
self.__engine = self.__config.database.engine
|
||||
self.__sessionmanager = sessionmaker(self.__engine)
|
||||
self.__conn = self.__engine.connect()
|
||||
self.__session = self.__sessionmanager(bind=self.__conn)
|
||||
self.__batch = False
|
||||
|
||||
# Set up DB connection stuff.
|
||||
self.__engine = self.__config.database.engine
|
||||
session_factory = sessionmaker(self.__engine)
|
||||
self.__conn = scoped_session(session_factory)
|
||||
|
||||
def start_batch(self) -> None:
|
||||
self.__batch = True
|
||||
|
||||
def finish_batch(self) -> None:
|
||||
self.__session.commit()
|
||||
self.__conn.commit()
|
||||
self.__batch = False
|
||||
|
||||
def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> CursorResult:
|
||||
@ -96,12 +97,12 @@ class ImportBase:
|
||||
]:
|
||||
if write_statement in sql.lower():
|
||||
raise Exception("Read-only mode is active!")
|
||||
return self.__session.execute(text(sql), params if params is not None else {})
|
||||
return self.__conn.execute(text(sql), params if params is not None else {})
|
||||
|
||||
def remote_music(self, server: str, token: str) -> GlobalMusicData:
|
||||
api = ReadAPI(server, token)
|
||||
user = UserData(self.__config, self.__session)
|
||||
music = MusicData(self.__config, self.__session)
|
||||
user = UserData(self.__config, self.__conn)
|
||||
music = MusicData(self.__config, self.__conn)
|
||||
return GlobalMusicData(api, user, music)
|
||||
|
||||
def remote_game(self, server: str, token: str) -> GlobalGameData:
|
||||
@ -110,7 +111,7 @@ class ImportBase:
|
||||
|
||||
def get_next_music_id(self) -> int:
|
||||
cursor = self.execute("SELECT MAX(id) AS next_id FROM `music`")
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
try:
|
||||
return result["next_id"] + 1
|
||||
except TypeError:
|
||||
@ -138,7 +139,7 @@ class ImportBase:
|
||||
},
|
||||
)
|
||||
if cursor.rowcount != 0:
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["id"]
|
||||
else:
|
||||
return None
|
||||
@ -183,7 +184,7 @@ class ImportBase:
|
||||
},
|
||||
)
|
||||
if cursor.rowcount != 0:
|
||||
result = cursor.fetchone()
|
||||
result = cursor.mappings().fetchone() # type: ignore
|
||||
return result["id"]
|
||||
else:
|
||||
return None
|
||||
@ -369,8 +370,6 @@ class ImportBase:
|
||||
# Make sure we don't leak connections after finising insertion.
|
||||
if self.__batch:
|
||||
raise Exception("Logic error, opened a batch without closing!")
|
||||
if self.__session is not None:
|
||||
self.__session.close()
|
||||
if self.__conn is not None:
|
||||
self.__conn.close()
|
||||
self.__conn = None
|
||||
|
@ -1,6 +1,6 @@
|
||||
setuptools
|
||||
Cython
|
||||
SQLAlchemy<2.0.0
|
||||
SQLAlchemy
|
||||
alembic
|
||||
PyYAML
|
||||
Flask
|
||||
|
Loading…
Reference in New Issue
Block a user