import json
import random
from typing import Dict, Any, Optional

from bemani.common import Time
from bemani.data.config import Config

from sqlalchemy.engine.base import Connection  # type: ignore
from sqlalchemy.engine import CursorResult  # type: ignore
from sqlalchemy.sql import text  # type: ignore
from sqlalchemy.types import String, Integer  # type: ignore
from sqlalchemy import Table, Column, MetaData  # type: ignore

metadata = MetaData()

"""
Table for storing session IDs, so a session ID can be used to look up an arbitrary ID.
This is currently used for user logins, user and arcade PASELI sessions.
"""
session = Table(
    'session',
    metadata,
    Column('id', Integer, nullable=False),
    Column('type', String(32), nullable=False),
    Column('session', String(32), nullable=False, unique=True),
    Column('expiration', Integer),
    mysql_charset='utf8mb4',
)


class _BytesEncoder(json.JSONEncoder):
    """
    JSON encoder which additionally understands bytes objects. A bytes value is
    emitted as a list whose first element is the marker string '__bytes__' and
    whose remaining elements are the individual byte values as integers.
    """

    def default(self, obj: Any) -> Any:
        if isinstance(obj, bytes):
            # We're abusing lists here, we have a mixed type
            return ['__bytes__', *obj]  # type: ignore
        return json.JSONEncoder.default(self, obj)


class BaseData:
    """
    Shared base for DB access objects. Provides query execution with a
    read-only guard, JSON (de)serialization that round-trips bytes, and
    helpers for creating, looking up and destroying rows in the session table.
    """

    # Number of characters in a generated session string. Matches the
    # String(32) width of the 'session' column declared above.
    SESSION_LENGTH = 32

    def __init__(self, config: Config, conn: Connection) -> None:
        """
        Initialize any DB singleton. Should only ever be called by Data.

        Parameters:
            config - config structure which is provided in case any function here needs
                     to look up configuration.
            conn - An established connection to the DB which will be used for all queries.
        """
        self.__config = config
        self.__conn = conn

    def execute(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        safe_write_operation: bool = False,
    ) -> CursorResult:
        """
        Given a SQL string and some parameters, execute the query and return the result.

        Parameters:
            sql - The SQL statement to execute.
            params - Dictionary of parameters which will be substituted into the sql string.
                     When None, an empty dictionary is passed to the connection.
            safe_write_operation - When True, the statement is allowed to run even if the
                                   database is configured as read-only.

        Returns:
            A SQLAlchemy CursorResult object.

        Raises:
            Exception - If the database is configured as read-only, the statement text
                        contains an insert/update/delete marker and safe_write_operation
                        is False.
        """
        if self.__config.database.read_only and not safe_write_operation:
            # See if this is an insert/update/delete. This is a plain substring
            # match against the lowercased statement, not a SQL parse.
            lowered = sql.lower()
            for write_statement in ("insert into ", "update ", "delete from "):
                if write_statement in lowered:
                    raise Exception('Read-only mode is active!')

        return self.__conn.execute(
            text(sql),
            params if params is not None else {},
        )

    def serialize(self, data: Dict[str, Any]) -> str:
        """
        Given an arbitrary dict, serialize it to JSON. Any bytes values are encoded
        as '__bytes__'-prefixed lists so that deserialize can restore them.
        """
        return json.dumps(data, cls=_BytesEncoder)

    def deserialize(self, data: Optional[str]) -> Dict[str, Any]:
        """
        Given a string, deserialize it from JSON, restoring any bytes values that
        were encoded by serialize.

        Parameters:
            data - A JSON string, or None.

        Returns:
            The decoded structure, or an empty dictionary if data is None.
        """
        if data is None:
            return {}

        def fix(jd: Any) -> Any:
            if isinstance(jd, dict):
                # Fix each element in the dictionary. Only values are replaced,
                # so the key set does not change while we iterate.
                for key in jd:
                    jd[key] = fix(jd[key])
                return jd

            if isinstance(jd, list):
                # Could be serialized by us, could be a normal list.
                if len(jd) >= 1 and jd[0] == '__bytes__':
                    # This is a serialized bytestring
                    return bytes(jd[1:])

                # Possibly one of these is a dictionary/list/serialized.
                for i, item in enumerate(jd):
                    jd[i] = fix(item)
                return jd

            # Normal value, its deserialized version is itself.
            return jd

        return fix(json.loads(data))

    def _from_session(self, session: str, sesstype: str) -> Optional[int]:
        """
        Given a previously-opened session, look up an ID.

        Parameters:
            session - String identifying a session that was opened by create_session.
            sesstype - Arbitrary string identifying the session type.

        Returns:
            ID as an integer if found, or None if the session is expired or doesn't exist.
        """
        # Look up the user account, making sure to expire old sessions
        sql = "SELECT id FROM session WHERE session = :session AND type = :type AND expiration > :timestamp"
        cursor = self.execute(sql, {'session': session, 'type': sesstype, 'timestamp': Time.now()})
        # NOTE(review): this relies on the DB driver reporting a row count for
        # SELECT statements - confirm for any driver other than the one in use.
        if cursor.rowcount != 1:
            # Couldn't find a user with this session
            return None

        result = cursor.fetchone()
        return result['id']

    def _create_session(self, opid: int, optype: str, expiration: int = (30 * 86400)) -> str:
        """
        Given an ID, create a session string.

        Parameters:
            opid - ID we wish to start a session for.
            optype - Arbitrary string identifying the session type, stored alongside
                     the session.
            expiration - Number of seconds before this session is invalid.

        Returns:
            A string that can be used as a session ID.
        """
        # Session strings act as login credentials, so draw them from the OS
        # entropy source instead of the default, predictable generator.
        rng = random.SystemRandom()

        # Create a new session that is unique
        while True:
            token = ''.join(rng.choice('0123456789ABCDEF') for _ in range(BaseData.SESSION_LENGTH))
            sql = "SELECT session FROM session WHERE session = :session"
            cursor = self.execute(sql, {'session': token})
            if cursor.rowcount != 0:
                # Already taken, roll a new one.
                continue

            # Make sure sessions expire in a reasonable amount of time. This is
            # kept in its own variable so that a retry of this loop does not add
            # the current time onto an already-absolute timestamp.
            expires_at = Time.now() + expiration

            # Use that session
            sql = (
                "INSERT INTO session (id, session, type, expiration) " +
                "VALUES (:id, :session, :optype, :expiration)"
            )
            cursor = self.execute(
                sql,
                {'id': opid, 'session': token, 'optype': optype, 'expiration': expires_at},
                safe_write_operation=True,
            )
            if cursor.rowcount == 1:
                return token

    def _destroy_session(self, session: str, sesstype: str) -> None:
        """
        Destroy a previously-created session. Also deletes every session row whose
        expiration is earlier than the current time.

        Parameters:
            session - A session string as returned from create_session.
            sesstype - Arbitrary string identifying the session type; only a session
                       of this type is removed.
        """
        # Remove the session token
        sql = "DELETE FROM session WHERE session = :session AND type = :sesstype"
        self.execute(sql, {'session': session, 'sesstype': sesstype}, safe_write_operation=True)

        # Also weed out any other defunct sessions
        sql = "DELETE FROM session WHERE expiration < :timestamp"
        self.execute(sql, {'timestamp': Time.now()}, safe_write_operation=True)