547 lines
19 KiB
Python
547 lines
19 KiB
Python
|
import copy
|
||
|
import struct
|
||
|
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
||
|
|
||
|
from bemani.protocol.stream import InputStream
|
||
|
from bemani.protocol.node import Node
|
||
|
|
||
|
|
||
|
class XmlEncodingException(Exception):
|
||
|
"""
|
||
|
An exception that is thrown when we encounter an error encoding to or decoding from XML.
|
||
|
"""
|
||
|
|
||
|
|
||
|
class XmlDecoder:
|
||
|
"""
|
||
|
A hand-rolled XML parser, suitable for parsing old-style XML documents in
|
||
|
game data or from legacy game traffic. I did consider using lxml and other
|
||
|
data stores, but they insist on mangling data inside binary/string blobs
|
||
|
making them unsuitable for a protocol with exact specifications.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, data: bytes, encoding: str) -> None:
|
||
|
"""
|
||
|
Initialize the XML decoder.
|
||
|
|
||
|
Parameters:
|
||
|
data - String XML data which should be decoded into Nodes.
|
||
|
encoding - The expected encoding of the XML.
|
||
|
"""
|
||
|
self.stream = InputStream(data)
|
||
|
self.root: Optional[Node] = None
|
||
|
self.current: List[Node] = []
|
||
|
self.encoding = encoding
|
||
|
|
||
|
def __start_element(self, tag: bytes, attributes: Dict[str, str]) -> None:
|
||
|
"""
|
||
|
Called when we encounter an element open tag. Also called when we encounter
|
||
|
an empty element. Creates a new node with the specified name and attributes.
|
||
|
|
||
|
Parameters:
|
||
|
tag - The string tag name.
|
||
|
attributes - A dictionary keyed by attribute name and whose values are the string
|
||
|
attribute values. This attribute values should already be decoded from
|
||
|
the XML's encoding.
|
||
|
"""
|
||
|
data_type = attributes.get('__type')
|
||
|
|
||
|
array_str = attributes.get('__count')
|
||
|
if array_str is not None:
|
||
|
array = True
|
||
|
else:
|
||
|
array = False
|
||
|
|
||
|
if data_type is None:
|
||
|
# Special case for nodes that don't have a type
|
||
|
node = Node(name=tag.decode('ascii'), type=Node.NODE_TYPE_VOID)
|
||
|
else:
|
||
|
# Get the data value
|
||
|
type_int = Node.typename_to_type(data_type)
|
||
|
if type_int is None:
|
||
|
raise XmlEncodingException('Invalid node type {} for node {}'.format(data_type, tag.decode('ascii')))
|
||
|
|
||
|
node = Node(name=tag.decode('ascii'), type=type_int, array=array)
|
||
|
|
||
|
# Now, do the attributes
|
||
|
for attr in attributes:
|
||
|
if attr == '__type' or attr == '__count':
|
||
|
# Skip these, handled
|
||
|
continue
|
||
|
else:
|
||
|
node.set_attribute(attr, attributes.get(attr))
|
||
|
|
||
|
self.current.append(node)
|
||
|
|
||
|
def __end_element(self, tag: bytes) -> None:
|
||
|
"""
|
||
|
Called when we encounter an element close tag. Also called when we encounter an empty element,
|
||
|
after __start_element is called. Does bookkeeping related to element order.
|
||
|
|
||
|
Parameters:
|
||
|
tag - The string tag name.
|
||
|
"""
|
||
|
node = self.current.pop()
|
||
|
|
||
|
if node.name != tag.decode('ascii'):
|
||
|
raise Exception('Logic error, expected {} but got {}'.format(tag.decode('ascii'), node.name))
|
||
|
|
||
|
if len(self.current) == 0:
|
||
|
self.root = node
|
||
|
else:
|
||
|
parent = self.current[-1]
|
||
|
parent.add_child(node)
|
||
|
|
||
|
def __yield_values(self, text: str) -> Iterator[str]:
|
||
|
value = ''
|
||
|
|
||
|
for c in text:
|
||
|
if c.isspace():
|
||
|
if len(value) > 0:
|
||
|
yield value
|
||
|
value = ''
|
||
|
else:
|
||
|
value = value + c
|
||
|
|
||
|
if len(value) > 0:
|
||
|
yield value
|
||
|
|
||
|
def __text(self, text: bytes) -> None:
|
||
|
"""
|
||
|
Called when we finish parsing arbitrary non-element text. Note that the text passed in is in
|
||
|
the XML document's encoding and it is this function's responsibility to decode it.
|
||
|
|
||
|
Parameters:
|
||
|
text - String text value of the node, as encoded by the XML document's encoding.
|
||
|
"""
|
||
|
try:
|
||
|
value = text.decode(self.encoding)
|
||
|
except UnicodeDecodeError:
|
||
|
raise XmlEncodingException('Failed to decode text node with given encoding')
|
||
|
|
||
|
if len(self.current) > 0:
|
||
|
data_type = self.current[-1].data_type
|
||
|
composite = self.current[-1].is_composite
|
||
|
array = self.current[-1].is_array
|
||
|
|
||
|
if data_type == 'void':
|
||
|
# We can't handle this
|
||
|
return
|
||
|
|
||
|
if data_type == 'str':
|
||
|
# Do nothing, already fine
|
||
|
value = value.replace('&', '&')
|
||
|
value = value.replace('<', '<')
|
||
|
value = value.replace('>', '>')
|
||
|
value = value.replace(''', '\'')
|
||
|
value = value.replace('"', '\"')
|
||
|
if self.current[-1].value is None:
|
||
|
self.current[-1].set_value(value)
|
||
|
else:
|
||
|
self.current[-1].set_value(self.current[-1].value + value)
|
||
|
elif data_type == 'bin':
|
||
|
# Convert from a hex string
|
||
|
def hex_to_bin(hexval: str) -> bytes:
|
||
|
intval = int(hexval, 16)
|
||
|
return struct.pack('>B', intval)
|
||
|
|
||
|
# Remove any spaces first
|
||
|
value = ''.join([c for c in value if not c.isspace()])
|
||
|
if self.current[-1].value is None:
|
||
|
self.current[-1].set_value(b''.join([hex_to_bin(value[i:(i + 2)]) for i in range(0, len(value), 2)]))
|
||
|
else:
|
||
|
self.current[-1].set_value(self.current[-1].value + b''.join([hex_to_bin(value[i:(i + 2)]) for i in range(0, len(value), 2)]))
|
||
|
elif data_type == 'ip4':
|
||
|
# Do nothing, already fine
|
||
|
self.current[-1].set_value(value)
|
||
|
elif data_type == 'bool':
|
||
|
def conv_bool(val: str) -> bool:
|
||
|
if val and val.lower() in ['0', 'false']:
|
||
|
return False
|
||
|
else:
|
||
|
return True
|
||
|
|
||
|
if array or composite:
|
||
|
self.current[-1].set_value([conv_bool(v) for v in self.__yield_values(value)])
|
||
|
else:
|
||
|
self.current[-1].set_value(conv_bool(value))
|
||
|
elif data_type == 'float':
|
||
|
if array or composite:
|
||
|
self.current[-1].set_value([float(v) for v in self.__yield_values(value)])
|
||
|
else:
|
||
|
self.current[-1].set_value(float(value))
|
||
|
else:
|
||
|
if array or composite:
|
||
|
self.current[-1].set_value([int(v) for v in self.__yield_values(value)])
|
||
|
else:
|
||
|
self.current[-1].set_value(int(value))
|
||
|
|
||
|
def __parse_attributes(self, attributes: bytes) -> Dict[str, str]:
|
||
|
"""
|
||
|
Given a string representing zero or more possible attributes, parse them into
|
||
|
a dictionary.
|
||
|
|
||
|
Returns:
|
||
|
A dictionary keyed by the attribute name and who's values are unescaped strings.
|
||
|
If no attributes exist, this returns an empty dictionary.
|
||
|
"""
|
||
|
attr_stream = InputStream(attributes)
|
||
|
parsed_attrs: Dict[str, str] = {}
|
||
|
state = 'space'
|
||
|
attr = b''
|
||
|
val = b''
|
||
|
|
||
|
def unescape(value: bytes) -> str:
|
||
|
val = value.decode(self.encoding)
|
||
|
val = val.replace('&', '&')
|
||
|
val = val.replace('<', '<')
|
||
|
val = val.replace('>', '>')
|
||
|
val = val.replace(''', '\'')
|
||
|
val = val.replace('"', '\"')
|
||
|
val = val.replace(' ', '\r')
|
||
|
return val.replace(' ', '\n')
|
||
|
|
||
|
while True:
|
||
|
c = attr_stream.read_byte()
|
||
|
|
||
|
if c is None:
|
||
|
return parsed_attrs
|
||
|
if state == 'space':
|
||
|
if not c.isspace():
|
||
|
state = 'attr'
|
||
|
attr = c
|
||
|
elif state == 'attr':
|
||
|
if c == b'=':
|
||
|
attr = attr.strip()
|
||
|
state = 'valstart'
|
||
|
else:
|
||
|
attr = attr + c
|
||
|
elif state == 'valstart':
|
||
|
if c == b'"':
|
||
|
state = 'valdouble'
|
||
|
val = b''
|
||
|
elif c == b'\'':
|
||
|
state = 'valsingle'
|
||
|
val = b''
|
||
|
elif state == 'valdouble':
|
||
|
if c == b'"':
|
||
|
state = 'space'
|
||
|
parsed_attrs[attr.decode('ascii')] = unescape(val)
|
||
|
else:
|
||
|
val = val + c
|
||
|
elif state == 'valsingle':
|
||
|
if c == b'\'':
|
||
|
state = 'space'
|
||
|
parsed_attrs[attr.decode('ascii')] = unescape(val)
|
||
|
else:
|
||
|
val = val + c
|
||
|
|
||
|
def __split_node(self, content: bytes) -> Tuple[bytes, bytes]:
|
||
|
node_stream = InputStream(content)
|
||
|
tag = b''
|
||
|
attributes = b''
|
||
|
state = "tag"
|
||
|
|
||
|
while True:
|
||
|
c = node_stream.read_byte()
|
||
|
|
||
|
if c is None:
|
||
|
break
|
||
|
if state == "tag":
|
||
|
if c.isspace():
|
||
|
state = "space"
|
||
|
else:
|
||
|
tag = tag + c
|
||
|
elif state == "space":
|
||
|
if not c.isspace():
|
||
|
attributes = c
|
||
|
state = "attributes"
|
||
|
elif state == "attributes":
|
||
|
attributes = attributes + c
|
||
|
|
||
|
return (tag, attributes)
|
||
|
|
||
|
def __handle_node(self, content: bytes) -> None:
|
||
|
"""
|
||
|
Called whenever we encounter any node type. Filters out special nodes,
|
||
|
determines whether this is a start, end or empty node, and fires off
|
||
|
calls to the respective __start_element and __end_element functions.
|
||
|
|
||
|
Parameters:
|
||
|
The node contents, minus the < and > characters. This will be encoded
|
||
|
in the XML document's encoding.
|
||
|
"""
|
||
|
if content[:1] == b'?' and content[-1:] == b'?':
|
||
|
# Special node, parse to get the encoding.
|
||
|
tag, attributes = self.__split_node(content[1:-1])
|
||
|
if tag == b'xml':
|
||
|
attributes_dict = self.__parse_attributes(attributes)
|
||
|
if 'encoding' in attributes_dict:
|
||
|
self.encoding = attributes_dict['encoding']
|
||
|
return
|
||
|
|
||
|
if content[:1] == b'/':
|
||
|
# We got an element end
|
||
|
self.__end_element(content[1:])
|
||
|
else:
|
||
|
# We got a start element
|
||
|
if content[-1:] == b'/':
|
||
|
# This is an empty element
|
||
|
empty = True
|
||
|
content = content[:-1]
|
||
|
else:
|
||
|
# This node has subnodes or text
|
||
|
empty = False
|
||
|
|
||
|
tag, attributes = self.__split_node(content)
|
||
|
self.__start_element(tag, self.__parse_attributes(attributes))
|
||
|
if empty:
|
||
|
self.__end_element(tag)
|
||
|
|
||
|
def get_tree(self) -> Node:
|
||
|
"""
|
||
|
Walk the XML document and parse into nodes.
|
||
|
|
||
|
Returns:
|
||
|
A Node object representing the root of the XML document.
|
||
|
"""
|
||
|
state = 'text'
|
||
|
text = b''
|
||
|
node = b''
|
||
|
|
||
|
while True:
|
||
|
c = self.stream.read_byte()
|
||
|
|
||
|
if c is None:
|
||
|
return self.root
|
||
|
elif state == 'text':
|
||
|
if c == b'<':
|
||
|
self.__text(text)
|
||
|
state = 'node'
|
||
|
node = b''
|
||
|
else:
|
||
|
text = text + c
|
||
|
elif state == 'node':
|
||
|
if c == b'>':
|
||
|
self.__handle_node(node)
|
||
|
state = 'text'
|
||
|
text = b''
|
||
|
else:
|
||
|
node = node + c
|
||
|
|
||
|
|
||
|
class XmlEncoder:
|
||
|
def __init__(self, tree: Node, encoding: str) -> None:
|
||
|
"""
|
||
|
Initialize the XML encoder.
|
||
|
|
||
|
Parameters:
|
||
|
tree - A binary blob of data to be decoded
|
||
|
encoding - A string representing the text encoding for string elements. Should be either
|
||
|
'shift-jis', 'euc-jp' or 'utf-8'
|
||
|
"""
|
||
|
self.tree = tree
|
||
|
self.encoding = encoding
|
||
|
|
||
|
def get_data(self) -> bytes:
|
||
|
magic = '<?xml version="1.0" encoding="{}"?>'.format(self.encoding).encode('ascii')
|
||
|
payload = self.to_xml(self.tree)
|
||
|
|
||
|
return magic + payload
|
||
|
|
||
|
def to_xml(self, node: Node) -> bytes:
|
||
|
"""
|
||
|
Convert this node, attributes and all children to an XML-like representation of the tree.
|
||
|
|
||
|
Parameters:
|
||
|
node: A Node representing the root of the tree to be encoded.
|
||
|
|
||
|
Returns:
|
||
|
Bytes representing the XML-like data for this node and all children.
|
||
|
"""
|
||
|
attrs_dict = copy.deepcopy(node.attributes)
|
||
|
order = sorted(attrs_dict.keys())
|
||
|
if node.data_length != 0:
|
||
|
# Represent type and length
|
||
|
if node.is_array:
|
||
|
if node.value is None:
|
||
|
attrs_dict['__count'] = '0'
|
||
|
else:
|
||
|
attrs_dict['__count'] = str(len(node.value))
|
||
|
order.insert(0, '__count')
|
||
|
attrs_dict['__type'] = node.data_type
|
||
|
order.insert(0, '__type')
|
||
|
|
||
|
def escape(val: Any, attr: bool=False) -> bytes:
|
||
|
if isinstance(val, str):
|
||
|
val = val.replace('&', '&')
|
||
|
val = val.replace('<', '<')
|
||
|
val = val.replace('>', '>')
|
||
|
val = val.replace('\'', ''')
|
||
|
val = val.replace('\"', '"')
|
||
|
if attr:
|
||
|
val = val.replace('\r', ' ')
|
||
|
val = val.replace('\n', ' ')
|
||
|
|
||
|
return val.encode(self.encoding)
|
||
|
else:
|
||
|
return str(val).encode('ascii')
|
||
|
|
||
|
if attrs_dict:
|
||
|
attrs = b' ' + b' '.join([b''.join([attr.encode('ascii'), b'="', escape(attrs_dict[attr], attr=True), b'"']) for attr in order])
|
||
|
else:
|
||
|
attrs = b''
|
||
|
|
||
|
if node.children:
|
||
|
# Has children nodes
|
||
|
children = [self.to_xml(child) for child in node.children]
|
||
|
string = b''.join([
|
||
|
b'<',
|
||
|
node.name.encode('ascii'),
|
||
|
attrs,
|
||
|
b'>',
|
||
|
b''.join(children),
|
||
|
b'</',
|
||
|
node.name.encode('ascii'),
|
||
|
b'>',
|
||
|
])
|
||
|
else:
|
||
|
# Doesn't have children nodes
|
||
|
if node.data_length == 0:
|
||
|
# Void node
|
||
|
string = b''.join([
|
||
|
b'<',
|
||
|
node.name.encode('ascii'),
|
||
|
attrs,
|
||
|
b'/>',
|
||
|
])
|
||
|
else:
|
||
|
# Node with values
|
||
|
if node.is_array or node.is_composite:
|
||
|
if node.value is None:
|
||
|
vals = ''
|
||
|
else:
|
||
|
if node.data_type == 'bool':
|
||
|
vals = ' '.join([('1' if val else '0') for val in node.value])
|
||
|
else:
|
||
|
vals = ' '.join([str(val) for val in node.value])
|
||
|
binary = vals.encode('ascii')
|
||
|
elif node.data_type == 'str':
|
||
|
binary = escape(node.value)
|
||
|
elif node.data_type == 'bool':
|
||
|
binary = b'1' if node.value else b'0'
|
||
|
elif node.data_type == 'ip4':
|
||
|
vals = '.'.join([str(val) for val in node.value])
|
||
|
binary = vals.encode('ascii')
|
||
|
elif node.data_type == 'bin':
|
||
|
# Convert to a hex string
|
||
|
def bin_to_hex(binary: int) -> str:
|
||
|
val = hex(binary)[2:]
|
||
|
while len(val) < 2:
|
||
|
val = '0' + val
|
||
|
return val
|
||
|
|
||
|
vals = ''.join([bin_to_hex(v) for v in node.value])
|
||
|
binary = vals.encode('ascii')
|
||
|
else:
|
||
|
vals = str(node.value)
|
||
|
binary = vals.encode('ascii')
|
||
|
|
||
|
string = b''.join([
|
||
|
b'<',
|
||
|
node.name.encode('ascii'),
|
||
|
attrs,
|
||
|
b'>',
|
||
|
binary,
|
||
|
b'</',
|
||
|
node.name.encode('ascii'),
|
||
|
b'>',
|
||
|
])
|
||
|
|
||
|
return string
|
||
|
|
||
|
|
||
|
class XmlEncoding:
|
||
|
"""
|
||
|
Wrapper class representing an XML encoding.
|
||
|
"""
|
||
|
# The string values should match the constants in EAmuseProtocol.
|
||
|
# I have no better way to link these than to write this comment,
|
||
|
# as otherwise we would have a circular dependency.
|
||
|
ACCEPTED_ENCODINGS = ['shift-jis', 'euc-jp', 'utf-8', 'ascii']
|
||
|
|
||
|
def __init__(self) -> None:
|
||
|
"""
|
||
|
Initialize the encoding object.
|
||
|
"""
|
||
|
self.encoding: Optional[str] = None
|
||
|
|
||
|
def __fix_encoding(self, encoding: str) -> str:
|
||
|
"""
|
||
|
Given an encoding, try to normalize it, looking for specific ways that
|
||
|
older games might send it back.
|
||
|
|
||
|
Parameters:
|
||
|
encoding - The encoding we want to normalize.
|
||
|
|
||
|
Returns:
|
||
|
A new encoding string that is equivalent but normalized.
|
||
|
"""
|
||
|
encoding = encoding.lower()
|
||
|
encoding = encoding.replace('_', '-')
|
||
|
return encoding
|
||
|
|
||
|
def decode(self, data: bytes, skip_on_exceptions: bool=False) -> Optional[Node]:
|
||
|
"""
|
||
|
Given a data blob, decode the data with the current encoding. Will set
|
||
|
the class property value 'encoding' to the encoding used on the last
|
||
|
decode.
|
||
|
|
||
|
Parameters:
|
||
|
data - Blob of text representing the data to decode.
|
||
|
|
||
|
Returns:
|
||
|
Node object representing the root of the decoded tree, or None
|
||
|
if we couldn't decode the object for some reason.
|
||
|
"""
|
||
|
# Always assume this, unless we get told otherwise in the XML
|
||
|
self.encoding = 'shift-jis'
|
||
|
|
||
|
# Decode property/value
|
||
|
try:
|
||
|
xml = XmlDecoder(data, self.encoding)
|
||
|
tree = xml.get_tree()
|
||
|
self.encoding = xml.encoding
|
||
|
return tree
|
||
|
except XmlEncodingException:
|
||
|
if skip_on_exceptions:
|
||
|
return None
|
||
|
else:
|
||
|
raise
|
||
|
|
||
|
def encode(self, tree: Node, encoding: Optional[str]=None) -> bytes:
|
||
|
"""
|
||
|
Given a tree of Node objects, encode the data with the current encoding.
|
||
|
|
||
|
Parameters:
|
||
|
tree - Node tree representing the data to encode
|
||
|
encoding - The text encoding to use. If None, will try to use the encoding from
|
||
|
the last successful decode
|
||
|
|
||
|
Returns:
|
||
|
String blob representing encoded data as XML.
|
||
|
"""
|
||
|
# Ensure we got the right encoding
|
||
|
if encoding is None:
|
||
|
encoding = self.encoding
|
||
|
if encoding is None:
|
||
|
raise XmlEncodingException('Unknown encoding')
|
||
|
|
||
|
encoding = self.__fix_encoding(encoding)
|
||
|
if encoding not in XmlEncoding.ACCEPTED_ENCODINGS:
|
||
|
# XML pages only support shift-jis
|
||
|
raise XmlEncodingException("Invalid text encoding {}".format(encoding))
|
||
|
|
||
|
xml = XmlEncoder(tree, encoding)
|
||
|
return xml.get_data()
|