# Source-viewer metadata captured with this file (not part of the module): 159 lines, 5.8 KiB
import logging
import secrets
import struct
from hashlib import md5
from logging.handlers import TimedRotatingFileHandler
from os import path
from typing import Tuple, Dict, List

import coloredlogs
import yaml
from Crypto.Cipher import Blowfish
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route

from core import CoreConfig, Utils
from core.title import BaseServlet
from titles.sao.config import SaoConfig
from titles.sao.const import SaoConstants
from titles.sao.base import SaoBase
from titles.sao.handlers.base import *
class SaoServlet(BaseServlet):
    """Title servlet for SAO: loads the title config and serves its POST endpoint.

    Requests are dispatched to ``handle_<cmd>`` methods on a ``SaoBase``
    instance, with optional static-hash verification and Blowfish-CBC
    encryption of the payload controlled by the title config.
    """

    def __init__(self, core_cfg: CoreConfig, cfg_dir: str) -> None:
        """Load the title config, set up the "sao" logger and build the handler.

        Args:
            core_cfg: Core server configuration; forwarded to the base servlet
                and to ``SaoBase``.
            cfg_dir: Directory that is searched for ``sao.yaml``.
        """
        super().__init__(core_cfg, cfg_dir)
        self.config_dir = cfg_dir
        self.game_cfg = SaoConfig()

        cfg_file = f"{cfg_dir}/sao.yaml"
        if path.exists(cfg_file):
            # NOTE(review): the body of this branch was missing from the
            # reviewed text. Loading the YAML into the config is inferred from
            # the file-level `yaml` import; confirm that SaoConfig exposes
            # `update()`.
            with open(cfg_file, encoding="utf-8") as cfg_fp:
                self.game_cfg.update(yaml.safe_load(cfg_fp))

        self.logger = logging.getLogger("sao")
        # getLogger() hands back the same object for the same name, so the
        # marker attribute keeps handlers from being attached more than once.
        if not hasattr(self.logger, "inited"):
            log_fmt_str = "[%(asctime)s] SAO | %(levelname)s | %(message)s"
            log_fmt = logging.Formatter(log_fmt_str)

            # NOTE(review): every argument after the path was cut from the
            # reviewed text; UTF-8, daily rotation and 10 backups are
            # assumptions - confirm against the other title servlets.
            file_handler = TimedRotatingFileHandler(
                "{0}/{1}.log".format(self.core_cfg.server.log_dir, "sao"),
                encoding="utf8",
                when="d",
                backupCount=10,
            )
            file_handler.setFormatter(log_fmt)

            console_handler = logging.StreamHandler()
            console_handler.setFormatter(log_fmt)

            # NOTE(review): the handler wiring and the `coloredlogs.install`
            # call name were also missing; only its keyword arguments survived.
            self.logger.addHandler(file_handler)
            self.logger.addHandler(console_handler)
            self.logger.setLevel(self.game_cfg.server.loglevel)
            coloredlogs.install(
                level=self.game_cfg.server.loglevel,
                logger=self.logger,
                fmt=log_fmt_str,
            )
            self.logger.inited = True

        self.base = SaoBase(core_cfg, self.game_cfg)

        # The expected "hash" is the MD5 of a fixed config string computed once
        # here, so it acts as a static shared token rather than a per-message
        # integrity check.
        self.static_hash = None
        if self.game_cfg.hash.verify_hash:
            self.static_hash = md5(self.game_cfg.hash.hash_base.encode()).digest()

    def get_routes(self) -> List[Route]:
        """Return the single POST route served by this title."""
        return [
            Route(
                "/{datecode:int}/proto/if/{category:str}/{endpoint:str}",
                self.render_POST,
                methods=["POST"],
            ),
        ]

    @classmethod
    def is_game_enabled(cls, game_code: str, core_cfg: CoreConfig, cfg_dir: str) -> bool:
        """Report whether the title is enabled in its config file.

        Args:
            game_code: Not used by this implementation.
            core_cfg: Not used by this implementation.
            cfg_dir: Directory that is searched for the title config file.

        Returns:
            The truthiness of ``server.enable`` after the config file, if
            present, has been applied on top of the defaults.
        """
        game_cfg = SaoConfig()

        cfg_file = f"{cfg_dir}/{SaoConstants.CONFIG_NAME}"
        if path.exists(cfg_file):
            # NOTE(review): branch body restored by inference, as in __init__.
            with open(cfg_file, encoding="utf-8") as cfg_fp:
                game_cfg.update(yaml.safe_load(cfg_fp))

        return bool(game_cfg.server.enable)

    def get_allnet_info(self, game_code: str, game_ver: int, keychip: str) -> Tuple[str, str]:
        """Build the base URI handed out for this title.

        The port is omitted from the URI when it is the scheme's default
        (80 for http, 443 for https). The arguments are not used here.

        Returns:
            ``(uri, "")`` where ``uri`` ends with a trailing slash.
        """
        port_ssl = Utils.get_title_port_ssl(self.core_cfg)
        port_normal = Utils.get_title_port(self.core_cfg)

        if self.game_cfg.server.use_https:
            proto, port_num, default_port = "https", port_ssl, 443
        else:
            proto, port_num, default_port = "http", port_normal, 80

        port = f":{port_num}" if port_num != default_port else ""
        return (f"{proto}://{self.core_cfg.title.hostname}{port}/", "")

    def get_mucha_info(self, core_cfg: CoreConfig, cfg_dir: str) -> Tuple[bool, List, List]:
        """Return ``(enabled, game_cds, netid_prefixes)`` for this title.

        When the title is disabled both lists are empty; otherwise they come
        from ``SaoConstants``. The arguments are not used here.
        """
        if not self.game_cfg.server.enable:
            return (False, [], [])

        return (True, SaoConstants.GAME_CDS, SaoConstants.NETID_PREFIX)

    def _encrypt_response(self, resp: bytes) -> bytes:
        """Encrypt everything after the first 24 bytes of ``resp`` with Blowfish-CBC.

        The payload is zero-padded to a multiple of 8 bytes and encrypted
        under a fresh random 8-byte IV. Bytes 20-23 of the header are replaced
        by the big-endian length of IV + ciphertext.
        """
        iv = secrets.token_bytes(8)

        payload = resp[24:]
        payload += b"\x00" * (-len(payload) % 8)  # pad up to the 8-byte block size

        cipher = Blowfish.new(self.game_cfg.crypt.key.encode(), Blowfish.MODE_CBC, iv)
        data_crypt = cipher.encrypt(payload)

        # NOTE(review): the original author was unsure whether the client
        # expects the length of the encrypted response here - confirm.
        length_field = struct.pack("!I", len(data_crypt) + len(iv))
        return resp[:20] + length_field + iv + data_crypt

    async def render_POST(self, request: Request) -> Response:
        """Handle one game request and return the (optionally encrypted) reply.

        The first 40 bytes of the body are parsed as the request header. The
        rest is the payload, preceded by an 8-byte IV when encryption is
        enabled. The handler is looked up on ``self.base`` by the 4-digit hex
        command, falling back to ``handle_noop``.

        Returns:
            An empty ``Response`` for malformed requests, hash mismatches and
            unsupported handler results; otherwise the response bytes.
        """
        endpoint = request.path_params.get("endpoint", "")
        req_raw = await request.body()

        if len(req_raw) < 40:
            self.logger.warning(f"Malformed request to {endpoint} - {req_raw.hex()}")
            return Response()

        req_header = SaoRequestHeader(req_raw)
        cmd_str = f"{req_header.cmd:04x}"

        if self.game_cfg.hash.verify_hash and self.static_hash != req_header.hash:
            self.logger.error(f"Hash mismatch! Expecting {self.static_hash} but recieved {req_header.hash}")
            return Response()

        if self.game_cfg.crypt.enable:
            iv = req_raw[40:48]
            crypt_data = req_raw[48:]
            # Blowfish-CBC needs a full 8-byte IV and whole 8-byte blocks;
            # anything else would raise inside the cipher, so treat it like
            # any other malformed request.
            if len(iv) != 8 or len(crypt_data) % 8 != 0:
                self.logger.warning(f"Malformed encrypted request to {endpoint} - {req_raw.hex()}")
                return Response()

            cipher = Blowfish.new(self.game_cfg.crypt.key.encode(), Blowfish.MODE_CBC, iv)
            req_data = cipher.decrypt(crypt_data)
            self.logger.debug(f"Decrypted {req_data.hex()} with IV {iv.hex()}")
        else:
            # Without this `else` the plaintext slice would overwrite the
            # decrypted payload.
            req_data = req_raw[40:]

        handler = getattr(self.base, f"handle_{cmd_str}", self.base.handle_noop)
        self.logger.info(f"{endpoint} - {cmd_str} request")
        self.logger.debug(f"Request: {req_raw.hex()}")
        resp = await handler(req_header, req_data)

        if resp is None:
            resp = SaoNoopResponse(req_header.cmd + 1).make()

        if isinstance(resp, SaoBaseResponse):
            # Handlers return response *instances*; issubclass() would raise
            # TypeError on them.
            resp = resp.make()
        elif not isinstance(resp, bytes):
            self.logger.error(f"Unknown response type {type(resp)}")
            return Response()

        self.logger.debug(f"Response: {resp.hex()}")

        if self.game_cfg.crypt.enable:
            resp = self._encrypt_response(resp)
            self.logger.debug(f"Encrypted Response: {resp.hex()}")

        return Response(resp, media_type="text/html; charset=utf-8")