Make project mypy
-compliant (and add mypy
check to test suite) (#58)
This PR adopts static type hints as per `mypy` recommendations. This will hopefully clean up the code a little and squash a few type-related bugs. Additionally, this will hopefully open up the door to compiling the project using `mypyc`, which may score us some performance gains without needing to rewrite the project in an entirely new language.
This commit is contained in:
parent
d99c7f5984
commit
5229db4aab
32
.github/workflows/check_code_quality.yml
vendored
Normal file
32
.github/workflows/check_code_quality.yml
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
#file: noinspection LongLine
|
||||
name: "Check code quality"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
check-code-quality:
|
||||
runs-on: windows-2019
|
||||
steps:
|
||||
- name: Checkout tja2fumen
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10.x'
|
||||
|
||||
- name: Install tja2fumen and its dev dependencies
|
||||
run: |
|
||||
pip install -e .[dev]
|
||||
|
||||
- name: Lint project using flake8
|
||||
run: pflake8
|
||||
|
||||
- name: Type analysis using mypy
|
||||
run: mypy src --strict
|
@ -33,9 +33,6 @@ jobs:
|
||||
run: |
|
||||
pip install -e .[dev]
|
||||
|
||||
- name: Lint project
|
||||
run: pflake8
|
||||
|
||||
- name: Run tests (Python API)
|
||||
run: |
|
||||
pytest testing --entry-point python-api
|
||||
|
@ -19,7 +19,7 @@ tja2fumen = "tja2fumen:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest", "build", "pyinstaller", "twine", "toml-cli",
|
||||
"flake8", "pyproject-flake8", "pydantic"]
|
||||
"flake8", "pyproject-flake8", "mypy"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
@ -1,14 +1,16 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from typing import Sequence
|
||||
|
||||
from tja2fumen.parsers import parse_tja
|
||||
from tja2fumen.converters import convert_tja_to_fumen
|
||||
from tja2fumen.writers import write_fumen
|
||||
from tja2fumen.constants import COURSE_IDS
|
||||
from tja2fumen.types import TJACourse
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
def main(argv: Sequence[str] = ()) -> None:
|
||||
"""
|
||||
Main entry point for tja2fumen's command line interface.
|
||||
|
||||
@ -35,14 +37,16 @@ def main(argv=None):
|
||||
parsed_tja = parse_tja(fname_tja)
|
||||
|
||||
# Convert parsed TJA courses and write each course to `.bin` files
|
||||
for course in parsed_tja.courses.items():
|
||||
convert_and_write(course, base_name,
|
||||
for course_name, course in parsed_tja.courses.items():
|
||||
convert_and_write(course, course_name, base_name,
|
||||
single_course=(len(parsed_tja.courses) == 1))
|
||||
|
||||
|
||||
def convert_and_write(parsed_course, base_name, single_course=False):
|
||||
def convert_and_write(tja_data: TJACourse,
|
||||
course_name: str,
|
||||
base_name: str,
|
||||
single_course: bool = False) -> None:
|
||||
"""Process the parsed data for a single TJA course."""
|
||||
course_name, tja_data = parsed_course
|
||||
fumen_data = convert_tja_to_fumen(tja_data)
|
||||
# Add course ID (e.g. '_x', '_x_1', '_x_2') to the output file's base name
|
||||
output_name = base_name
|
||||
|
@ -1,10 +1,11 @@
|
||||
import re
|
||||
|
||||
from tja2fumen.types import (TJAMeasureProcessed,
|
||||
from tja2fumen.types import (TJACourse, TJAMeasureProcessed,
|
||||
FumenCourse, FumenHeader, FumenMeasure, FumenNote)
|
||||
|
||||
|
||||
def process_tja_commands(tja):
|
||||
def process_tja_commands(tja: TJACourse) \
|
||||
-> dict[str, list[TJAMeasureProcessed]]:
|
||||
"""
|
||||
Process each #COMMAND present in a TJASong's measures, and assign their
|
||||
values as attributes to each measure.
|
||||
@ -22,8 +23,9 @@ def process_tja_commands(tja):
|
||||
After this function is finished, all the #COMMANDS will be gone, and each
|
||||
measure will have attributes (e.g. measure.bpm, measure.scroll) instead.
|
||||
"""
|
||||
tja_branches_processed = {branch_name: []
|
||||
for branch_name in tja.branches.keys()}
|
||||
tja_branches_processed: dict[str, list[TJAMeasureProcessed]] = {
|
||||
branch_name: [] for branch_name in tja.branches.keys()
|
||||
}
|
||||
for branch_name, branch_measures_tja in tja.branches.items():
|
||||
current_bpm = tja.BPM
|
||||
current_scroll = 1.0
|
||||
@ -50,16 +52,15 @@ def process_tja_commands(tja):
|
||||
elif data.name == 'delay':
|
||||
measure_tja_processed.delay = float(data.value) * 1000
|
||||
elif data.name == 'branch_start':
|
||||
branch_condition = data.value.split(',')
|
||||
if branch_condition[0] == 'r': # r = drumRoll
|
||||
branch_condition[1] = int(branch_condition[1])
|
||||
branch_condition[2] = int(branch_condition[2])
|
||||
elif branch_condition[0] == 'p': # p = Percentage
|
||||
branch_condition[1] = float(branch_condition[1]) / 100
|
||||
branch_condition[2] = float(branch_condition[2]) / 100
|
||||
measure_tja_processed.branch_start = branch_condition
|
||||
branch_type, val1, val2 = data.value.split(',')
|
||||
if branch_type == 'r': # r = drumRoll
|
||||
branch_cond = (float(val1), float(val2))
|
||||
elif branch_type == 'p': # p = Percentage
|
||||
branch_cond = (float(val1)/100, float(val2)/100)
|
||||
measure_tja_processed.branch_type = branch_type
|
||||
measure_tja_processed.branch_cond = branch_cond
|
||||
elif data.name == 'section':
|
||||
measure_tja_processed.section = data.value
|
||||
measure_tja_processed.section = bool(data.value)
|
||||
elif data.name == 'levelhold':
|
||||
measure_tja_processed.levelhold = True
|
||||
elif data.name == 'barline':
|
||||
@ -135,7 +136,7 @@ def process_tja_commands(tja):
|
||||
return tja_branches_processed
|
||||
|
||||
|
||||
def convert_tja_to_fumen(tja):
|
||||
def convert_tja_to_fumen(tja: TJACourse) -> FumenCourse:
|
||||
"""
|
||||
Convert TJA data to Fumen data by calculating Fumen-specific values.
|
||||
|
||||
@ -193,7 +194,8 @@ def convert_tja_to_fumen(tja):
|
||||
branch_points_measure = 0
|
||||
current_drumroll = None
|
||||
current_levelhold = False
|
||||
branch_conditions = []
|
||||
branch_types: list[str] = []
|
||||
branch_conditions: list[tuple[float, float]] = []
|
||||
course_balloons = tja.balloon.copy()
|
||||
|
||||
# Iterate through the measures within the branch
|
||||
@ -215,12 +217,13 @@ def convert_tja_to_fumen(tja):
|
||||
)
|
||||
|
||||
# Compute the millisecond offsets for the start/end of each measure
|
||||
measure_fumen.set_ms_offsets(
|
||||
song_offset=tja.offset,
|
||||
delay=measure_tja.delay,
|
||||
prev_measure=(fumen.measures[idx_m-1] if idx_m else None),
|
||||
first_measure=(idx_m == 0)
|
||||
)
|
||||
if idx_m == 0:
|
||||
measure_fumen.set_first_ms_offsets(song_offset=tja.offset)
|
||||
else:
|
||||
measure_fumen.set_ms_offsets(
|
||||
delay=measure_tja.delay,
|
||||
prev_measure=fumen.measures[idx_m-1],
|
||||
)
|
||||
|
||||
# Handle whether barline should be hidden:
|
||||
# 1. Measures where #BARLINEOFF has been set
|
||||
@ -232,11 +235,13 @@ def convert_tja_to_fumen(tja):
|
||||
measure_fumen.barline = False
|
||||
|
||||
# Check to see if the measure contains a branching condition
|
||||
branch_condition = measure_tja.branch_start
|
||||
if branch_condition:
|
||||
branch_type = measure_tja.branch_type
|
||||
branch_cond = measure_tja.branch_cond
|
||||
if branch_type and branch_cond:
|
||||
# Update the branch_info values for the measure
|
||||
measure_fumen.set_branch_info(
|
||||
branch_condition, branch_points_total, current_branch,
|
||||
branch_type, branch_cond,
|
||||
branch_points_total, current_branch,
|
||||
first_branch_condition=(not branch_conditions),
|
||||
has_section=bool(measure_tja.section),
|
||||
has_levelhold=current_levelhold
|
||||
@ -248,7 +253,8 @@ def convert_tja_to_fumen(tja):
|
||||
current_levelhold = False
|
||||
# Keep track of the branch conditions (to later determine how
|
||||
# to set the header bytes for branches)
|
||||
branch_conditions.append(branch_condition)
|
||||
branch_types.append(branch_type)
|
||||
branch_conditions.append(branch_cond)
|
||||
|
||||
# NB: We update the branch condition note counter *after*
|
||||
# we check the current measure's branch condition.
|
||||
@ -276,6 +282,11 @@ def convert_tja_to_fumen(tja):
|
||||
|
||||
# Handle '8' notes (end of a drumroll/balloon)
|
||||
if data.value == "EndDRB":
|
||||
if not isinstance(current_drumroll, FumenNote):
|
||||
raise ValueError(
|
||||
"'8' note encountered without matching "
|
||||
"drumroll/balloon/kusudama note."
|
||||
)
|
||||
# If a drumroll spans a single measure, then add the
|
||||
# difference between start/end position
|
||||
if not current_drumroll.multimeasure:
|
||||
@ -311,7 +322,7 @@ def convert_tja_to_fumen(tja):
|
||||
note.hits = course_balloons.pop(0)
|
||||
except IndexError:
|
||||
raise ValueError(f"Not enough values for 'BALLOON: "
|
||||
f"{','.join(course_balloons)}'")
|
||||
f"{course_balloons}'")
|
||||
current_drumroll = note
|
||||
elif note.note_type in ["Drumroll", "DRUMROLL"]:
|
||||
current_drumroll = note
|
||||
@ -355,11 +366,11 @@ def convert_tja_to_fumen(tja):
|
||||
# If song has only drumroll branching conditions (also allowing percentage
|
||||
# conditions that force a level up/level down), then set the header bytes
|
||||
# so that only drumrolls contribute to branching.
|
||||
drumroll_only = branch_conditions != [] and all([
|
||||
(cond[0] == 'r') or
|
||||
(cond[0] == 'p' and cond[1] == 0.0 and cond[2] == 0.0) or
|
||||
(cond[0] == 'p' and cond[1] > 1.00 and cond[2] > 1.00)
|
||||
for cond in branch_conditions
|
||||
drumroll_only = branch_types != [] and branch_conditions != [] and all([
|
||||
(branch_type == 'r') or
|
||||
(branch_type == 'p' and cond[0] == 0.0 and cond[1] == 0.0) or
|
||||
(branch_type == 'p' and cond[0] > 1.00 and cond[1] > 1.00)
|
||||
for branch_type, cond in zip(branch_types, branch_conditions)
|
||||
])
|
||||
if drumroll_only:
|
||||
fumen.header.b468_b471_branch_points_good = 0
|
||||
@ -371,9 +382,9 @@ def convert_tja_to_fumen(tja):
|
||||
|
||||
# Alternatively, if the song has only percentage-based conditions, then set
|
||||
# the header bytes so that only notes and balloons contribute to branching.
|
||||
percentage_only = branch_conditions != [] and all([
|
||||
(condition[0] != 'r')
|
||||
for condition in branch_conditions
|
||||
percentage_only = branch_types != [] and all([
|
||||
(branch_type != 'r')
|
||||
for branch_type in branch_types
|
||||
])
|
||||
if percentage_only:
|
||||
fumen.header.b480_b483_branch_points_drumroll = 0
|
||||
|
@ -2,6 +2,7 @@ import os
|
||||
import re
|
||||
import struct
|
||||
from copy import deepcopy
|
||||
from typing import BinaryIO, Any
|
||||
|
||||
from tja2fumen.types import (TJASong, TJACourse, TJAMeasure, TJAData,
|
||||
FumenCourse, FumenMeasure, FumenBranch, FumenNote,
|
||||
@ -15,7 +16,7 @@ from tja2fumen.constants import (NORMALIZE_COURSE, COURSE_NAMES, BRANCH_NAMES,
|
||||
###############################################################################
|
||||
|
||||
|
||||
def parse_tja(fname_tja):
|
||||
def parse_tja(fname_tja: str) -> TJASong:
|
||||
"""Read in lines of a .tja file and load them into a TJASong object."""
|
||||
try:
|
||||
tja_text = open(fname_tja, "r", encoding="utf-8-sig").read()
|
||||
@ -30,7 +31,7 @@ def parse_tja(fname_tja):
|
||||
return tja
|
||||
|
||||
|
||||
def split_tja_lines_into_courses(lines):
|
||||
def split_tja_lines_into_courses(lines: list[str]) -> TJASong:
|
||||
"""
|
||||
Parse TJA metadata in order to split TJA lines into separate courses.
|
||||
|
||||
@ -152,7 +153,7 @@ def split_tja_lines_into_courses(lines):
|
||||
return parsed_tja
|
||||
|
||||
|
||||
def parse_tja_course_data(course):
|
||||
def parse_tja_course_data(course: TJACourse) -> None:
|
||||
"""
|
||||
Parse course data (notes, commands) into a nested song structure.
|
||||
|
||||
@ -188,40 +189,41 @@ def parse_tja_course_data(course):
|
||||
idx_m_branchstart = 0
|
||||
for idx_l, line in enumerate(course.data):
|
||||
# 0. Check to see whether line is a command or note data
|
||||
command, name, value, notes = None, None, None, None
|
||||
command, name, value, note_data = '', '', '', ''
|
||||
match_command = re.match(r"^#([A-Z]+)(?:\s+(.+))?", line)
|
||||
if match_command:
|
||||
command, value = match_command.groups()
|
||||
value = '' if value is None else value
|
||||
else:
|
||||
notes = line # If not a command, then line must be note data
|
||||
note_data = line # If not a command, then line must be note data
|
||||
|
||||
# 1. Parse measure notes
|
||||
if notes:
|
||||
if note_data:
|
||||
# If measure has ended, then add notes to the current measure,
|
||||
# then start a new measure by incrementing idx_m
|
||||
if notes.endswith(','):
|
||||
for branch in (course.branches.keys()
|
||||
if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
course.branches[branch][idx_m].notes += notes[0:-1]
|
||||
course.branches[branch].append(TJAMeasure())
|
||||
if note_data.endswith(','):
|
||||
for branch_name in (course.branches.keys()
|
||||
if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
course.branches[branch_name][idx_m].notes += note_data[:-1]
|
||||
course.branches[branch_name].append(TJAMeasure())
|
||||
idx_m += 1
|
||||
# Otherwise, keep adding notes to the current measure ('idx_m')
|
||||
else:
|
||||
for branch in (course.branches.keys()
|
||||
if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
course.branches[branch][idx_m].notes += notes
|
||||
for branch_name in (course.branches.keys()
|
||||
if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
course.branches[branch_name][idx_m].notes += note_data
|
||||
|
||||
# 2. Parse measure commands that produce an "event"
|
||||
elif command in ['GOGOSTART', 'GOGOEND', 'BARLINEON', 'BARLINEOFF',
|
||||
'DELAY', 'SCROLL', 'BPMCHANGE', 'MEASURE',
|
||||
'LEVELHOLD', 'SECTION', 'BRANCHSTART']:
|
||||
# Get position of the event
|
||||
for branch in (course.branches.keys() if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
pos = len(course.branches[branch][idx_m].notes)
|
||||
for branch_name in (course.branches.keys()
|
||||
if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
pos = len(course.branches[branch_name][idx_m].notes)
|
||||
|
||||
# Parse event type
|
||||
if command == 'GOGOSTART':
|
||||
@ -263,9 +265,10 @@ def parse_tja_course_data(course):
|
||||
idx_m_branchstart = idx_m
|
||||
|
||||
# Append event to the current measure's events
|
||||
for branch in (course.branches.keys() if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
course.branches[branch][idx_m].events.append(
|
||||
for branch_name in (course.branches.keys()
|
||||
if current_branch == 'all'
|
||||
else [current_branch]):
|
||||
course.branches[branch_name][idx_m].events.append(
|
||||
TJAData(name=name, value=value, pos=pos)
|
||||
)
|
||||
|
||||
@ -328,7 +331,8 @@ def parse_tja_course_data(course):
|
||||
# Fumen-parsing functions #
|
||||
###############################################################################
|
||||
|
||||
def parse_fumen(fumen_file, exclude_empty_measures=False):
|
||||
def parse_fumen(fumen_file: str,
|
||||
exclude_empty_measures: bool = False) -> FumenCourse:
|
||||
"""
|
||||
Parse bytes of a fumen .bin file into nested measures, branches, and notes.
|
||||
|
||||
@ -469,7 +473,10 @@ def parse_fumen(fumen_file, exclude_empty_measures=False):
|
||||
return song
|
||||
|
||||
|
||||
def read_struct(file, order, format_string, seek=None):
|
||||
def read_struct(file: BinaryIO,
|
||||
order: str,
|
||||
format_string: str,
|
||||
seek: int = 0) -> tuple[Any, ...]:
|
||||
"""
|
||||
Interpret bytes as packed binary data.
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import csv
|
||||
import os
|
||||
import struct
|
||||
from typing import Dict, List
|
||||
from typing import Any, Optional
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields
|
||||
|
||||
from tja2fumen.constants import BRANCH_NAMES
|
||||
|
||||
@ -20,9 +20,9 @@ class TJAData:
|
||||
@dataclass(slots=True)
|
||||
class TJAMeasure:
|
||||
"""Contains all the data in a single TJA measure (denoted by ',')."""
|
||||
notes: List[TJAData] = field(default_factory=list)
|
||||
events: List[TJAData] = field(default_factory=list)
|
||||
combined: List[TJAData] = field(default_factory=list)
|
||||
notes: list[str] = field(default_factory=list)
|
||||
events: list[TJAData] = field(default_factory=list)
|
||||
combined: list[TJAData] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -32,11 +32,11 @@ class TJACourse:
|
||||
offset: float
|
||||
course: str
|
||||
level: int = 0
|
||||
balloon: list = field(default_factory=list)
|
||||
balloon: list[int] = field(default_factory=list)
|
||||
score_init: int = 0
|
||||
score_diff: int = 0
|
||||
data: list = field(default_factory=list)
|
||||
branches: Dict[str, List[TJAMeasure]] = field(
|
||||
data: list[str] = field(default_factory=list)
|
||||
branches: dict[str, list[TJAMeasure]] = field(
|
||||
default_factory=lambda: {k: [TJAMeasure()] for k in BRANCH_NAMES}
|
||||
)
|
||||
|
||||
@ -46,7 +46,7 @@ class TJASong:
|
||||
"""Contains all the data in a single TJA (`.tja`) chart file."""
|
||||
BPM: float
|
||||
offset: float
|
||||
courses: Dict[str, TJACourse]
|
||||
courses: dict[str, TJACourse]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -64,15 +64,16 @@ class TJAMeasureProcessed:
|
||||
scroll: float
|
||||
gogo: bool
|
||||
barline: bool
|
||||
time_sig: List[int]
|
||||
time_sig: list[int]
|
||||
subdivisions: int
|
||||
pos_start: int = 0
|
||||
pos_end: int = 0
|
||||
delay: float = 0.0
|
||||
section: bool = False
|
||||
levelhold: bool = False
|
||||
branch_start: List = field(default_factory=list)
|
||||
data: list = field(default_factory=list)
|
||||
branch_type: str = ''
|
||||
branch_cond: tuple[float, float] = (0.0, 0.0)
|
||||
data: list[TJAData] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -97,7 +98,7 @@ class FumenBranch:
|
||||
length: int = 0
|
||||
speed: float = 0.0
|
||||
padding: int = 0
|
||||
notes: list = field(default_factory=list)
|
||||
notes: list[FumenNote] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -109,15 +110,17 @@ class FumenMeasure:
|
||||
duration: float = 0.0
|
||||
gogo: bool = False
|
||||
barline: bool = True
|
||||
branch_start: list = field(default_factory=list)
|
||||
branch_info: List[int] = field(default_factory=lambda: [-1] * 6)
|
||||
branches: Dict[str, FumenBranch] = field(
|
||||
branch_info: list[int] = field(default_factory=lambda: [-1] * 6)
|
||||
branches: dict[str, FumenBranch] = field(
|
||||
default_factory=lambda: {b: FumenBranch() for b in BRANCH_NAMES}
|
||||
)
|
||||
padding1: int = 0
|
||||
padding2: int = 0
|
||||
|
||||
def set_duration(self, time_sig, measure_length, subdivisions):
|
||||
def set_duration(self,
|
||||
time_sig: list[int],
|
||||
measure_length: int,
|
||||
subdivisions: int) -> None:
|
||||
"""Compute the millisecond duration of the measure."""
|
||||
# First, we compute the duration for a full 4/4 measure.
|
||||
full_duration = (4 * 60_000 / self.bpm)
|
||||
@ -135,26 +138,37 @@ class FumenMeasure:
|
||||
)
|
||||
self.duration = (full_duration * measure_size * measure_ratio)
|
||||
|
||||
def set_ms_offsets(self, song_offset, delay, prev_measure, first_measure):
|
||||
"""Compute the millisecond offsets for the start/end of the measure."""
|
||||
if first_measure:
|
||||
self.offset_start = (song_offset * -1000) - (4 * 60_000 / self.bpm)
|
||||
else:
|
||||
# First, start with the end timing of the previous measure
|
||||
self.offset_start = prev_measure.offset_end
|
||||
# Add any #DELAY commands
|
||||
self.offset_start += delay
|
||||
# Adjust the start timing to account for #BPMCHANGE commands
|
||||
# (!!! Discovered by tana :3 !!!)
|
||||
self.offset_start += (4 * 60_000 / prev_measure.bpm)
|
||||
self.offset_start -= (4 * 60_000 / self.bpm)
|
||||
|
||||
def set_first_ms_offsets(self, song_offset: float) -> None:
|
||||
# First, start with song's OFFSET: metadata
|
||||
self.offset_start = (song_offset * -1 * 1000) # s -> ms
|
||||
# Then, subtract a full 4/4 measure for the current BPM
|
||||
self.offset_start -= (4 * 60_000 / self.bpm)
|
||||
# Compute the end offset by adding the duration to the start offset
|
||||
self.offset_end = self.offset_start + self.duration
|
||||
|
||||
def set_branch_info(self, branch_condition, branch_points_total,
|
||||
current_branch, first_branch_condition,
|
||||
has_section, has_levelhold):
|
||||
def set_ms_offsets(self,
|
||||
delay: float,
|
||||
prev_measure: 'FumenMeasure') -> None:
|
||||
"""Compute the millisecond offsets for the start/end of the measure."""
|
||||
# First, start with the end timing of the previous measure
|
||||
self.offset_start = prev_measure.offset_end
|
||||
# Add any #DELAY commands
|
||||
self.offset_start += delay
|
||||
# Adjust the start timing to account for #BPMCHANGE commands
|
||||
# (!!! Discovered by tana :3 !!!)
|
||||
self.offset_start += (4 * 60_000 / prev_measure.bpm)
|
||||
self.offset_start -= (4 * 60_000 / self.bpm)
|
||||
# Compute the end offset by adding the duration to the start offset
|
||||
self.offset_end = self.offset_start + self.duration
|
||||
|
||||
def set_branch_info(self,
|
||||
branch_type: str,
|
||||
branch_cond: tuple[float, float],
|
||||
branch_points_total: int,
|
||||
current_branch: str,
|
||||
first_branch_condition: bool,
|
||||
has_section: bool,
|
||||
has_levelhold: bool) -> None:
|
||||
"""Compute the values that represent branching/diverge conditions."""
|
||||
# If levelhold is set, force the branch to stay the same,
|
||||
# regardless of the value of the current branch condition.
|
||||
@ -171,9 +185,9 @@ class FumenMeasure:
|
||||
# 1. Percentage is between 0% and 100%
|
||||
# 2. Percentage is above 100% (guaranteed level down)
|
||||
# 3. Percentage is 0% (guaranteed level up)
|
||||
elif branch_condition[0] == 'p':
|
||||
elif branch_type == 'p':
|
||||
vals = []
|
||||
for percent in branch_condition[1:]:
|
||||
for percent in branch_cond:
|
||||
if 0 < percent <= 1:
|
||||
vals.append(int(branch_points_total * percent))
|
||||
elif percent > 1:
|
||||
@ -195,20 +209,21 @@ class FumenMeasure:
|
||||
# 3. It's not the first branching condition, and it
|
||||
# doesn't have a #SECTION command.
|
||||
# TODO: Determine the behavior for these 3 conditions
|
||||
elif branch_condition[0] == 'r':
|
||||
elif branch_type == 'r':
|
||||
vals = [int(v) for v in branch_cond]
|
||||
if current_branch == 'normal':
|
||||
self.branch_info[0:2] = branch_condition[1:]
|
||||
self.branch_info[0:2] = vals
|
||||
elif current_branch == 'professional':
|
||||
self.branch_info[2:4] = branch_condition[1:]
|
||||
self.branch_info[2:4] = vals
|
||||
elif current_branch == 'master':
|
||||
self.branch_info[4:6] = branch_condition[1:]
|
||||
self.branch_info[4:6] = vals
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FumenHeader:
|
||||
"""Contains all the byte values for a Fumen chart file's header."""
|
||||
order: str = "<"
|
||||
b000_b431_timing_windows: List[float] = field(default_factory=lambda:
|
||||
b000_b431_timing_windows: list[float] = field(default_factory=lambda:
|
||||
[25.025, 75.075, 108.422]*36)
|
||||
b432_b435_has_branches: int = 0
|
||||
b436_b439_hp_max: int = 10000
|
||||
@ -233,7 +248,7 @@ class FumenHeader:
|
||||
b512_b515_number_of_measures: int = 0
|
||||
b516_b519_unknown_data: int = 0
|
||||
|
||||
def parse_header_values(self, raw_bytes):
|
||||
def parse_header_values(self, raw_bytes: bytes) -> None:
|
||||
"""Parse a raw string of 520 bytes to get the header values."""
|
||||
self.order = self._parse_order(raw_bytes)
|
||||
rb = raw_bytes # We use a shortened form just for visual clarity:
|
||||
@ -261,14 +276,15 @@ class FumenHeader:
|
||||
self.b512_b515_number_of_measures = self.up(rb, "i", 512, 515)
|
||||
self.b516_b519_unknown_data = self.up(rb, "i", 516, 519)
|
||||
|
||||
def up(self, raw_bytes, type_string, s=None, e=None):
|
||||
def up(self, raw_bytes: bytes, type_string: str,
|
||||
s: Optional[int] = None, e: Optional[int] = None) -> Any:
|
||||
"""Unpack a raw byte string according to specific types."""
|
||||
if s is not None and e is not None:
|
||||
raw_bytes = raw_bytes[s:e+1]
|
||||
vals = struct.unpack(self.order + type_string, raw_bytes)
|
||||
return vals[0] if len(vals) == 1 else vals
|
||||
|
||||
def _parse_order(self, raw_bytes):
|
||||
def _parse_order(self, raw_bytes: bytes) -> str:
|
||||
"""Parse the order of the song (little or big endian)."""
|
||||
self.order = ''
|
||||
# Bytes 512-515 are the number of measures. We check the values using
|
||||
@ -279,7 +295,8 @@ class FumenHeader:
|
||||
else:
|
||||
return "<"
|
||||
|
||||
def set_hp_bytes(self, n_notes, difficulty, stars):
|
||||
def set_hp_bytes(self, n_notes: int, difficulty: str,
|
||||
stars: int) -> None:
|
||||
"""Compute header bytes related to the soul gauge (HP) behavior."""
|
||||
# Note: Ura Oni is equivalent to Oni for soul gauge behavior
|
||||
difficulty = 'Oni' if difficulty in ['Ura', 'Edit'] else difficulty
|
||||
@ -287,7 +304,8 @@ class FumenHeader:
|
||||
self.b440_b443_hp_clear = {'Easy': 6000, 'Normal': 7000,
|
||||
'Hard': 7000, 'Oni': 8000}[difficulty]
|
||||
|
||||
def _get_hp_from_LUTs(self, n_notes, difficulty, stars):
|
||||
def _get_hp_from_LUTs(self, n_notes: int, difficulty: str,
|
||||
stars: int) -> None:
|
||||
"""Fetch pre-computed soul gauge values from lookup tables (LUTs)."""
|
||||
if not 0 < n_notes <= 2500:
|
||||
return
|
||||
@ -312,18 +330,18 @@ class FumenHeader:
|
||||
break
|
||||
|
||||
@property
|
||||
def raw_bytes(self):
|
||||
def raw_bytes(self) -> bytes:
|
||||
"""Represent the header values as a string of raw bytes."""
|
||||
value_list = []
|
||||
format_string = self.order
|
||||
for key in self.__slots__:
|
||||
if key in ["order", "_raw_bytes"]:
|
||||
for f in fields(self):
|
||||
if f.name in ["order", "_raw_bytes"]:
|
||||
pass
|
||||
elif key == "b000_b431_timing_windows":
|
||||
value_list.extend(list(getattr(self, key)))
|
||||
format_string += "f" * len(getattr(self, key))
|
||||
elif f.name == "b000_b431_timing_windows":
|
||||
value_list.extend(list(getattr(self, f.name)))
|
||||
format_string += "f" * len(getattr(self, f.name))
|
||||
else:
|
||||
value_list.append(getattr(self, key))
|
||||
value_list.append(getattr(self, f.name))
|
||||
format_string += "i"
|
||||
raw_bytes = struct.pack(format_string, *value_list)
|
||||
assert len(raw_bytes) == 520
|
||||
@ -334,6 +352,6 @@ class FumenHeader:
|
||||
class FumenCourse:
|
||||
"""Contains all the data in a single Fumen (`.bin`) chart file."""
|
||||
header: FumenHeader
|
||||
measures: List[FumenMeasure] = field(default_factory=list)
|
||||
measures: list[FumenMeasure] = field(default_factory=list)
|
||||
score_init: int = 0
|
||||
score_diff: int = 0
|
||||
|
@ -1,9 +1,11 @@
|
||||
import struct
|
||||
from typing import BinaryIO, Any
|
||||
|
||||
from tja2fumen.types import FumenCourse
|
||||
from tja2fumen.constants import BRANCH_NAMES, FUMEN_TYPE_NOTES
|
||||
|
||||
|
||||
def write_fumen(path_out, song):
|
||||
def write_fumen(path_out: str, song: FumenCourse) -> None:
|
||||
"""
|
||||
Write the values in a FumenCourse object to a `.bin` file.
|
||||
|
||||
@ -47,7 +49,11 @@ def write_fumen(path_out, song):
|
||||
file.write(note.drumroll_bytes)
|
||||
|
||||
|
||||
def write_struct(file, order, format_string, value_list, seek=None):
|
||||
def write_struct(file: BinaryIO,
|
||||
order: str,
|
||||
format_string: str,
|
||||
value_list: list[Any],
|
||||
seek: int = 0) -> None:
|
||||
"""Pack (int, float, etc.) values into a string of bytes, then write."""
|
||||
if seek:
|
||||
file.seek(seek)
|
||||
|
Loading…
Reference in New Issue
Block a user