1
0
mirror of synced 2025-01-07 09:41:33 +01:00
bemaniutils/bemani/data/mysql/base.py

187 lines
6.6 KiB
Python
Raw Normal View History

import json
import random
from typing import Dict, Any, Optional
from bemani.common import Time
from bemani.data.config import Config
from sqlalchemy.engine.base import Connection # type: ignore
from sqlalchemy.engine import CursorResult # type: ignore
from sqlalchemy.sql import text # type: ignore
from sqlalchemy.types import String, Integer # type: ignore
from sqlalchemy import Table, Column, MetaData # type: ignore
metadata = MetaData()
"""
Table for storing session IDs, so a session ID can be used to look up an arbitrary ID.
This is currently used for user logins, user and arcade PASELI sessions.
"""
session = Table(
'session',
metadata,
Column('id', Integer, nullable=False),
Column('type', String(32), nullable=False),
Column('session', String(32), nullable=False, unique=True),
Column('expiration', Integer),
mysql_charset='utf8mb4',
)
class _BytesEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if isinstance(obj, bytes):
# We're abusing lists here, we have a mixed type
return ['__bytes__'] + [b for b in obj] # type: ignore
return json.JSONEncoder.default(self, obj)
class BaseData:
SESSION_LENGTH = 32
def __init__(self, config: Config, conn: Connection) -> None:
"""
Initialize any DB singleton.
Should only ever be called by Data.
Parameters:
config - config structure which is provided in case any function here
needs to look up configuration.
conn - An established connection to the DB which will be used for all
queries.
"""
self.__config = config
self.__conn = conn
def execute(self, sql: str, params: Optional[Dict[str, Any]]=None, safe_write_operation: bool=False) -> CursorResult:
"""
Given a SQL string and some parameters, execute the query and return the result.
Parameters:
sql - The SQL statement to execute.
params - Dictionary of parameters which will be substituted into the sql string.
Returns:
A SQLAlchemy CursorResult object.
"""
if self.__config.database.read_only:
# See if this is an insert/update/delete
for write_statement in [
"insert into ",
"update ",
"delete from ",
]:
if write_statement in sql.lower() and not safe_write_operation:
raise Exception('Read-only mode is active!')
return self.__conn.execute(
text(sql),
params if params is not None else {},
)
def serialize(self, data: Dict[str, Any]) -> str:
"""
Given an arbitrary dict, serialize it to JSON.
"""
return json.dumps(data, cls=_BytesEncoder)
def deserialize(self, data: Optional[str]) -> Dict[str, Any]:
"""
Given a string, deserialize it from JSON.
"""
if data is None:
return {}
def fix(jd: Any) -> Any:
if type(jd) == dict:
# Fix each element in the dictionary.
for key in jd:
jd[key] = fix(jd[key])
return jd
if type(jd) == list:
# Could be serialized by us, could be a normal list.
if len(jd) >= 1 and jd[0] == '__bytes__':
# This is a serialized bytestring
return bytes(jd[1:])
# Possibly one of these is a dictionary/list/serialized.
for i in range(len(jd)):
jd[i] = fix(jd[i])
return jd
# Normal value, its deserialized version is itself.
return jd
return fix(json.loads(data))
def _from_session(self, session: str, sesstype: str) -> Optional[int]:
"""
Given a previously-opened session, look up an ID.
Parameters:
session - String identifying a session that was opened by create_session.
sesstype - Arbitrary string identifying the session type.
Returns:
ID as an integer if found, or None if the session is expired or doesn't exist.
"""
# Look up the user account, making sure to expire old sessions
sql = "SELECT id FROM session WHERE session = :session AND type = :type AND expiration > :timestamp"
cursor = self.execute(sql, {'session': session, 'type': sesstype, 'timestamp': Time.now()})
if cursor.rowcount != 1:
# Couldn't find a user with this session
return None
result = cursor.fetchone()
return result['id']
def _create_session(self, opid: int, optype: str, expiration: int=(30 * 86400)) -> str:
"""
Given an ID, create a session string.
Parameters:
opid - ID we wish to start a session for.
expiration - Number of seconds before this session is invalid.
Returns:
A string that can be used as a session ID.
"""
# Create a new session that is unique
while True:
session = ''.join(random.choice('0123456789ABCDEF') for _ in range(BaseData.SESSION_LENGTH))
sql = "SELECT session FROM session WHERE session = :session"
cursor = self.execute(sql, {'session': session})
if cursor.rowcount == 0:
# Make sure sessions expire in a reasonable amount of time
expiration = Time.now() + expiration
# Use that session
sql = (
"INSERT INTO session (id, session, type, expiration) " +
"VALUES (:id, :session, :optype, :expiration)"
)
cursor = self.execute(
sql,
{'id': opid, 'session': session, 'optype': optype, 'expiration': expiration},
safe_write_operation=True,
)
if cursor.rowcount == 1:
return session
def _destroy_session(self, session: str, sesstype: str) -> None:
"""
Destroy a previously-created session.
Parameters:
session - A session string as returned from create_session.
"""
# Remove the session token
sql = "DELETE FROM session WHERE session = :session AND type = :sesstype"
self.execute(sql, {'session': session, 'sesstype': sesstype}, safe_write_operation=True)
# Also weed out any other defunct sessions
sql = "DELETE FROM session WHERE expiration < :timestamp"
self.execute(sql, {'timestamp': Time.now()}, safe_write_operation=True)