215 lines
6.9 KiB
Python
215 lines
6.9 KiB
Python
import os
|
|
from typing import Dict, Any
|
|
|
|
import alembic.config # type: ignore
|
|
from alembic.migration import MigrationContext # type: ignore
|
|
from alembic.autogenerate import compare_metadata # type: ignore
|
|
from sqlalchemy import create_engine # type: ignore
|
|
from sqlalchemy.orm import scoped_session # type: ignore
|
|
from sqlalchemy.orm import sessionmaker # type: ignore
|
|
from sqlalchemy.engine import Engine # type: ignore
|
|
from sqlalchemy.sql import text # type: ignore
|
|
from sqlalchemy.exc import ProgrammingError # type: ignore
|
|
|
|
from bemani.data.api.user import GlobalUserData
|
|
from bemani.data.api.game import GlobalGameData
|
|
from bemani.data.api.music import GlobalMusicData
|
|
from bemani.data.mysql.base import metadata
|
|
from bemani.data.mysql.user import UserData
|
|
from bemani.data.mysql.music import MusicData
|
|
from bemani.data.mysql.machine import MachineData
|
|
from bemani.data.mysql.game import GameData
|
|
from bemani.data.mysql.network import NetworkData
|
|
from bemani.data.mysql.lobby import LobbyData
|
|
from bemani.data.mysql.api import APIData
|
|
|
|
|
|
class DBCreateException(Exception):
    """
    Raised by the schema-management methods (create, generate, upgrade) when
    the database is not in the state the requested operation requires.
    """
|
|
|
|
|
|
class LocalProvider:
    """
    Thin container bundling the local data accessors (the MySQL-backed
    classes). Every constructor argument is stored unchanged on an attribute
    of the same name, e.g. the user accessor is reachable as ``provider.user``.
    """

    def __init__(
        self,
        user: UserData,
        music: MusicData,
        machine: MachineData,
        game: GameData,
        network: NetworkData,
        lobby: LobbyData,
        api: APIData,
    ) -> None:
        # Tuple targets are assigned left to right, matching the original order.
        (
            self.user,
            self.music,
            self.machine,
            self.game,
            self.network,
            self.lobby,
            self.api,
        ) = (user, music, machine, game, network, lobby, api)
|
|
|
|
|
|
class GlobalProvider:
    """
    Exposes user, music and game accessors that combine data fetched from the
    local MySQL store with data fetched from remote servers that support BEMAPI.
    """

    def __init__(
        self,
        local: LocalProvider,
    ) -> None:
        # Every global accessor is layered over the local API client; the
        # user and music ones additionally wrap the matching local stores.
        api_client = local.api
        local_users = local.user
        self.user = GlobalUserData(api_client, local_users)
        self.music = GlobalMusicData(api_client, local_users, local.music)
        self.game = GlobalGameData(api_client)
|
|
|
|
|
|
class Data:
    """
    An object that is meant to be used as a singleton, in order to hold
    DB configuration info and provide a set of functions for querying
    and storing data.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initializes the data object.

        Parameters:
            config - A config structure with a 'database' section which is used
                     to initialize an internal DB connection. That section is
                     read for 'engine' (the SQLAlchemy engine sessions bind to)
                     and for 'user', 'password', 'address' and 'database'
                     (used to build the URL handed to alembic).
        """
        session_factory = sessionmaker(
            bind=config['database']['engine'],
            autoflush=True,
            autocommit=True,
        )
        self.__config = config
        self.__session = scoped_session(session_factory)
        self.__url = self.sqlalchemy_url(config)

        # Every local accessor shares the same scoped session.
        self.__user = UserData(config, self.__session)
        self.__music = MusicData(config, self.__session)
        self.__machine = MachineData(config, self.__session)
        self.__game = GameData(config, self.__session)
        self.__network = NetworkData(config, self.__session)
        self.__lobby = LobbyData(config, self.__session)
        self.__api = APIData(config, self.__session)

        self.local = LocalProvider(
            self.__user,
            self.__music,
            self.__machine,
            self.__game,
            self.__network,
            self.__lobby,
            self.__api,
        )
        self.remote = GlobalProvider(self.local)

    @classmethod
    def sqlalchemy_url(cls, config: Dict[str, Any]) -> str:
        """
        Build the MySQL connection URL (utf8mb4 charset) from the 'database'
        section of the config.

        NOTE(review): user and password are interpolated verbatim, not
        URL-escaped, so credentials containing characters such as '@', ':'
        or '/' will yield a malformed URL.
        """
        return f"mysql://{config['database']['user']}:{config['database']['password']}@{config['database']['address']}/{config['database']['database']}?charset=utf8mb4"

    @classmethod
    def create_engine(cls, config: Dict[str, Any]) -> Engine:
        """
        Create a SQLAlchemy engine for the configured database. Pooled
        connections are recycled after 3600 seconds.
        """
        # Inside the method body this name resolves to sqlalchemy's module-level
        # create_engine, not to this classmethod.
        return create_engine(  # type: ignore
            cls.sqlalchemy_url(config),
            pool_recycle=3600,
        )

    def __exists(self) -> bool:
        """
        Return True if the DB was already created, meaning the alembic_version
        table holds exactly one revision row. Return False if querying that
        table raises a ProgrammingError, presumably because it does not exist yet.
        """
        try:
            # scalar() reads the first column of the first row; unlike
            # row['count'], it works on SQLAlchemy 1.3, 1.4 and 2.0 rows.
            count = self.__session.execute(
                text('SELECT COUNT(version_num) AS count FROM alembic_version')
            ).scalar()
        except ProgrammingError:
            return False
        return count == 1

    def __alembic_cmd(self, command: str, *args: str) -> None:
        """
        Run an alembic CLI command against this DB using the 'migrations'
        directory that sits next to this file.

        Parameters:
            command - The alembic subcommand to run (e.g. 'stamp', 'revision', 'upgrade').
            args - Extra arguments appended after the subcommand.
        """
        base_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'migrations')
        alembic_args = [
            '-c',
            os.path.join(base_dir, 'alembic.ini'),
            '-x',
            f'script_location={base_dir}',
            '-x',
            f'sqlalchemy.url={self.__url}',
            command,
            *args,
        ]

        # Alembic is invoked from inside the migrations directory, but the
        # caller's working directory is restored afterwards (even if alembic
        # raises or exits) so this does not leak a process-wide cwd change.
        previous_cwd = os.getcwd()
        os.chdir(base_dir)
        try:
            alembic.config.main(argv=alembic_args)
        finally:
            os.chdir(previous_cwd)

    def create(self) -> None:
        """
        Create any tables that need to be created, then stamp the alembic head
        revision so that later schema changes go through upgrade.

        Raises:
            DBCreateException - If the DB already has an alembic revision recorded.
        """
        if self.__exists():
            # Cowardly refuse to do anything, we should be using the upgrade path instead.
            raise DBCreateException('Tables already created, use upgrade to upgrade schema!')

        # engine.begin() commits on success and always returns the connection
        # to the pool, instead of leaking an unclosed connection.
        with self.__config['database']['engine'].begin() as connection:
            metadata.create_all(  # type: ignore
                connection,
                checkfirst=True,
            )

        # Stamp the end revision as if alembic had created it, so it can take off after this.
        self.__alembic_cmd(
            'stamp',
            'head',
        )

    def generate(self, message: str, allow_empty: bool) -> None:
        """
        Generate upgrade scripts using alembic.

        Parameters:
            message - Revision message passed to alembic via '-m'.
            allow_empty - When False, refuse to generate a migration if the
                          code's metadata does not differ from the DB schema.

        Raises:
            DBCreateException - If the tables have not been created yet, or if
                                there is no difference and allow_empty is False.
        """
        if not self.__exists():
            raise DBCreateException('Tables have not been created yet, use create to create them!')

        # Verify that there are actual changes, and refuse to create empty migration scripts.
        # The comparison connection is closed before alembic runs.
        with self.__config['database']['engine'].connect() as connection:
            context = MigrationContext.configure(connection, opts={'compare_type': True})
            diff = compare_metadata(context, metadata)
        if not allow_empty and not diff:
            raise DBCreateException('There is nothing different between code and the DB, refusing to create migration!')

        self.__alembic_cmd(
            'revision',
            '--autogenerate',
            '-m',
            message,
        )

    def upgrade(self) -> None:
        """
        Upgrade an existing DB to the current model by running alembic up to head.

        Raises:
            DBCreateException - If the tables have not been created yet.
        """
        if not self.__exists():
            raise DBCreateException('Tables have not been created yet, use create to create them!')

        self.__alembic_cmd(
            'upgrade',
            'head',
        )

    def close(self) -> None:
        """
        Close any open data connection. Safe to call repeatedly: after the
        first call the session reference is dropped and later calls do nothing.
        """
        # Make sure we don't leak connections between web requests
        if self.__session is not None:
            self.__session.close()
            self.__session = None
|