1
0
mirror of synced 2024-11-27 22:40:49 +01:00

Convert TJA/Fumen classes into dataclasses (#57)

This branch was originally used to test out Pydantic. However, after
some profiling, I came to the conclusion that runtime type checking is
overkill for a project like this. So, instead, I'm sticking with static
typing via dataclasses.

Fixes #56.
This commit is contained in:
Viv 2023-07-30 12:56:06 -04:00 committed by GitHub
parent 03c8892243
commit d99c7f5984
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 189 additions and 223 deletions

View File

@ -27,7 +27,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.8.x'
python-version: '3.10.x'
- name: Install tja2fumen and its dev dependencies
run: |

View File

@ -3,7 +3,7 @@ name = "tja2fumen"
version = "0.0.dev0"
description = "Convert TJA chart files into fumen (.bin) chart files"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
license = {file = "LICENSE.txt"}
keywords = ["taiko", "tatsujin", "fumen", "TJA"]
@ -19,7 +19,7 @@ tja2fumen = "tja2fumen:main"
[project.optional-dependencies]
dev = ["pytest", "build", "pyinstaller", "twine", "toml-cli",
"flake8", "pyproject-flake8"]
"flake8", "pyproject-flake8", "pydantic"]
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,6 +1,7 @@
import re
from tja2fumen.types import TJAMeasureProcessed, FumenCourse, FumenNote
from tja2fumen.types import (TJAMeasureProcessed,
FumenCourse, FumenHeader, FumenMeasure, FumenNote)
def process_tja_commands(tja):
@ -47,9 +48,16 @@ def process_tja_commands(tja):
# Handle commands that can only be placed between measures
# (i.e. no mid-measure variations)
elif data.name == 'delay':
measure_tja_processed.delay = data.value * 1000 # ms -> s
measure_tja_processed.delay = float(data.value) * 1000
elif data.name == 'branch_start':
measure_tja_processed.branch_start = data.value
branch_condition = data.value.split(',')
if branch_condition[0] == 'r': # r = drumRoll
branch_condition[1] = int(branch_condition[1])
branch_condition[2] = int(branch_condition[2])
elif branch_condition[0] == 'p': # p = Percentage
branch_condition[1] = float(branch_condition[1]) / 100
branch_condition[2] = float(branch_condition[2]) / 100
measure_tja_processed.branch_start = branch_condition
elif data.name == 'section':
measure_tja_processed.section = data.value
elif data.name == 'levelhold':
@ -76,7 +84,7 @@ def process_tja_commands(tja):
if data.name == 'bpm':
new_val = current_bpm = float(data.value)
elif data.name == 'scroll':
new_val = current_scroll = data.value
new_val = current_scroll = float(data.value)
elif data.name == 'gogo':
new_val = current_gogo = bool(int(data.value))
# Check for mid-measure commands
@ -164,7 +172,8 @@ def convert_tja_to_fumen(tja):
# Pre-allocate the measures for the converted TJA
n_measures = len(tja_branches_processed['normal'])
fumen = FumenCourse(
measures=n_measures,
measures=[FumenMeasure() for _ in range(n_measures)],
header=FumenHeader(),
score_init=tja.score_init,
score_diff=tja.score_diff,
)

View File

@ -3,10 +3,12 @@ import re
import struct
from copy import deepcopy
from tja2fumen.types import (TJASong, TJAMeasure, TJAData, FumenCourse,
FumenMeasure, FumenBranch, FumenNote, FumenHeader)
from tja2fumen.types import (TJASong, TJACourse, TJAMeasure, TJAData,
FumenCourse, FumenMeasure, FumenBranch, FumenNote,
FumenHeader)
from tja2fumen.constants import (NORMALIZE_COURSE, COURSE_NAMES, BRANCH_NAMES,
TJA_NOTE_TYPES, FUMEN_NOTE_TYPES)
TJA_COURSE_NAMES, TJA_NOTE_TYPES,
FUMEN_NOTE_TYPES)
###############################################################################
# TJA-parsing functions #
@ -59,11 +61,16 @@ def split_tja_lines_into_courses(lines):
if line.split("//")[0].strip()]
# Initialize song with BPM and OFFSET global metadata
bpm = [line.split(":")[1] for line in lines
if line.startswith("BPM")][0]
offset = [line.split(":")[1] for line in lines
if line.startswith("OFFSET")][0]
parsed_tja = TJASong(bpm, offset)
bpm = float([line.split(":")[1] for line in lines
if line.startswith("BPM")][0])
offset = float([line.split(":")[1] for line in lines
if line.startswith("OFFSET")][0])
parsed_tja = TJASong(
BPM=bpm,
offset=offset,
courses={course: TJACourse(BPM=bpm, offset=offset, course=course)
for course in TJA_COURSE_NAMES}
)
current_course = ''
current_course_basename = ''
@ -174,17 +181,18 @@ def parse_tja_course_data(course):
"""
has_branches = bool([d for d in course.data if d.startswith('#BRANCH')])
current_branch = 'all' if has_branches else 'normal'
branch_condition = None
branch_condition = ''
# Process course lines
idx_m = 0
idx_m_branchstart = 0
for idx_l, line in enumerate(course.data):
# 0. Check to see whether line is a command or note data
command, value, notes = None, None, None
command, name, value, notes = None, None, None, None
match_command = re.match(r"^#([A-Z]+)(?:\s+(.+))?", line)
if match_command:
command, value = match_command.groups()
value = '' if value is None else value
else:
notes = line # If not a command, then line must be note data
@ -217,54 +225,49 @@ def parse_tja_course_data(course):
# Parse event type
if command == 'GOGOSTART':
current_event = TJAData('gogo', '1', pos)
name, value = 'gogo', '1'
elif command == 'GOGOEND':
current_event = TJAData('gogo', '0', pos)
name, value = 'gogo', '0'
elif command == 'BARLINEON':
current_event = TJAData('barline', '1', pos)
name, value = 'barline', '1'
elif command == 'BARLINEOFF':
current_event = TJAData('barline', '0', pos)
name, value = 'barline', '0'
elif command == 'DELAY':
current_event = TJAData('delay', float(value), pos)
name = 'delay'
elif command == 'SCROLL':
current_event = TJAData('scroll', float(value), pos)
name = 'scroll'
elif command == 'BPMCHANGE':
current_event = TJAData('bpm', float(value), pos)
name = 'bpm'
elif command == 'MEASURE':
current_event = TJAData('measure', value, pos)
name = 'measure'
elif command == 'LEVELHOLD':
current_event = TJAData('levelhold', None, pos)
name = 'levelhold'
elif command == 'SECTION':
# If #SECTION occurs before a #BRANCHSTART, then ensure that
# it's present on every branch. Otherwise, #SECTION will only
# be present on the current branch, and so the `branch_info`
# values won't be correctly set for the other two branches.
if course.data[idx_l+1].startswith('#BRANCHSTART'):
current_event = TJAData('section', None, pos)
name = 'section'
current_branch = 'all'
# Otherwise, #SECTION exists in isolation. In this case, to
# reset the accuracy, we just repeat the previous #BRANCHSTART.
else:
current_event = TJAData('branch_start', branch_condition,
pos)
name, value = 'branch_start', branch_condition
elif command == 'BRANCHSTART':
# Ensure that the #BRANCHSTART command is added to all branches
current_branch = 'all'
branch_condition = value.split(',')
if branch_condition[0] == 'r': # r = drumRoll
branch_condition[1] = int(branch_condition[1]) # drumrolls
branch_condition[2] = int(branch_condition[2]) # drumrolls
elif branch_condition[0] == 'p': # p = Percentage
branch_condition[1] = float(branch_condition[1]) / 100 # %
branch_condition[2] = float(branch_condition[2]) / 100 # %
current_event = TJAData('branch_start', branch_condition, pos)
name = 'branch_start'
branch_condition = value
# Preserve the index of the BRANCHSTART command to re-use
idx_m_branchstart = idx_m
# Append event to the current measure's events
for branch in (course.branches.keys() if current_branch == 'all'
else [current_branch]):
course.branches[branch][idx_m].events.append(current_event)
course.branches[branch][idx_m].events.append(
TJAData(name=name, value=value, pos=pos)
)
# 3. Parse commands that don't create an event
# (e.g. simply changing the current branch)
@ -295,7 +298,7 @@ def parse_tja_course_data(course):
# Merge measure data and measure events in chronological order
for branch_name, branch in course.branches.items():
for measure in branch:
notes = [TJAData('note', TJA_NOTE_TYPES[note], i)
notes = [TJAData(name='note', value=TJA_NOTE_TYPES[note], pos=i)
for i, note in enumerate(measure.notes) if
TJA_NOTE_TYPES[note] != 'Blank']
events = measure.events
@ -353,9 +356,9 @@ def parse_fumen(fumen_file, exclude_empty_measures=False):
file = open(fumen_file, "rb")
size = os.fstat(file.fileno()).st_size
song = FumenCourse(
header=FumenHeader(raw_bytes=file.read(520))
)
header = FumenHeader()
header.parse_header_values(file.read(520))
song = FumenCourse(header=header)
for measure_number in range(song.header.b512_b515_number_of_measures):
# Parse the measure data using the following `format_string`:
@ -374,8 +377,8 @@ def parse_fumen(fumen_file, exclude_empty_measures=False):
measure = FumenMeasure(
bpm=measure_struct[0],
offset_start=measure_struct[1],
gogo=measure_struct[2],
barline=measure_struct[3],
gogo=bool(measure_struct[2]),
barline=bool(measure_struct[3]),
padding1=measure_struct[4],
branch_info=list(measure_struct[5:11]),
padding2=measure_struct[11]

View File

@ -1,65 +1,56 @@
import csv
import os
import struct
from typing import Dict, List
from tja2fumen.constants import TJA_COURSE_NAMES, BRANCH_NAMES
from dataclasses import dataclass, field
from tja2fumen.constants import BRANCH_NAMES
class DefaultObject:
"""Set default methods for all TJA/Fumen classes."""
def __repr__(self):
return str(self.__dict__)
@dataclass(slots=True)
class TJAData:
"""Contains the information for a single note or single command."""
name: str
value: str
# For TJA, 'pos' is stored as an integer rather than in milliseconds
pos: int
class TJASong(DefaultObject):
"""Contains all the data in a single TJA (`.tja`) chart file."""
def __init__(self, BPM=None, offset=None):
# Note: TJA song metadata (e.g. TITLE, SUBTITLE, WAVE) is not stored
# because it is not needed to convert a `.tja` to `.bin` files.
self.BPM = float(BPM)
self.offset = float(offset)
self.courses = {course: TJACourse(self.BPM, self.offset, course)
for course in TJA_COURSE_NAMES}
def __repr__(self):
# Show truncated version of courses to avoid long representation
return (f"{{'BPM': {self.BPM}, 'offset': {self.offset}, "
f"'courses': {list(self.courses.keys())}}}")
class TJACourse(DefaultObject):
"""Contains all the data in a single TJA `COURSE:` section."""
def __init__(self, BPM, offset, course, level=0, balloon=None,
score_init=0, score_diff=0):
self.level = level
self.balloon = [] if balloon is None else balloon
self.score_init = score_init
self.score_diff = score_diff
self.BPM = BPM
self.offset = offset
self.course = course
self.data = []
# A "TJA Branch" is just a list of measures
self.branches = {
'normal': [TJAMeasure()],
'professional': [TJAMeasure()],
'master': [TJAMeasure()]
}
def __repr__(self):
# Don't show default fields if the course contains no data
return str(self.__dict__) if self.data else "{'data': []}"
class TJAMeasure(DefaultObject):
@dataclass(slots=True)
class TJAMeasure:
"""Contains all the data in a single TJA measure (denoted by ',')."""
def __init__(self, notes=None, events=None):
self.notes = [] if notes is None else notes
self.events = [] if events is None else events
self.combined = []
notes: List[TJAData] = field(default_factory=list)
events: List[TJAData] = field(default_factory=list)
combined: List[TJAData] = field(default_factory=list)
class TJAMeasureProcessed(DefaultObject):
@dataclass(slots=True)
class TJACourse:
"""Contains all the data in a single TJA `COURSE:` section."""
BPM: float
offset: float
course: str
level: int = 0
balloon: list = field(default_factory=list)
score_init: int = 0
score_diff: int = 0
data: list = field(default_factory=list)
branches: Dict[str, List[TJAMeasure]] = field(
default_factory=lambda: {k: [TJAMeasure()] for k in BRANCH_NAMES}
)
@dataclass(slots=True)
class TJASong:
"""Contains all the data in a single TJA (`.tja`) chart file."""
BPM: float
offset: float
courses: Dict[str, TJACourse]
@dataclass(slots=True)
class TJAMeasureProcessed:
"""
Contains all the data in a single TJA measure (denoted by ','), but with
all `#COMMAND` lines processed, and their values stored as attributes.
@ -69,61 +60,62 @@ class TJAMeasureProcessed(DefaultObject):
number of `TJAMeasureProcessed` objects will often be greater than
the number of `TJAMeasure` objects for a given song.))
"""
def __init__(self, bpm, scroll, gogo, barline, time_sig, subdivisions,
pos_start=0, pos_end=0, delay=0, levelhold=False,
section=None, branch_start=None, data=None):
self.bpm = bpm
self.scroll = scroll
self.gogo = gogo
self.barline = barline
self.time_sig = time_sig
self.subdivisions = subdivisions
self.pos_start = pos_start
self.pos_end = pos_end
self.delay = delay
self.section = section
self.levelhold = levelhold
self.branch_start = branch_start
self.data = [] if data is None else data
bpm: float
scroll: float
gogo: bool
barline: bool
time_sig: List[int]
subdivisions: int
pos_start: int = 0
pos_end: int = 0
delay: float = 0.0
section: bool = False
levelhold: bool = False
branch_start: List = field(default_factory=list)
data: list = field(default_factory=list)
class TJAData(DefaultObject):
"""Contains the information for a single note or single command."""
def __init__(self, name, value, pos=None):
# For TJA, 'pos' is stored as an integer rather than in milliseconds
self.pos = pos
self.name = name
self.value = value
@dataclass(slots=True)
class FumenNote:
"""Contains all the byte values for a single Fumen note."""
note_type: str = ''
pos: float = 0.0
score_init: int = 0
score_diff: int = 0
padding: int = 0
item: int = 0
duration: float = 0.0
multimeasure: bool = False
hits: int = 0
hits_padding: int = 0
drumroll_bytes: bytes = b'\x00\x00\x00\x00\x00\x00\x00\x00'
class FumenCourse(DefaultObject):
"""Contains all the data in a single Fumen (`.bin`) chart file."""
def __init__(self, measures=None, header=None, score_init=0, score_diff=0):
if isinstance(measures, int):
self.measures = [FumenMeasure() for _ in range(measures)]
else:
self.measures = [] if measures is None else measures
self.header = FumenHeader() if header is None else header
self.score_init = score_init
self.score_diff = score_diff
@dataclass(slots=True)
class FumenBranch:
"""Contains all the data in a single Fumen branch."""
length: int = 0
speed: float = 0.0
padding: int = 0
notes: list = field(default_factory=list)
class FumenMeasure(DefaultObject):
@dataclass(slots=True)
class FumenMeasure:
"""Contains all the data in a single Fumen measure."""
def __init__(self, bpm=0.0, offset_start=0.0, offset_end=0.0,
duration=0.0, gogo=False, barline=True, branch_start=None,
branch_info=None, padding1=0, padding2=0):
self.bpm = bpm
self.offset_start = offset_start
self.offset_end = offset_end
self.duration = duration
self.gogo = gogo
self.barline = barline
self.branch_start = branch_start
self.branch_info = [-1] * 6 if branch_info is None else branch_info
self.branches = {b: FumenBranch() for b in BRANCH_NAMES}
self.padding1 = padding1
self.padding2 = padding2
bpm: float = 0.0
offset_start: float = 0.0
offset_end: float = 0.0
duration: float = 0.0
gogo: bool = False
barline: bool = True
branch_start: list = field(default_factory=list)
branch_info: List[int] = field(default_factory=lambda: [-1] * 6)
branches: Dict[str, FumenBranch] = field(
default_factory=lambda: {b: FumenBranch() for b in BRANCH_NAMES}
)
padding1: int = 0
padding2: int = 0
def set_duration(self, time_sig, measure_length, subdivisions):
"""Compute the millisecond duration of the measure."""
@ -212,79 +204,38 @@ class FumenMeasure(DefaultObject):
self.branch_info[4:6] = branch_condition[1:]
class FumenBranch(DefaultObject):
"""Contains all the data in a single Fumen branch."""
def __init__(self, length=0, speed=0.0, padding=0):
self.length = length
self.speed = speed
self.padding = padding
self.notes = []
class FumenNote(DefaultObject):
"""Contains all the byte values for a single Fumen note."""
def __init__(self, note_type='', pos=0.0, score_init=0, score_diff=0,
padding=0, item=0, duration=0.0, multimeasure=False,
hits=0, hits_padding=0,
drumroll_bytes=b'\x00\x00\x00\x00\x00\x00\x00\x00'):
self.note_type = note_type
self.pos = pos
self.score_init = score_init
self.score_diff = score_diff
self.padding = padding
# TODO: Determine how to properly set the item byte
# (https://github.com/vivaria/tja2fumen/issues/17)
self.item = item
# These attributes are only used for drumrolls/balloons
self.duration = duration
self.multimeasure = multimeasure
self.hits = hits
self.hits_padding = hits_padding
self.drumroll_bytes = drumroll_bytes
class FumenHeader(DefaultObject):
@dataclass(slots=True)
class FumenHeader:
"""Contains all the byte values for a Fumen chart file's header."""
def __init__(self, raw_bytes=None):
if raw_bytes is None:
self.order = "<"
self._assign_default_header_values()
else:
self.order = self._parse_order(raw_bytes)
self._parse_header_values(raw_bytes)
order: str = "<"
b000_b431_timing_windows: List[float] = field(default_factory=lambda:
[25.025, 75.075, 108.422]*36)
b432_b435_has_branches: int = 0
b436_b439_hp_max: int = 10000
b440_b443_hp_clear: int = 8000
b444_b447_hp_gain_good: int = 10
b448_b451_hp_gain_ok: int = 5
b452_b455_hp_loss_bad: int = -20
b456_b459_normal_normal_ratio: int = 65536
b460_b463_normal_professional_ratio: int = 65536
b464_b467_normal_master_ratio: int = 65536
b468_b471_branch_points_good: int = 20
b472_b475_branch_points_ok: int = 10
b476_b479_branch_points_bad: int = 0
b480_b483_branch_points_drumroll: int = 1
b484_b487_branch_points_good_big: int = 20
b488_b491_branch_points_ok_big: int = 10
b492_b495_branch_points_drumroll_big: int = 1
b496_b499_branch_points_balloon: int = 30
b500_b503_branch_points_kusudama: int = 30
b504_b507_branch_points_unknown: int = 20
b508_b511_dummy_data: int = 12345678
b512_b515_number_of_measures: int = 0
b516_b519_unknown_data: int = 0
def _assign_default_header_values(self):
"""Set the default header values."""
# This byte string corresponds to the timing windows for Hard/Oni
# ((When these bytes are parsed, you get roughly about
# (25.025, 75.075, 108.442), but repeated 36 times.))
timing_windows = self.up(b'43\xc8Ag&\x96B"\xe2\xd8B' * 36, "fff" * 36)
self.b000_b431_timing_windows = timing_windows
self.b432_b435_has_branches = 0
self.b436_b439_hp_max = 10000
self.b440_b443_hp_clear = 8000
self.b444_b447_hp_gain_good = 10
self.b448_b451_hp_gain_ok = 5
self.b452_b455_hp_loss_bad = -20
self.b456_b459_normal_normal_ratio = 65536
self.b460_b463_normal_professional_ratio = 65536
self.b464_b467_normal_master_ratio = 65536
self.b468_b471_branch_points_good = 20
self.b472_b475_branch_points_ok = 10
self.b476_b479_branch_points_bad = 0
self.b480_b483_branch_points_drumroll = 1
self.b484_b487_branch_points_good_big = 20
self.b488_b491_branch_points_ok_big = 10
self.b492_b495_branch_points_drumroll_big = 1
self.b496_b499_branch_points_balloon = 30
self.b500_b503_branch_points_kusudama = 30
self.b504_b507_branch_points_unknown = 20
self.b508_b511_dummy_data = 12345678
self.b512_b515_number_of_measures = 0
self.b516_b519_unknown_data = 0
def _parse_header_values(self, raw_bytes):
def parse_header_values(self, raw_bytes):
"""Parse a raw string of 520 bytes to get the header values."""
self.order = self._parse_order(raw_bytes)
rb = raw_bytes # We use a shortened form just for visual clarity:
self.b000_b431_timing_windows = self.up(rb, "f"*108, 0, 431)
self.b432_b435_has_branches = self.up(rb, "i", 432, 435)
@ -365,21 +316,24 @@ class FumenHeader(DefaultObject):
"""Represent the header values as a string of raw bytes."""
value_list = []
format_string = self.order
for key, val in self.__dict__.items():
for key in self.__slots__:
if key in ["order", "_raw_bytes"]:
pass
elif key == "b000_b431_timing_windows":
value_list.extend(list(val))
format_string += "f" * len(val)
value_list.extend(list(getattr(self, key)))
format_string += "f" * len(getattr(self, key))
else:
value_list.append(val)
value_list.append(getattr(self, key))
format_string += "i"
raw_bytes = struct.pack(format_string, *value_list)
assert len(raw_bytes) == 520
return raw_bytes
def __repr__(self):
# Display truncated version of timing windows
return str([v if not isinstance(v, tuple)
else [round(timing, 2) for timing in v[:3]]
for v in self.__dict__.values()])
@dataclass(slots=True)
class FumenCourse:
"""Contains all the data in a single Fumen (`.bin`) chart file."""
header: FumenHeader
measures: List[FumenMeasure] = field(default_factory=list)
score_init: int = 0
score_diff: int = 0