diff --git a/bemani/data/mysql/api.py b/bemani/data/mysql/api.py index 07c140a..7444b0c 100644 --- a/bemani/data/mysql/api.py +++ b/bemani/data/mysql/api.py @@ -55,7 +55,7 @@ class APIData(APIProviderInterface, BaseData): result["name"], result["token"], ) - for result in cursor.fetchall() + for result in cursor ] def validate_client(self, token: str) -> bool: @@ -159,7 +159,7 @@ class APIData(APIProviderInterface, BaseData): sql = "SELECT id, timestamp, uri, token, config FROM server ORDER BY timestamp ASC" cursor = self.execute(sql) - return [format_result(result) for result in cursor.fetchall()] + return [format_result(result) for result in cursor] def create_server(self, uri: str, token: str) -> int: """ diff --git a/bemani/data/mysql/game.py b/bemani/data/mysql/game.py index 316c11c..2e6b76e 100644 --- a/bemani/data/mysql/game.py +++ b/bemani/data/mysql/game.py @@ -183,18 +183,15 @@ class GameData(BaseData): sql = "SELECT id, type, data FROM series_achievement WHERE game = :game AND userid = :userid" cursor = self.execute(sql, {"game": game.value, "userid": userid}) - achievements = [] - for result in cursor.fetchall(): - achievements.append( - Achievement( - result["id"], - result["type"], - None, - self.deserialize(result["data"]), - ) + return [ + Achievement( + result["id"], + result["type"], + None, + self.deserialize(result["data"]), ) - - return achievements + for result in cursor + ] def put_achievement( self, @@ -292,13 +289,16 @@ class GameData(BaseData): # setting doesn't exist return [] - settings = [] - for result in cursor.fetchall(): - retval = ValidatedDict(self.deserialize(result["data"])) - retval["start_time"] = result["start_time"] - retval["end_time"] = result["end_time"] - settings.append(retval) - return settings + return [ + ValidatedDict( + { + **self.deserialize(result["data"]), + "start_time": result["start_time"], + "end_time": result["end_time"], + } + ) + for result in cursor + ] def put_time_sensitive_settings( self, game: GameConstants, version: int, name: str, settings: Dict[str, Any] @@ -345,7 +345,7 @@ class GameData(BaseData): "end_time": end_time, }, ) - for result in cursor.fetchall(): + for result in cursor: if result["start_time"] == start_time and result["end_time"] == end_time: # This is just this event being updated, that's fine. continue @@ -417,14 +417,11 @@ class GameData(BaseData): sql = "SELECT id, type, data FROM catalog WHERE game = :game AND version = :version" cursor = self.execute(sql, {"game": game.value, "version": version}) - catalog = [] - for result in cursor.fetchall(): - catalog.append( - Item( - result["type"], - result["id"], - self.deserialize(result["data"]), - ) + return [ + Item( + result["type"], + result["id"], + self.deserialize(result["data"]), ) - - return catalog + for result in cursor + ] diff --git a/bemani/data/mysql/lobby.py b/bemani/data/mysql/lobby.py index cd74c39..6b49a21 100644 --- a/bemani/data/mysql/lobby.py +++ b/bemani/data/mysql/lobby.py @@ -115,13 +115,13 @@ class LobbyData(BaseData): }, ) - ret = [] - for result in cursor.fetchall(): + def format_result(result: Dict[str, Any]) -> ValidatedDict: data = ValidatedDict(self.deserialize(result["data"])) data["id"] = result["id"] data["time"] = result["time"] - ret.append((UserID(result["userid"]), data)) - return ret + return data + + return [(UserID(result["userid"]), format_result(result)) for result in cursor] def put_play_session_info( self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any] @@ -251,13 +251,13 @@ class LobbyData(BaseData): }, ) - ret = [] - for result in cursor.fetchall(): + def format_result(result: Dict[str, Any]) -> ValidatedDict: data = ValidatedDict(self.deserialize(result["data"])) data["id"] = result["id"] data["time"] = result["time"] - ret.append((UserID(result["userid"]), data)) - return ret + return data + + return [(UserID(result["userid"]), format_result(result)) for result in cursor] def put_lobby( self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any] diff --git a/bemani/data/mysql/machine.py b/bemani/data/mysql/machine.py index b9c0d0c..88300df 100644 --- a/bemani/data/mysql/machine.py +++ b/bemani/data/mysql/machine.py @@ -139,7 +139,7 @@ class MachineData(BaseData): """ sql = "SELECT arcadeid FROM arcade_owner WHERE userid = :userid" cursor = self.execute(sql, {"userid": userid}) - return [ArcadeID(result["arcadeid"]) for result in cursor.fetchall()] + return [ArcadeID(result["arcadeid"]) for result in cursor] def from_session(self, session: str) -> Optional[ArcadeID]: """ @@ -211,7 +211,7 @@ class MachineData(BaseData): result["version"], self.deserialize(result["data"]), ) - for result in cursor.fetchall() + for result in cursor ] def put_machine(self, machine: Machine) -> None: @@ -385,7 +385,7 @@ class MachineData(BaseData): result["pref"], result["area"] or None, self.deserialize(result["data"]), - [owner["userid"] for owner in cursor.fetchall()], + [owner["userid"] for owner in cursor], ) def put_arcade(self, arcade: Arcade) -> None: @@ -446,7 +446,7 @@ class MachineData(BaseData): sql = "SELECT userid, arcadeid FROM arcade_owner" cursor = self.execute(sql) arcade_to_owners: Dict[int, List[UserID]] = {} - for row in cursor.fetchall(): + for row in cursor: arcade = row["arcadeid"] owner = UserID(row["userid"]) if arcade not in arcade_to_owners: @@ -466,7 +466,7 @@ class MachineData(BaseData): self.deserialize(result["data"]), arcade_to_owners.get(result["id"], []), ) - for result in cursor.fetchall() + for result in cursor ] def get_settings( @@ -543,15 +543,13 @@ class MachineData(BaseData): """ sql = "SELECT userid, balance FROM balance WHERE arcadeid = :arcadeid" cursor = self.execute(sql, {"arcadeid": arcadeid}) - balances = [] - for entry in cursor.fetchall(): - balances.append( - ( - UserID(entry["userid"]), - entry["balance"], - ) + return [ + ( + UserID(entry["userid"]), + entry["balance"], ) - return balances + for entry in cursor + ] def create_session(self, arcadeid: ArcadeID, expiration: int = (30 * 86400)) -> str: """ diff --git a/bemani/data/mysql/music.py b/bemani/data/mysql/music.py index c3850a7..0d4dd82 100644 --- a/bemani/data/mysql/music.py +++ b/bemani/data/mysql/music.py @@ -378,23 +378,20 @@ class MusicData(BaseData): }, ) - scores = [] - for result in cursor.fetchall(): - scores.append( - Score( - result["scorekey"], - result["songid"], - result["chart"], - result["points"], - result["timestamp"], - result["update"], - result["lid"], - result["plays"], - self.deserialize(result["data"]), - ) + return [ + Score( + result["scorekey"], + result["songid"], + result["chart"], + result["points"], + result["timestamp"], + result["update"], + result["lid"], + result["plays"], + self.deserialize(result["data"]), ) - - return scores + for result in cursor + ] def get_most_played( self, game: GameConstants, version: int, userid: UserID, count: int @@ -422,11 +419,7 @@ class MusicData(BaseData): {"userid": userid, "game": game.value, "version": version, "count": count}, ) - most_played = [] - for result in cursor.fetchall(): - most_played.append((result["songid"], result["plays"])) - - return most_played + return [(result["songid"], result["plays"]) for result in cursor] def get_last_played( self, game: GameConstants, version: int, userid: UserID, count: int @@ -454,11 +447,7 @@ class MusicData(BaseData): {"userid": userid, "game": game.value, "version": version, "count": count}, ) - last_played = [] - for result in cursor.fetchall(): - last_played.append((result["songid"], result["timestamp"])) - - return last_played + return [(result["songid"], result["timestamp"]) for result in cursor] def get_hit_chart( self, @@ -499,11 +488,7 @@ class MusicData(BaseData): }, ) - most_played = [] - for result in cursor.fetchall(): - most_played.append((result["songid"], result["plays"])) - - return most_played + return [(result["songid"], result["plays"]) for result in cursor] def get_song( self, @@ -580,22 +565,19 @@ class MusicData(BaseData): sql += " ORDER BY music.version DESC" cursor = self.execute(sql, params) - all_songs = [] - for result in cursor.fetchall(): - all_songs.append( - Song( - game, - result["version"], - result["songid"], - result["chart"], - result["name"], - result["artist"], - result["genre"], - self.deserialize(result["data"]), - ) + return [ + Song( + game, + result["version"], + result["songid"], + result["chart"], + result["name"], + result["artist"], + result["genre"], + self.deserialize(result["data"]), ) - - return all_songs + for result in cursor + ] def get_all_versions_of_song( self, @@ -625,21 +607,20 @@ class MusicData(BaseData): if interested_versions is not None: sql += f" AND music.version in ({','.join(str(int(v)) for v in interested_versions)})" cursor = self.execute(sql, {"musicid": musicid}) - all_songs = [] - for result in cursor.fetchall(): - all_songs.append( - Song( - game, - result["version"], - result["songid"], - result["chart"], - result["name"], - result["artist"], - result["genre"], - self.deserialize(result["data"]), - ) + + return [ + Song( + game, + result["version"], + result["songid"], + result["chart"], + result["name"], + result["artist"], + result["genre"], + self.deserialize(result["data"]), ) - return all_songs + for result in cursor + ] def get_all_scores( self, @@ -710,26 +691,23 @@ class MusicData(BaseData): ) # Objectify result - scores = [] - for result in cursor.fetchall(): - scores.append( - ( - UserID(result["userid"]), - Score( - result["scorekey"], - result["songid"], - result["chart"], - result["points"], - result["timestamp"], - result["update"], - result["lid"], - result["plays"], - self.deserialize(result["data"]), - ), - ) + return [ + ( + UserID(result["userid"]), + Score( + result["scorekey"], + result["songid"], + result["chart"], + result["points"], + result["timestamp"], + result["update"], + result["lid"], + result["plays"], + self.deserialize(result["data"]), + ), ) - - return scores + for result in cursor + ] def get_all_records( self, @@ -801,26 +779,23 @@ class MusicData(BaseData): ).format(songidquery, chartquery, records_sql) cursor = self.execute(sql, params) - scores = [] - for result in cursor.fetchall(): - scores.append( - ( - UserID(result["userid"]), - Score( - result["scorekey"], - result["songid"], - result["chart"], - result["points"], - result["timestamp"], - result["update"], - result["lid"], - result["plays"], - self.deserialize(result["data"]), - ), - ) + return [ + ( + UserID(result["userid"]), + Score( + result["scorekey"], + result["songid"], + result["chart"], + result["points"], + result["timestamp"], + result["update"], + result["lid"], + result["plays"], + self.deserialize(result["data"]), + ), ) - - return scores + for result in cursor + ] def get_attempt_by_key( self, game: GameConstants, version: int, key: int @@ -939,22 +914,19 @@ class MusicData(BaseData): ) # Now objectify the attempts - attempts = [] - for result in cursor.fetchall(): - attempts.append( - ( - UserID(result["userid"]) if result["userid"] > 0 else None, - Attempt( - result["scorekey"], - result["songid"], - result["chart"], - result["points"], - result["timestamp"], - result["lid"], - True if result["new_record"] == 1 else False, - self.deserialize(result["data"]), - ), - ) + return [ + ( + UserID(result["userid"]) if result["userid"] > 0 else None, + Attempt( + result["scorekey"], + result["songid"], + result["chart"], + result["points"], + result["timestamp"], + result["lid"], + True if result["new_record"] == 1 else False, + self.deserialize(result["data"]), + ), ) - - return attempts + for result in cursor + ] diff --git a/bemani/data/mysql/network.py b/bemani/data/mysql/network.py index 5c037ac..79b88d1 100644 --- a/bemani/data/mysql/network.py +++ b/bemani/data/mysql/network.py @@ -75,7 +75,7 @@ class NetworkData(BaseData): result["title"], result["body"], ) - for result in cursor.fetchall() + for result in cursor ] def create_news(self, title: str, body: str) -> int: @@ -318,27 +318,20 @@ class NetworkData(BaseData): "until_id": until_id, }, ) - events = [] - for result in cursor.fetchall(): - if result["userid"] is not None: - userid = UserID(result["userid"]) - else: - userid = None - if result["arcadeid"] is not None: - arcadeid = ArcadeID(result["arcadeid"]) - else: - arcadeid = None - events.append( - Event( - result["id"], - result["timestamp"], - userid, - arcadeid, - result["type"], - self.deserialize(result["data"]), - ), + + return [ + Event( + result["id"], + result["timestamp"], + UserID(result["userid"]) if result["userid"] is not None else None, + ArcadeID(result["arcadeid"]) + if result["arcadeid"] is not None + else None, + result["type"], + self.deserialize(result["data"]), ) - return events + for result in cursor + ] def delete_events(self, oldest_event_ts: int) -> None: """ diff --git a/bemani/data/mysql/user.py b/bemani/data/mysql/user.py index 2c1373a..87e420c 100644 --- a/bemani/data/mysql/user.py +++ b/bemani/data/mysql/user.py @@ -324,7 +324,7 @@ class UserData(BaseData): result["email"], result["admin"] == 1, ) - for result in cursor.fetchall() + for result in cursor ] def get_all_usernames(self) -> List[str]: @@ -339,7 +339,7 @@ class UserData(BaseData): """ sql = "SELECT username FROM user WHERE username is not null" cursor = self.execute(sql) - return [res["username"] for res in cursor.fetchall()] + return [res["username"] for res in cursor] def get_all_cards(self) -> List[Tuple[str, UserID]]: """ @@ -350,9 +350,7 @@ class UserData(BaseData): """ sql = "SELECT id, userid FROM card" cursor = self.execute(sql) - return [ - (str(res["id"]).upper(), UserID(res["userid"])) for res in cursor.fetchall() - ] + return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor] def get_cards(self, userid: UserID) -> List[str]: """ @@ -366,7 +364,7 @@ class UserData(BaseData): """ sql = "SELECT id FROM card WHERE userid = :userid" cursor = self.execute(sql, {"userid": userid}) - return [str(res["id"]).upper() for res in cursor.fetchall()] + return [str(res["id"]).upper() for res in cursor] def add_card(self, userid: UserID, cardid: str) -> None: """ @@ -581,7 +579,7 @@ class UserData(BaseData): cursor = self.execute(sql, {"game": game.value, "userids": userids}) profilever: Dict[UserID, int] = {} - for result in cursor.fetchall(): + for result in cursor: tuid = UserID(result["userid"]) tver = result["version"] @@ -627,10 +625,7 @@ class UserData(BaseData): vals["game"] = game.value cursor = self.execute(sql, vals) - profiles = [] - for result in cursor.fetchall(): - profiles.append((GameConstants(result["game"]), result["version"])) - return profiles + return [(GameConstants(result["game"]), result["version"]) for result in cursor] def get_all_profiles( self, game: GameConstants, version: int @@ -653,22 +648,19 @@ class UserData(BaseData): ) cursor = self.execute(sql, {"game": game.value, "version": version}) - profiles = [] - for result in cursor.fetchall(): - profiles.append( - ( - UserID(result["userid"]), - Profile( - game, - version, - result["refid"], - result["extid"], - self.deserialize(result["data"]), - ), - ) + return [ + ( + UserID(result["userid"]), + Profile( + game, + version, + result["refid"], + result["extid"], + self.deserialize(result["data"]), + ), ) - - return profiles + for result in cursor + ] def get_all_players(self, game: GameConstants, version: int) -> List[UserID]: """ @@ -687,7 +679,7 @@ class UserData(BaseData): ) cursor = self.execute(sql, {"game": game.value, "version": version}) - return [UserID(result["userid"]) for result in cursor.fetchall()] + return [UserID(result["userid"]) for result in cursor] def get_all_achievements( self, @@ -720,21 +712,18 @@ class UserData(BaseData): params["id"] = achievementid cursor = self.execute(sql, params) - achievements = [] - for result in cursor.fetchall(): - achievements.append( - ( - UserID(result["userid"]), - Achievement( - result["id"], - result["type"], - None, - self.deserialize(result["data"]), - ), - ) + return [ + ( + UserID(result["userid"]), + Achievement( + result["id"], + result["type"], + None, + self.deserialize(result["data"]), + ), ) - - return achievements + for result in cursor + ] def put_profile( self, game: GameConstants, version: int, userid: UserID, profile: Profile @@ -834,18 +823,15 @@ class UserData(BaseData): sql = "SELECT id, type, data FROM achievement WHERE refid = :refid" cursor = self.execute(sql, {"refid": refid}) - achievements = [] - for result in cursor.fetchall(): - achievements.append( - Achievement( - result["id"], - result["type"], - None, - self.deserialize(result["data"]), - ) + return [ + Achievement( + result["id"], + result["type"], + None, + self.deserialize(result["data"]), ) - - return achievements + for result in cursor + ] def put_achievement( self, @@ -949,18 +935,15 @@ class UserData(BaseData): {"refid": refid, "type": achievementtype, "since": since, "until": until}, ) - achievements = [] - for result in cursor.fetchall(): - achievements.append( - Achievement( - result["id"], - result["type"], - result["timestamp"], - self.deserialize(result["data"]), - ) + return [ + Achievement( + result["id"], + result["type"], + result["timestamp"], + self.deserialize(result["data"]), ) - - return achievements + for result in cursor + ] def put_time_based_achievement( self, @@ -1022,21 +1005,18 @@ class UserData(BaseData): ) cursor = self.execute(sql, {"game": game.value, "version": version}) - achievements = [] - for result in cursor.fetchall(): - achievements.append( - ( - UserID(result["userid"]), - Achievement( - result["id"], - result["type"], - result["timestamp"], - self.deserialize(result["data"]), - ), - ) + return [ + ( + UserID(result["userid"]), + Achievement( + result["id"], + result["type"], + result["timestamp"], + self.deserialize(result["data"]), + ), ) - - return achievements + for result in cursor + ] def get_link( self, @@ -1099,18 +1079,15 @@ class UserData(BaseData): sql, {"game": game.value, "version": version, "userid": userid} ) - links = [] - for result in cursor.fetchall(): - links.append( - Link( - userid, - result["type"], - UserID(result["other_userid"]), - self.deserialize(result["data"]), - ) + return [ + Link( + userid, + result["type"], + UserID(result["other_userid"]), + self.deserialize(result["data"]), ) - - return links + for result in cursor + ] def put_link( self, diff --git a/bemani/tests/helpers.py b/bemani/tests/helpers.py index 9f3f313..f168eac 100644 --- a/bemani/tests/helpers.py +++ b/bemani/tests/helpers.py @@ -24,6 +24,7 @@ class FakeCursor: def __init__(self, rows: List[Dict[str, Any]]) -> None: self.__rows = rows self.rowcount = len(rows) + self.pos = -1 def fetchone(self) -> Dict[str, Any]: if len(self.__rows) != 1: @@ -32,8 +33,16 @@ class FakeCursor: ) return self.__rows[0] - def fetchall(self) -> List[Dict[str, Any]]: - return self.__rows + def __iter__(self) -> "FakeCursor": + self.pos = -1 + return self + + def __next__(self) -> Dict[str, Any]: + self.pos += 1 + if self.pos < self.rowcount: + return self.__rows[self.pos] + else: + raise StopIteration def get_fixture(name: str) -> bytes: