1
0
mirror of synced 2025-02-07 14:31:20 +01:00

fix: make database async

This commit is contained in:
beerpsi 2024-11-14 12:36:22 +07:00
parent 1331d473c9
commit bc7524c8fc
9 changed files with 297 additions and 154 deletions

View File

@ -1,8 +1,14 @@
from __future__ import with_statement from __future__ import with_statement
from alembic import context
from sqlalchemy import engine_from_config, pool import asyncio
import threading
from logging.config import fileConfig from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from core.data.schema.base import metadata from core.data.schema.base import metadata
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
@ -37,34 +43,16 @@ def run_migrations_offline():
script output. script output.
""" """
raise Exception('Not implemented or configured!') raise Exception("Not implemented or configured!")
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option("sqlalchemy.url")
context.configure( context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
def run_migrations_online(): def do_run_migrations(connection: Connection) -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
ini_section = config.get_section(config.config_ini_section)
overrides = context.get_x_argument(as_dictionary=True)
for override in overrides:
ini_section[override] = overrides[override]
connectable = engine_from_config(
ini_section,
prefix='sqlalchemy.',
poolclass=pool.NullPool)
with connectable.connect() as connection:
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=target_metadata, target_metadata=target_metadata,
@ -75,6 +63,44 @@ def run_migrations_online():
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
ini_section = config.get_section(config.config_ini_section)
overrides = context.get_x_argument(as_dictionary=True)
for override in overrides:
ini_section[override] = overrides[override]
connectable = async_engine_from_config(
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# there's no event loop
asyncio.run(run_async_migrations())
else:
# there's currently an event loop and trying to wait for a coroutine
# to finish without using `await` is pretty wormy. nested event loops
# are explicitly forbidden by asyncio.
#
# take the easy way out, spawn it in another thread.
thread = threading.Thread(target=asyncio.run, args=(run_async_migrations(),))
thread.start()
thread.join()
if context.is_offline_mode(): if context.is_offline_mode():
run_migrations_offline() run_migrations_offline()
else: else:

View File

@ -1,54 +1,65 @@
import logging, coloredlogs import asyncio
from typing import Optional import logging
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy import create_engine
from logging.handlers import TimedRotatingFileHandler
import os import os
import secrets, string import secrets
import bcrypt import string
import warnings
from hashlib import sha256 from hashlib import sha256
from logging.handlers import TimedRotatingFileHandler
from typing import ClassVar, Optional
import alembic.config import alembic.config
import glob import bcrypt
import coloredlogs
import pymysql.err
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_scoped_session,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker
from core.config import CoreConfig from core.config import CoreConfig
from core.data.schema import * from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
from core.utils import Utils from core.utils import MISSING, Utils
class Data: class Data:
engine = None engine: ClassVar[AsyncEngine] = MISSING
session = None session: ClassVar[AsyncSession] = MISSING
user = None user: ClassVar[UserData] = MISSING
arcade = None arcade: ClassVar[ArcadeData] = MISSING
card = None card: ClassVar[CardData] = MISSING
base = None base: ClassVar[BaseData] = MISSING
def __init__(self, cfg: CoreConfig) -> None: def __init__(self, cfg: CoreConfig) -> None:
self.config = cfg self.config = cfg
if self.config.database.sha2_password: if self.config.database.sha2_password:
passwd = sha256(self.config.database.password.encode()).digest() passwd = sha256(self.config.database.password.encode()).digest()
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
else: else:
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
if Data.engine is None: if Data.engine is MISSING:
Data.engine = create_engine(self.__url, pool_recycle=3600) Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT")
self.__engine = Data.engine self.__engine = Data.engine
if Data.session is None: if Data.session is MISSING:
s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True) s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
Data.session = scoped_session(s) Data.session = async_scoped_session(s, asyncio.current_task)
if Data.user is None: if Data.user is MISSING:
Data.user = UserData(self.config, self.session) Data.user = UserData(self.config, self.session)
if Data.arcade is None: if Data.arcade is MISSING:
Data.arcade = ArcadeData(self.config, self.session) Data.arcade = ArcadeData(self.config, self.session)
if Data.card is None: if Data.card is MISSING:
Data.card = CardData(self.config, self.session) Data.card = CardData(self.config, self.session)
if Data.base is None: if Data.base is MISSING:
Data.base = BaseData(self.config, self.session) Data.base = BaseData(self.config, self.session)
self.logger = logging.getLogger("database") self.logger = logging.getLogger("database")
@ -94,20 +105,29 @@ class Data:
alembic.config.main(argv=alembicArgs) alembic.config.main(argv=alembicArgs)
os.chdir(old_dir) os.chdir(old_dir)
def create_database(self): async def create_database(self):
self.logger.info("Creating databases...") self.logger.info("Creating databases...")
metadata.create_all(
self.engine, with warnings.catch_warnings():
checkfirst=True, # SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
) )
async with self.engine.begin() as conn:
await conn.run_sync(metadata.create_all, checkfirst=True)
for _, mod in Utils.get_all_titles().items(): for _, mod in Utils.get_all_titles().items():
if hasattr(mod, "database"): if hasattr(mod, "database"):
mod.database(self.config) mod.database(self.config)
metadata.create_all(
self.engine, await conn.run_sync(metadata.create_all, checkfirst=True)
checkfirst=True,
)
# Stamp the end revision as if alembic had created it, so it can take off after this. # Stamp the end revision as if alembic had created it, so it can take off after this.
self.__alembic_cmd( self.__alembic_cmd(
@ -115,19 +135,43 @@ class Data:
"head", "head",
) )
def schema_upgrade(self, ver: str = None): def schema_upgrade(self, ver: Optional[str] = None):
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)
self.__alembic_cmd( self.__alembic_cmd(
"upgrade", "upgrade",
"head" if not ver else ver, "head" if not ver else ver,
) )
def schema_downgrade(self, ver: str): def schema_downgrade(self, ver: str):
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)
self.__alembic_cmd( self.__alembic_cmd(
"downgrade", "downgrade",
ver, ver,
) )
async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None: async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None:
pw = "".join( pw = "".join(
secrets.choice(string.ascii_letters + string.digits) for i in range(20) secrets.choice(string.ascii_letters + string.digits) for i in range(20)
) )
@ -150,12 +194,12 @@ class Data:
async def migrate(self) -> None: async def migrate(self) -> None:
exist = await self.base.execute("SELECT * FROM alembic_version") exist = await self.base.execute("SELECT * FROM alembic_version")
if exist is not None: if exist is not None:
self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!") self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
return return
self.logger.info("Upgrading to latest with legacy system") self.logger.info("Upgrading to latest with legacy system")
if not await self.legacy_upgrade(): if not await self.legacy_upgrade():
self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!") self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
return return
self.logger.info("Done") self.logger.info("Done")

View File

@ -1,16 +1,16 @@
from typing import Optional, Dict, List import re
from sqlalchemy import Table, Column, and_, or_ from typing import List, Optional
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import Integer, String, Boolean, JSON from sqlalchemy import Column, Table, and_, or_
from sqlalchemy.sql import func, select
from sqlalchemy.dialects.mysql import insert from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row from sqlalchemy.engine import Row
import re from sqlalchemy.sql import func, select
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import JSON, Boolean, Integer, String
from core.data.schema.base import BaseData, metadata from core.data.schema.base import BaseData, metadata
from core.const import *
arcade = Table( arcade: Table = Table(
"arcade", "arcade",
metadata, metadata,
Column("id", Integer, primary_key=True, nullable=False), Column("id", Integer, primary_key=True, nullable=False),
@ -26,7 +26,7 @@ arcade = Table(
mysql_charset="utf8mb4", mysql_charset="utf8mb4",
) )
machine = Table( machine: Table = Table(
"machine", "machine",
metadata, metadata,
Column("id", Integer, primary_key=True, nullable=False), Column("id", Integer, primary_key=True, nullable=False),
@ -47,7 +47,7 @@ machine = Table(
mysql_charset="utf8mb4", mysql_charset="utf8mb4",
) )
arcade_owner = Table( arcade_owner: Table = Table(
"arcade_owner", "arcade_owner",
metadata, metadata,
Column( Column(
@ -69,7 +69,7 @@ arcade_owner = Table(
class ArcadeData(BaseData): class ArcadeData(BaseData):
async def get_machine(self, serial: str = None, id: int = None) -> Optional[Row]: async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]:
if serial is not None: if serial is not None:
serial = serial.replace("-", "") serial = serial.replace("-", "")
if len(serial) == 11: if len(serial) == 11:
@ -98,8 +98,8 @@ class ArcadeData(BaseData):
self, self,
arcade_id: int, arcade_id: int,
serial: str = "", serial: str = "",
board: str = None, board: Optional[str] = None,
game: str = None, game: Optional[str] = None,
is_cab: bool = False, is_cab: bool = False,
) -> Optional[int]: ) -> Optional[int]:
if not arcade_id: if not arcade_id:
@ -150,8 +150,8 @@ class ArcadeData(BaseData):
async def create_arcade( async def create_arcade(
self, self,
name: str = None, name: Optional[str] = None,
nickname: str = None, nickname: Optional[str] = None,
country: str = "JPN", country: str = "JPN",
country_id: int = 1, country_id: int = 1,
state: str = "", state: str = "",

View File

@ -1,22 +1,23 @@
import asyncio
import json import json
import logging import logging
from random import randrange from random import randrange
from typing import Any, Optional, Dict, List from typing import Any, Dict, List, Optional
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.engine import Row from sqlalchemy.engine import Row
from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql import text, func, select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import MetaData, Table, Column from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.types import Integer, String, TIMESTAMP, JSON, INTEGER, TEXT
from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import ForeignKey
from sqlalchemy.dialects.mysql import insert from sqlalchemy.sql import func, text
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
from core.config import CoreConfig from core.config import CoreConfig
metadata = MetaData() metadata = MetaData()
event_log = Table( event_log: Table = Table(
"event_log", "event_log",
metadata, metadata,
Column("id", Integer, primary_key=True, nullable=False), Column("id", Integer, primary_key=True, nullable=False),
@ -37,7 +38,7 @@ event_log = Table(
class BaseData: class BaseData:
def __init__(self, cfg: CoreConfig, conn: Connection) -> None: def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None:
self.config = cfg self.config = cfg
self.conn = conn self.conn = conn
self.logger = logging.getLogger("database") self.logger = logging.getLogger("database")
@ -47,7 +48,7 @@ class BaseData:
try: try:
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
res = self.conn.execute(text(sql), opts) res = await self.conn.execute(text(sql), opts)
except SQLAlchemyError as e: except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}") self.logger.error(f"SQLAlchemy error {e}")
@ -59,7 +60,7 @@ class BaseData:
except Exception: except Exception:
try: try:
res = self.conn.execute(sql, opts) res = await self.conn.execute(sql, opts)
except SQLAlchemyError as e: except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}") self.logger.error(f"SQLAlchemy error {e}")
@ -83,7 +84,7 @@ class BaseData:
async def log_event( async def log_event(
self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None, self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,
arcade: int = None, machine: int = None, ip: str = None, game: str = None, version: str = None arcade: int = None, machine: int = None, ip: Optional[str] = None, game: Optional[str] = None, version: Optional[str] = None
) -> Optional[int]: ) -> Optional[int]:
sql = event_log.insert().values( sql = event_log.insert().values(
system=system, system=system,

View File

@ -1,13 +1,14 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
from sqlalchemy import Table, Column, UniqueConstraint
from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP, BIGINT, VARCHAR from sqlalchemy import Column, Table, UniqueConstraint
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.engine import Row from sqlalchemy.engine import Row
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String
from core.data.schema.base import BaseData, metadata from core.data.schema.base import BaseData, metadata
aime_card = Table( aime_card: Table = Table(
"aime_card", "aime_card",
metadata, metadata,
Column("id", Integer, primary_key=True, nullable=False), Column("id", Integer, primary_key=True, nullable=False),

View File

@ -1,15 +1,15 @@
from typing import Optional, List from typing import List, Optional
from sqlalchemy import Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, select
from sqlalchemy.engine import Row
import bcrypt import bcrypt
from sqlalchemy import Column, Table
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row
from sqlalchemy.sql import func, select
from sqlalchemy.types import TIMESTAMP, Integer, String
from core.data.schema.base import BaseData, metadata from core.data.schema.base import BaseData, metadata
aime_user = Table( aime_user: Table = Table(
"aime_user", "aime_user",
metadata, metadata,
Column("id", Integer, nullable=False, primary_key=True, autoincrement=True), Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
@ -26,10 +26,10 @@ aime_user = Table(
class UserData(BaseData): class UserData(BaseData):
async def create_user( async def create_user(
self, self,
id: int = None, id: Optional[int] = None,
username: str = None, username: Optional[str] = None,
email: str = None, email: Optional[str] = None,
password: str = None, password: Optional[str] = None,
permission: int = 1, permission: int = 1,
) -> Optional[int]: ) -> Optional[int]:
if id is None: if id is None:

View File

@ -1,18 +1,47 @@
from typing import Dict, Any, Optional
from types import ModuleType
from starlette.requests import Request
import logging
import importlib import importlib
from os import walk import logging
import jwt
from base64 import b64decode from base64 import b64decode
from datetime import datetime, timezone from datetime import datetime, timezone
from os import walk
from types import ModuleType
from typing import Any, Dict, Optional
import jwt
from starlette.requests import Request
from .config import CoreConfig from .config import CoreConfig
class _MissingSentinel:
__slots__: tuple[str, ...] = ()
def __eq__(self, other) -> bool:
return False
def __bool__(self) -> bool:
return False
def __hash__(self) -> int:
return 0
def __repr__(self):
return "..."
MISSING: Any = _MissingSentinel()
"""This is different from `None` in that its type is `Any`, and so it can be used
as a placeholder for values that are *definitely* going to be initialized,
so they don't have to be typed as `T | None`, which makes type checkers
angry when an attribute is accessed.
This can also be used for when `None` has actual meaning as a value, and so a
separate value is needed to mean "unset"."""
class Utils: class Utils:
real_title_port = None real_title_port = None
real_title_port_ssl = None real_title_port_ssl = None
@classmethod @classmethod
def get_all_titles(cls) -> Dict[str, ModuleType]: def get_all_titles(cls) -> Dict[str, ModuleType]:
ret: Dict[str, Any] = {} ret: Dict[str, Any] = {}
@ -39,24 +68,53 @@ class Utils:
@classmethod @classmethod
def get_title_port(cls, cfg: CoreConfig): def get_title_port(cls, cfg: CoreConfig):
if cls.real_title_port is not None: return cls.real_title_port if cls.real_title_port is not None:
return cls.real_title_port
cls.real_title_port = cfg.server.proxy_port if cfg.server.is_using_proxy and cfg.server.proxy_port else cfg.server.port cls.real_title_port = (
cfg.server.proxy_port
if cfg.server.is_using_proxy and cfg.server.proxy_port
else cfg.server.port
)
return cls.real_title_port return cls.real_title_port
@classmethod @classmethod
def get_title_port_ssl(cls, cfg: CoreConfig): def get_title_port_ssl(cls, cfg: CoreConfig):
if cls.real_title_port_ssl is not None: return cls.real_title_port_ssl if cls.real_title_port_ssl is not None:
return cls.real_title_port_ssl
cls.real_title_port_ssl = cfg.server.proxy_port_ssl if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl else 443 cls.real_title_port_ssl = (
cfg.server.proxy_port_ssl
if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl
else 443
)
return cls.real_title_port_ssl return cls.real_title_port_ssl
def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str, b64_secret: str, exp_seconds: int = 86400, err_logger: str = 'aimedb') -> Optional[str]:
def create_sega_auth_key(
aime_id: int,
game: str,
place_id: int,
keychip_id: str,
b64_secret: str,
exp_seconds: int = 86400,
err_logger: str = "aimedb",
) -> Optional[str]:
logger = logging.getLogger(err_logger) logger = logging.getLogger(err_logger)
try: try:
return jwt.encode({ "aime_id": aime_id, "game": game, "place_id": place_id, "keychip_id": keychip_id, "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds }, b64decode(b64_secret), algorithm="HS256") return jwt.encode(
{
"aime_id": aime_id,
"game": game,
"place_id": place_id,
"keychip_id": keychip_id,
"exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds,
},
b64decode(b64_secret),
algorithm="HS256",
)
except jwt.InvalidKeyError: except jwt.InvalidKeyError:
logger.error("Failed to encode Sega Auth Key because the secret is invalid!") logger.error("Failed to encode Sega Auth Key because the secret is invalid!")
return None return None
@ -64,10 +122,19 @@ def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str
logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}") logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}")
return None return None
def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb') -> Optional[Dict]:
def decode_sega_auth_key(
token: str, b64_secret: str, err_logger: str = "aimedb"
) -> Optional[Dict]:
logger = logging.getLogger(err_logger) logger = logging.getLogger(err_logger)
try: try:
return jwt.decode(token, "secret", b64decode(b64_secret), algorithms=["HS256"], options={"verify_signature": True}) return jwt.decode(
token,
"secret",
b64decode(b64_secret),
algorithms=["HS256"],
options={"verify_signature": True},
)
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
logger.error("Sega Auth Key failed to validate due to an expired signature!") logger.error("Sega Auth Key failed to validate due to an expired signature!")
return None return None
@ -83,4 +150,3 @@ def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb'
except Exception as e: except Exception as e:
logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}") logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}")
return None return None

View File

@ -1,12 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import logging
from os import mkdir, path, access, W_OK, environ
import yaml
import asyncio import asyncio
import logging
from os import W_OK, access, environ, mkdir, path
import yaml
from core.data import Data
from core.config import CoreConfig from core.config import CoreConfig
from core.data import Data
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Database utilities") parser = argparse.ArgumentParser(description="Database utilities")
@ -46,7 +47,7 @@ if __name__ == "__main__":
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if args.action == "create": if args.action == "create":
data.create_database() loop.run_until_complete(data.create_database())
elif args.action == "upgrade": elif args.action == "upgrade":
data.schema_upgrade(args.version) data.schema_upgrade(args.version)

24
read.py
View File

@ -1,16 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import re
import os
import yaml
from os import path
import logging
import coloredlogs
import asyncio import asyncio
import logging
import os
import re
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from os import path
from typing import List, Optional from typing import List, Optional
import coloredlogs
import yaml
from core import CoreConfig, Utils from core import CoreConfig, Utils
@ -44,7 +44,7 @@ class BaseReader:
pass pass
if __name__ == "__main__": async def main():
parser = argparse.ArgumentParser(description="Import Game Information") parser = argparse.ArgumentParser(description="Import Game Information")
parser.add_argument( parser.add_argument(
"--game", "--game",
@ -140,8 +140,12 @@ if __name__ == "__main__":
for dir, mod in titles.items(): for dir, mod in titles.items():
if args.game in mod.game_codes: if args.game in mod.game_codes:
handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra) handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.read()) await handler.read()
logger.info("Done") logger.info("Done")
if __name__ == "__main__":
asyncio.run(main())