1
0
mirror of synced 2024-11-28 07:50:51 +01:00
bemaniutils/bemani/protocol/lz77.py

390 lines
15 KiB
Python

from collections import defaultdict
from typing import Generator, List, Mapping, Optional, Set, Tuple
class LzException(Exception):
    """
    Raised when an Lz77 stream cannot be encoded or decoded, for example
    when compressed input ends in the middle of an instruction.
    """

    pass
class Lz77Decompress:
    """
    A class that can decompress an Lz77 stream of data. Notably, this is a different
    variant to the Lz77 found in firebeat executables and BIOS. This is used for
    over-the-wire compression of XML data, as well as compression inside a decent
    amount of file formats found in various Konami games.

    As decoded below, each flag byte supplies eight instructions, consumed
    least-significant bit first. A FLAG_COPY bit copies one literal byte from
    the input. A FLAG_BACKREF bit introduces a two-byte backref holding a
    12-bit distance and a 4-bit length (stored minus 3). A backref whose
    distance is zero marks the end of the stream.
    """

    RING_LENGTH = 0x1000

    FLAG_COPY = 1
    FLAG_BACKREF = 0

    def __init__(self, data: bytes, backref: Optional[int] = None) -> None:
        """
        Initialize the object.

        Parameters:
            data - Binary blob representing the data to be decompressed.
            backref - Optional size of the backref ring. RING_LENGTH is used
                when this is omitted or falsy.
        """
        self.eof: bool = False
        self.data: bytes = data
        self.read_pos: int = 0
        self.left: int = len(self.data)
        # 1 means "no flag bits remain". Each loaded flag byte is OR'd with
        # 0x100, so this sentinel bit surfaces once all eight have been used.
        self.flags: int = 1
        self.write_pos: int = 0
        self.pending_copy_amount: int = 0
        self.pending_copy_pos: int = 0
        self.pending_copy_max: int = 0
        self.ringlength: int = backref or self.RING_LENGTH
        # Mutable, so writes are in-place slice assignments instead of a
        # rebuild of the whole ring on every write.
        self.ring: bytearray = bytearray(self.ringlength)

    def __ring_read(self, copy_pos: int, copy_len: int) -> Generator[bytes, None, None]:
        """
        Read bytes out of the backref ring, writing each chunk back into the
        ring at the current write offset as it is produced.

        Parameters:
            copy_pos - Ring offset to start reading from.
            copy_len - Number of bytes to read.

        Returns:
            a generator which yields bytes, split wherever the read wraps
            around the end of the ring.
        """
        while copy_len > 0:
            if copy_pos + copy_len > self.ringlength:
                # Copy first chunk, let subsequent loop handle the next chunks
                amount = self.ringlength - copy_pos
            else:
                # Copy the whole thing out, we have enough space to do so
                amount = copy_len

            # bytes() takes an immutable snapshot, so the write below cannot
            # disturb it and callers always receive bytes rather than bytearray.
            ret = bytes(self.ring[copy_pos:(copy_pos + amount)])
            self.__ring_write(ret)
            yield ret

            copy_pos = (copy_pos + amount) % self.ringlength
            copy_len -= amount

    def __ring_write(self, bytedata: bytes) -> None:
        """
        Write bytes into the backref ring, wrapping around its end as needed.

        Parameters:
            bytedata - Bytes to be written starting at the current write offset.
        """
        while bytedata:
            amount = min(len(bytedata), self.ringlength - self.write_pos)
            self.ring[self.write_pos:(self.write_pos + amount)] = bytedata[:amount]
            bytedata = bytedata[amount:]
            self.write_pos = (self.write_pos + amount) % self.ringlength

    def decompress_bytes(self) -> Generator[bytes, None, None]:
        """
        Yield decompressed output in chunks until the stream ends.

        Each iteration does one of three things:
            - continues a pending overlapped backref copy;
            - copies a run of literal bytes straight from the input;
            - decodes a new backref.

        Every chunk produced is also written into the backref ring. Iteration
        stops when the input runs out where a flag byte or a backref was
        expected, or when a backref with a zero distance is decoded.

        Returns:
            a generator that yields bytes.

        Raises:
            LzException - if the input ends partway through a literal run or
                a backref.
        """
        while not self.eof:
            if self.pending_copy_amount > 0:
                # We had a backref that would have copied more data than we had available
                # in the ringbuffer (because every read, even a backref, adds to the
                # ringbuffer). So, since we read that last time and wrote it to the backbuffer
                # we are safe to read again. pending_copy_pos is deliberately not advanced:
                # the output repeats with a period equal to the backref distance and every
                # non-final read is exactly that long, so rereading the same offset yields
                # the correct continuation.
                amount = min(self.pending_copy_amount, self.pending_copy_max)
                yield from self.__ring_read(self.pending_copy_pos, amount)

                # We read this many bytes and are about to write them to the ringbuffer,
                # so bookkeep that.
                self.pending_copy_amount -= amount
                self.pending_copy_max = amount
                continue

            if self.flags == 1:
                # Load the next byte for processing
                if self.left == 0:
                    # We have nothing left to read, so skip the
                    # ringbuffer write below and exit early.
                    return
                self.flags = 0x100 | self.data[self.read_pos]
                self.read_pos += 1
                self.left -= 1

            # Shift the lowest bit out to be retrieved as a flag
            flag = self.flags & 1
            self.flags >>= 1

            if flag == self.FLAG_COPY:
                # Figure out how much to pull at once
                amount = 1
                while self.flags != 1 and (self.flags & 1) == self.FLAG_COPY:
                    # We would do a copy next time, so pop that flag and just add to our read amount
                    self.flags >>= 1
                    amount += 1

                if amount > self.left:
                    # The flag byte promises more literal bytes than the input
                    # holds. Fail with a stream error now instead of reading
                    # past the end of the data.
                    raise LzException('Unexpected EOF mid-copy')

                # Grab chunk right out of the data source
                b = self.data[self.read_pos:(self.read_pos + amount)]
                self.__ring_write(b)
                yield b
                self.read_pos += amount
                self.left -= amount
            elif flag == self.FLAG_BACKREF:
                yield from self.__read_backref()
            else:
                raise Exception("Logic error!")

    def __read_backref(self) -> Generator[bytes, None, None]:
        """
        Decode one backref and yield the first portion of its copy.

        The two bytes at the read position hold a 12-bit distance and a 4-bit
        length. The distance is the high byte plus the upper nibble of the low
        byte. The length is the lower nibble of the low byte, to which 3 is
        added.

        When the length exceeds the distance, only `distance` bytes are copied
        here. The remainder is recorded in the pending_copy_* fields, and
        decompress_bytes() finishes it on subsequent iterations.

        Instead of producing bytes, the eof flag is set when the distance is
        zero, or when no input remains where the backref should start. The eof
        flag terminates the main decompression loop.

        Returns:
            a generator that yields bytes.

        Raises:
            LzException - if only one of the two backref bytes is present.
        """
        if self.left == 0:
            self.eof = True
            return
        if self.left == 1:
            raise LzException('Unexpected EOF mid-backref')

        hi = self.data[self.read_pos]
        lo = self.data[self.read_pos + 1]
        self.read_pos += 2
        self.left -= 2

        copy_len = lo & 0xF
        copy_pos = (hi << 4) | (lo >> 4)

        if copy_pos > 0:
            copy_len += 3

            if copy_len > copy_pos:
                # Remember what we have to do left, and the safe
                # amount to copy next time (which is our length,
                # since we are about to write that many bytes to
                # the ringbuffer right after reading them).
                self.pending_copy_amount = copy_len - copy_pos
                self.pending_copy_pos = self.write_pos
                self.pending_copy_max = copy_pos

                # Only copy the available bytes
                copy_len = copy_pos

            # Convert the distance into a ring offset; Python's % already
            # maps a negative difference into [0, ringlength).
            copy_pos = (self.write_pos - copy_pos) % self.ringlength
            yield from self.__ring_read(copy_pos, copy_len)
        else:
            self.eof = True
            return
class Lz77Compress:
    """
    A class that can compress arbitrary binary data using the Lz77 protocol.
    Note that this does support overlapped backtracks, so for instance the
    string "abcabcabc" will be compressed properly (see unit tests for examples).
    Great care has been taken in optimizing this and then we further optimize
    by using Cython to build, netting us another 40% speed-up. This is important
    because for any given packet we are decompressing and compressing at least
    once, and if we use a proxy to direct traffic, possibly a second time.
    """

    RING_LENGTH = 0x1000
    # A backref stores its distance in 12 bits: the high byte plus the upper
    # nibble of the low byte. No match may reach back further than this,
    # whatever ring size was requested.
    MAX_BACKREF_DISTANCE = 0xFFF

    FLAG_COPY = 1
    FLAG_BACKREF = 0

    def __init__(self, data: bytes, backref: Optional[int] = None) -> None:
        """
        Initialize the object.

        Parameters:
            data - Binary blob representing the data to be compressed.
            backref - Optional size of the backref ring. RING_LENGTH is used
                when this is omitted or falsy.
        """
        self.data: bytes = data
        self.read_pos: int = 0
        self.left: int = len(self.data)
        self.eof: bool = False
        self.bytes_written: int = 0
        self.ringlength: int = backref or self.RING_LENGTH
        # Every absolute output position at which each byte value was written.
        self.locations: Mapping[int, Set[int]] = defaultdict(set)
        # Every absolute output position at which each three-byte sequence starts.
        self.starts: Mapping[bytes, Set[int]] = defaultdict(set)
        # The last three byte values written, used as the key into self.starts.
        self.last_start: Tuple[int, int, int] = (0, 0, 0)

    def __ring_write(self, bytedata: bytes) -> None:
        """
        Record bytes as written to the output so later matches can find them.

        Parameters:
            bytedata - Bytes to be recorded at the current output position.
        """
        for byte in bytedata:
            # Update the start locations hashmap if we're past the beginning
            self.last_start = (self.last_start[1], self.last_start[2], byte)
            if self.bytes_written >= 2:
                self.starts[bytes(self.last_start)].add(self.bytes_written - 2)

            # Update the rest of the location hashmaps
            self.locations[byte].add(self.bytes_written)
            self.bytes_written += 1

    def __copy_literal(self) -> bytes:
        """
        Consume the next input byte as a literal, record it as written, and
        return it as a one-byte chunk.
        """
        chunk = self.data[self.read_pos:(self.read_pos + 1)]
        self.__ring_write(chunk)
        self.read_pos += 1
        self.left -= 1
        return chunk

    def compress_bytes(self) -> Generator[bytes, None, None]:
        """
        Given the current stream, go through and assemble the next flag byte
        followed by the next chunk of compressed data.

        Each yielded group is a flag byte followed by up to eight instruction
        payloads. A payload is either a single literal byte or a two-byte
        backref. The stream always ends with a two-byte zero backref as the
        end-of-stream marker. If the input runs out exactly at a group
        boundary, an extra group of a zero flag byte plus that marker is
        emitted.
        """
        # Never reach back further than the ring holds or than the 12-bit
        # distance field can encode; larger distances would be truncated.
        max_distance = min(self.ringlength - 1, self.MAX_BACKREF_DISTANCE)

        while not self.eof:
            if self.left == 0:
                # Output a dummy flag and an end of stream marker.
                self.eof = True
                yield b"\x00\x00\x00"
                continue

            # Need to assemble and return the next chunk, which is a flag
            # byte and then 8 instructions.
            flags = 0x0
            chunks: List[bytes] = [b""] * 8

            for flagpos in range(8):
                if self.left == 0:
                    # Output the end of stream marker, set EOF since we've succeeded
                    # in outputting all flags.
                    flags |= self.FLAG_BACKREF << flagpos
                    chunks[flagpos] = b"\x00\x00"
                    self.eof = True
                    break

                if self.left < 3 or self.bytes_written < 3:
                    # We either don't have enough data written to backref, or we
                    # don't have enough data in the stream that could be made into
                    # a backref.
                    flags |= self.FLAG_COPY << flagpos
                    chunks[flagpos] = self.__copy_literal()
                    continue

                # Figure out the maximum backref we can attempt to find
                backref_amount = min(self.left, 18)

                # Iterate over all spots where the first three bytes match and are in range.
                earliest = max(0, self.bytes_written - max_distance)
                possible_backref_locations: List[int] = [
                    absolute_pos for absolute_pos in self.starts[self.data[self.read_pos:(self.read_pos + 3)]]
                    if absolute_pos >= earliest
                ]

                # Output the data as a copy if we couldn't find a backref
                if not possible_backref_locations:
                    flags |= self.FLAG_COPY << flagpos
                    chunks[flagpos] = self.__copy_literal()
                    continue

                # Now, find the longest actual backref of our possibilities. We know
                # we're going to write at least these three bytes, so append it to the
                # output buffer.
                start_write_size = self.bytes_written
                self.__ring_write(self.data[self.read_pos:(self.read_pos + 3)])
                copy_amount = 3

                for _ in range(backref_amount - 3):
                    # Check our existing locations to figure out if we still have
                    # longest prefixes.
                    locations = self.locations[self.data[self.read_pos + copy_amount]]
                    new_backref_locations: List[int] = [
                        absolute_pos for absolute_pos in possible_backref_locations
                        if absolute_pos + copy_amount in locations
                    ]

                    # If we have no longest prefixes, that means that any of the
                    # previous prefixes are good enough.
                    if not new_backref_locations:
                        break

                    # Mark that we're copying an extra byte from the backref.
                    self.__ring_write(self.data[(self.read_pos + copy_amount):(self.read_pos + copy_amount + 1)])
                    copy_amount += 1
                    possible_backref_locations = new_backref_locations

                # Now that we have a list of candidates, arbitrarily pick the
                # first one as our candidate and output it.
                absolute_pos = possible_backref_locations[0]
                backref_pos = start_write_size - absolute_pos
                lo = (copy_amount - 3) & 0xF | ((backref_pos & 0xF) << 4)
                hi = (backref_pos >> 4) & 0xFF
                flags |= self.FLAG_BACKREF << flagpos
                chunks[flagpos] = bytes([hi, lo])
                self.read_pos += copy_amount
                self.left -= copy_amount

            yield bytes([flags]) + b"".join(chunks)
class Lz77:
    """
    Thin facade over Lz77Decompress and Lz77Compress that remembers an
    optional backref ring size and hands it to every codec it creates.
    """

    # The point at which we consider it better to trade off smaller data
    # sent over the wire for a more computationally expensive compression.
    REAL_COMPRESSION_THRESHOLD = 10 * 1024

    def __init__(self, backref: Optional[int] = None) -> None:
        """
        Store the backref ring size used for every subsequent call.

        Parameters:
            backref - Ring size override, or None to let each codec use its default.
        """
        self.backref = backref

    def decompress(self, data: bytes) -> bytes:
        """
        Expand an Lz77-compressed blob back into its raw bytes.

        Parameters:
            data - Lz77-compressed binary data.

        Returns:
            The decompressed binary data.
        """
        pieces = Lz77Decompress(data, backref=self.backref).decompress_bytes()
        return b''.join(pieces)

    def compress(self, data: bytes) -> bytes:
        """
        Encode a raw blob into an Lz77-compressed stream.

        Parameters:
            data - Raw binary data.

        Returns:
            Lz77-compressed binary data.
        """
        pieces = Lz77Compress(data, backref=self.backref).compress_bytes()
        return b''.join(pieces)