1
0
mirror of synced 2025-01-22 11:33:39 +01:00

548 lines
19 KiB
Python

import copy
import struct
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing_extensions import Final
from bemani.protocol.stream import InputStream
from bemani.protocol.node import Node
class XmlEncodingException(Exception):
    """
    Raised by the XML encoder/decoder in this module when a document cannot be
    converted to or from XML (bad text encoding, unknown node type, and so on).
    """
class XmlDecoder:
    """
    A hand-rolled XML parser, suitable for parsing old-style XML documents in
    game data or from legacy game traffic. I did consider using lxml and other
    data stores, but they insist on mangling data inside binary/string blobs
    making them unsuitable for a protocol with exact specifications.

    The parser is a small byte-at-a-time state machine: get_tree() splits the
    stream into "text" and "<node>" spans, __handle_node() classifies each node
    span, and __start_element()/__end_element()/__text() build the Node tree
    using self.current as the stack of currently open elements.
    """

    def __init__(self, data: bytes, encoding: str) -> None:
        """
        Initialize the XML decoder.

        Parameters:
            data - Raw XML bytes which should be decoded into Nodes.
            encoding - The expected encoding of the XML. Replaced by the value of
                       the encoding attribute if the document carries an
                       <?xml ... encoding="..."?> declaration.
        """
        self.stream = InputStream(data)
        # Set once the outermost element has been closed.
        self.root: Optional[Node] = None
        # Stack of elements that have been opened but not yet closed.
        self.current: List[Node] = []
        self.encoding = encoding

    @staticmethod
    def _unescape(value: str, attr: bool = False) -> str:
        """
        Replace the XML entities that XmlEncoder emits with their literal characters.

        Parameters:
            value - Already-decoded text which may contain entities.
            attr - True when the text is an attribute value, in which case the
                   numeric carriage return/line feed entities are also replaced.

        Returns:
            The unescaped text.
        """
        value = value.replace('&lt;', '<')
        value = value.replace('&gt;', '>')
        value = value.replace('&apos;', '\'')
        value = value.replace('&quot;', '\"')
        if attr:
            value = value.replace('&#13;', '\r')
            value = value.replace('&#10;', '\n')
        # The ampersand MUST be handled last. Doing it first would turn an escaped
        # literal such as '&amp;lt;' into '&lt;' and then wrongly into '<'.
        return value.replace('&amp;', '&')

    @staticmethod
    def _to_bool(val: str) -> bool:
        """
        Interpret a textual boolean. Only '0' and 'false' (any case, ignoring
        surrounding whitespace) are False; everything else is True.
        """
        return val.strip().lower() not in ('0', 'false')

    def __start_element(self, tag: bytes, attributes: Dict[str, str]) -> None:
        """
        Called when we encounter an element open tag. Also called when we encounter
        an empty element. Creates a new node with the specified name and attributes
        and pushes it onto the stack of open elements.

        Parameters:
            tag - The tag name, as ASCII bytes.
            attributes - A dictionary keyed by attribute name and whose values are the string
                         attribute values. These attribute values should already be decoded from
                         the XML's encoding.

        Raises:
            XmlEncodingException - The __type attribute names a type Node does not recognize.
        """
        name = tag.decode('ascii')
        data_type = attributes.get('__type')
        # The mere presence of __count marks the node as an array.
        array = attributes.get('__count') is not None

        if data_type is None:
            # Special case for nodes that don't have a type
            node = Node(name=name, type=Node.NODE_TYPE_VOID)
        else:
            type_int = Node.typename_to_type(data_type)
            if type_int is None:
                raise XmlEncodingException(f'Invalid node type {data_type} for node {name}')
            node = Node(name=name, type=type_int, array=array)

        # __type and __count describe the value and were consumed above; everything
        # else is a real attribute of the node.
        for attr, attr_value in attributes.items():
            if attr in ('__type', '__count'):
                continue
            node.set_attribute(attr, attr_value)

        self.current.append(node)

    def __end_element(self, tag: bytes) -> None:
        """
        Called when we encounter an element close tag. Also called when we encounter an empty element,
        after __start_element is called. Pops the innermost open element and either attaches it
        to its parent or, if nothing else is open, records it as the document root.

        Parameters:
            tag - The tag name, as ASCII bytes.

        Raises:
            Exception - The close tag does not match the innermost open element.
        """
        node = self.current.pop()
        expected = tag.decode('ascii')
        if node.name != expected:
            raise Exception(f'Logic error, expected {expected} but got {node.name}')

        if self.current:
            self.current[-1].add_child(node)
        else:
            self.root = node

    def __yield_values(self, text: str) -> Iterator[str]:
        """
        Yield each whitespace-separated token of text, skipping empty tokens.
        """
        # str.split() with no separator splits on runs of whitespace and never
        # produces empty strings, which is exactly the tokenization needed here.
        yield from text.split()

    def __text(self, text: bytes) -> None:
        """
        Called when we finish parsing arbitrary non-element text. Note that the text passed in is in
        the XML document's encoding and it is this function's responsibility to decode it. The decoded
        text is converted according to the data type of the innermost open element and stored as
        that element's value. Text outside of any element, or inside a void element, is ignored.

        Parameters:
            text - Text value of the node, as encoded by the XML document's encoding.

        Raises:
            XmlEncodingException - The text is not valid in the document's encoding.
        """
        try:
            value = text.decode(self.encoding)
        except UnicodeDecodeError as e:
            raise XmlEncodingException('Failed to decode text node with given encoding') from e

        if not self.current:
            return

        node = self.current[-1]
        data_type = node.data_type
        multiple = node.is_array or node.is_composite

        if data_type == 'void':
            # We can't handle this
            return

        if data_type == 'str':
            # Undo entity escaping, then append to whatever text we already collected.
            value = self._unescape(value)
            if node.value is None:
                node.set_value(value)
            else:
                node.set_value(node.value + value)
        elif data_type == 'bin':
            # Convert from a hex string, ignoring any whitespace between digits.
            digits = ''.join(value.split())
            decoded = b''.join(
                struct.pack('>B', int(digits[i:(i + 2)], 16))
                for i in range(0, len(digits), 2)
            )
            if node.value is None:
                node.set_value(decoded)
            else:
                node.set_value(node.value + decoded)
        elif data_type == 'ip4':
            # Stored as the decoded text, no conversion.
            node.set_value(value)
        elif data_type == 'bool':
            if multiple:
                node.set_value([self._to_bool(v) for v in self.__yield_values(value)])
            else:
                node.set_value(self._to_bool(value))
        elif data_type == 'float':
            if multiple:
                node.set_value([float(v) for v in self.__yield_values(value)])
            else:
                node.set_value(float(value))
        else:
            # Every remaining type is parsed as an integer.
            if multiple:
                node.set_value([int(v) for v in self.__yield_values(value)])
            else:
                node.set_value(int(value))

    def __parse_attributes(self, attributes: bytes) -> Dict[str, str]:
        """
        Given a string representing zero or more possible attributes, parse them into
        a dictionary.

        Parameters:
            attributes - The attribute portion of a tag, in the XML document's encoding.

        Returns:
            A dictionary keyed by the attribute name and whose values are unescaped strings.
            If no attributes exist, this returns an empty dictionary.
        """
        attr_stream = InputStream(attributes)
        parsed_attrs: Dict[str, str] = {}

        # States: 'space' (between attributes), 'attr' (reading a name),
        # 'valstart' (saw '=', waiting for the opening quote), 'value' (inside quotes).
        state = 'space'
        attr = b''
        val = b''
        quote = b''

        while True:
            c = attr_stream.read_byte()
            if c is None:
                return parsed_attrs

            if state == 'space':
                if not c.isspace():
                    state = 'attr'
                    attr = c
            elif state == 'attr':
                if c == b'=':
                    attr = attr.strip()
                    state = 'valstart'
                else:
                    attr = attr + c
            elif state == 'valstart':
                # Either quote style may open a value; remember which so that only
                # the matching quote closes it.
                if c in (b'"', b'\''):
                    quote = c
                    state = 'value'
                    val = b''
            elif state == 'value':
                if c == quote:
                    state = 'space'
                    parsed_attrs[attr.decode('ascii')] = self._unescape(val.decode(self.encoding), attr=True)
                else:
                    val = val + c

    def __split_node(self, content: bytes) -> Tuple[bytes, bytes]:
        """
        Split the inside of a tag into its name and the raw attribute text.

        Parameters:
            content - Tag contents, minus the < and > characters.

        Returns:
            A tuple of (tag name, attribute bytes). The attribute bytes start at the first
            non-whitespace byte after the name and are empty if there is none.
        """
        node_stream = InputStream(content)
        tag = b''
        attributes = b''
        state = "tag"

        while True:
            c = node_stream.read_byte()
            if c is None:
                break

            if state == "tag":
                if c.isspace():
                    state = "space"
                else:
                    tag = tag + c
            elif state == "space":
                if not c.isspace():
                    attributes = c
                    state = "attributes"
            elif state == "attributes":
                attributes = attributes + c

        return (tag, attributes)

    def __handle_node(self, content: bytes) -> None:
        """
        Called whenever we encounter any node type. Filters out special nodes,
        determines whether this is a start, end or empty node, and fires off
        calls to the respective __start_element and __end_element functions.

        Parameters:
            content - The node contents, minus the < and > characters. This will be encoded
                      in the XML document's encoding.
        """
        if content[:1] == b'?' and content[-1:] == b'?':
            # Special node, parse to get the encoding.
            tag, attributes = self.__split_node(content[1:-1])
            if tag == b'xml':
                attributes_dict = self.__parse_attributes(attributes)
                if 'encoding' in attributes_dict:
                    self.encoding = attributes_dict['encoding']
            return

        if content[:1] == b'/':
            # We got an element end
            self.__end_element(content[1:])
            return

        # We got a start element. A trailing slash means it is an empty element,
        # otherwise it has subnodes or text and will be closed by a later end tag.
        empty = content[-1:] == b'/'
        if empty:
            content = content[:-1]

        tag, attributes = self.__split_node(content)
        self.__start_element(tag, self.__parse_attributes(attributes))
        if empty:
            self.__end_element(tag)

    def get_tree(self) -> Optional[Node]:
        """
        Walk the XML document and parse into nodes.

        Returns:
            A Node object representing the root of the XML document, or None if
            no outermost element was ever closed.
        """
        state = 'text'
        text = b''
        node = b''

        while True:
            c = self.stream.read_byte()

            if c is None:
                return self.root
            elif state == 'text':
                if c == b'<':
                    # Text ends where the next tag begins; hand it to the open element.
                    self.__text(text)
                    state = 'node'
                    node = b''
                else:
                    text = text + c
            elif state == 'node':
                if c == b'>':
                    self.__handle_node(node)
                    state = 'text'
                    text = b''
                else:
                    node = node + c
class XmlEncoder:
    """
    Serializes a tree of Node objects into the XML dialect understood by
    XmlDecoder: typed nodes carry a __type attribute (and __count for arrays)
    ahead of their own attributes, which are written in sorted order.
    """

    def __init__(self, tree: Node, encoding: str) -> None:
        """
        Initialize the XML encoder.

        Parameters:
            tree - The root Node of the tree to be encoded.
            encoding - A string representing the text encoding for string elements. Should be either
                       'shift-jis', 'euc-jp', 'utf-8' or 'ascii'.
        """
        self.tree = tree
        self.encoding = encoding

    def get_data(self) -> bytes:
        """
        Encode the whole tree.

        Returns:
            Bytes consisting of an XML declaration naming the encoding followed by
            the encoded root node.
        """
        magic = f'<?xml version="1.0" encoding="{self.encoding}"?>'.encode('ascii')
        return magic + self.to_xml(self.tree)

    def _escape(self, val: Any, attr: bool = False) -> bytes:
        """
        Convert a value to bytes, entity-escaping it first if it is a string.

        Parameters:
            val - The value to convert. Strings are escaped and encoded with the
                  encoder's text encoding; anything else is str()'d and encoded as ASCII.
            attr - True when the value is an attribute value, in which case carriage
                   returns and line feeds are also written as numeric entities.

        Returns:
            The encoded bytes.
        """
        if not isinstance(val, str):
            return str(val).encode('ascii')

        # The ampersand goes first so the entities added below are not re-escaped.
        val = val.replace('&', '&amp;')
        val = val.replace('<', '&lt;')
        val = val.replace('>', '&gt;')
        val = val.replace('\'', '&apos;')
        val = val.replace('\"', '&quot;')
        if attr:
            val = val.replace('\r', '&#13;')
            val = val.replace('\n', '&#10;')
        return val.encode(self.encoding)

    def _encode_value(self, node: Node) -> bytes:
        """
        Render the value of a childless, non-void node as the bytes placed
        between its open and close tags.
        """
        if node.is_array or node.is_composite:
            # Multiple values are written space-separated; booleans as 1/0.
            if node.value is None:
                vals = ''
            elif node.data_type == 'bool':
                vals = ' '.join('1' if val else '0' for val in node.value)
            else:
                vals = ' '.join(str(val) for val in node.value)
            return vals.encode('ascii')

        if node.data_type == 'str':
            return self._escape(node.value)
        if node.data_type == 'bool':
            return b'1' if node.value else b'0'
        if node.data_type == 'ip4':
            return '.'.join(str(val) for val in node.value).encode('ascii')
        if node.data_type == 'bin':
            # Two lowercase hex digits per byte.
            return ''.join(f'{byte:02x}' for byte in node.value).encode('ascii')
        return str(node.value).encode('ascii')

    def to_xml(self, node: Node) -> bytes:
        """
        Convert this node, attributes and all children to an XML-like representation of the tree.

        Parameters:
            node - A Node representing the root of the tree to be encoded.

        Returns:
            Bytes representing the XML-like data for this node and all children.
        """
        # Work on a copy so the synthetic __type/__count entries never leak into the node.
        attrs_dict = copy.deepcopy(node.attributes)
        order = sorted(attrs_dict.keys())

        if node.data_length != 0:
            # Represent type and length. Each is inserted at the front, so the final
            # order is __type, then __count (arrays only), then the sorted attributes.
            if node.is_array:
                if node.value is None:
                    attrs_dict['__count'] = '0'
                else:
                    attrs_dict['__count'] = str(len(node.value))
                order.insert(0, '__count')
            attrs_dict['__type'] = node.data_type
            order.insert(0, '__type')

        if attrs_dict:
            attrs = b' ' + b' '.join(
                b''.join([attr.encode('ascii'), b'="', self._escape(attrs_dict[attr], attr=True), b'"'])
                for attr in order
            )
        else:
            attrs = b''

        name = node.name.encode('ascii')

        if node.children:
            # Has children nodes; only the children are written as content.
            inner = b''.join(self.to_xml(child) for child in node.children)
        elif node.data_length == 0:
            # Void node, written as an empty element.
            return b''.join([b'<', name, attrs, b'/>'])
        else:
            # Node with values
            inner = self._encode_value(node)

        return b''.join([b'<', name, attrs, b'>', inner, b'</', name, b'>'])
class XmlEncoding:
    """
    Front-end that pairs XmlDecoder and XmlEncoder and remembers, in the
    'encoding' attribute, which text encoding the most recent decode used.
    """

    # The string values should match the constants in EAmuseProtocol.
    # I have no better way to link these than to write this comment,
    # as otherwise we would have a circular dependency.
    ACCEPTED_ENCODINGS: Final[List[str]] = ['shift-jis', 'euc-jp', 'utf-8', 'ascii']

    def __init__(self) -> None:
        """
        Create the wrapper with no remembered encoding.
        """
        self.encoding: Optional[str] = None

    def __fix_encoding(self, encoding: str) -> str:
        """
        Normalize the spelling of an encoding name: lowercase it and use
        dashes in place of underscores.

        Parameters:
            encoding - The encoding name to normalize.

        Returns:
            The normalized encoding name.
        """
        return encoding.lower().replace('_', '-')

    def decode(self, data: bytes, skip_on_exceptions: bool=False) -> Optional[Node]:
        """
        Parse an XML blob into a Node tree. The 'encoding' attribute is reset to
        'shift-jis' before parsing and, after a successful parse, updated to the
        encoding the decoder ended up with.

        Parameters:
            data - Blob of text representing the data to decode.
            skip_on_exceptions - When True, an XmlEncodingException during parsing
                                 results in None being returned instead of the
                                 exception propagating.

        Returns:
            Node object representing the root of the decoded tree, or None
            if we couldn't decode the object for some reason.
        """
        # Default assumption; the document's own declaration may override it.
        self.encoding = 'shift-jis'

        try:
            decoder = XmlDecoder(data, self.encoding)
            root = decoder.get_tree()
            self.encoding = decoder.encoding
            return root
        except XmlEncodingException:
            if not skip_on_exceptions:
                raise
            return None

    def encode(self, tree: Node, encoding: Optional[str]=None) -> bytes:
        """
        Serialize a tree of Node objects to XML bytes.

        Parameters:
            tree - Node tree representing the data to encode
            encoding - The text encoding to use. If None, falls back to the encoding
                       remembered from the last decode.

        Returns:
            Bytes of the encoded XML document.

        Raises:
            XmlEncodingException - No encoding is available, or the normalized
                                   encoding is not in ACCEPTED_ENCODINGS.
        """
        chosen = self.encoding if encoding is None else encoding
        if chosen is None:
            raise XmlEncodingException('Unknown encoding')

        normalized = self.__fix_encoding(chosen)
        if normalized not in XmlEncoding.ACCEPTED_ENCODINGS:
            # XML pages only support a few encodings.
            raise XmlEncodingException(f"Invalid text encoding {normalized}")

        return XmlEncoder(tree, normalized).get_data()