fix: make database async
parent 1331d473c9
commit bc7524c8fc
@@ -1,8 +1,14 @@
from __future__ import with_statement
from alembic import context
from sqlalchemy import engine_from_config, pool

import asyncio
import threading
from logging.config import fileConfig

from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from core.data.schema.base import metadata

# this is the Alembic Config object, which provides
@@ -37,34 +43,16 @@ def run_migrations_offline():
script output.

"""
raise Exception('Not implemented or configured!')
raise Exception("Not implemented or configured!")

url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=target_metadata, literal_binds=True)
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)

with context.begin_transaction():
context.run_migrations()

def run_migrations_online():
"""Run migrations in 'online' mode.

In this scenario we need to create an Engine
and associate a connection with the context.

"""
ini_section = config.get_section(config.config_ini_section)
overrides = context.get_x_argument(as_dictionary=True)
for override in overrides:
ini_section[override] = overrides[override]

connectable = engine_from_config(
ini_section,
prefix='sqlalchemy.',
poolclass=pool.NullPool)

with connectable.connect() as connection:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
@@ -75,6 +63,44 @@ def run_migrations_online():
with context.begin_transaction():
context.run_migrations()

async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.

"""
ini_section = config.get_section(config.config_ini_section)
overrides = context.get_x_argument(as_dictionary=True)
for override in overrides:
ini_section[override] = overrides[override]

connectable = async_engine_from_config(
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
)

async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)

await connectable.dispose()

def run_migrations_online():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# there's no event loop
asyncio.run(run_async_migrations())
else:
# there's currently an event loop and trying to wait for a coroutine
# to finish without using `await` is pretty wormy. nested event loops
# are explicitly forbidden by asyncio.
#
# take the easy way out, spawn it in another thread.
thread = threading.Thread(target=asyncio.run, args=(run_async_migrations(),))
thread.start()
thread.join()

if context.is_offline_mode():
run_migrations_offline()
else:
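
Note on the pattern above: Alembic's env.py can be entered either from the plain CLI (no event loop yet) or from application code that is already inside a running loop, and asyncio forbids nesting loops. A minimal standalone sketch of the same fallback, with an illustrative helper name that is not part of this commit:

import asyncio
import threading
from typing import Coroutine

def run_coro_from_anywhere(coro: Coroutine) -> None:
    # Run a coroutine to completion whether or not the calling thread
    # already has a running event loop (illustrative helper only).
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop here, so it is safe to start one.
        asyncio.run(coro)
    else:
        # A loop is already running; hand the coroutine to a fresh loop in a
        # worker thread and block until it finishes, mirroring
        # run_migrations_online above.
        worker = threading.Thread(target=asyncio.run, args=(coro,))
        worker.start()
        worker.join()
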
@@ -1,54 +1,65 @@
import logging, coloredlogs
from typing import Optional
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy import create_engine
from logging.handlers import TimedRotatingFileHandler
import asyncio
import logging
import os
import secrets, string
import bcrypt
import secrets
import string
import warnings
from hashlib import sha256
from logging.handlers import TimedRotatingFileHandler
from typing import ClassVar, Optional

import alembic.config
import glob
import bcrypt
import coloredlogs
import pymysql.err
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_scoped_session,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker

from core.config import CoreConfig
from core.data.schema import *
from core.utils import Utils
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
from core.utils import MISSING, Utils

class Data:
engine = None
session = None
user = None
arcade = None
card = None
base = None
engine: ClassVar[AsyncEngine] = MISSING
session: ClassVar[AsyncSession] = MISSING
user: ClassVar[UserData] = MISSING
arcade: ClassVar[ArcadeData] = MISSING
card: ClassVar[CardData] = MISSING
base: ClassVar[BaseData] = MISSING

def __init__(self, cfg: CoreConfig) -> None:
self.config = cfg

if self.config.database.sha2_password:
passwd = sha256(self.config.database.password.encode()).digest()
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
else:
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"

if Data.engine is None:
Data.engine = create_engine(self.__url, pool_recycle=3600)
if Data.engine is MISSING:
Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT")
self.__engine = Data.engine

if Data.session is None:
s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True)
Data.session = scoped_session(s)
if Data.session is MISSING:
s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
Data.session = async_scoped_session(s, asyncio.current_task)

if Data.user is None:
if Data.user is MISSING:
Data.user = UserData(self.config, self.session)

if Data.arcade is None:
if Data.arcade is MISSING:
Data.arcade = ArcadeData(self.config, self.session)

if Data.card is None:
if Data.card is MISSING:
Data.card = CardData(self.config, self.session)

if Data.base is None:
if Data.base is MISSING:
Data.base = BaseData(self.config, self.session)

self.logger = logging.getLogger("database")
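
A hedged usage sketch (not from this commit) of how the now-async Data object would be driven; it assumes CoreConfig is constructed and populated the same way the rest of the project does it:

import asyncio

from core.config import CoreConfig
from core.data import Data

async def main() -> None:
    cfg = CoreConfig()  # assumption: config loading happens as elsewhere in the project
    data = Data(cfg)
    # Schema helpers are now coroutines backed by the async_scoped_session above.
    user_id = await data.user.create_user(username="example")
    print(user_id)

asyncio.run(main())
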
@@ -94,20 +105,29 @@ class Data:
alembic.config.main(argv=alembicArgs)
os.chdir(old_dir)

def create_database(self):
async def create_database(self):
self.logger.info("Creating databases...")
metadata.create_all(
self.engine,
checkfirst=True,

with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)

async with self.engine.begin() as conn:
await conn.run_sync(metadata.create_all, checkfirst=True)

for _, mod in Utils.get_all_titles().items():
if hasattr(mod, "database"):
mod.database(self.config)
metadata.create_all(
self.engine,
checkfirst=True,
)

await conn.run_sync(metadata.create_all, checkfirst=True)

# Stamp the end revision as if alembic had created it, so it can take off after this.
self.__alembic_cmd(
@@ -115,19 +135,43 @@ class Data:
"head",
)

def schema_upgrade(self, ver: str = None):
def schema_upgrade(self, ver: Optional[str] = None):
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)

self.__alembic_cmd(
"upgrade",
"head" if not ver else ver,
)

def schema_downgrade(self, ver: str):
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)

self.__alembic_cmd(
"downgrade",
ver,
)

async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None:
async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None:
pw = "".join(
secrets.choice(string.ascii_letters + string.digits) for i in range(20)
)
@@ -150,12 +194,12 @@ class Data:
async def migrate(self) -> None:
exist = await self.base.execute("SELECT * FROM alembic_version")
if exist is not None:
self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
return

self.logger.info("Upgrading to latest with legacy system")
if not await self.legacy_upgrade():
self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
return
self.logger.info("Done")
@@ -1,16 +1,16 @@
from typing import Optional, Dict, List
from sqlalchemy import Table, Column, and_, or_
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import Integer, String, Boolean, JSON
from sqlalchemy.sql import func, select
import re
from typing import List, Optional

from sqlalchemy import Column, Table, and_, or_
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row
import re
from sqlalchemy.sql import func, select
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import JSON, Boolean, Integer, String

from core.data.schema.base import BaseData, metadata
from core.const import *

arcade = Table(
arcade: Table = Table(
"arcade",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
@@ -26,7 +26,7 @@ arcade = Table(
mysql_charset="utf8mb4",
)

machine = Table(
machine: Table = Table(
"machine",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
@@ -47,7 +47,7 @@ machine = Table(
mysql_charset="utf8mb4",
)

arcade_owner = Table(
arcade_owner: Table = Table(
"arcade_owner",
metadata,
Column(
@@ -69,7 +69,7 @@ arcade_owner = Table(

class ArcadeData(BaseData):
async def get_machine(self, serial: str = None, id: int = None) -> Optional[Row]:
async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]:
if serial is not None:
serial = serial.replace("-", "")
if len(serial) == 11:
@@ -98,8 +98,8 @@ class ArcadeData(BaseData):
self,
arcade_id: int,
serial: str = "",
board: str = None,
game: str = None,
board: Optional[str] = None,
game: Optional[str] = None,
is_cab: bool = False,
) -> Optional[int]:
if not arcade_id:
@@ -150,8 +150,8 @@ class ArcadeData(BaseData):

async def create_arcade(
self,
name: str = None,
nickname: str = None,
name: Optional[str] = None,
nickname: Optional[str] = None,
country: str = "JPN",
country_id: int = 1,
state: str = "",

@@ -1,22 +1,23 @@
import asyncio
import json
import logging
from random import randrange
from typing import Any, Optional, Dict, List
from typing import Any, Dict, List, Optional

from sqlalchemy import Column, MetaData, Table
from sqlalchemy.engine import Row
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql import text, func, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import MetaData, Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP, JSON, INTEGER, TEXT
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.schema import ForeignKey
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, text
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String

from core.config import CoreConfig

metadata = MetaData()

event_log = Table(
event_log: Table = Table(
"event_log",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
@@ -37,7 +38,7 @@ event_log = Table(

class BaseData:
def __init__(self, cfg: CoreConfig, conn: Connection) -> None:
def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None:
self.config = cfg
self.conn = conn
self.logger = logging.getLogger("database")
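
Illustrative only: with conn now an AsyncSession, raw SQL issued by the schema classes has to be awaited, in line with the change to execute() further down. A small sketch under that assumption:

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def count_events(session: AsyncSession) -> int:
    # Await the execution and read a single scalar from the result.
    result = await session.execute(text("SELECT COUNT(*) FROM event_log"))
    return result.scalar_one()
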
@@ -47,7 +48,7 @@ class BaseData:

try:
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
res = self.conn.execute(text(sql), opts)
res = await self.conn.execute(text(sql), opts)

except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
@@ -59,7 +60,7 @@ class BaseData:

except Exception:
try:
res = self.conn.execute(sql, opts)
res = await self.conn.execute(sql, opts)

except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
@@ -83,7 +84,7 @@ class BaseData:

async def log_event(
self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,
arcade: int = None, machine: int = None, ip: str = None, game: str = None, version: str = None
arcade: int = None, machine: int = None, ip: Optional[str] = None, game: Optional[str] = None, version: Optional[str] = None
) -> Optional[int]:
sql = event_log.insert().values(
system=system,

@@ -1,13 +1,14 @@
from typing import Dict, List, Optional
from sqlalchemy import Table, Column, UniqueConstraint
from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP, BIGINT, VARCHAR
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.sql import func

from sqlalchemy import Column, Table, UniqueConstraint
from sqlalchemy.engine import Row
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String

from core.data.schema.base import BaseData, metadata

aime_card = Table(
aime_card: Table = Table(
"aime_card",
metadata,
Column("id", Integer, primary_key=True, nullable=False),

@@ -1,15 +1,15 @@
from typing import Optional, List
from sqlalchemy import Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, select
from sqlalchemy.engine import Row
from typing import List, Optional

import bcrypt
from sqlalchemy import Column, Table
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row
from sqlalchemy.sql import func, select
from sqlalchemy.types import TIMESTAMP, Integer, String

from core.data.schema.base import BaseData, metadata

aime_user = Table(
aime_user: Table = Table(
"aime_user",
metadata,
Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
@@ -26,10 +26,10 @@ aime_user = Table(
class UserData(BaseData):
async def create_user(
self,
id: int = None,
username: str = None,
email: str = None,
password: str = None,
id: Optional[int] = None,
username: Optional[str] = None,
email: Optional[str] = None,
password: Optional[str] = None,
permission: int = 1,
) -> Optional[int]:
if id is None:

@@ -1,18 +1,47 @@
from typing import Dict, Any, Optional
from types import ModuleType
from starlette.requests import Request
import logging
import importlib
from os import walk
import jwt
import logging
from base64 import b64decode
from datetime import datetime, timezone
from os import walk
from types import ModuleType
from typing import Any, Dict, Optional

import jwt
from starlette.requests import Request

from .config import CoreConfig

class _MissingSentinel:
__slots__: tuple[str, ...] = ()

def __eq__(self, other) -> bool:
return False

def __bool__(self) -> bool:
return False

def __hash__(self) -> int:
return 0

def __repr__(self):
return "..."

MISSING: Any = _MissingSentinel()
"""This is different from `None` in that its type is `Any`, and so it can be used
as a placeholder for values that are *definitely* going to be initialized,
so they don't have to be typed as `T | None`, which makes type checkers
angry when an attribute is accessed.

This can also be used for when `None` has actual meaning as a value, and so a
separate value is needed to mean "unset"."""
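
A small usage sketch of the sentinel, reusing the MISSING defined above; the Registry class here is invented for illustration:

class Registry:
    # Assigned during startup; typed without Optional because MISSING is typed as Any.
    backend: str = MISSING

    def describe(self) -> str:
        if self.backend is MISSING:
            return "backend is not initialised yet"
        return f"backend = {self.backend}"
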
class Utils:
real_title_port = None
real_title_port_ssl = None

@classmethod
def get_all_titles(cls) -> Dict[str, ModuleType]:
ret: Dict[str, Any] = {}
@@ -39,24 +68,53 @@ class Utils:

@classmethod
def get_title_port(cls, cfg: CoreConfig):
if cls.real_title_port is not None: return cls.real_title_port
if cls.real_title_port is not None:
return cls.real_title_port

cls.real_title_port = cfg.server.proxy_port if cfg.server.is_using_proxy and cfg.server.proxy_port else cfg.server.port
cls.real_title_port = (
cfg.server.proxy_port
if cfg.server.is_using_proxy and cfg.server.proxy_port
else cfg.server.port
)

return cls.real_title_port

@classmethod
def get_title_port_ssl(cls, cfg: CoreConfig):
if cls.real_title_port_ssl is not None: return cls.real_title_port_ssl
if cls.real_title_port_ssl is not None:
return cls.real_title_port_ssl

cls.real_title_port_ssl = cfg.server.proxy_port_ssl if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl else 443
cls.real_title_port_ssl = (
cfg.server.proxy_port_ssl
if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl
else 443
)

return cls.real_title_port_ssl

def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str, b64_secret: str, exp_seconds: int = 86400, err_logger: str = 'aimedb') -> Optional[str]:

def create_sega_auth_key(
aime_id: int,
game: str,
place_id: int,
keychip_id: str,
b64_secret: str,
exp_seconds: int = 86400,
err_logger: str = "aimedb",
) -> Optional[str]:
logger = logging.getLogger(err_logger)
try:
return jwt.encode({ "aime_id": aime_id, "game": game, "place_id": place_id, "keychip_id": keychip_id, "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds }, b64decode(b64_secret), algorithm="HS256")
return jwt.encode(
{
"aime_id": aime_id,
"game": game,
"place_id": place_id,
"keychip_id": keychip_id,
"exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds,
},
b64decode(b64_secret),
algorithm="HS256",
)
except jwt.InvalidKeyError:
logger.error("Failed to encode Sega Auth Key because the secret is invalid!")
return None
@@ -64,10 +122,19 @@ def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str
logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}")
return None

def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb') -> Optional[Dict]:

def decode_sega_auth_key(
token: str, b64_secret: str, err_logger: str = "aimedb"
) -> Optional[Dict]:
logger = logging.getLogger(err_logger)
try:
return jwt.decode(token, "secret", b64decode(b64_secret), algorithms=["HS256"], options={"verify_signature": True})
return jwt.decode(
token,
"secret",
b64decode(b64_secret),
algorithms=["HS256"],
options={"verify_signature": True},
)
except jwt.ExpiredSignatureError:
logger.error("Sega Auth Key failed to validate due to an expired signature!")
return None
@@ -83,4 +150,3 @@ def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb'
except Exception as e:
logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}")
return None
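
A rough sketch of issuing a key with the helper above and inspecting its claims directly with PyJWT; the secret and identifiers below are invented for illustration:

import secrets
from base64 import b64decode, b64encode

import jwt

# Invented secret; in practice this comes from the server config.
b64_secret = b64encode(secrets.token_bytes(32)).decode()
token = create_sega_auth_key(
    aime_id=1, game="SDEZ", place_id=123, keychip_id="A69E01A0000", b64_secret=b64_secret
)
if token is not None:
    # Decode with plain PyJWT to show the claims written by the helper above.
    claims = jwt.decode(token, b64decode(b64_secret), algorithms=["HS256"])
    print(claims["aime_id"], claims["exp"])
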
dbutils.py
@@ -1,12 +1,13 @@
#!/usr/bin/env python3
import argparse
import logging
from os import mkdir, path, access, W_OK, environ
import yaml
import asyncio
import logging
from os import W_OK, access, environ, mkdir, path

import yaml

from core.data import Data
from core.config import CoreConfig
from core.data import Data

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Database utilities")
@@ -46,7 +47,7 @@ if __name__ == "__main__":
loop = asyncio.get_event_loop()

if args.action == "create":
data.create_database()
loop.run_until_complete(data.create_database())

elif args.action == "upgrade":
data.schema_upgrade(args.version)
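
Side note, sketch only: since create_database is now a coroutine, the call could equivalently be driven with asyncio.run, which avoids relying on get_event_loop creating a loop implicitly (deprecated behaviour on newer Python versions):

import asyncio

if args.action == "create":
    # Start a fresh event loop, run the coroutine, then tear the loop down.
    asyncio.run(data.create_database())
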
read.py
@@ -1,16 +1,16 @@
#!/usr/bin/env python3
import argparse
import re
import os
import yaml
from os import path
import logging
import coloredlogs
import asyncio

import logging
import os
import re
from logging.handlers import TimedRotatingFileHandler
from os import path
from typing import List, Optional

import coloredlogs
import yaml

from core import CoreConfig, Utils

@@ -44,7 +44,7 @@ class BaseReader:
pass

if __name__ == "__main__":
async def main():
parser = argparse.ArgumentParser(description="Import Game Information")
parser.add_argument(
"--game",
@@ -140,8 +140,12 @@ if __name__ == "__main__":
for dir, mod in titles.items():
if args.game in mod.game_codes:
handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.read())

await handler.read()

logger.info("Done")

if __name__ == "__main__":
asyncio.run(main())