diff --git a/.github/workflows/check_code_quality.yml b/.github/workflows/check_code_quality.yml new file mode 100644 index 0000000..49ea761 --- /dev/null +++ b/.github/workflows/check_code_quality.yml @@ -0,0 +1,32 @@ +#file: noinspection LongLine +name: "Check code quality" + +on: + push: + branches: + - main + pull_request: + branches: + - '*' + +jobs: + check-code-quality: + runs-on: windows-2019 + steps: + - name: Checkout tja2fumen + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10.x' + + - name: Install tja2fumen and its dev dependencies + run: | + pip install -e .[dev] + + - name: Lint project using flake8 + run: pflake8 + + - name: Type analysis using mypy + run: mypy src --strict diff --git a/.github/workflows/test_and_publish_release.yml b/.github/workflows/test_and_publish_release.yml index f2563c0..9259cd5 100644 --- a/.github/workflows/test_and_publish_release.yml +++ b/.github/workflows/test_and_publish_release.yml @@ -33,9 +33,6 @@ jobs: run: | pip install -e .[dev] - - name: Lint project - run: pflake8 - - name: Run tests (Python API) run: | pytest testing --entry-point python-api diff --git a/pyproject.toml b/pyproject.toml index 50a4c0d..aa771b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ tja2fumen = "tja2fumen:main" [project.optional-dependencies] dev = ["pytest", "build", "pyinstaller", "twine", "toml-cli", - "flake8", "pyproject-flake8", "pydantic"] + "flake8", "pyproject-flake8", "mypy"] [tool.setuptools.packages.find] where = ["src"] diff --git a/src/tja2fumen/__init__.py b/src/tja2fumen/__init__.py index d78d2e2..a6edd0d 100644 --- a/src/tja2fumen/__init__.py +++ b/src/tja2fumen/__init__.py @@ -1,14 +1,16 @@ import argparse import os import sys +from typing import Sequence from tja2fumen.parsers import parse_tja from tja2fumen.converters import convert_tja_to_fumen from tja2fumen.writers import write_fumen from tja2fumen.constants import COURSE_IDS +from tja2fumen.types import TJACourse -def main(argv=None): +def main(argv: Sequence[str] = ()) -> None: """ Main entry point for tja2fumen's command line interface. @@ -35,14 +37,16 @@ def main(argv=None): parsed_tja = parse_tja(fname_tja) # Convert parsed TJA courses and write each course to `.bin` files - for course in parsed_tja.courses.items(): - convert_and_write(course, base_name, + for course_name, course in parsed_tja.courses.items(): + convert_and_write(course, course_name, base_name, single_course=(len(parsed_tja.courses) == 1)) -def convert_and_write(parsed_course, base_name, single_course=False): +def convert_and_write(tja_data: TJACourse, + course_name: str, + base_name: str, + single_course: bool = False) -> None: """Process the parsed data for a single TJA course.""" - course_name, tja_data = parsed_course fumen_data = convert_tja_to_fumen(tja_data) # Add course ID (e.g. '_x', '_x_1', '_x_2') to the output file's base name output_name = base_name diff --git a/src/tja2fumen/converters.py b/src/tja2fumen/converters.py index 49a25c2..8e38ade 100644 --- a/src/tja2fumen/converters.py +++ b/src/tja2fumen/converters.py @@ -1,10 +1,11 @@ import re -from tja2fumen.types import (TJAMeasureProcessed, +from tja2fumen.types import (TJACourse, TJAMeasureProcessed, FumenCourse, FumenHeader, FumenMeasure, FumenNote) -def process_tja_commands(tja): +def process_tja_commands(tja: TJACourse) \ + -> dict[str, list[TJAMeasureProcessed]]: """ Process each #COMMAND present in a TJASong's measures, and assign their values as attributes to each measure. @@ -22,8 +23,9 @@ def process_tja_commands(tja): After this function is finished, all the #COMMANDS will be gone, and each measure will have attributes (e.g. measure.bpm, measure.scroll) instead. """ - tja_branches_processed = {branch_name: [] - for branch_name in tja.branches.keys()} + tja_branches_processed: dict[str, list[TJAMeasureProcessed]] = { + branch_name: [] for branch_name in tja.branches.keys() + } for branch_name, branch_measures_tja in tja.branches.items(): current_bpm = tja.BPM current_scroll = 1.0 @@ -50,16 +52,15 @@ def process_tja_commands(tja): elif data.name == 'delay': measure_tja_processed.delay = float(data.value) * 1000 elif data.name == 'branch_start': - branch_condition = data.value.split(',') - if branch_condition[0] == 'r': # r = drumRoll - branch_condition[1] = int(branch_condition[1]) - branch_condition[2] = int(branch_condition[2]) - elif branch_condition[0] == 'p': # p = Percentage - branch_condition[1] = float(branch_condition[1]) / 100 - branch_condition[2] = float(branch_condition[2]) / 100 - measure_tja_processed.branch_start = branch_condition + branch_type, val1, val2 = data.value.split(',') + if branch_type == 'r': # r = drumRoll + branch_cond = (float(val1), float(val2)) + elif branch_type == 'p': # p = Percentage + branch_cond = (float(val1)/100, float(val2)/100) + measure_tja_processed.branch_type = branch_type + measure_tja_processed.branch_cond = branch_cond elif data.name == 'section': - measure_tja_processed.section = data.value + measure_tja_processed.section = bool(data.value) elif data.name == 'levelhold': measure_tja_processed.levelhold = True elif data.name == 'barline': @@ -135,7 +136,7 @@ def process_tja_commands(tja): return tja_branches_processed -def convert_tja_to_fumen(tja): +def convert_tja_to_fumen(tja: TJACourse) -> FumenCourse: """ Convert TJA data to Fumen data by calculating Fumen-specific values. @@ -193,7 +194,8 @@ def convert_tja_to_fumen(tja): branch_points_measure = 0 current_drumroll = None current_levelhold = False - branch_conditions = [] + branch_types: list[str] = [] + branch_conditions: list[tuple[float, float]] = [] course_balloons = tja.balloon.copy() # Iterate through the measures within the branch @@ -215,12 +217,13 @@ def convert_tja_to_fumen(tja): ) # Compute the millisecond offsets for the start/end of each measure - measure_fumen.set_ms_offsets( - song_offset=tja.offset, - delay=measure_tja.delay, - prev_measure=(fumen.measures[idx_m-1] if idx_m else None), - first_measure=(idx_m == 0) - ) + if idx_m == 0: + measure_fumen.set_first_ms_offsets(song_offset=tja.offset) + else: + measure_fumen.set_ms_offsets( + delay=measure_tja.delay, + prev_measure=fumen.measures[idx_m-1], + ) # Handle whether barline should be hidden: # 1. Measures where #BARLINEOFF has been set @@ -232,11 +235,13 @@ def convert_tja_to_fumen(tja): measure_fumen.barline = False # Check to see if the measure contains a branching condition - branch_condition = measure_tja.branch_start - if branch_condition: + branch_type = measure_tja.branch_type + branch_cond = measure_tja.branch_cond + if branch_type and branch_cond: # Update the branch_info values for the measure measure_fumen.set_branch_info( - branch_condition, branch_points_total, current_branch, + branch_type, branch_cond, + branch_points_total, current_branch, first_branch_condition=(not branch_conditions), has_section=bool(measure_tja.section), has_levelhold=current_levelhold @@ -248,7 +253,8 @@ def convert_tja_to_fumen(tja): current_levelhold = False # Keep track of the branch conditions (to later determine how # to set the header bytes for branches) - branch_conditions.append(branch_condition) + branch_types.append(branch_type) + branch_conditions.append(branch_cond) # NB: We update the branch condition note counter *after* # we check the current measure's branch condition. @@ -276,6 +282,11 @@ def convert_tja_to_fumen(tja): # Handle '8' notes (end of a drumroll/balloon) if data.value == "EndDRB": + if not isinstance(current_drumroll, FumenNote): + raise ValueError( + "'8' note encountered without matching " + "drumroll/balloon/kusudama note." + ) # If a drumroll spans a single measure, then add the # difference between start/end position if not current_drumroll.multimeasure: @@ -311,7 +322,7 @@ def convert_tja_to_fumen(tja): note.hits = course_balloons.pop(0) except IndexError: raise ValueError(f"Not enough values for 'BALLOON: " - f"{','.join(course_balloons)}'") + f"{course_balloons}'") current_drumroll = note elif note.note_type in ["Drumroll", "DRUMROLL"]: current_drumroll = note @@ -355,11 +366,11 @@ def convert_tja_to_fumen(tja): # If song has only drumroll branching conditions (also allowing percentage # conditions that force a level up/level down), then set the header bytes # so that only drumrolls contribute to branching. - drumroll_only = branch_conditions != [] and all([ - (cond[0] == 'r') or - (cond[0] == 'p' and cond[1] == 0.0 and cond[2] == 0.0) or - (cond[0] == 'p' and cond[1] > 1.00 and cond[2] > 1.00) - for cond in branch_conditions + drumroll_only = branch_types != [] and branch_conditions != [] and all([ + (branch_type == 'r') or + (branch_type == 'p' and cond[0] == 0.0 and cond[1] == 0.0) or + (branch_type == 'p' and cond[0] > 1.00 and cond[1] > 1.00) + for branch_type, cond in zip(branch_types, branch_conditions) ]) if drumroll_only: fumen.header.b468_b471_branch_points_good = 0 @@ -371,9 +382,9 @@ def convert_tja_to_fumen(tja): # Alternatively, if the song has only percentage-based conditions, then set # the header bytes so that only notes and balloons contribute to branching. - percentage_only = branch_conditions != [] and all([ - (condition[0] != 'r') - for condition in branch_conditions + percentage_only = branch_types != [] and all([ + (branch_type != 'r') + for branch_type in branch_types ]) if percentage_only: fumen.header.b480_b483_branch_points_drumroll = 0 diff --git a/src/tja2fumen/parsers.py b/src/tja2fumen/parsers.py index 2b82b3f..5f4a806 100644 --- a/src/tja2fumen/parsers.py +++ b/src/tja2fumen/parsers.py @@ -2,6 +2,7 @@ import os import re import struct from copy import deepcopy +from typing import BinaryIO, Any from tja2fumen.types import (TJASong, TJACourse, TJAMeasure, TJAData, FumenCourse, FumenMeasure, FumenBranch, FumenNote, @@ -15,7 +16,7 @@ from tja2fumen.constants import (NORMALIZE_COURSE, COURSE_NAMES, BRANCH_NAMES, ############################################################################### -def parse_tja(fname_tja): +def parse_tja(fname_tja: str) -> TJASong: """Read in lines of a .tja file and load them into a TJASong object.""" try: tja_text = open(fname_tja, "r", encoding="utf-8-sig").read() @@ -30,7 +31,7 @@ def parse_tja(fname_tja): return tja -def split_tja_lines_into_courses(lines): +def split_tja_lines_into_courses(lines: list[str]) -> TJASong: """ Parse TJA metadata in order to split TJA lines into separate courses. @@ -152,7 +153,7 @@ def split_tja_lines_into_courses(lines): return parsed_tja -def parse_tja_course_data(course): +def parse_tja_course_data(course: TJACourse) -> None: """ Parse course data (notes, commands) into a nested song structure. @@ -188,40 +189,41 @@ def parse_tja_course_data(course): idx_m_branchstart = 0 for idx_l, line in enumerate(course.data): # 0. Check to see whether line is a command or note data - command, name, value, notes = None, None, None, None + command, name, value, note_data = '', '', '', '' match_command = re.match(r"^#([A-Z]+)(?:\s+(.+))?", line) if match_command: command, value = match_command.groups() value = '' if value is None else value else: - notes = line # If not a command, then line must be note data + note_data = line # If not a command, then line must be note data # 1. Parse measure notes - if notes: + if note_data: # If measure has ended, then add notes to the current measure, # then start a new measure by incrementing idx_m - if notes.endswith(','): - for branch in (course.branches.keys() - if current_branch == 'all' - else [current_branch]): - course.branches[branch][idx_m].notes += notes[0:-1] - course.branches[branch].append(TJAMeasure()) + if note_data.endswith(','): + for branch_name in (course.branches.keys() + if current_branch == 'all' + else [current_branch]): + course.branches[branch_name][idx_m].notes += note_data[:-1] + course.branches[branch_name].append(TJAMeasure()) idx_m += 1 # Otherwise, keep adding notes to the current measure ('idx_m') else: - for branch in (course.branches.keys() - if current_branch == 'all' - else [current_branch]): - course.branches[branch][idx_m].notes += notes + for branch_name in (course.branches.keys() + if current_branch == 'all' + else [current_branch]): + course.branches[branch_name][idx_m].notes += note_data # 2. Parse measure commands that produce an "event" elif command in ['GOGOSTART', 'GOGOEND', 'BARLINEON', 'BARLINEOFF', 'DELAY', 'SCROLL', 'BPMCHANGE', 'MEASURE', 'LEVELHOLD', 'SECTION', 'BRANCHSTART']: # Get position of the event - for branch in (course.branches.keys() if current_branch == 'all' - else [current_branch]): - pos = len(course.branches[branch][idx_m].notes) + for branch_name in (course.branches.keys() + if current_branch == 'all' + else [current_branch]): + pos = len(course.branches[branch_name][idx_m].notes) # Parse event type if command == 'GOGOSTART': @@ -263,9 +265,10 @@ def parse_tja_course_data(course): idx_m_branchstart = idx_m # Append event to the current measure's events - for branch in (course.branches.keys() if current_branch == 'all' - else [current_branch]): - course.branches[branch][idx_m].events.append( + for branch_name in (course.branches.keys() + if current_branch == 'all' + else [current_branch]): + course.branches[branch_name][idx_m].events.append( TJAData(name=name, value=value, pos=pos) ) @@ -328,7 +331,8 @@ def parse_tja_course_data(course): # Fumen-parsing functions # ############################################################################### -def parse_fumen(fumen_file, exclude_empty_measures=False): +def parse_fumen(fumen_file: str, + exclude_empty_measures: bool = False) -> FumenCourse: """ Parse bytes of a fumen .bin file into nested measures, branches, and notes. @@ -469,7 +473,10 @@ def parse_fumen(fumen_file, exclude_empty_measures=False): return song -def read_struct(file, order, format_string, seek=None): +def read_struct(file: BinaryIO, + order: str, + format_string: str, + seek: int = 0) -> tuple[Any, ...]: """ Interpret bytes as packed binary data. diff --git a/src/tja2fumen/types.py b/src/tja2fumen/types.py index b2f9b94..376bde2 100644 --- a/src/tja2fumen/types.py +++ b/src/tja2fumen/types.py @@ -1,9 +1,9 @@ import csv import os import struct -from typing import Dict, List +from typing import Any, Optional -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from tja2fumen.constants import BRANCH_NAMES @@ -20,9 +20,9 @@ class TJAData: @dataclass(slots=True) class TJAMeasure: """Contains all the data in a single TJA measure (denoted by ',').""" - notes: List[TJAData] = field(default_factory=list) - events: List[TJAData] = field(default_factory=list) - combined: List[TJAData] = field(default_factory=list) + notes: list[str] = field(default_factory=list) + events: list[TJAData] = field(default_factory=list) + combined: list[TJAData] = field(default_factory=list) @dataclass(slots=True) @@ -32,11 +32,11 @@ class TJACourse: offset: float course: str level: int = 0 - balloon: list = field(default_factory=list) + balloon: list[int] = field(default_factory=list) score_init: int = 0 score_diff: int = 0 - data: list = field(default_factory=list) - branches: Dict[str, List[TJAMeasure]] = field( + data: list[str] = field(default_factory=list) + branches: dict[str, list[TJAMeasure]] = field( default_factory=lambda: {k: [TJAMeasure()] for k in BRANCH_NAMES} ) @@ -46,7 +46,7 @@ class TJASong: """Contains all the data in a single TJA (`.tja`) chart file.""" BPM: float offset: float - courses: Dict[str, TJACourse] + courses: dict[str, TJACourse] @dataclass(slots=True) @@ -64,15 +64,16 @@ class TJAMeasureProcessed: scroll: float gogo: bool barline: bool - time_sig: List[int] + time_sig: list[int] subdivisions: int pos_start: int = 0 pos_end: int = 0 delay: float = 0.0 section: bool = False levelhold: bool = False - branch_start: List = field(default_factory=list) - data: list = field(default_factory=list) + branch_type: str = '' + branch_cond: tuple[float, float] = (0.0, 0.0) + data: list[TJAData] = field(default_factory=list) @dataclass(slots=True) @@ -97,7 +98,7 @@ class FumenBranch: length: int = 0 speed: float = 0.0 padding: int = 0 - notes: list = field(default_factory=list) + notes: list[FumenNote] = field(default_factory=list) @dataclass(slots=True) @@ -109,15 +110,17 @@ class FumenMeasure: duration: float = 0.0 gogo: bool = False barline: bool = True - branch_start: list = field(default_factory=list) - branch_info: List[int] = field(default_factory=lambda: [-1] * 6) - branches: Dict[str, FumenBranch] = field( + branch_info: list[int] = field(default_factory=lambda: [-1] * 6) + branches: dict[str, FumenBranch] = field( default_factory=lambda: {b: FumenBranch() for b in BRANCH_NAMES} ) padding1: int = 0 padding2: int = 0 - def set_duration(self, time_sig, measure_length, subdivisions): + def set_duration(self, + time_sig: list[int], + measure_length: int, + subdivisions: int) -> None: """Compute the millisecond duration of the measure.""" # First, we compute the duration for a full 4/4 measure. full_duration = (4 * 60_000 / self.bpm) @@ -135,26 +138,37 @@ class FumenMeasure: ) self.duration = (full_duration * measure_size * measure_ratio) - def set_ms_offsets(self, song_offset, delay, prev_measure, first_measure): - """Compute the millisecond offsets for the start/end of the measure.""" - if first_measure: - self.offset_start = (song_offset * -1000) - (4 * 60_000 / self.bpm) - else: - # First, start with the end timing of the previous measure - self.offset_start = prev_measure.offset_end - # Add any #DELAY commands - self.offset_start += delay - # Adjust the start timing to account for #BPMCHANGE commands - # (!!! Discovered by tana :3 !!!) - self.offset_start += (4 * 60_000 / prev_measure.bpm) - self.offset_start -= (4 * 60_000 / self.bpm) - + def set_first_ms_offsets(self, song_offset: float) -> None: + # First, start with song's OFFSET: metadata + self.offset_start = (song_offset * -1 * 1000) # s -> ms + # Then, subtract a full 4/4 measure for the current BPM + self.offset_start -= (4 * 60_000 / self.bpm) # Compute the end offset by adding the duration to the start offset self.offset_end = self.offset_start + self.duration - def set_branch_info(self, branch_condition, branch_points_total, - current_branch, first_branch_condition, - has_section, has_levelhold): + def set_ms_offsets(self, + delay: float, + prev_measure: 'FumenMeasure') -> None: + """Compute the millisecond offsets for the start/end of the measure.""" + # First, start with the end timing of the previous measure + self.offset_start = prev_measure.offset_end + # Add any #DELAY commands + self.offset_start += delay + # Adjust the start timing to account for #BPMCHANGE commands + # (!!! Discovered by tana :3 !!!) + self.offset_start += (4 * 60_000 / prev_measure.bpm) + self.offset_start -= (4 * 60_000 / self.bpm) + # Compute the end offset by adding the duration to the start offset + self.offset_end = self.offset_start + self.duration + + def set_branch_info(self, + branch_type: str, + branch_cond: tuple[float, float], + branch_points_total: int, + current_branch: str, + first_branch_condition: bool, + has_section: bool, + has_levelhold: bool) -> None: """Compute the values that represent branching/diverge conditions.""" # If levelhold is set, force the branch to stay the same, # regardless of the value of the current branch condition. @@ -171,9 +185,9 @@ class FumenMeasure: # 1. Percentage is between 0% and 100% # 2. Percentage is above 100% (guaranteed level down) # 3. Percentage is 0% (guaranteed level up) - elif branch_condition[0] == 'p': + elif branch_type == 'p': vals = [] - for percent in branch_condition[1:]: + for percent in branch_cond: if 0 < percent <= 1: vals.append(int(branch_points_total * percent)) elif percent > 1: @@ -195,20 +209,21 @@ class FumenMeasure: # 3. It's not the first branching condition, and it # doesn't have a #SECTION command. # TODO: Determine the behavior for these 3 conditions - elif branch_condition[0] == 'r': + elif branch_type == 'r': + vals = [int(v) for v in branch_cond] if current_branch == 'normal': - self.branch_info[0:2] = branch_condition[1:] + self.branch_info[0:2] = vals elif current_branch == 'professional': - self.branch_info[2:4] = branch_condition[1:] + self.branch_info[2:4] = vals elif current_branch == 'master': - self.branch_info[4:6] = branch_condition[1:] + self.branch_info[4:6] = vals @dataclass(slots=True) class FumenHeader: """Contains all the byte values for a Fumen chart file's header.""" order: str = "<" - b000_b431_timing_windows: List[float] = field(default_factory=lambda: + b000_b431_timing_windows: list[float] = field(default_factory=lambda: [25.025, 75.075, 108.422]*36) b432_b435_has_branches: int = 0 b436_b439_hp_max: int = 10000 @@ -233,7 +248,7 @@ class FumenHeader: b512_b515_number_of_measures: int = 0 b516_b519_unknown_data: int = 0 - def parse_header_values(self, raw_bytes): + def parse_header_values(self, raw_bytes: bytes) -> None: """Parse a raw string of 520 bytes to get the header values.""" self.order = self._parse_order(raw_bytes) rb = raw_bytes # We use a shortened form just for visual clarity: @@ -261,14 +276,15 @@ class FumenHeader: self.b512_b515_number_of_measures = self.up(rb, "i", 512, 515) self.b516_b519_unknown_data = self.up(rb, "i", 516, 519) - def up(self, raw_bytes, type_string, s=None, e=None): + def up(self, raw_bytes: bytes, type_string: str, + s: Optional[int] = None, e: Optional[int] = None) -> Any: """Unpack a raw byte string according to specific types.""" if s is not None and e is not None: raw_bytes = raw_bytes[s:e+1] vals = struct.unpack(self.order + type_string, raw_bytes) return vals[0] if len(vals) == 1 else vals - def _parse_order(self, raw_bytes): + def _parse_order(self, raw_bytes: bytes) -> str: """Parse the order of the song (little or big endian).""" self.order = '' # Bytes 512-515 are the number of measures. We check the values using @@ -279,7 +295,8 @@ class FumenHeader: else: return "<" - def set_hp_bytes(self, n_notes, difficulty, stars): + def set_hp_bytes(self, n_notes: int, difficulty: str, + stars: int) -> None: """Compute header bytes related to the soul gauge (HP) behavior.""" # Note: Ura Oni is equivalent to Oni for soul gauge behavior difficulty = 'Oni' if difficulty in ['Ura', 'Edit'] else difficulty @@ -287,7 +304,8 @@ class FumenHeader: self.b440_b443_hp_clear = {'Easy': 6000, 'Normal': 7000, 'Hard': 7000, 'Oni': 8000}[difficulty] - def _get_hp_from_LUTs(self, n_notes, difficulty, stars): + def _get_hp_from_LUTs(self, n_notes: int, difficulty: str, + stars: int) -> None: """Fetch pre-computed soul gauge values from lookup tables (LUTs).""" if not 0 < n_notes <= 2500: return @@ -312,18 +330,18 @@ class FumenHeader: break @property - def raw_bytes(self): + def raw_bytes(self) -> bytes: """Represent the header values as a string of raw bytes.""" value_list = [] format_string = self.order - for key in self.__slots__: - if key in ["order", "_raw_bytes"]: + for f in fields(self): + if f.name in ["order", "_raw_bytes"]: pass - elif key == "b000_b431_timing_windows": - value_list.extend(list(getattr(self, key))) - format_string += "f" * len(getattr(self, key)) + elif f.name == "b000_b431_timing_windows": + value_list.extend(list(getattr(self, f.name))) + format_string += "f" * len(getattr(self, f.name)) else: - value_list.append(getattr(self, key)) + value_list.append(getattr(self, f.name)) format_string += "i" raw_bytes = struct.pack(format_string, *value_list) assert len(raw_bytes) == 520 @@ -334,6 +352,6 @@ class FumenHeader: class FumenCourse: """Contains all the data in a single Fumen (`.bin`) chart file.""" header: FumenHeader - measures: List[FumenMeasure] = field(default_factory=list) + measures: list[FumenMeasure] = field(default_factory=list) score_init: int = 0 score_diff: int = 0 diff --git a/src/tja2fumen/writers.py b/src/tja2fumen/writers.py index f4af6bb..54f8884 100644 --- a/src/tja2fumen/writers.py +++ b/src/tja2fumen/writers.py @@ -1,9 +1,11 @@ import struct +from typing import BinaryIO, Any +from tja2fumen.types import FumenCourse from tja2fumen.constants import BRANCH_NAMES, FUMEN_TYPE_NOTES -def write_fumen(path_out, song): +def write_fumen(path_out: str, song: FumenCourse) -> None: """ Write the values in a FumenCourse object to a `.bin` file. @@ -47,7 +49,11 @@ def write_fumen(path_out, song): file.write(note.drumroll_bytes) -def write_struct(file, order, format_string, value_list, seek=None): +def write_struct(file: BinaryIO, + order: str, + format_string: str, + value_list: list[Any], + seek: int = 0) -> None: """Pack (int, float, etc.) values into a string of bytes, then write.""" if seek: file.seek(seek)