diff --git a/bemani/api/objects/base.py b/bemani/api/objects/base.py index 09c71e2..75e546a 100644 --- a/bemani/api/objects/base.py +++ b/bemani/api/objects/base.py @@ -1,6 +1,7 @@ from typing import List, Any, Dict from bemani.api.exceptions import APIException +from bemani.common import GameConstants from bemani.data import Data @@ -13,7 +14,7 @@ class BaseObject: various fetch versions. """ - def __init__(self, data: Data, game: str, version: int, omnimix: bool) -> None: + def __init__(self, data: Data, game: GameConstants, version: int, omnimix: bool) -> None: self.data = data self.game = game self.version = version diff --git a/bemani/backend/base.py b/bemani/backend/base.py index 05d445a..c27ebd1 100644 --- a/bemani/backend/base.py +++ b/bemani/backend/base.py @@ -1,7 +1,8 @@ +from abc import ABC import traceback from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type -from bemani.common import Model, ValidatedDict, Time +from bemani.common import Model, ValidatedDict, GameConstants, Time from bemani.data import Data, UserID, RemoteUser @@ -65,7 +66,7 @@ class Factory: data.local.network.put_event(event[0], event[1]) @classmethod - def all_games(cls) -> Iterator[Tuple[str, int, str]]: + def all_games(cls) -> Iterator[Tuple[GameConstants, int, str]]: """ Given a particular factory, iterate over all game, version combinations. Useful for loading things from the DB without wanting to hardcode values. @@ -74,7 +75,7 @@ class Factory: yield (game.game, game.version, game.name) @classmethod - def all_settings(cls) -> Iterator[Tuple[str, int, Dict[str, Any]]]: + def all_settings(cls) -> Iterator[Tuple[GameConstants, int, Dict[str, Any]]]: """ Given a particular factory, iterate over all game, version combinations that have settings and return those settings. @@ -104,7 +105,7 @@ class Factory: raise Exception('Override this in subclass!') -class Base: +class Base(ABC): """ The base class every game class inherits from. Incudes handlers for card management, PASELI, most non-game startup packets, and simple code for loading/storing profiles. @@ -116,17 +117,17 @@ class Base: """ Override this in your subclass. """ - game = 'dummy' + game: GameConstants """ Override this in your subclass. """ - version = 0 + version: int """ Override this in your subclass. """ - name = 'dummy' + name: str def __init__(self, data: Data, config: Dict[str, Any], model: Model) -> None: self.data = data @@ -192,7 +193,7 @@ class Base: return {} @classmethod - def all_games(cls) -> Iterator[Tuple[str, int, str]]: + def all_games(cls) -> Iterator[Tuple[GameConstants, int, str]]: """ Given all registered factories, iterate over all game, version combinations. Useful for loading things from the DB without wanting to hardcode values. @@ -202,7 +203,7 @@ class Base: yield (game.game, game.version, game.name) @classmethod - def all_settings(cls) -> Iterator[Tuple[str, int, Dict[str, Any]]]: + def all_settings(cls) -> Iterator[Tuple[GameConstants, int, Dict[str, Any]]]: """ Given all registered factories, iterate over all game, version combinations that have settings and return those settings. diff --git a/bemani/common/constants.py b/bemani/common/constants.py index 2afdad6..1a55f74 100644 --- a/bemani/common/constants.py +++ b/bemani/common/constants.py @@ -2,7 +2,13 @@ from enum import Enum from typing_extensions import Final -class GameConstants: +class GameConstants(Enum): + """ + Constants that identify a game series. These are used in the code as enums + in order to catch type errors and prevent raw strings being introduced for + game series. They re also used verbatum in MySQL, so any column named 'game' + in any of the tables should only contain one of the following strings. + """ BISHI_BASHI: Final[str] = 'bishi' DANCE_EVOLUTION: Final[str] = 'danevo' DDR: Final[str] = 'ddr' @@ -15,6 +21,11 @@ class GameConstants: class VersionConstants: + """ + Constants used to centralize game versions. These are not enumerations + since there are multiple keys with the same value. However, all database + column named 'version' should contain only values found here. + """ BISHI_BASHI_TSBB: Final[int] = 1 DDR_1STMIX: Final[int] = 1 @@ -122,6 +133,11 @@ class VersionConstants: class APIConstants: + """ + The four types of IDs found in a BEMAPI request or response. + + TODO: These should be an enum. + """ ID_TYPE_SERVER: Final[str] = 'server' ID_TYPE_CARD: Final[str] = 'card' ID_TYPE_SONG: Final[str] = 'song' @@ -129,6 +145,10 @@ class APIConstants: class DBConstants: + """ + Constants found in the DB relating to clear lamps, halos, grades, and the like. + """ + # When adding new game series, I try to make sure that constants # go in order, and have a difference of 100 between them. This is # so I can promote lamps/scores/etc by using a simple "max", while @@ -261,6 +281,12 @@ class DBConstants: class BroadcastConstants(Enum): + """ + Enum representing the various sections of a broadcast trigger. These come + into play when a new score is earned and there is a trigger such as a + discord webhook that needs to be notified. + """ + # Sections related to the player/song/etc. DJ_NAME: Final[str] = 'DJ Name' SONG_NAME: Final[str] = 'Song' diff --git a/bemani/data/api/client.py b/bemani/data/api/client.py index 2328875..41bac8b 100644 --- a/bemani/data/api/client.py +++ b/bemani/data/api/client.py @@ -107,7 +107,7 @@ class APIClient: raise UnsupportedVersionAPIException('The server does not support this version of the API!') raise APIException('The server returned an invalid status code {}!', format(r.status_code)) - def __translate(self, game: str, version: int) -> Tuple[str, str]: + def __translate(self, game: GameConstants, version: int) -> Tuple[str, str]: servergame = { GameConstants.DDR: 'ddr', GameConstants.IIDX: 'iidx', @@ -194,7 +194,7 @@ class APIClient: 'versions': resp['versions'], }) - def get_profiles(self, game: str, version: int, idtype: str, ids: List[str]) -> List[Dict[str, Any]]: + def get_profiles(self, game: GameConstants, version: int, idtype: str, ids: List[str]) -> List[Dict[str, Any]]: # Allow remote servers to be disabled if not self.allow_scores: return [] @@ -216,7 +216,7 @@ class APIClient: def get_records( self, - game: str, + game: GameConstants, version: int, idtype: str, ids: List[str], @@ -247,7 +247,7 @@ class APIClient: # Couldn't talk to server, assume empty records return [] - def get_statistics(self, game: str, version: int, idtype: str, ids: List[str]) -> List[Dict[str, Any]]: + def get_statistics(self, game: GameConstants, version: int, idtype: str, ids: List[str]) -> List[Dict[str, Any]]: # Allow remote servers to be disabled if not self.allow_stats: return [] @@ -267,7 +267,7 @@ class APIClient: # Couldn't talk to server, assume empty statistics return [] - def get_catalog(self, game: str, version: int) -> Dict[str, List[Dict[str, Any]]]: + def get_catalog(self, game: GameConstants, version: int) -> Dict[str, List[Dict[str, Any]]]: # No point disallowing this, since its only ever used for bootstrapping. try: diff --git a/bemani/data/api/game.py b/bemani/data/api/game.py index 533e556..b187d18 100644 --- a/bemani/data/api/game.py +++ b/bemani/data/api/game.py @@ -49,12 +49,12 @@ class GlobalGameData(BaseGlobalData): } ) - def get_items(self, game: str, version: int) -> List[Item]: + def get_items(self, game: GameConstants, version: int) -> List[Item]: """ Given a game/userid, find all items in the catalog. Parameters: - game - String identifier of the game looking up the catalog. + game - Enum value identifier of the game looking up the catalog. version - Integer identifier of the version looking up this catalog. Returns: @@ -103,7 +103,7 @@ class GlobalGameData(BaseGlobalData): seen.add(key) return retval - def get_item(self, game: str, version: int, catid: int, cattype: str) -> Optional[ValidatedDict]: + def get_item(self, game: GameConstants, version: int, catid: int, cattype: str) -> Optional[ValidatedDict]: """ Given a game/userid and catalog id/type, find that catalog entry. @@ -111,7 +111,7 @@ class GlobalGameData(BaseGlobalData): as long as each one is a different type. Essentially, cattype namespaces catalog entry. Parameters: - game - String identifier of the game looking up this entry. + game - Enum value identifier of the game looking up this entry. version - Integer identifier of the version looking up this entry. catid - Integer ID, as provided by a game. cattype - The type of catalog entry. diff --git a/bemani/data/api/music.py b/bemani/data/api/music.py index 496fa1c..15e5ca2 100644 --- a/bemani/data/api/music.py +++ b/bemani/data/api/music.py @@ -289,7 +289,7 @@ class GlobalMusicData(BaseGlobalData): }, ) - def __format_score(self, game: str, version: int, songid: int, songchart: int, data: Dict[str, Any]) -> Optional[Score]: + def __format_score(self, game: GameConstants, version: int, songid: int, songchart: int, data: Dict[str, Any]) -> Optional[Score]: if game == GameConstants.DDR: return self.__format_ddr_score(version, songid, songchart, data) if game == GameConstants.IIDX: @@ -432,7 +432,7 @@ class GlobalMusicData(BaseGlobalData): }, ) - def __merge_score(self, game: str, version: int, oldscore: Score, newscore: Score) -> Score: + def __merge_score(self, game: GameConstants, version: int, oldscore: Score, newscore: Score) -> Score: if oldscore.id != newscore.id or oldscore.chart != newscore.chart: raise Exception('Logic error! Tried to merge scores from different song/charts!') @@ -453,7 +453,7 @@ class GlobalMusicData(BaseGlobalData): return oldscore - def get_score(self, game: str, version: int, userid: UserID, songid: int, songchart: int) -> Optional[Score]: + def get_score(self, game: GameConstants, version: int, userid: UserID, songid: int, songchart: int) -> Optional[Score]: # Helper function so we can iterate over all servers for a single card def get_scores_for_card(cardid: str) -> List[Score]: return Parallel.flatten(Parallel.call( @@ -502,7 +502,7 @@ class GlobalMusicData(BaseGlobalData): def get_scores( self, - game: str, + game: GameConstants, version: int, userid: UserID, since: Optional[int]=None, @@ -571,7 +571,7 @@ class GlobalMusicData(BaseGlobalData): def __merge_global_scores( self, - game: str, + game: GameConstants, version: int, localcards: List[Tuple[str, UserID]], localscores: List[Tuple[UserID, Score]], @@ -629,7 +629,7 @@ class GlobalMusicData(BaseGlobalData): def get_all_scores( self, - game: str, + game: GameConstants, version: Optional[int]=None, userid: Optional[UserID]=None, songid: Optional[int]=None, @@ -670,7 +670,7 @@ class GlobalMusicData(BaseGlobalData): def __merge_global_records( self, - game: str, + game: GameConstants, version: int, localcards: List[Tuple[str, UserID]], localscores: List[Tuple[UserID, Score]], @@ -731,7 +731,7 @@ class GlobalMusicData(BaseGlobalData): def get_all_records( self, - game: str, + game: GameConstants, version: Optional[int]=None, userlist: Optional[List[UserID]]=None, locationlist: Optional[List[int]]=None, @@ -761,7 +761,7 @@ class GlobalMusicData(BaseGlobalData): def get_clear_rates( self, - game: str, + game: GameConstants, version: int, songid: Optional[int]=None, songchart: Optional[int]=None, @@ -1029,7 +1029,7 @@ class GlobalMusicData(BaseGlobalData): def __format_song( self, - game: str, + game: GameConstants, version: int, songid: int, songchart: int, @@ -1056,14 +1056,14 @@ class GlobalMusicData(BaseGlobalData): def get_all_songs( self, - game: str, + game: GameConstants, version: Optional[int]=None, ) -> List[Song]: """ Given a game and a version, look up all song/chart combos associated with that game. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. Returns: diff --git a/bemani/data/api/user.py b/bemani/data/api/user.py index a25b2b2..b085c4d 100644 --- a/bemani/data/api/user.py +++ b/bemani/data/api/user.py @@ -88,24 +88,25 @@ class GlobalUserData(BaseGlobalData): 'extid': profile['extid'], } - if profile.get('game') == GameConstants.DDR: + profilegame = GameConstants(profile['game']) + if profilegame == GameConstants.DDR: base.update(self.__format_ddr_profile(profile)) - if profile.get('game') == GameConstants.IIDX: + if profilegame == GameConstants.IIDX: base.update(self.__format_iidx_profile(profile)) - if profile.get('game') == GameConstants.JUBEAT: + if profilegame == GameConstants.JUBEAT: base.update(self.__format_jubeat_profile(profile)) - if profile.get('game') == GameConstants.MUSECA: + if profilegame == GameConstants.MUSECA: base.update(self.__format_museca_profile(profile)) - if profile.get('game') == GameConstants.POPN_MUSIC: + if profilegame == GameConstants.POPN_MUSIC: base.update(self.__format_popn_profile(profile)) - if profile.get('game') == GameConstants.REFLEC_BEAT: + if profilegame == GameConstants.REFLEC_BEAT: base.update(self.__format_reflec_profile(profile)) - if profile.get('game') == GameConstants.SDVX: + if profilegame == GameConstants.SDVX: base.update(self.__format_sdvx_profile(profile)) return ValidatedDict(base) - def __profile_request(self, game: str, version: int, userid: UserID, exact: bool) -> Optional[ValidatedDict]: + def __profile_request(self, game: GameConstants, version: int, userid: UserID, exact: bool) -> Optional[ValidatedDict]: # First, get or create the extid/refid for this virtual user cardid = RemoteUser.userid_to_card(userid) refid = self.user.get_refid(game, version, userid) @@ -134,7 +135,7 @@ class GlobalUserData(BaseGlobalData): del profile['match'] # Add in our defaults we always provide - profile['game'] = game + profile['game'] = game.value profile['version'] = version if exact_match else 0 profile['refid'] = refid profile['extid'] = extid @@ -149,25 +150,25 @@ class GlobalUserData(BaseGlobalData): userid = RemoteUser.card_to_userid(cardid) return userid - def from_refid(self, game: str, version: int, refid: str) -> Optional[UserID]: + def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]: return self.user.from_refid(game, version, refid) - def from_extid(self, game: str, version: int, extid: int) -> Optional[UserID]: + def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]: return self.user.from_extid(game, version, extid) - def get_profile(self, game: str, version: int, userid: UserID) -> Optional[ValidatedDict]: + def get_profile(self, game: GameConstants, version: int, userid: UserID) -> Optional[ValidatedDict]: if RemoteUser.is_remote(userid): return self.__profile_request(game, version, userid, exact=True) else: return self.user.get_profile(game, version, userid) - def get_any_profile(self, game: str, version: int, userid: UserID) -> Optional[ValidatedDict]: + def get_any_profile(self, game: GameConstants, version: int, userid: UserID) -> Optional[ValidatedDict]: if RemoteUser.is_remote(userid): return self.__profile_request(game, version, userid, exact=False) else: return self.user.get_any_profile(game, version, userid) - def get_any_profiles(self, game: str, version: int, userids: List[UserID]) -> List[Tuple[UserID, Optional[ValidatedDict]]]: + def get_any_profiles(self, game: GameConstants, version: int, userids: List[UserID]) -> List[Tuple[UserID, Optional[ValidatedDict]]]: if len(userids) == 0: return [] @@ -223,7 +224,7 @@ class GlobalUserData(BaseGlobalData): extid = self.user.get_extid(game, version, userid) # Add in our defaults we always provide - profile['game'] = game + profile['game'] = game.value profile['version'] = version if exact_match else 0 profile['refid'] = refid profile['extid'] = extid @@ -241,7 +242,7 @@ class GlobalUserData(BaseGlobalData): return local_profiles - def get_all_profiles(self, game: str, version: int) -> List[Tuple[UserID, ValidatedDict]]: + def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, ValidatedDict]]: # Fetch local and remote profiles, and then merge by adding remote profiles to local # profiles when we don't have a profile for that user ID yet. local_cards, local_profiles, remote_profiles = Parallel.execute([ @@ -282,7 +283,7 @@ class GlobalUserData(BaseGlobalData): extid = self.user.get_extid(game, version, userid) # Add in our defaults we always provide - profile['game'] = game + profile['game'] = game.value profile['version'] = version profile['refid'] = refid profile['extid'] = extid diff --git a/bemani/data/mysql/game.py b/bemani/data/mysql/game.py index ca9ceeb..ce09497 100644 --- a/bemani/data/mysql/game.py +++ b/bemani/data/mysql/game.py @@ -3,7 +3,7 @@ from sqlalchemy.types import String, Integer, JSON # type: ignore from sqlalchemy.dialects.mysql import BIGINT as BigInteger # type: ignore from typing import Any, Dict, List, Optional -from bemani.common import ValidatedDict, Time +from bemani.common import GameConstants, ValidatedDict, Time from bemani.data.mysql.base import BaseData, metadata from bemani.data.types import Achievement, Item, UserID @@ -74,7 +74,7 @@ time_sensitive_settings = Table( class GameData(BaseData): - def get_settings(self, game: str, userid: UserID) -> Optional[ValidatedDict]: + def get_settings(self, game: GameConstants, userid: UserID) -> Optional[ValidatedDict]: """ Given a game and a user ID, look up game-wide settings as a dictionary. @@ -83,7 +83,7 @@ class GameData(BaseData): all versions in a game series. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. userid - Integer identifying a user, as possibly looked up by UserData. Returns: @@ -91,7 +91,7 @@ class GameData(BaseData): if there are no settings for this game/user. """ sql = "SELECT data FROM game_settings WHERE game = :game AND userid = :userid" - cursor = self.execute(sql, {'game': game, 'userid': userid}) + cursor = self.execute(sql, {'game': game.value, 'userid': userid}) if cursor.rowcount != 1: # Settings doesn't exist @@ -100,12 +100,12 @@ class GameData(BaseData): result = cursor.fetchone() return ValidatedDict(self.deserialize(result['data'])) - def put_settings(self, game: str, userid: UserID, settings: Dict[str, Any]) -> None: + def put_settings(self, game: GameConstants, userid: UserID, settings: Dict[str, Any]) -> None: """ Given a game and a user ID, save game-wide settings to the DB. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. userid - Integer identifying a user. settings - A dictionary of settings that a game wishes to retrieve later. """ @@ -115,9 +115,9 @@ class GameData(BaseData): "VALUES (:game, :userid, :data) " + "ON DUPLICATE KEY UPDATE data=VALUES(data)" ) - self.execute(sql, {'game': game, 'userid': userid, 'data': self.serialize(settings)}) + self.execute(sql, {'game': game.value, 'userid': userid, 'data': self.serialize(settings)}) - def get_achievement(self, game: str, userid: UserID, achievementid: int, achievementtype: str) -> Optional[ValidatedDict]: + def get_achievement(self, game: GameConstants, userid: UserID, achievementid: int, achievementtype: str) -> Optional[ValidatedDict]: """ Given a game/userid and achievement id/type, find that achievement. @@ -125,7 +125,7 @@ class GameData(BaseData): as long as each one is a different type. Essentially, achievementtype namespaces achievements. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementid - Integer ID, as provided by a game. achievementtype - The type of achievement. @@ -137,7 +137,7 @@ class GameData(BaseData): "SELECT data FROM series_achievement " "WHERE game = :game AND userid = :userid AND id = :id AND type = :type" ) - cursor = self.execute(sql, {'game': game, 'userid': userid, 'id': achievementid, 'type': achievementtype}) + cursor = self.execute(sql, {'game': game.value, 'userid': userid, 'id': achievementid, 'type': achievementtype}) if cursor.rowcount != 1: # score doesn't exist return None @@ -145,19 +145,19 @@ class GameData(BaseData): result = cursor.fetchone() return ValidatedDict(self.deserialize(result['data'])) - def get_achievements(self, game: str, userid: UserID) -> List[Achievement]: + def get_achievements(self, game: GameConstants, userid: UserID) -> List[Achievement]: """ Given a game/userid, find all achievements Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. Returns: A list of Achievement objects. """ sql = "SELECT id, type, data FROM series_achievement WHERE game = :game AND userid = :userid" - cursor = self.execute(sql, {'game': game, 'userid': userid}) + cursor = self.execute(sql, {'game': game.value, 'userid': userid}) achievements = [] for result in cursor.fetchall(): @@ -172,12 +172,12 @@ class GameData(BaseData): return achievements - def put_achievement(self, game: str, userid: UserID, achievementid: int, achievementtype: str, data: Dict[str, Any]) -> None: + def put_achievement(self, game: GameConstants, userid: UserID, achievementid: int, achievementtype: str, data: Dict[str, Any]) -> None: """ Given a game/userid and achievement id/type, save an achievement. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementid - Integer ID, as provided by a game. achievementtype - The type of achievement. @@ -189,14 +189,14 @@ class GameData(BaseData): "VALUES (:game, :userid, :id, :type, :data) " + "ON DUPLICATE KEY UPDATE data=VALUES(data)" ) - self.execute(sql, {'game': game, 'userid': userid, 'id': achievementid, 'type': achievementtype, 'data': self.serialize(data)}) + self.execute(sql, {'game': game.value, 'userid': userid, 'id': achievementid, 'type': achievementtype, 'data': self.serialize(data)}) - def get_time_sensitive_settings(self, game: str, version: int, name: str) -> Optional[ValidatedDict]: + def get_time_sensitive_settings(self, game: GameConstants, version: int, name: str) -> Optional[ValidatedDict]: """ Given a game/version/name, look up the current time-sensitive settings for this game. Parameters: - game - String identifier of the game we want settings for. + game - Enum value identifier of the game we want settings for. version - Integer identifying the game version we want settings for. name - The name of the setting we are concerned with. @@ -209,7 +209,7 @@ class GameData(BaseData): "SELECT data, start_time, end_time FROM time_sensitive_settings WHERE " "game = :game AND version = :version AND name = :name AND start_time <= :time AND end_time > :time" ) - cursor = self.execute(sql, {'game': game, 'version': version, 'name': name, 'time': Time.now()}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'name': name, 'time': Time.now()}) if cursor.rowcount != 1: # setting doesn't exist return None @@ -220,12 +220,12 @@ class GameData(BaseData): retval['end_time'] = result['end_time'] return retval - def get_all_time_sensitive_settings(self, game: str, version: int, name: str) -> List[ValidatedDict]: + def get_all_time_sensitive_settings(self, game: GameConstants, version: int, name: str) -> List[ValidatedDict]: """ Given a game/version/name, look up all of the time-sensitive settings for this game. Parameters: - game - String identifier of the game we want settings for. + game - Enum value identifier of the game we want settings for. version - Integer identifying the game version we want settings for. name - The name of the setting we are concerned with. @@ -238,7 +238,7 @@ class GameData(BaseData): "SELECT data, start_time, end_time FROM time_sensitive_settings WHERE " "game = :game AND version = :version AND name = :name" ) - cursor = self.execute(sql, {'game': game, 'version': version, 'name': name}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'name': name}) if cursor.rowcount == 0: # setting doesn't exist return [] @@ -251,7 +251,7 @@ class GameData(BaseData): settings.append(retval) return settings - def put_time_sensitive_settings(self, game: str, version: int, name: str, settings: Dict[str, Any]) -> None: + def put_time_sensitive_settings(self, game: GameConstants, version: int, name: str, settings: Dict[str, Any]) -> None: """ Given a game/version/name and a settings dictionary that contains 'start_time' and 'end_time', as seconds since the unix epoch (UTC), update the DB to store or update this time-sensitive @@ -259,7 +259,7 @@ class GameData(BaseData): the setting duration, and that this setting doesn't overlap any other setting already present. Parameters: - game - String identifier of the game we want settings for. + game - Enum value identifier of the game we want settings for. version - Integer identifying the game version we want settings for. name - The name of the setting we are concerned with. settings - A dictionary containing at least 'start_time' and 'end_time'. @@ -287,7 +287,7 @@ class GameData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'name': name, 'start_time': start_time, @@ -309,7 +309,7 @@ class GameData(BaseData): self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'name': name, 'start_time': start_time, @@ -318,7 +318,7 @@ class GameData(BaseData): }, ) - def get_item(self, game: str, version: int, catid: int, cattype: str) -> Optional[ValidatedDict]: + def get_item(self, game: GameConstants, version: int, catid: int, cattype: str) -> Optional[ValidatedDict]: """ Given a game/userid and catalog id/type, find that catalog entry. @@ -326,7 +326,7 @@ class GameData(BaseData): as long as each one is a different type. Essentially, cattype namespaces catalog entry. Parameters: - game - String identifier of the game looking up this entry. + game - Enum value identifier of the game looking up this entry. version - Integer identifier of the version looking up this entry. catid - Integer ID, as provided by a game. cattype - The type of catalog entry. @@ -338,7 +338,7 @@ class GameData(BaseData): "SELECT data FROM catalog " "WHERE game = :game AND version = :version AND id = :id AND type = :type" ) - cursor = self.execute(sql, {'game': game, 'version': version, 'id': catid, 'type': cattype}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'id': catid, 'type': cattype}) if cursor.rowcount != 1: # entry doesn't exist return None @@ -346,19 +346,19 @@ class GameData(BaseData): result = cursor.fetchone() return ValidatedDict(self.deserialize(result['data'])) - def get_items(self, game: str, version: int) -> List[Item]: + def get_items(self, game: GameConstants, version: int) -> List[Item]: """ Given a game/userid, find all items in the catalog. Parameters: - game - String identifier of the game looking up the catalog. + game - Enum value identifier of the game looking up the catalog. version - Integer identifier of the version looking up this catalog. Returns: A list of Item objects. """ sql = "SELECT id, type, data FROM catalog WHERE game = :game AND version = :version" - cursor = self.execute(sql, {'game': game, 'version': version}) + cursor = self.execute(sql, {'game': game.value, 'version': version}) catalog = [] for result in cursor.fetchall(): diff --git a/bemani/data/mysql/lobby.py b/bemani/data/mysql/lobby.py index 63ea30c..1778c07 100644 --- a/bemani/data/mysql/lobby.py +++ b/bemani/data/mysql/lobby.py @@ -5,7 +5,7 @@ from sqlalchemy.types import String, Integer, JSON # type: ignore from sqlalchemy.dialects.mysql import BIGINT as BigInteger # type: ignore from typing import Optional, Dict, List, Tuple, Any -from bemani.common import ValidatedDict, Time +from bemani.common import GameConstants, ValidatedDict, Time from bemani.data.mysql.base import BaseData, metadata from bemani.data.types import UserID @@ -46,12 +46,12 @@ lobby = Table( class LobbyData(BaseData): - def get_play_session_info(self, game: str, version: int, userid: UserID) -> Optional[ValidatedDict]: + def get_play_session_info(self, game: GameConstants, version: int, userid: UserID) -> Optional[ValidatedDict]: """ Given a game, version and a user ID, look up play session information for that user. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. userid - Integer identifying a user, as possibly looked up by UserData. @@ -69,7 +69,7 @@ class LobbyData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, 'time': Time.now() - Time.SECONDS_IN_HOUR, @@ -86,12 +86,12 @@ class LobbyData(BaseData): data['time'] = result['time'] return data - def get_all_play_session_infos(self, game: str, version: int) -> List[Tuple[UserID, ValidatedDict]]: + def get_all_play_session_infos(self, game: GameConstants, version: int) -> List[Tuple[UserID, ValidatedDict]]: """ Given a game and version, look up all play session information. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. Returns: @@ -106,7 +106,7 @@ class LobbyData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'time': Time.now() - Time.SECONDS_IN_HOUR, }, @@ -120,12 +120,12 @@ class LobbyData(BaseData): ret.append((UserID(result['userid']), data)) return ret - def put_play_session_info(self, game: str, version: int, userid: UserID, data: Dict[str, Any]) -> None: + def put_play_session_info(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None: """ Given a game, version and a user ID, save play session information for that user. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. userid - Integer identifying a user. data - A dictionary of play session information to store. @@ -143,7 +143,7 @@ class LobbyData(BaseData): self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, 'time': Time.now(), @@ -151,12 +151,12 @@ class LobbyData(BaseData): }, ) - def destroy_play_session_info(self, game: str, version: int, userid: UserID) -> None: + def destroy_play_session_info(self, game: GameConstants, version: int, userid: UserID) -> None: """ Given a game, version and a user ID, throw away session info for that play session. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. userid - Integer identifying a user, as possibly looked up by UserData. """ @@ -167,7 +167,7 @@ class LobbyData(BaseData): self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, }, @@ -176,12 +176,12 @@ class LobbyData(BaseData): sql = "DELETE FROM playsession WHERE time <= :time" self.execute(sql, {'time': Time.now() - Time.SECONDS_IN_HOUR}) - def get_lobby(self, game: str, version: int, userid: UserID) -> Optional[ValidatedDict]: + def get_lobby(self, game: GameConstants, version: int, userid: UserID) -> Optional[ValidatedDict]: """ Given a game, version and a user ID, look up lobby information for that user. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. userid - Integer identifying a user, as possibly looked up by UserData. @@ -199,7 +199,7 @@ class LobbyData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, 'time': Time.now() - Time.SECONDS_IN_HOUR, @@ -216,12 +216,12 @@ class LobbyData(BaseData): data['time'] = result['time'] return data - def get_all_lobbies(self, game: str, version: int) -> List[Tuple[UserID, ValidatedDict]]: + def get_all_lobbies(self, game: GameConstants, version: int) -> List[Tuple[UserID, ValidatedDict]]: """ Given a game and version, look up all active lobbies. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. Returns: @@ -234,7 +234,7 @@ class LobbyData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'time': Time.now() - Time.SECONDS_IN_HOUR, }, @@ -247,12 +247,12 @@ class LobbyData(BaseData): ret.append((UserID(result['userid']), data)) return ret - def put_lobby(self, game: str, version: int, userid: UserID, data: Dict[str, Any]) -> None: + def put_lobby(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None: """ Given a game, version and a user ID, save lobby information for that user. Parameters: - game - String identifying a game series. + game - Enum value identifying a game series. version - Integer identifying the version of the game in the series. userid - Integer identifying a user. data - A dictionary of lobby information to store. @@ -270,7 +270,7 @@ class LobbyData(BaseData): self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, 'time': Time.now(), diff --git a/bemani/data/mysql/machine.py b/bemani/data/mysql/machine.py index 43a7382..3b9a7b1 100644 --- a/bemani/data/mysql/machine.py +++ b/bemani/data/mysql/machine.py @@ -3,7 +3,7 @@ from sqlalchemy.types import String, Integer, JSON # type: ignore from sqlalchemy.dialects.mysql import BIGINT as BigInteger # type: ignore from typing import Optional, Dict, List, Tuple, Any -from bemani.common import ValidatedDict +from bemani.common import GameConstants, ValidatedDict from bemani.data.mysql.base import BaseData, metadata from bemani.data.types import Machine, Arcade, UserID, ArcadeID @@ -172,7 +172,7 @@ class MachineData(BaseData): result['description'], result['arcadeid'], result['port'], - result['game'], + GameConstants(result['game']) if result['game'] else None, result['version'], self.deserialize(result['data']), ) @@ -199,7 +199,7 @@ class MachineData(BaseData): result['description'], result['arcadeid'], result['port'], - result['game'], + GameConstants(result['game']) if result['game'] else None, result['version'], self.deserialize(result['data']), ) for result in cursor.fetchall() @@ -224,7 +224,7 @@ class MachineData(BaseData): 'description': machine.description, 'arcadeid': machine.arcade, 'port': machine.port, - 'game': machine.game, + 'game': machine.game.value if machine.game else None, 'version': machine.version, 'pcbid': machine.pcbid, 'data': self.serialize(machine.data) @@ -424,13 +424,13 @@ class MachineData(BaseData): ) for result in cursor.fetchall() ] - def get_settings(self, arcadeid: ArcadeID, game: str, version: int, setting: str) -> Optional[ValidatedDict]: + def get_settings(self, arcadeid: ArcadeID, game: GameConstants, version: int, setting: str) -> Optional[ValidatedDict]: """ Given an arcade and a game/version combo, look up this particular setting. Parameters: arcadeid - Integer specifying the arcade to delete. - game - String identifying a game series. + game - Enum value identifying a game series. version - String identifying a game version. setting - String identifying the particular setting we're interestsed in. @@ -438,7 +438,7 @@ class MachineData(BaseData): A dictionary representing game settings, or None if there are no settings for this game/user. """ sql = "SELECT data FROM arcade_settings WHERE arcadeid = :id AND game = :game AND version = :version AND type = :type" - cursor = self.execute(sql, {'id': arcadeid, 'game': game, 'version': version, 'type': setting}) + cursor = self.execute(sql, {'id': arcadeid, 'game': game.value, 'version': version, 'type': setting}) if cursor.rowcount != 1: # Settings doesn't exist @@ -447,13 +447,13 @@ class MachineData(BaseData): result = cursor.fetchone() return ValidatedDict(self.deserialize(result['data'])) - def put_settings(self, arcadeid: ArcadeID, game: str, version: int, setting: str, data: Dict[str, Any]) -> None: + def put_settings(self, arcadeid: ArcadeID, game: GameConstants, version: int, setting: str, data: Dict[str, Any]) -> None: """ Given an arcade and a game/version combo, update the particular setting. Parameters: arcadeid - Integer specifying the arcade to delete. - game - String identifying a game series. + game - Enum value identifying a game series. version - String identifying a game version. setting - String identifying the particular setting we're interestsed in. data - A dictionary that should be saved for this setting. @@ -463,7 +463,7 @@ class MachineData(BaseData): "VALUES (:id, :game, :version, :type, :data) " "ON DUPLICATE KEY UPDATE data=VALUES(data)" ) - self.execute(sql, {'id': arcadeid, 'game': game, 'version': version, 'type': setting, 'data': self.serialize(data)}) + self.execute(sql, {'id': arcadeid, 'game': game.value, 'version': version, 'type': setting, 'data': self.serialize(data)}) def get_balances(self, arcadeid: ArcadeID) -> List[Tuple[UserID, int]]: """ diff --git a/bemani/data/mysql/music.py b/bemani/data/mysql/music.py index 83848d9..c76b541 100644 --- a/bemani/data/mysql/music.py +++ b/bemani/data/mysql/music.py @@ -4,7 +4,7 @@ from sqlalchemy.types import String, Integer, JSON # type: ignore from sqlalchemy.dialects.mysql import BIGINT as BigInteger # type: ignore from typing import Optional, Dict, List, Tuple, Any -from bemani.common import Time +from bemani.common import GameConstants, Time from bemani.data.exceptions import ScoreSaveException from bemani.data.mysql.base import BaseData, metadata from bemani.data.types import Score, Attempt, Song, UserID @@ -82,12 +82,12 @@ music = Table( class MusicData(BaseData): - def __get_musicid(self, game: str, version: int, songid: int, songchart: int) -> int: + def __get_musicid(self, game: GameConstants, version: int, songid: int, songchart: int) -> int: """ Given a game/version/songid/chart, look up the unique music ID for this song. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. songid - ID of the song according to the game. songchart - Chart number according to the game. @@ -98,7 +98,7 @@ class MusicData(BaseData): sql = ( "SELECT id FROM music WHERE songid = :songid AND chart = :chart AND game = :game AND version = :version" ) - cursor = self.execute(sql, {'songid': songid, 'chart': songchart, 'game': game, 'version': version}) + cursor = self.execute(sql, {'songid': songid, 'chart': songchart, 'game': game.value, 'version': version}) if cursor.rowcount != 1: # music doesn't exist raise Exception(f'Song {songid} chart {songchart} doesn\'t exist for game {game} version {version}') @@ -107,7 +107,7 @@ class MusicData(BaseData): def put_score( self, - game: str, + game: GameConstants, version: int, userid: UserID, songid: int, @@ -122,7 +122,7 @@ class MusicData(BaseData): Given a game/version/song/chart and user ID, save a new/updated high score. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userid - Integer representing a user. Usually looked up with UserData. songid - ID of the song according to the game. @@ -168,7 +168,7 @@ class MusicData(BaseData): def put_attempt( self, - game: str, + game: GameConstants, version: int, userid: Optional[UserID], songid: int, @@ -186,7 +186,7 @@ class MusicData(BaseData): per song/chart in a given game, but they can have as many history entries as times played. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userid - Integer representing a user. Usually looked up with UserData. songid - ID of the song according to the game. @@ -224,12 +224,12 @@ class MusicData(BaseData): f'There is already an attempt by {userid if userid is not None else 0} for music id {musicid} at {ts}' ) - def get_score(self, game: str, version: int, userid: UserID, songid: int, songchart: int) -> Optional[Score]: + def get_score(self, game: GameConstants, version: int, userid: UserID, songid: int, songchart: int) -> Optional[Score]: """ Look up a user's previous high score. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userid - Integer representing a user. Usually looked up with UserData. songid - ID of the song according to the game. @@ -248,7 +248,7 @@ class MusicData(BaseData): sql, { 'userid': userid, - 'game': game, + 'game': game.value, 'version': version, 'songid': songid, 'songchart': songchart, @@ -271,12 +271,12 @@ class MusicData(BaseData): self.deserialize(result['data']), ) - def get_score_by_key(self, game: str, version: int, key: int) -> Optional[Tuple[UserID, Score]]: + def get_score_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Score]]: """ Look up previous high score by key. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. key - Integer representing a unique key fetched in a previous Score lookup. @@ -293,7 +293,7 @@ class MusicData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'scorekey': key, }, @@ -320,7 +320,7 @@ class MusicData(BaseData): def get_scores( self, - game: str, + game: GameConstants, version: int, userid: UserID, since: Optional[int]=None, @@ -330,7 +330,7 @@ class MusicData(BaseData): Look up all of a user's previous high scores. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userid - Integer representing a user. Usually looked up with UserData. @@ -347,7 +347,7 @@ class MusicData(BaseData): sql = sql + ' AND score.update >= :since' if until is not None: sql = sql + ' AND score.update < :until' - cursor = self.execute(sql, {'userid': userid, 'game': game, 'version': version, 'since': since, 'until': until}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value, 'version': version, 'since': since, 'until': until}) scores = [] for result in cursor.fetchall(): @@ -367,12 +367,12 @@ class MusicData(BaseData): return scores - def get_most_played(self, game: str, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: + def get_most_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: """ Look up a user's most played songs. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userid - Integer representing a user. Usually looked up with UserData. count - Number of scores to look up. @@ -386,7 +386,7 @@ class MusicData(BaseData): "AND music.game = :game AND music.version = :version " + "GROUP BY songid ORDER BY plays DESC LIMIT :count" ) - cursor = self.execute(sql, {'userid': userid, 'game': game, 'version': version, 'count': count}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value, 'version': version, 'count': count}) most_played = [] for result in cursor.fetchall(): @@ -396,12 +396,12 @@ class MusicData(BaseData): return most_played - def get_last_played(self, game: str, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: + def get_last_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]: """ Look up a user's last played songs. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userid - Integer representing a user. Usually looked up with UserData. count - Number of scores to look up. @@ -415,7 +415,7 @@ class MusicData(BaseData): "AND music.game = :game AND music.version = :version " + "ORDER BY timestamp DESC LIMIT :count" ) - cursor = self.execute(sql, {'userid': userid, 'game': game, 'version': version, 'count': count}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value, 'version': version, 'count': count}) last_played = [] for result in cursor.fetchall(): @@ -427,7 +427,7 @@ class MusicData(BaseData): def get_hit_chart( self, - game: str, + game: GameConstants, version: int, count: int, days: Optional[int]=None, @@ -436,7 +436,7 @@ class MusicData(BaseData): Look up a game's most played songs. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. count - Number of scores to look up. @@ -454,7 +454,7 @@ class MusicData(BaseData): timestamp = Time.now() - (Time.SECONDS_IN_DAY * days) sql = sql + "GROUP BY songid ORDER BY plays DESC LIMIT :count" - cursor = self.execute(sql, {'game': game, 'version': version, 'count': count, 'timestamp': timestamp}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'count': count, 'timestamp': timestamp}) most_played = [] for result in cursor.fetchall(): @@ -466,7 +466,7 @@ class MusicData(BaseData): def get_song( self, - game: str, + game: GameConstants, version: int, songid: int, songchart: int, @@ -475,7 +475,7 @@ class MusicData(BaseData): Given a game/version/songid/chart, look up the name, artist and genre of that song. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. songid - Integer representing the ID (from the game) for this song. songchart - Integer representing the chart for this song. @@ -488,7 +488,7 @@ class MusicData(BaseData): "FROM music WHERE music.game = :game AND music.version = :version AND " + "music.songid = :songid AND music.chart = :songchart" ) - cursor = self.execute(sql, {'game': game, 'version': version, 'songid': songid, 'songchart': songchart}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'songid': songid, 'songchart': songchart}) if cursor.rowcount != 1: # music doesn't exist return None @@ -506,14 +506,14 @@ class MusicData(BaseData): def get_all_songs( self, - game: str, + game: GameConstants, version: Optional[int]=None, ) -> List[Song]: """ Given a game and a version, look up all song/chart combos associated with that game. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. Returns: @@ -523,7 +523,7 @@ class MusicData(BaseData): "SELECT version, songid, chart, name, artist, genre, data FROM music " "WHERE music.game = :game" ) - params: Dict[str, Any] = {'game': game} + params: Dict[str, Any] = {'game': game.value} if version is not None: sql += " AND music.version = :version" params['version'] = version @@ -550,7 +550,7 @@ class MusicData(BaseData): def get_all_versions_of_song( self, - game: str, + game: GameConstants, version: int, songid: int, songchart: int, @@ -560,7 +560,7 @@ class MusicData(BaseData): Given a game/version/songid/chart, look up all versions of that song across all game versions. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. songid - Integer representing the ID (from the game) for this song. songchart - Integer representing the chart for this song. @@ -594,7 +594,7 @@ class MusicData(BaseData): def get_all_scores( self, - game: str, + game: GameConstants, version: Optional[int]=None, userid: Optional[UserID]=None, songid: Optional[int]=None, @@ -606,7 +606,7 @@ class MusicData(BaseData): Look up all of a game's high scores for all users. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. Returns: @@ -660,7 +660,7 @@ class MusicData(BaseData): # Now, query itself cursor = self.execute(sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, 'songid': songid, @@ -693,7 +693,7 @@ class MusicData(BaseData): def get_all_records( self, - game: str, + game: GameConstants, version: Optional[int]=None, userlist: Optional[List[UserID]]=None, locationlist: Optional[List[int]]=None, @@ -706,7 +706,7 @@ class MusicData(BaseData): display area-local high scores, etc. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. userlist - List of UserIDs to limit the search to. locationlist - A list of location IDs to limit searches to. @@ -734,7 +734,7 @@ class MusicData(BaseData): musicid_sql = ( "SELECT DISTINCT(score.musicid) FROM score, music WHERE score.musicid = music.id AND music.game = :game" ) - params: Dict[str, Any] = {'game': game} + params: Dict[str, Any] = {'game': game.value} if version is not None: musicid_sql = musicid_sql + ' AND music.version = :version' params['version'] = version @@ -790,12 +790,12 @@ class MusicData(BaseData): return scores - def get_attempt_by_key(self, game: str, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]: + def get_attempt_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]: """ Look up a previous attempt by key. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. key - Integer representing a unique key fetched in a previous Attempt lookup. @@ -810,7 +810,7 @@ class MusicData(BaseData): cursor = self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'scorekey': key, }, @@ -836,7 +836,7 @@ class MusicData(BaseData): def get_all_attempts( self, - game: str, + game: GameConstants, version: Optional[int]=None, userid: Optional[UserID]=None, songid: Optional[int]=None, @@ -849,7 +849,7 @@ class MusicData(BaseData): Look up all of the attempts to score for a particular game. Parameters: - game - String representing a game series. + game - Enum value representing a game series. version - Integer representing which version of the game. Returns: @@ -901,7 +901,7 @@ class MusicData(BaseData): # Now, query itself cursor = self.execute(sql, { - 'game': game, + 'game': game.value, 'version': version, 'userid': userid, 'songid': songid, diff --git a/bemani/data/mysql/network.py b/bemani/data/mysql/network.py index debe3a3..b2109a7 100644 --- a/bemani/data/mysql/network.py +++ b/bemani/data/mysql/network.py @@ -3,7 +3,7 @@ from sqlalchemy.types import String, Integer, Text, JSON # type: ignore from sqlalchemy.dialects.mysql import BIGINT as BigInteger # type: ignore from typing import Optional, Dict, List, Tuple, Any -from bemani.common import Time +from bemani.common import GameConstants, Time from bemani.data.mysql.base import BaseData, metadata from bemani.data.types import News, Event, UserID, ArcadeID @@ -152,7 +152,7 @@ class NetworkData(BaseData): # Should never happen return (0, 0) - def should_schedule(self, game: str, version: int, name: str, schedule: str) -> bool: + def should_schedule(self, game: GameConstants, version: int, name: str, schedule: str) -> bool: """ Given a game/version/name pair and a schedule value, return whether this scheduled work is overdue or not. @@ -165,7 +165,7 @@ class NetworkData(BaseData): "WHERE game = :game AND version = :version AND " "name = :name AND schedule = :schedule" ) - cursor = self.execute(sql, {'game': game, 'version': version, 'name': name, 'schedule': schedule}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'name': name, 'schedule': schedule}) if cursor.rowcount != 1: # No scheduled work was registered, so time to get going! return True @@ -191,7 +191,7 @@ class NetworkData(BaseData): # We have already run this work for this schedule return False - def mark_scheduled(self, game: str, version: int, name: str, schedule: str) -> None: + def mark_scheduled(self, game: GameConstants, version: int, name: str, schedule: str) -> None: if schedule not in ['daily', 'weekly']: raise Exception('Logic error, specify either \'daily\' or \'weekly\' for schedule type!') @@ -205,7 +205,7 @@ class NetworkData(BaseData): self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'name': name, 'schedule': schedule, @@ -224,7 +224,7 @@ class NetworkData(BaseData): self.execute( sql, { - 'game': game, + 'game': game.value, 'version': version, 'name': name, 'schedule': schedule, diff --git a/bemani/data/mysql/user.py b/bemani/data/mysql/user.py index cf64529..5e2db16 100644 --- a/bemani/data/mysql/user.py +++ b/bemani/data/mysql/user.py @@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError # type: ignore from typing import Optional, Dict, List, Tuple, Any from passlib.hash import pbkdf2_sha512 # type: ignore -from bemani.common import ValidatedDict, Time +from bemani.common import ValidatedDict, GameConstants, Time from bemani.data.mysql.base import BaseData, metadata from bemani.data.remoteuser import RemoteUser from bemani.data.types import User, Achievement, Link, UserID, ArcadeID @@ -210,7 +210,7 @@ class UserData(BaseData): result = cursor.fetchone() return UserID(result['id']) - def from_refid(self, game: str, version: int, refid: str) -> Optional[UserID]: + def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]: """ Given a generated RefID, look up a user ID. @@ -218,7 +218,7 @@ class UserData(BaseData): to look up a user. When creating a new profile, we generate a unique RefID and ExtID. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. refid - RefID in question, most likely previously generated by this class. @@ -227,7 +227,7 @@ class UserData(BaseData): """ # First, look up the user account sql = "SELECT userid FROM refid WHERE game = :game AND version = :version AND refid = :refid" - cursor = self.execute(sql, {'game': game, 'version': version, 'refid': refid}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'refid': refid}) if cursor.rowcount != 1: # Couldn't find a user with this refid return None @@ -235,7 +235,7 @@ class UserData(BaseData): result = cursor.fetchone() return UserID(result['userid']) - def from_extid(self, game: str, version: int, extid: int) -> Optional[UserID]: + def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]: """ Given a generated ExtID, look up a user ID. @@ -243,7 +243,7 @@ class UserData(BaseData): to look up a user. When creating a new profile, we generate a unique RefID and ExtID. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. extid - ExtID in question, most likely previously generated by this class. @@ -252,7 +252,7 @@ class UserData(BaseData): """ # First, look up the user account sql = "SELECT userid FROM extid WHERE game = :game AND extid = :extid" - cursor = self.execute(sql, {'game': game, 'extid': extid}) + cursor = self.execute(sql, {'game': game.value, 'extid': extid}) if cursor.rowcount != 1: # Couldn't find a user with this refid return None @@ -476,12 +476,12 @@ class UserData(BaseData): sql = "UPDATE user SET password = :hash WHERE id = :userid" self.execute(sql, {'hash': passhash, 'userid': userid}) - def get_profile(self, game: str, version: int, userid: UserID) -> Optional[ValidatedDict]: + def get_profile(self, game: GameConstants, version: int, userid: UserID) -> Optional[ValidatedDict]: """ Given a game/version/userid, look up the associated profile. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -494,7 +494,7 @@ class UserData(BaseData): "WHERE refid.userid = :userid AND refid.game = :game AND refid.version = :version AND " "extid.userid = refid.userid AND extid.game = refid.game" ) - cursor = self.execute(sql, {'userid': userid, 'game': game, 'version': version}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value, 'version': version}) if cursor.rowcount != 1: # Profile doesn't exist return None @@ -503,7 +503,7 @@ class UserData(BaseData): profile = { 'refid': result['refid'], 'extid': result['extid'], - 'game': game, + 'game': game.value, 'version': version, } @@ -517,14 +517,14 @@ class UserData(BaseData): profile.update(self.deserialize(result['data'])) return ValidatedDict(profile) - def get_any_profile(self, game: str, version: int, userid: UserID) -> Optional[ValidatedDict]: + def get_any_profile(self, game: GameConstants, version: int, userid: UserID) -> Optional[ValidatedDict]: """ Given a game/version/userid, look up the associated profile. If the profile for that version doesn't exist, try another profile, failing only if there is no profile for any version of this game. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -542,13 +542,13 @@ class UserData(BaseData): else: return None - def get_any_profiles(self, game: str, version: int, userids: List[UserID]) -> List[Tuple[UserID, Optional[ValidatedDict]]]: + def get_any_profiles(self, game: GameConstants, version: int, userids: List[UserID]) -> List[Tuple[UserID, Optional[ValidatedDict]]]: """ Does the exact same thing as get_any_profile but across a list of users instead of one. Provided purely as a convenience function. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userids - List of Integer user IDs, as looked up by one of the above functions. @@ -561,7 +561,7 @@ class UserData(BaseData): for userid in userids ] - def get_games_played(self, userid: UserID) -> List[Tuple[str, int]]: + def get_games_played(self, userid: UserID) -> List[Tuple[GameConstants, int]]: """ Given a user ID, look up all game/version combos this user has played. @@ -575,15 +575,15 @@ class UserData(BaseData): cursor = self.execute(sql, {'userid': userid}) profiles = [] for result in cursor.fetchall(): - profiles.append((result['game'], result['version'])) + profiles.append((GameConstants(result['game']), result['version'])) return profiles - def get_all_profiles(self, game: str, version: int) -> List[Tuple[UserID, ValidatedDict]]: + def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, ValidatedDict]]: """ Given a game/version, look up all user profiles for that game. Parameters: - game - String identifier of the game we want all user profiles for. + game - Enum value identifier of the game we want all user profiles for. version - Integer version of the game we want all user profiles for. Returns: @@ -595,14 +595,14 @@ class UserData(BaseData): "WHERE refid.game = :game AND refid.version = :version " "AND refid.refid = profile.refid AND extid.game = refid.game AND extid.userid = refid.userid" ) - cursor = self.execute(sql, {'game': game, 'version': version}) + cursor = self.execute(sql, {'game': game.value, 'version': version}) profiles = [] for result in cursor.fetchall(): profile = { 'refid': result['refid'], 'extid': result['extid'], - 'game': game, + 'game': game.value, 'version': version, } profile.update(self.deserialize(result['data'])) @@ -615,12 +615,12 @@ class UserData(BaseData): return profiles - def get_all_players(self, game: str, version: int) -> List[UserID]: + def get_all_players(self, game: GameConstants, version: int) -> List[UserID]: """ Given a game/version, look up all user IDs that played this game/version. Parameters: - game - String identifier of the game we want all user profiles for. + game - Enum value identifier of the game we want all user profiles for. version - Integer version of the game we want all user profiles for. Returns: @@ -630,16 +630,16 @@ class UserData(BaseData): "SELECT refid.userid AS userid FROM refid " "WHERE refid.game = :game AND refid.version = :version" ) - cursor = self.execute(sql, {'game': game, 'version': version}) + cursor = self.execute(sql, {'game': game.value, 'version': version}) return [UserID(result['userid']) for result in cursor.fetchall()] - def get_all_achievements(self, game: str, version: int) -> List[Tuple[UserID, Achievement]]: + def get_all_achievements(self, game: GameConstants, version: int) -> List[Tuple[UserID, Achievement]]: """ Given a game/version, find all achievements for al players. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. Returns: @@ -650,7 +650,7 @@ class UserData(BaseData): "refid.userid AS userid FROM achievement, refid WHERE refid.game = :game AND " "refid.version = :version AND refid.refid = achievement.refid" ) - cursor = self.execute(sql, {'game': game, 'version': version}) + cursor = self.execute(sql, {'game': game.value, 'version': version}) achievements = [] for result in cursor.fetchall(): @@ -668,12 +668,12 @@ class UserData(BaseData): return achievements - def put_profile(self, game: str, version: int, userid: UserID, profile: Dict[str, Any]) -> None: + def put_profile(self, game: GameConstants, version: int, userid: UserID, profile: Dict[str, Any]) -> None: """ Given a game/version/userid, save an associated profile. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. profile - A dictionary that a game class will want to retrieve later. @@ -697,7 +697,7 @@ class UserData(BaseData): ) self.execute(sql, {'refid': refid, 'json': self.serialize(profile)}) - def get_achievement(self, game: str, version: int, userid: UserID, achievementid: int, achievementtype: str) -> Optional[ValidatedDict]: + def get_achievement(self, game: GameConstants, version: int, userid: UserID, achievementid: int, achievementtype: str) -> Optional[ValidatedDict]: """ Given a game/version/userid and achievement id/type, find that achievement. @@ -705,7 +705,7 @@ class UserData(BaseData): as long as each one is a different type. Essentially, achievementtype namespaces achievements. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementid - Integer ID, as provided by a game. @@ -726,12 +726,12 @@ class UserData(BaseData): result = cursor.fetchone() return ValidatedDict(self.deserialize(result['data'])) - def get_achievements(self, game: str, version: int, userid: UserID) -> List[Achievement]: + def get_achievements(self, game: GameConstants, version: int, userid: UserID) -> List[Achievement]: """ Given a game/version/userid, find all achievements Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -755,12 +755,12 @@ class UserData(BaseData): return achievements - def put_achievement(self, game: str, version: int, userid: UserID, achievementid: int, achievementtype: str, data: Dict[str, Any]) -> None: + def put_achievement(self, game: GameConstants, version: int, userid: UserID, achievementid: int, achievementtype: str, data: Dict[str, Any]) -> None: """ Given a game/version/userid and achievement id/type, save an achievement. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementid - Integer ID, as provided by a game. @@ -777,12 +777,12 @@ class UserData(BaseData): ) self.execute(sql, {'refid': refid, 'id': achievementid, 'type': achievementtype, 'data': self.serialize(data)}) - def destroy_achievement(self, game: str, version: int, userid: UserID, achievementid: int, achievementtype: str) -> None: + def destroy_achievement(self, game: GameConstants, version: int, userid: UserID, achievementid: int, achievementtype: str) -> None: """ Given a game/version/userid and achievement id/type, delete an achievement. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementid - Integer ID, as provided by a game. @@ -798,7 +798,7 @@ class UserData(BaseData): def get_time_based_achievements( self, - game: str, + game: GameConstants, version: int, userid: UserID, achievementtype: Optional[str]=None, @@ -809,7 +809,7 @@ class UserData(BaseData): Given a game/version/userid, find all time-based achievements Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementtype - Optional string specifying to constrain to a type of achievement. @@ -844,7 +844,7 @@ class UserData(BaseData): def put_time_based_achievement( self, - game: str, + game: GameConstants, version: int, userid: UserID, achievementid: int, @@ -856,7 +856,7 @@ class UserData(BaseData): time-based achievements are immutable once saved. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. achievementid - Integer ID, as provided by a game. @@ -872,12 +872,12 @@ class UserData(BaseData): ) self.execute(sql, {'refid': refid, 'id': achievementid, 'type': achievementtype, 'ts': Time.now(), 'data': self.serialize(data)}) - def get_all_time_based_achievements(self, game: str, version: int) -> List[Tuple[UserID, Achievement]]: + def get_all_time_based_achievements(self, game: GameConstants, version: int) -> List[Tuple[UserID, Achievement]]: """ Given a game/version, find all time-based achievements for all players. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. Returns: @@ -889,7 +889,7 @@ class UserData(BaseData): "refid.userid AS userid FROM time_based_achievement, refid WHERE refid.game = :game AND " "refid.version = :version AND refid.refid = time_based_achievement.refid" ) - cursor = self.execute(sql, {'game': game, 'version': version}) + cursor = self.execute(sql, {'game': game.value, 'version': version}) achievements = [] for result in cursor.fetchall(): @@ -907,7 +907,7 @@ class UserData(BaseData): return achievements - def get_link(self, game: str, version: int, userid: UserID, linktype: str, other_userid: UserID) -> Optional[ValidatedDict]: + def get_link(self, game: GameConstants, version: int, userid: UserID, linktype: str, other_userid: UserID) -> Optional[ValidatedDict]: """ Given a game/version/userid and link type + other userid, find that link. @@ -915,7 +915,7 @@ class UserData(BaseData): as long as each one is a different type. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. linktype - The type of link. @@ -927,7 +927,7 @@ class UserData(BaseData): sql = ( "SELECT data FROM link WHERE game = :game AND version = :version AND userid = :userid AND type = :type AND other_userid = :other_userid" ) - cursor = self.execute(sql, {'game': game, 'version': version, 'userid': userid, 'type': linktype, 'other_userid': other_userid}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'userid': userid, 'type': linktype, 'other_userid': other_userid}) if cursor.rowcount != 1: # score doesn't exist return None @@ -935,12 +935,12 @@ class UserData(BaseData): result = cursor.fetchone() return ValidatedDict(self.deserialize(result['data'])) - def get_links(self, game: str, version: int, userid: UserID) -> List[Link]: + def get_links(self, game: GameConstants, version: int, userid: UserID) -> List[Link]: """ Given a game/version/userid, find all links between this user and other users Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -948,7 +948,7 @@ class UserData(BaseData): A list of Link objects. """ sql = "SELECT type, other_userid, data FROM link WHERE game = :game AND version = :version AND userid = :userid" - cursor = self.execute(sql, {'game': game, 'version': version, 'userid': userid}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'userid': userid}) links = [] for result in cursor.fetchall(): @@ -963,12 +963,12 @@ class UserData(BaseData): return links - def put_link(self, game: str, version: int, userid: UserID, linktype: str, other_userid: UserID, data: Dict[str, Any]) -> None: + def put_link(self, game: GameConstants, version: int, userid: UserID, linktype: str, other_userid: UserID, data: Dict[str, Any]) -> None: """ Given a game/version/userid and link id + other_userid, save an link. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. linktype - The type of link. @@ -981,14 +981,14 @@ class UserData(BaseData): "VALUES (:game, :version, :userid, :type, :other_userid, :data) " "ON DUPLICATE KEY UPDATE data=VALUES(data)" ) - self.execute(sql, {'game': game, 'version': version, 'userid': userid, 'type': linktype, 'other_userid': other_userid, 'data': self.serialize(data)}) + self.execute(sql, {'game': game.value, 'version': version, 'userid': userid, 'type': linktype, 'other_userid': other_userid, 'data': self.serialize(data)}) - def destroy_link(self, game: str, version: int, userid: UserID, linktype: str, other_userid: UserID) -> None: + def destroy_link(self, game: GameConstants, version: int, userid: UserID, linktype: str, other_userid: UserID) -> None: """ Given a game/version/userid and link id + other_userid, destroy the link. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. linktype - The type of link. @@ -997,7 +997,7 @@ class UserData(BaseData): sql = ( "DELETE FROM link WHERE game = :game AND version = :version AND userid = :userid AND type = :type AND other_userid = :other_userid" ) - self.execute(sql, {'game': game, 'version': version, 'userid': userid, 'type': linktype, 'other_userid': other_userid}) + self.execute(sql, {'game': game.value, 'version': version, 'userid': userid, 'type': linktype, 'other_userid': other_userid}) def get_balance(self, userid: UserID, arcadeid: ArcadeID) -> int: """ @@ -1043,12 +1043,12 @@ class UserData(BaseData): return None return newbalance - def get_refid(self, game: str, version: int, userid: UserID) -> str: + def get_refid(self, game: GameConstants, version: int, userid: UserID) -> str: """ Given a game/version and user ID, look up the RefID for the profile. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -1057,19 +1057,19 @@ class UserData(BaseData): and returns it, which can be used for creating/looking up a profile in the future. """ sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version" - cursor = self.execute(sql, {'userid': userid, 'game': game, 'version': version}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value, 'version': version}) if cursor.rowcount == 1: result = cursor.fetchone() return result['refid'] else: return self.create_refid(game, version, userid) - def get_extid(self, game: str, version: int, userid: UserID) -> int: + def get_extid(self, game: GameConstants, version: int, userid: UserID) -> int: """ Given a game/version and a user ID, look up the ExtID for the profile. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -1080,7 +1080,7 @@ class UserData(BaseData): def fetch_extid() -> Optional[int]: sql = "SELECT extid FROM extid WHERE userid = :userid AND game = :game" - cursor = self.execute(sql, {'userid': userid, 'game': game}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value}) if cursor.rowcount == 1: result = cursor.fetchone() return result['extid'] @@ -1120,7 +1120,7 @@ class UserData(BaseData): """ self._destroy_session(session, 'userid') - def create_refid(self, game: str, version: int, userid: UserID) -> str: + def create_refid(self, game: GameConstants, version: int, userid: UserID) -> str: """ Given a game/version/userid, create a RefID and an ExtID if necessary. @@ -1130,7 +1130,7 @@ class UserData(BaseData): is guaranteed unique for each game series/user. Parameters: - game - String identifier of the game looking up the user. + game - Enum value identifier of the game looking up the user. version - Integer version of the game looking up the user. userid - Integer user ID, as looked up by one of the above functions. @@ -1151,7 +1151,7 @@ class UserData(BaseData): "VALUES (:game, :extid, :userid)" ) try: - cursor = self.execute(sql, {'game': game, 'extid': extid, 'userid': userid}) + cursor = self.execute(sql, {'game': game.value, 'extid': extid, 'userid': userid}) except IntegrityError: # User already has an ExtID for this game series pass @@ -1170,7 +1170,7 @@ class UserData(BaseData): "VALUES (:game, :version, :refid, :userid)" ) try: - cursor = self.execute(sql, {'game': game, 'version': version, 'refid': refid, 'userid': userid}) + cursor = self.execute(sql, {'game': game.value, 'version': version, 'refid': refid, 'userid': userid}) if cursor.rowcount != 1: raise AccountCreationException() return refid @@ -1178,7 +1178,7 @@ class UserData(BaseData): # We maybe lost the race? Look up the ID from another creation. Don't call get_refid # because it calls us, so we don't want an infinite loop. sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version" - cursor = self.execute(sql, {'userid': userid, 'game': game, 'version': version}) + cursor = self.execute(sql, {'userid': userid, 'game': game.value, 'version': version}) if cursor.rowcount == 1: result = cursor.fetchone() return result['refid'] diff --git a/bemani/data/triggers.py b/bemani/data/triggers.py index f6c71d4..7434964 100644 --- a/bemani/data/triggers.py +++ b/bemani/data/triggers.py @@ -13,7 +13,7 @@ class Triggers: def __init__(self, config: Dict[str, Any]): self.config = config - def __gameconst_to_series(self, game: str) -> str: + def __gameconst_to_series(self, game: GameConstants) -> str: return { GameConstants.BISHI_BASHI: 'Bishi Bashi', GameConstants.DANCE_EVOLUTION: 'Dance Evolution', @@ -26,12 +26,12 @@ class Triggers: GameConstants.SDVX: 'Sound Voltex', }.get(game, 'Unknown') - def broadcast_score(self, data: Dict[BroadcastConstants, str], game: str, song: Song) -> None: + def broadcast_score(self, data: Dict[BroadcastConstants, str], game: GameConstants, song: Song) -> None: # For now we only support discord if self.config.get('webhooks', {}).get('discord', {}).get(game, None) is not None: self.broadcast_score_discord(data, game, song) - def broadcast_score_discord(self, data: Dict[BroadcastConstants, str], game: str, song: Song) -> None: + def broadcast_score_discord(self, data: Dict[BroadcastConstants, str], game: GameConstants, song: Song) -> None: if game == GameConstants.IIDX: now = datetime.now() diff --git a/bemani/data/types.py b/bemani/data/types.py index 8c7b521..c38950b 100644 --- a/bemani/data/types.py +++ b/bemani/data/types.py @@ -1,6 +1,6 @@ from typing import Optional, List, Dict, Any, NewType -from bemani.common import ValidatedDict +from bemani.common import ValidatedDict, GameConstants UserID = NewType('UserID', int) ArcadeID = NewType('ArcadeID', int) @@ -102,7 +102,7 @@ class Machine: description: str, arcade: Optional[ArcadeID], port: int, - game: Optional[str], + game: Optional[GameConstants], version: Optional[int], data: Dict[str, Any], ) -> None: @@ -174,7 +174,7 @@ class Song: def __init__( self, - game: str, + game: GameConstants, version: int, songid: int, songchart: int, diff --git a/bemani/frontend/admin/admin.py b/bemani/frontend/admin/admin.py index ef5b208..4eafc2b 100644 --- a/bemani/frontend/admin/admin.py +++ b/bemani/frontend/admin/admin.py @@ -48,7 +48,7 @@ def format_machine(machine: Machine) -> Dict[str, Any]: 'description': machine.description, 'arcade': machine.arcade, 'port': machine.port, - 'game': machine.game or 'any', + 'game': machine.game.value if machine.game else 'any', 'version': machine.version, } @@ -231,9 +231,9 @@ def viewarcades() -> Response: def viewmachines() -> Response: games: Dict[str, Dict[int, str]] = {} for (game, version, name) in Base.all_games(): - if game not in games: - games[game] = {} - games[game][version] = name + if game.value not in games: + games[game.value] = {} + games[game.value][version] = name return render_react( 'Machines', @@ -242,14 +242,14 @@ def viewmachines() -> Response: 'machines': [format_machine(machine) for machine in g.data.local.machine.get_all_machines()], 'arcades': {arcade.id: arcade.name for arcade in g.data.local.machine.get_all_arcades()}, 'series': { - GameConstants.BISHI_BASHI: 'BishiBashi', - GameConstants.DDR: 'DDR', - GameConstants.IIDX: 'IIDX', - GameConstants.JUBEAT: 'Jubeat', - GameConstants.MUSECA: 'MÚSECA', - GameConstants.POPN_MUSIC: 'Pop\'n Music', - GameConstants.REFLEC_BEAT: 'Reflec Beat', - GameConstants.SDVX: 'SDVX', + GameConstants.BISHI_BASHI.value: 'BishiBashi', + GameConstants.DDR.value: 'DDR', + GameConstants.IIDX.value: 'IIDX', + GameConstants.JUBEAT.value: 'Jubeat', + GameConstants.MUSECA.value: 'MÚSECA', + GameConstants.POPN_MUSIC.value: 'Pop\'n Music', + GameConstants.REFLEC_BEAT.value: 'Reflec Beat', + GameConstants.SDVX.value: 'SDVX', }, 'games': games, 'enforcing': g.config['server']['enforce_pcbid'], @@ -701,7 +701,7 @@ def updatepcbid() -> Dict[str, Any]: current_machine.description = machine['description'] current_machine.arcade = machine['arcade'] current_machine.port = machine['port'] - current_machine.game = None if machine['game'] == 'any' else machine['game'] + current_machine.game = None if machine['game'] == 'any' else GameConstants(machine['game']) current_machine.version = None if machine['game'] == 'any' else machine['version'] g.data.local.machine.put_machine(current_machine) diff --git a/bemani/frontend/arcade/arcade.py b/bemani/frontend/arcade/arcade.py index afb19b6..b1251ed 100644 --- a/bemani/frontend/arcade/arcade.py +++ b/bemani/frontend/arcade/arcade.py @@ -78,8 +78,8 @@ def format_event(event: Event) -> Dict[str, Any]: def get_game_settings(arcade: Arcade) -> List[Dict[str, Any]]: - game_lut: Dict[str, Dict[int, str]] = {} - settings_lut: Dict[str, Dict[int, Dict[str, Any]]] = {} + game_lut: Dict[GameConstants, Dict[int, str]] = {} + settings_lut: Dict[GameConstants, Dict[int, Dict[str, Any]]] = {} all_settings = [] for (game, version, name) in Base.all_games(): @@ -95,7 +95,7 @@ def get_game_settings(arcade: Arcade) -> List[Dict[str, Any]]: # First, set up the basics game_settings: Dict[str, Any] = { - 'game': game, + 'game': game.value, 'version': version, 'name': game_lut[game][version], 'bools': [], @@ -351,7 +351,7 @@ def updatesettings(arcadeid: int) -> Dict[str, Any]: if g.userID not in arcade.owners: raise Exception('You don\'t own this arcade, refusing to update!') - game = request.get_json()['game'] + game = GameConstants(request.get_json()['game']) version = request.get_json()['version'] for setting_type, update_function in [ @@ -380,6 +380,6 @@ def updatesettings(arcadeid: int) -> Dict[str, Any]: return { 'game_settings': [ gs for gs in get_game_settings(arcade) - if gs['game'] == game and gs['version'] == version + if gs['game'] == game.value and gs['version'] == version ][0], } diff --git a/bemani/frontend/base.py b/bemani/frontend/base.py index fb204c2..9801465 100644 --- a/bemani/frontend/base.py +++ b/bemani/frontend/base.py @@ -1,20 +1,21 @@ # vim: set fileencoding=utf-8 import copy +from abc import ABC from typing import Any, Dict, Iterator, List, Optional, Set, Tuple from flask_caching import Cache # type: ignore -from bemani.common import ValidatedDict, ID +from bemani.common import GameConstants, ValidatedDict, ID from bemani.data import Data, Score, Attempt, Link, Song, UserID, RemoteUser -class FrontendBase: +class FrontendBase(ABC): """ All subclasses should override this attribute with the string the game series uses in the DB. """ - game: str = None + game: GameConstants """ If a subclass wishes to constrain music searches to a particular @@ -98,14 +99,14 @@ class FrontendBase: else: return elems[:-(num % 10)] - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: """ Override this to return an interator based on a game series factory. """ def get_all_songs(self, force_db_load: bool=False) -> Dict[int, Dict[str, Any]]: if not force_db_load: - cached_songs = self.cache.get(f'{self.game}.sorted_songs') + cached_songs = self.cache.get(f'{self.game.value}.sorted_songs') if cached_songs is not None: return cached_songs @@ -120,7 +121,7 @@ class FrontendBase: else: songs[song.id] = self.merge_song(songs[song.id], song) - self.cache.set(f'{self.game}.sorted_songs', songs, timeout=600) + self.cache.set(f'{self.game.value}.sorted_songs', songs, timeout=600) return songs def get_all_player_info(self, userids: List[UserID], limit: Optional[int]=None, allow_remote: bool=False) -> Dict[UserID, Dict[int, Dict[str, Any]]]: diff --git a/bemani/frontend/bishi/bishi.py b/bemani/frontend/bishi/bishi.py index fe196b3..b25dbaa 100644 --- a/bemani/frontend/bishi/bishi.py +++ b/bemani/frontend/bishi/bishi.py @@ -18,7 +18,7 @@ class BishiBashiFrontend(FrontendBase): super().__init__(data, config, cache) self.machines: Dict[int, str] = {} - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from BishiBashiFactory.all_games() def __update_value(self, oldvalue: str, newvalue: bytes) -> str: diff --git a/bemani/frontend/ddr/ddr.py b/bemani/frontend/ddr/ddr.py index c5615ad..261206f 100644 --- a/bemani/frontend/ddr/ddr.py +++ b/bemani/frontend/ddr/ddr.py @@ -37,7 +37,7 @@ class DDRFrontend(FrontendBase): VersionConstants.DDR_A20: 3, } - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from DDRFactory.all_games() def update_name(self, profile: ValidatedDict, name: str) -> ValidatedDict: diff --git a/bemani/frontend/iidx/iidx.py b/bemani/frontend/iidx/iidx.py index c4dc2f5..000bd19 100644 --- a/bemani/frontend/iidx/iidx.py +++ b/bemani/frontend/iidx/iidx.py @@ -31,7 +31,7 @@ class IIDXFrontend(FrontendBase): super().__init__(data, config, cache) self.machines: Dict[int, str] = {} - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from IIDXFactory.all_games() def get_duplicate_id(self, musicid: int, chart: int) -> Optional[Tuple[int, int]]: diff --git a/bemani/frontend/jubeat/jubeat.py b/bemani/frontend/jubeat/jubeat.py index 0b905d7..2a138bb 100644 --- a/bemani/frontend/jubeat/jubeat.py +++ b/bemani/frontend/jubeat/jubeat.py @@ -19,10 +19,10 @@ class JubeatFrontend(FrontendBase): valid_rival_types = ['rival'] - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from JubeatFactory.all_games() - def sanitized_games(self) -> Iterator[Tuple[str, int, str]]: + def sanitized_games(self) -> Iterator[Tuple[GameConstants, int, str]]: mapping = { VersionConstants.JUBEAT: 1, VersionConstants.JUBEAT_RIPPLES: 2, diff --git a/bemani/frontend/museca/museca.py b/bemani/frontend/museca/museca.py index 2346b69..5e12305 100644 --- a/bemani/frontend/museca/museca.py +++ b/bemani/frontend/museca/museca.py @@ -22,7 +22,7 @@ class MusecaFrontend(FrontendBase): def __init__(self, data: Data, config: Dict[str, Any], cache: Cache) -> None: super().__init__(data, config, cache) - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from MusecaFactory.all_games() yield ( GameConstants.MUSECA, diff --git a/bemani/frontend/popn/popn.py b/bemani/frontend/popn/popn.py index 80c599f..347aa8b 100644 --- a/bemani/frontend/popn/popn.py +++ b/bemani/frontend/popn/popn.py @@ -31,7 +31,7 @@ class PopnMusicFrontend(FrontendBase): VersionConstants.POPN_MUSIC_USANEKO: 4, } - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from PopnMusicFactory.all_games() yield (GameConstants.POPN_MUSIC, 0, 'CS and Licenses') # Folder that doesn't belong to any specific game diff --git a/bemani/frontend/reflec/reflec.py b/bemani/frontend/reflec/reflec.py index c06ebe6..b63660a 100644 --- a/bemani/frontend/reflec/reflec.py +++ b/bemani/frontend/reflec/reflec.py @@ -29,7 +29,7 @@ class ReflecBeatFrontend(FrontendBase): def __init__(self, data: Data, config: Dict[str, Any], cache: Cache) -> None: super().__init__(data, config, cache) - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from ReflecBeatFactory.all_games() def format_score(self, userid: UserID, score: Score) -> Dict[str, Any]: diff --git a/bemani/frontend/sdvx/sdvx.py b/bemani/frontend/sdvx/sdvx.py index 3da6002..8a65f96 100644 --- a/bemani/frontend/sdvx/sdvx.py +++ b/bemani/frontend/sdvx/sdvx.py @@ -28,7 +28,7 @@ class SoundVoltexFrontend(FrontendBase): def __init__(self, data: Data, config: Dict[str, Any], cache: Cache) -> None: super().__init__(data, config, cache) - def all_games(self) -> Iterator[Tuple[str, int, str]]: + def all_games(self) -> Iterator[Tuple[GameConstants, int, str]]: yield from SoundVoltexFactory.all_games() def format_score(self, userid: UserID, score: Score) -> Dict[str, Any]: diff --git a/bemani/utils/read.py b/bemani/utils/read.py index d8624dc..17322ac 100644 --- a/bemani/utils/read.py +++ b/bemani/utils/read.py @@ -50,7 +50,7 @@ class ImportBase: def __init__( self, config: Dict[str, Any], - game: str, + game: GameConstants, version: Optional[int], no_combine: bool, update: bool, @@ -123,7 +123,7 @@ class ImportBase: "SELECT id FROM `music` WHERE songid = :songid AND chart = :chart AND game = :game AND version = :version" ) - cursor = self.execute(sql, {'songid': songid, 'chart': chart, 'game': self.game, 'version': version}) + cursor = self.execute(sql, {'songid': songid, 'chart': chart, 'game': self.game.value, 'version': version}) if cursor.rowcount != 0: result = cursor.fetchone() return result['id'] @@ -158,7 +158,7 @@ class ImportBase: frags.append("version = :version") sql = "SELECT id FROM `music` WHERE " + " AND ".join(frags) - cursor = self.execute(sql, {'title': title, 'artist': artist, 'genre': genre, 'chart': chart, 'game': self.game, 'version': version}) + cursor = self.execute(sql, {'title': title, 'artist': artist, 'genre': genre, 'chart': chart, 'game': self.game.value, 'version': version}) if cursor.rowcount != 0: result = cursor.fetchone() return result['id'] @@ -194,7 +194,7 @@ class ImportBase: 'id': musicid, 'songid': songid, 'chart': chart, - 'game': self.game, + 'game': self.game.value, 'version': version, 'name': name, 'artist': artist, @@ -244,7 +244,7 @@ class ImportBase: { 'songid': songid, 'chart': chart, - 'game': self.game, + 'game': self.game.value, 'version': version, 'name': name, 'artist': artist, @@ -286,7 +286,7 @@ class ImportBase: sql, { 'musicid': musicid, - 'game': self.game, + 'game': self.game.value, 'version': version, 'name': name, 'artist': artist, @@ -315,7 +315,7 @@ class ImportBase: { 'id': catid, 'type': cattype, - 'game': self.game, + 'game': self.game.value, 'version': self.version, 'data': jsondata }, @@ -332,7 +332,7 @@ class ImportBase: { 'id': catid, 'type': cattype, - 'game': self.game, + 'game': self.game.value, 'version': self.version, 'data': jsondata },