fix: make database async
This commit is contained in:
parent
1331d473c9
commit
bc7524c8fc
@ -1,8 +1,14 @@
|
|||||||
from __future__ import with_statement
|
from __future__ import with_statement
|
||||||
from alembic import context
|
|
||||||
from sqlalchemy import engine_from_config, pool
|
import asyncio
|
||||||
|
import threading
|
||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
from core.data.schema.base import metadata
|
from core.data.schema.base import metadata
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
@ -37,34 +43,16 @@ def run_migrations_offline():
|
|||||||
script output.
|
script output.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise Exception('Not implemented or configured!')
|
raise Exception("Not implemented or configured!")
|
||||||
|
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
context.configure(
|
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
|
||||||
url=url, target_metadata=target_metadata, literal_binds=True)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
context.run_migrations()
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_online():
|
def do_run_migrations(connection: Connection) -> None:
|
||||||
"""Run migrations in 'online' mode.
|
|
||||||
|
|
||||||
In this scenario we need to create an Engine
|
|
||||||
and associate a connection with the context.
|
|
||||||
|
|
||||||
"""
|
|
||||||
ini_section = config.get_section(config.config_ini_section)
|
|
||||||
overrides = context.get_x_argument(as_dictionary=True)
|
|
||||||
for override in overrides:
|
|
||||||
ini_section[override] = overrides[override]
|
|
||||||
|
|
||||||
connectable = engine_from_config(
|
|
||||||
ini_section,
|
|
||||||
prefix='sqlalchemy.',
|
|
||||||
poolclass=pool.NullPool)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
|
||||||
context.configure(
|
context.configure(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
target_metadata=target_metadata,
|
target_metadata=target_metadata,
|
||||||
@ -75,6 +63,44 @@ def run_migrations_online():
|
|||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
context.run_migrations()
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations() -> None:
|
||||||
|
"""In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
ini_section = config.get_section(config.config_ini_section)
|
||||||
|
overrides = context.get_x_argument(as_dictionary=True)
|
||||||
|
for override in overrides:
|
||||||
|
ini_section[override] = overrides[override]
|
||||||
|
|
||||||
|
connectable = async_engine_from_config(
|
||||||
|
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
|
||||||
|
)
|
||||||
|
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online():
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
# there's no event loop
|
||||||
|
asyncio.run(run_async_migrations())
|
||||||
|
else:
|
||||||
|
# there's currently an event loop and trying to wait for a coroutine
|
||||||
|
# to finish without using `await` is pretty wormy. nested event loops
|
||||||
|
# are explicitly forbidden by asyncio.
|
||||||
|
#
|
||||||
|
# take the easy way out, spawn it in another thread.
|
||||||
|
thread = threading.Thread(target=asyncio.run, args=(run_async_migrations(),))
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
|
||||||
if context.is_offline_mode():
|
if context.is_offline_mode():
|
||||||
run_migrations_offline()
|
run_migrations_offline()
|
||||||
else:
|
else:
|
||||||
|
@ -1,54 +1,65 @@
|
|||||||
import logging, coloredlogs
|
import asyncio
|
||||||
from typing import Optional
|
import logging
|
||||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from logging.handlers import TimedRotatingFileHandler
|
|
||||||
import os
|
import os
|
||||||
import secrets, string
|
import secrets
|
||||||
import bcrypt
|
import string
|
||||||
|
import warnings
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from logging.handlers import TimedRotatingFileHandler
|
||||||
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
import alembic.config
|
import alembic.config
|
||||||
import glob
|
import bcrypt
|
||||||
|
import coloredlogs
|
||||||
|
import pymysql.err
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncEngine,
|
||||||
|
AsyncSession,
|
||||||
|
async_scoped_session,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.config import CoreConfig
|
from core.config import CoreConfig
|
||||||
from core.data.schema import *
|
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
|
||||||
from core.utils import Utils
|
from core.utils import MISSING, Utils
|
||||||
|
|
||||||
|
|
||||||
class Data:
|
class Data:
|
||||||
engine = None
|
engine: ClassVar[AsyncEngine] = MISSING
|
||||||
session = None
|
session: ClassVar[AsyncSession] = MISSING
|
||||||
user = None
|
user: ClassVar[UserData] = MISSING
|
||||||
arcade = None
|
arcade: ClassVar[ArcadeData] = MISSING
|
||||||
card = None
|
card: ClassVar[CardData] = MISSING
|
||||||
base = None
|
base: ClassVar[BaseData] = MISSING
|
||||||
|
|
||||||
def __init__(self, cfg: CoreConfig) -> None:
|
def __init__(self, cfg: CoreConfig) -> None:
|
||||||
self.config = cfg
|
self.config = cfg
|
||||||
|
|
||||||
if self.config.database.sha2_password:
|
if self.config.database.sha2_password:
|
||||||
passwd = sha256(self.config.database.password.encode()).digest()
|
passwd = sha256(self.config.database.password.encode()).digest()
|
||||||
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
||||||
else:
|
else:
|
||||||
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
|
||||||
|
|
||||||
if Data.engine is None:
|
if Data.engine is MISSING:
|
||||||
Data.engine = create_engine(self.__url, pool_recycle=3600)
|
Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT")
|
||||||
self.__engine = Data.engine
|
self.__engine = Data.engine
|
||||||
|
|
||||||
if Data.session is None:
|
if Data.session is MISSING:
|
||||||
s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True)
|
s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
|
||||||
Data.session = scoped_session(s)
|
Data.session = async_scoped_session(s, asyncio.current_task)
|
||||||
|
|
||||||
if Data.user is None:
|
if Data.user is MISSING:
|
||||||
Data.user = UserData(self.config, self.session)
|
Data.user = UserData(self.config, self.session)
|
||||||
|
|
||||||
if Data.arcade is None:
|
if Data.arcade is MISSING:
|
||||||
Data.arcade = ArcadeData(self.config, self.session)
|
Data.arcade = ArcadeData(self.config, self.session)
|
||||||
|
|
||||||
if Data.card is None:
|
if Data.card is MISSING:
|
||||||
Data.card = CardData(self.config, self.session)
|
Data.card = CardData(self.config, self.session)
|
||||||
|
|
||||||
if Data.base is None:
|
if Data.base is MISSING:
|
||||||
Data.base = BaseData(self.config, self.session)
|
Data.base = BaseData(self.config, self.session)
|
||||||
|
|
||||||
self.logger = logging.getLogger("database")
|
self.logger = logging.getLogger("database")
|
||||||
@ -94,20 +105,29 @@ class Data:
|
|||||||
alembic.config.main(argv=alembicArgs)
|
alembic.config.main(argv=alembicArgs)
|
||||||
os.chdir(old_dir)
|
os.chdir(old_dir)
|
||||||
|
|
||||||
def create_database(self):
|
async def create_database(self):
|
||||||
self.logger.info("Creating databases...")
|
self.logger.info("Creating databases...")
|
||||||
metadata.create_all(
|
|
||||||
self.engine,
|
with warnings.catch_warnings():
|
||||||
checkfirst=True,
|
# SQLAlchemy will generate a nice primary key constraint name, but in
|
||||||
|
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
|
||||||
|
# custom primary key name is generated, a warning is emitted from pymysql,
|
||||||
|
# which we don't care about. Other warnings may be helpful though, don't
|
||||||
|
# suppress everything.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
action="ignore",
|
||||||
|
message=r"Name '(.+)' ignored for PRIMARY key\.",
|
||||||
|
category=pymysql.err.Warning,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(metadata.create_all, checkfirst=True)
|
||||||
|
|
||||||
for _, mod in Utils.get_all_titles().items():
|
for _, mod in Utils.get_all_titles().items():
|
||||||
if hasattr(mod, "database"):
|
if hasattr(mod, "database"):
|
||||||
mod.database(self.config)
|
mod.database(self.config)
|
||||||
metadata.create_all(
|
|
||||||
self.engine,
|
await conn.run_sync(metadata.create_all, checkfirst=True)
|
||||||
checkfirst=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stamp the end revision as if alembic had created it, so it can take off after this.
|
# Stamp the end revision as if alembic had created it, so it can take off after this.
|
||||||
self.__alembic_cmd(
|
self.__alembic_cmd(
|
||||||
@ -115,19 +135,43 @@ class Data:
|
|||||||
"head",
|
"head",
|
||||||
)
|
)
|
||||||
|
|
||||||
def schema_upgrade(self, ver: str = None):
|
def schema_upgrade(self, ver: Optional[str] = None):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
# SQLAlchemy will generate a nice primary key constraint name, but in
|
||||||
|
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
|
||||||
|
# custom primary key name is generated, a warning is emitted from pymysql,
|
||||||
|
# which we don't care about. Other warnings may be helpful though, don't
|
||||||
|
# suppress everything.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
action="ignore",
|
||||||
|
message=r"Name '(.+)' ignored for PRIMARY key\.",
|
||||||
|
category=pymysql.err.Warning,
|
||||||
|
)
|
||||||
|
|
||||||
self.__alembic_cmd(
|
self.__alembic_cmd(
|
||||||
"upgrade",
|
"upgrade",
|
||||||
"head" if not ver else ver,
|
"head" if not ver else ver,
|
||||||
)
|
)
|
||||||
|
|
||||||
def schema_downgrade(self, ver: str):
|
def schema_downgrade(self, ver: str):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
# SQLAlchemy will generate a nice primary key constraint name, but in
|
||||||
|
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
|
||||||
|
# custom primary key name is generated, a warning is emitted from pymysql,
|
||||||
|
# which we don't care about. Other warnings may be helpful though, don't
|
||||||
|
# suppress everything.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
action="ignore",
|
||||||
|
message=r"Name '(.+)' ignored for PRIMARY key\.",
|
||||||
|
category=pymysql.err.Warning,
|
||||||
|
)
|
||||||
|
|
||||||
self.__alembic_cmd(
|
self.__alembic_cmd(
|
||||||
"downgrade",
|
"downgrade",
|
||||||
ver,
|
ver,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None:
|
async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None:
|
||||||
pw = "".join(
|
pw = "".join(
|
||||||
secrets.choice(string.ascii_letters + string.digits) for i in range(20)
|
secrets.choice(string.ascii_letters + string.digits) for i in range(20)
|
||||||
)
|
)
|
||||||
@ -150,12 +194,12 @@ class Data:
|
|||||||
async def migrate(self) -> None:
|
async def migrate(self) -> None:
|
||||||
exist = await self.base.execute("SELECT * FROM alembic_version")
|
exist = await self.base.execute("SELECT * FROM alembic_version")
|
||||||
if exist is not None:
|
if exist is not None:
|
||||||
self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
|
self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.logger.info("Upgrading to latest with legacy system")
|
self.logger.info("Upgrading to latest with legacy system")
|
||||||
if not await self.legacy_upgrade():
|
if not await self.legacy_upgrade():
|
||||||
self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
|
self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
|
||||||
return
|
return
|
||||||
self.logger.info("Done")
|
self.logger.info("Done")
|
||||||
|
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
from typing import Optional, Dict, List
|
import re
|
||||||
from sqlalchemy import Table, Column, and_, or_
|
from typing import List, Optional
|
||||||
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
|
|
||||||
from sqlalchemy.types import Integer, String, Boolean, JSON
|
from sqlalchemy import Column, Table, and_, or_
|
||||||
from sqlalchemy.sql import func, select
|
|
||||||
from sqlalchemy.dialects.mysql import insert
|
from sqlalchemy.dialects.mysql import insert
|
||||||
from sqlalchemy.engine import Row
|
from sqlalchemy.engine import Row
|
||||||
import re
|
from sqlalchemy.sql import func, select
|
||||||
|
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
|
||||||
|
from sqlalchemy.types import JSON, Boolean, Integer, String
|
||||||
|
|
||||||
from core.data.schema.base import BaseData, metadata
|
from core.data.schema.base import BaseData, metadata
|
||||||
from core.const import *
|
|
||||||
|
|
||||||
arcade = Table(
|
arcade: Table = Table(
|
||||||
"arcade",
|
"arcade",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", Integer, primary_key=True, nullable=False),
|
Column("id", Integer, primary_key=True, nullable=False),
|
||||||
@ -26,7 +26,7 @@ arcade = Table(
|
|||||||
mysql_charset="utf8mb4",
|
mysql_charset="utf8mb4",
|
||||||
)
|
)
|
||||||
|
|
||||||
machine = Table(
|
machine: Table = Table(
|
||||||
"machine",
|
"machine",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", Integer, primary_key=True, nullable=False),
|
Column("id", Integer, primary_key=True, nullable=False),
|
||||||
@ -47,7 +47,7 @@ machine = Table(
|
|||||||
mysql_charset="utf8mb4",
|
mysql_charset="utf8mb4",
|
||||||
)
|
)
|
||||||
|
|
||||||
arcade_owner = Table(
|
arcade_owner: Table = Table(
|
||||||
"arcade_owner",
|
"arcade_owner",
|
||||||
metadata,
|
metadata,
|
||||||
Column(
|
Column(
|
||||||
@ -69,7 +69,7 @@ arcade_owner = Table(
|
|||||||
|
|
||||||
|
|
||||||
class ArcadeData(BaseData):
|
class ArcadeData(BaseData):
|
||||||
async def get_machine(self, serial: str = None, id: int = None) -> Optional[Row]:
|
async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]:
|
||||||
if serial is not None:
|
if serial is not None:
|
||||||
serial = serial.replace("-", "")
|
serial = serial.replace("-", "")
|
||||||
if len(serial) == 11:
|
if len(serial) == 11:
|
||||||
@ -98,8 +98,8 @@ class ArcadeData(BaseData):
|
|||||||
self,
|
self,
|
||||||
arcade_id: int,
|
arcade_id: int,
|
||||||
serial: str = "",
|
serial: str = "",
|
||||||
board: str = None,
|
board: Optional[str] = None,
|
||||||
game: str = None,
|
game: Optional[str] = None,
|
||||||
is_cab: bool = False,
|
is_cab: bool = False,
|
||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
if not arcade_id:
|
if not arcade_id:
|
||||||
@ -150,8 +150,8 @@ class ArcadeData(BaseData):
|
|||||||
|
|
||||||
async def create_arcade(
|
async def create_arcade(
|
||||||
self,
|
self,
|
||||||
name: str = None,
|
name: Optional[str] = None,
|
||||||
nickname: str = None,
|
nickname: Optional[str] = None,
|
||||||
country: str = "JPN",
|
country: str = "JPN",
|
||||||
country_id: int = 1,
|
country_id: int = 1,
|
||||||
state: str = "",
|
state: str = "",
|
||||||
|
@ -1,22 +1,23 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from random import randrange
|
from random import randrange
|
||||||
from typing import Any, Optional, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Column, MetaData, Table
|
||||||
from sqlalchemy.engine import Row
|
from sqlalchemy.engine import Row
|
||||||
from sqlalchemy.engine.cursor import CursorResult
|
from sqlalchemy.engine.cursor import CursorResult
|
||||||
from sqlalchemy.engine.base import Connection
|
|
||||||
from sqlalchemy.sql import text, func, select
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy import MetaData, Table, Column
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.types import Integer, String, TIMESTAMP, JSON, INTEGER, TEXT
|
|
||||||
from sqlalchemy.schema import ForeignKey
|
from sqlalchemy.schema import ForeignKey
|
||||||
from sqlalchemy.dialects.mysql import insert
|
from sqlalchemy.sql import func, text
|
||||||
|
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
|
||||||
|
|
||||||
from core.config import CoreConfig
|
from core.config import CoreConfig
|
||||||
|
|
||||||
metadata = MetaData()
|
metadata = MetaData()
|
||||||
|
|
||||||
event_log = Table(
|
event_log: Table = Table(
|
||||||
"event_log",
|
"event_log",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", Integer, primary_key=True, nullable=False),
|
Column("id", Integer, primary_key=True, nullable=False),
|
||||||
@ -37,7 +38,7 @@ event_log = Table(
|
|||||||
|
|
||||||
|
|
||||||
class BaseData:
|
class BaseData:
|
||||||
def __init__(self, cfg: CoreConfig, conn: Connection) -> None:
|
def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None:
|
||||||
self.config = cfg
|
self.config = cfg
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
self.logger = logging.getLogger("database")
|
self.logger = logging.getLogger("database")
|
||||||
@ -47,7 +48,7 @@ class BaseData:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
|
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
|
||||||
res = self.conn.execute(text(sql), opts)
|
res = await self.conn.execute(text(sql), opts)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
self.logger.error(f"SQLAlchemy error {e}")
|
self.logger.error(f"SQLAlchemy error {e}")
|
||||||
@ -59,7 +60,7 @@ class BaseData:
|
|||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
res = self.conn.execute(sql, opts)
|
res = await self.conn.execute(sql, opts)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
self.logger.error(f"SQLAlchemy error {e}")
|
self.logger.error(f"SQLAlchemy error {e}")
|
||||||
@ -83,7 +84,7 @@ class BaseData:
|
|||||||
|
|
||||||
async def log_event(
|
async def log_event(
|
||||||
self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,
|
self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,
|
||||||
arcade: int = None, machine: int = None, ip: str = None, game: str = None, version: str = None
|
arcade: int = None, machine: int = None, ip: Optional[str] = None, game: Optional[str] = None, version: Optional[str] = None
|
||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
sql = event_log.insert().values(
|
sql = event_log.insert().values(
|
||||||
system=system,
|
system=system,
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from sqlalchemy import Table, Column, UniqueConstraint
|
|
||||||
from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP, BIGINT, VARCHAR
|
from sqlalchemy import Column, Table, UniqueConstraint
|
||||||
from sqlalchemy.sql.schema import ForeignKey
|
|
||||||
from sqlalchemy.sql import func
|
|
||||||
from sqlalchemy.engine import Row
|
from sqlalchemy.engine import Row
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy.sql.schema import ForeignKey
|
||||||
|
from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String
|
||||||
|
|
||||||
from core.data.schema.base import BaseData, metadata
|
from core.data.schema.base import BaseData, metadata
|
||||||
|
|
||||||
aime_card = Table(
|
aime_card: Table = Table(
|
||||||
"aime_card",
|
"aime_card",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", Integer, primary_key=True, nullable=False),
|
Column("id", Integer, primary_key=True, nullable=False),
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
from typing import Optional, List
|
from typing import List, Optional
|
||||||
from sqlalchemy import Table, Column
|
|
||||||
from sqlalchemy.types import Integer, String, TIMESTAMP
|
|
||||||
from sqlalchemy.sql import func
|
|
||||||
from sqlalchemy.dialects.mysql import insert
|
|
||||||
from sqlalchemy.sql import func, select
|
|
||||||
from sqlalchemy.engine import Row
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
from sqlalchemy import Column, Table
|
||||||
|
from sqlalchemy.dialects.mysql import insert
|
||||||
|
from sqlalchemy.engine import Row
|
||||||
|
from sqlalchemy.sql import func, select
|
||||||
|
from sqlalchemy.types import TIMESTAMP, Integer, String
|
||||||
|
|
||||||
from core.data.schema.base import BaseData, metadata
|
from core.data.schema.base import BaseData, metadata
|
||||||
|
|
||||||
aime_user = Table(
|
aime_user: Table = Table(
|
||||||
"aime_user",
|
"aime_user",
|
||||||
metadata,
|
metadata,
|
||||||
Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
|
Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
|
||||||
@ -26,10 +26,10 @@ aime_user = Table(
|
|||||||
class UserData(BaseData):
|
class UserData(BaseData):
|
||||||
async def create_user(
|
async def create_user(
|
||||||
self,
|
self,
|
||||||
id: int = None,
|
id: Optional[int] = None,
|
||||||
username: str = None,
|
username: Optional[str] = None,
|
||||||
email: str = None,
|
email: Optional[str] = None,
|
||||||
password: str = None,
|
password: Optional[str] = None,
|
||||||
permission: int = 1,
|
permission: int = 1,
|
||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
if id is None:
|
if id is None:
|
||||||
|
@ -1,18 +1,47 @@
|
|||||||
from typing import Dict, Any, Optional
|
|
||||||
from types import ModuleType
|
|
||||||
from starlette.requests import Request
|
|
||||||
import logging
|
|
||||||
import importlib
|
import importlib
|
||||||
from os import walk
|
import logging
|
||||||
import jwt
|
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from os import walk
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .config import CoreConfig
|
from .config import CoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _MissingSentinel:
|
||||||
|
__slots__: tuple[str, ...] = ()
|
||||||
|
|
||||||
|
def __eq__(self, other) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "..."
|
||||||
|
|
||||||
|
|
||||||
|
MISSING: Any = _MissingSentinel()
|
||||||
|
"""This is different from `None` in that its type is `Any`, and so it can be used
|
||||||
|
as a placeholder for values that are *definitely* going to be initialized,
|
||||||
|
so they don't have to be typed as `T | None`, which makes type checkers
|
||||||
|
angry when an attribute is accessed.
|
||||||
|
|
||||||
|
This can also be used for when `None` has actual meaning as a value, and so a
|
||||||
|
separate value is needed to mean "unset"."""
|
||||||
|
|
||||||
|
|
||||||
class Utils:
|
class Utils:
|
||||||
real_title_port = None
|
real_title_port = None
|
||||||
real_title_port_ssl = None
|
real_title_port_ssl = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_titles(cls) -> Dict[str, ModuleType]:
|
def get_all_titles(cls) -> Dict[str, ModuleType]:
|
||||||
ret: Dict[str, Any] = {}
|
ret: Dict[str, Any] = {}
|
||||||
@ -39,24 +68,53 @@ class Utils:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_title_port(cls, cfg: CoreConfig):
|
def get_title_port(cls, cfg: CoreConfig):
|
||||||
if cls.real_title_port is not None: return cls.real_title_port
|
if cls.real_title_port is not None:
|
||||||
|
return cls.real_title_port
|
||||||
|
|
||||||
cls.real_title_port = cfg.server.proxy_port if cfg.server.is_using_proxy and cfg.server.proxy_port else cfg.server.port
|
cls.real_title_port = (
|
||||||
|
cfg.server.proxy_port
|
||||||
|
if cfg.server.is_using_proxy and cfg.server.proxy_port
|
||||||
|
else cfg.server.port
|
||||||
|
)
|
||||||
|
|
||||||
return cls.real_title_port
|
return cls.real_title_port
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_title_port_ssl(cls, cfg: CoreConfig):
|
def get_title_port_ssl(cls, cfg: CoreConfig):
|
||||||
if cls.real_title_port_ssl is not None: return cls.real_title_port_ssl
|
if cls.real_title_port_ssl is not None:
|
||||||
|
return cls.real_title_port_ssl
|
||||||
|
|
||||||
cls.real_title_port_ssl = cfg.server.proxy_port_ssl if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl else 443
|
cls.real_title_port_ssl = (
|
||||||
|
cfg.server.proxy_port_ssl
|
||||||
|
if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl
|
||||||
|
else 443
|
||||||
|
)
|
||||||
|
|
||||||
return cls.real_title_port_ssl
|
return cls.real_title_port_ssl
|
||||||
|
|
||||||
def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str, b64_secret: str, exp_seconds: int = 86400, err_logger: str = 'aimedb') -> Optional[str]:
|
|
||||||
|
def create_sega_auth_key(
|
||||||
|
aime_id: int,
|
||||||
|
game: str,
|
||||||
|
place_id: int,
|
||||||
|
keychip_id: str,
|
||||||
|
b64_secret: str,
|
||||||
|
exp_seconds: int = 86400,
|
||||||
|
err_logger: str = "aimedb",
|
||||||
|
) -> Optional[str]:
|
||||||
logger = logging.getLogger(err_logger)
|
logger = logging.getLogger(err_logger)
|
||||||
try:
|
try:
|
||||||
return jwt.encode({ "aime_id": aime_id, "game": game, "place_id": place_id, "keychip_id": keychip_id, "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds }, b64decode(b64_secret), algorithm="HS256")
|
return jwt.encode(
|
||||||
|
{
|
||||||
|
"aime_id": aime_id,
|
||||||
|
"game": game,
|
||||||
|
"place_id": place_id,
|
||||||
|
"keychip_id": keychip_id,
|
||||||
|
"exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds,
|
||||||
|
},
|
||||||
|
b64decode(b64_secret),
|
||||||
|
algorithm="HS256",
|
||||||
|
)
|
||||||
except jwt.InvalidKeyError:
|
except jwt.InvalidKeyError:
|
||||||
logger.error("Failed to encode Sega Auth Key because the secret is invalid!")
|
logger.error("Failed to encode Sega Auth Key because the secret is invalid!")
|
||||||
return None
|
return None
|
||||||
@ -64,10 +122,19 @@ def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str
|
|||||||
logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}")
|
logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb') -> Optional[Dict]:
|
|
||||||
|
def decode_sega_auth_key(
|
||||||
|
token: str, b64_secret: str, err_logger: str = "aimedb"
|
||||||
|
) -> Optional[Dict]:
|
||||||
logger = logging.getLogger(err_logger)
|
logger = logging.getLogger(err_logger)
|
||||||
try:
|
try:
|
||||||
return jwt.decode(token, "secret", b64decode(b64_secret), algorithms=["HS256"], options={"verify_signature": True})
|
return jwt.decode(
|
||||||
|
token,
|
||||||
|
"secret",
|
||||||
|
b64decode(b64_secret),
|
||||||
|
algorithms=["HS256"],
|
||||||
|
options={"verify_signature": True},
|
||||||
|
)
|
||||||
except jwt.ExpiredSignatureError:
|
except jwt.ExpiredSignatureError:
|
||||||
logger.error("Sega Auth Key failed to validate due to an expired signature!")
|
logger.error("Sega Auth Key failed to validate due to an expired signature!")
|
||||||
return None
|
return None
|
||||||
@ -83,4 +150,3 @@ def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb'
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}")
|
logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}")
|
||||||
return None
|
return None
|
||||||
|
|
11
dbutils.py
11
dbutils.py
@ -1,12 +1,13 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
from os import mkdir, path, access, W_OK, environ
|
|
||||||
import yaml
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from os import W_OK, access, environ, mkdir, path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
from core.data import Data
|
|
||||||
from core.config import CoreConfig
|
from core.config import CoreConfig
|
||||||
|
from core.data import Data
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Database utilities")
|
parser = argparse.ArgumentParser(description="Database utilities")
|
||||||
@ -46,7 +47,7 @@ if __name__ == "__main__":
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
if args.action == "create":
|
if args.action == "create":
|
||||||
data.create_database()
|
loop.run_until_complete(data.create_database())
|
||||||
|
|
||||||
elif args.action == "upgrade":
|
elif args.action == "upgrade":
|
||||||
data.schema_upgrade(args.version)
|
data.schema_upgrade(args.version)
|
||||||
|
24
read.py
24
read.py
@ -1,16 +1,16 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import yaml
|
|
||||||
from os import path
|
|
||||||
import logging
|
|
||||||
import coloredlogs
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from logging.handlers import TimedRotatingFileHandler
|
from logging.handlers import TimedRotatingFileHandler
|
||||||
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import coloredlogs
|
||||||
|
import yaml
|
||||||
|
|
||||||
from core import CoreConfig, Utils
|
from core import CoreConfig, Utils
|
||||||
|
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ class BaseReader:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
async def main():
|
||||||
parser = argparse.ArgumentParser(description="Import Game Information")
|
parser = argparse.ArgumentParser(description="Import Game Information")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--game",
|
"--game",
|
||||||
@ -140,8 +140,12 @@ if __name__ == "__main__":
|
|||||||
for dir, mod in titles.items():
|
for dir, mod in titles.items():
|
||||||
if args.game in mod.game_codes:
|
if args.game in mod.game_codes:
|
||||||
handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra)
|
handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra)
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.run_until_complete(handler.read())
|
await handler.read()
|
||||||
|
|
||||||
|
|
||||||
logger.info("Done")
|
logger.info("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user