diff --git a/jubeatools/formats/jubeat_analyser/load_tools.py b/jubeatools/formats/jubeat_analyser/load_tools.py index 7499126..9f52e70 100644 --- a/jubeatools/formats/jubeat_analyser/load_tools.py +++ b/jubeatools/formats/jubeat_analyser/load_tools.py @@ -14,7 +14,7 @@ import constraint from parsimonious import Grammar, NodeVisitor, ParseError from parsimonious.nodes import Node -from jubeatools.formats.load_tools import make_folder_loader +from jubeatools.formats.load_tools import make_folder_loader, round_beats from jubeatools.song import BeatsTime, BPMEvent, Difficulty, LongNote, NotePosition from .symbols import ( @@ -185,11 +185,6 @@ def split_double_byte_line(line: str) -> List[str]: return symbols -def decimal_to_beats(decimal_time: Decimal) -> BeatsTime: - nearest_240th = round(decimal_time * 240) - return BeatsTime(nearest_240th, 240) - - @dataclass(frozen=True) class UnfinishedLongNote: time: BeatsTime @@ -331,7 +326,7 @@ class JubeatAnalyserParser: method() def do_b(self, value: str) -> None: - self.beats_per_section = decimal_to_beats(Decimal(value)) + self.beats_per_section = round_beats(Decimal(value)) def do_m(self, value: str) -> None: self.music = value @@ -432,7 +427,7 @@ class JubeatAnalyserParser: f"{self.beats_per_section} beats, a symbol cannot happen " f"afterwards at {timing}" ) - self.symbols[symbol] = decimal_to_beats(timing) + self.symbols[symbol] = round_beats(timing) def is_short_line(self, line: str) -> bool: return len(line.encode("shift-jis-2004")) < self.bytes_per_panel * 4 diff --git a/jubeatools/formats/jubeat_analyser/tests/memo/test_memo.py b/jubeatools/formats/jubeat_analyser/tests/memo/test_memo.py index 69ccea9..4fee60e 100644 --- a/jubeatools/formats/jubeat_analyser/tests/memo/test_memo.py +++ b/jubeatools/formats/jubeat_analyser/tests/memo/test_memo.py @@ -11,8 +11,9 @@ from jubeatools.formats.enum import Format from jubeatools.formats.jubeat_analyser.memo.dump import _dump_memo_chart from jubeatools.formats.jubeat_analyser.memo.load import MemoParser from jubeatools.testutils import strategies as jbst +from jubeatools.testutils.test_patterns import dump_and_load_then_compare -from ..test_utils import load_and_dump_then_check, memo_compatible_song +from ..test_utils import memo_compatible_song, temp_file_named_txt from . import example1, example2, example3 @@ -44,4 +45,10 @@ def test_that_notes_roundtrip(notes: Set[Union[song.TapNote, song.LongNote]]) -> @example(*example2.data) @example(*example3.data) def test_that_full_chart_roundtrips(song: song.Song, circle_free: bool) -> None: - load_and_dump_then_check(Format.MEMO, song, circle_free) + dump_and_load_then_compare( + Format.MEMO, + song, + temp_path=temp_file_named_txt(), + bytes_decoder=lambda b: b.decode("shift-jis-2004"), + dump_options={"circle_free": circle_free}, + ) diff --git a/jubeatools/formats/jubeat_analyser/tests/memo1/test_memo1.py b/jubeatools/formats/jubeat_analyser/tests/memo1/test_memo1.py index 29a64a0..ce3e977 100644 --- a/jubeatools/formats/jubeat_analyser/tests/memo1/test_memo1.py +++ b/jubeatools/formats/jubeat_analyser/tests/memo1/test_memo1.py @@ -11,8 +11,9 @@ from jubeatools.formats.jubeat_analyser.memo1.dump import _dump_memo1_chart from jubeatools.formats.jubeat_analyser.memo1.load import Memo1Parser from jubeatools.testutils.strategies import NoteOption from jubeatools.testutils.strategies import notes as notes_strat +from jubeatools.testutils.test_patterns import dump_and_load_then_compare -from ..test_utils import load_and_dump_then_check, memo_compatible_song +from ..test_utils import memo_compatible_song, temp_file_named_txt from . import example1 @@ -41,4 +42,10 @@ def test_that_notes_roundtrip(notes: List[Union[song.TapNote, song.LongNote]]) - @given(memo_compatible_song(), st.booleans()) @example(*example1.data) def test_that_full_chart_roundtrips(song: song.Song, circle_free: bool) -> None: - load_and_dump_then_check(Format.MEMO_1, song, circle_free) + dump_and_load_then_compare( + Format.MEMO_1, + song, + temp_path=temp_file_named_txt(), + bytes_decoder=lambda b: b.decode("shift-jis-2004"), + dump_options={"circle_free": circle_free}, + ) diff --git a/jubeatools/formats/jubeat_analyser/tests/memo2/test_memo2.py b/jubeatools/formats/jubeat_analyser/tests/memo2/test_memo2.py index 6097236..1e089b9 100644 --- a/jubeatools/formats/jubeat_analyser/tests/memo2/test_memo2.py +++ b/jubeatools/formats/jubeat_analyser/tests/memo2/test_memo2.py @@ -21,8 +21,9 @@ from jubeatools.song import ( ) from jubeatools.testutils.strategies import NoteOption from jubeatools.testutils.strategies import notes as notes_strat +from jubeatools.testutils.test_patterns import dump_and_load_then_compare -from ..test_utils import load_and_dump_then_check, memo_compatible_song +from ..test_utils import memo_compatible_song, temp_file_named_txt from . import example1, example2, example3 @@ -52,4 +53,10 @@ def test_that_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None: @example(*example2.data) @example(*example3.data) def test_that_full_chart_roundtrips(song: Song, circle_free: bool) -> None: - load_and_dump_then_check(Format.MEMO_2, song, circle_free) + dump_and_load_then_compare( + Format.MEMO_2, + song, + temp_path=temp_file_named_txt(), + bytes_decoder=lambda b: b.decode("shift-jis-2004"), + dump_options={"circle_free": circle_free}, + ) diff --git a/jubeatools/formats/jubeat_analyser/tests/mono_column/test_mono_column_hypothesis.py b/jubeatools/formats/jubeat_analyser/tests/mono_column/test_mono_column_hypothesis.py index bd9aa80..3a4c241 100644 --- a/jubeatools/formats/jubeat_analyser/tests/mono_column/test_mono_column_hypothesis.py +++ b/jubeatools/formats/jubeat_analyser/tests/mono_column/test_mono_column_hypothesis.py @@ -22,8 +22,9 @@ from jubeatools.song import ( from jubeatools.testutils.strategies import NoteOption, long_note from jubeatools.testutils.strategies import notes as notes_strat from jubeatools.testutils.strategies import tap_note +from jubeatools.testutils.test_patterns import dump_and_load_then_compare -from ..test_utils import load_and_dump_then_check, memo_compatible_song +from ..test_utils import memo_compatible_song, temp_file_named_txt @given(st.sets(tap_note(), min_size=1, max_size=100)) @@ -84,4 +85,10 @@ def test_that_many_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> Non @given(memo_compatible_song(), st.booleans()) def test_that_full_chart_roundtrips(song: Song, circle_free: bool) -> None: - load_and_dump_then_check(Format.MONO_COLUMN, song, circle_free) + dump_and_load_then_compare( + Format.MONO_COLUMN, + song, + temp_path=temp_file_named_txt(), + bytes_decoder=lambda b: b.decode("shift-jis-2004"), + dump_options={"circle_free": circle_free}, + ) diff --git a/jubeatools/formats/jubeat_analyser/tests/test_utils.py b/jubeatools/formats/jubeat_analyser/tests/test_utils.py index d9ebe36..a1119cf 100644 --- a/jubeatools/formats/jubeat_analyser/tests/test_utils.py +++ b/jubeatools/formats/jubeat_analyser/tests/test_utils.py @@ -1,12 +1,11 @@ import tempfile +from contextlib import contextmanager from pathlib import Path +from typing import Iterator -from hypothesis import note as hypothesis_note from hypothesis import strategies as st from jubeatools import song -from jubeatools.formats import DUMPERS, LOADERS, Format -from jubeatools.formats.guess import guess_format from jubeatools.testutils import strategies as jbst from jubeatools.testutils.typing import DrawFunc @@ -28,7 +27,7 @@ def memo_compatible_song(draw: DrawFunc) -> song.Song: diff = draw(st.sampled_from(list(d.value for d in song.Difficulty))) chart = draw( jbst.chart( - timing_strat=jbst.timing_info(bpm_changes=True), + timing_strat=jbst.timing_info(with_bpm_changes=True), notes_strat=jbst.notes(jbst.NoteOption.LONGS), ) ) @@ -39,17 +38,7 @@ def memo_compatible_song(draw: DrawFunc) -> song.Song: ) -def load_and_dump_then_check(f: Format, song: song.Song, circle_free: bool) -> None: - loader = LOADERS[f] - dumper = DUMPERS[f] +@contextmanager +def temp_file_named_txt() -> Iterator[Path]: with tempfile.NamedTemporaryFile(suffix=".txt") as dst: - path = Path(dst.name) - files = dumper(song, path, circle_free=circle_free) - assert len(files) == 1 - bytes_ = files.popitem()[1] - hypothesis_note(f"Chart file :\n{bytes_.decode('shift-jis-2004')}") - dst.write(bytes_) - dst.flush() - assert guess_format(path) == f - recovered_song = loader(path) - assert recovered_song == song + yield Path(dst.name) diff --git a/jubeatools/formats/load_tools.py b/jubeatools/formats/load_tools.py index a662923..0481415 100644 --- a/jubeatools/formats/load_tools.py +++ b/jubeatools/formats/load_tools.py @@ -1,5 +1,9 @@ +from decimal import Decimal +from fractions import Fraction from pathlib import Path -from typing import Dict, Iterable, Optional, Protocol, TypeVar +from typing import Dict, Iterable, Protocol, TypeVar, Union + +from jubeatools import song T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) @@ -39,3 +43,15 @@ def make_folder_loader(glob_pattern: str, file_loader: FileLoader) -> FolderLoad return files return folder_loader + + +# TODO +# use numbers.Number instead when this mypy issue is finally fixed +# https://github.com/python/mypy/issues/3186 +Number = Union[int, float, Decimal, Fraction] + + +def round_beats(beats: Number) -> song.BeatsTime: + """Rounds a given beat value to the nearest 1/240th""" + nearest_240th = round(beats * 240) + return song.BeatsTime(nearest_240th, 240) diff --git a/jubeatools/formats/tests/__init__.py b/jubeatools/formats/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jubeatools/formats/jubeat_analyser/tests/test_decimal_to_beats.py b/jubeatools/formats/tests/test_decimal_to_beats.py similarity index 82% rename from jubeatools/formats/jubeat_analyser/tests/test_decimal_to_beats.py rename to jubeatools/formats/tests/test_decimal_to_beats.py index 90b3796..2d96a4a 100644 --- a/jubeatools/formats/jubeat_analyser/tests/test_decimal_to_beats.py +++ b/jubeatools/formats/tests/test_decimal_to_beats.py @@ -3,7 +3,7 @@ from fractions import Fraction import pytest -from ..load_tools import decimal_to_beats +from ..load_tools import round_beats @pytest.mark.parametrize("numerator", range(240)) @@ -13,5 +13,5 @@ def test_fraction_recovery_after_rounding_to_three_decimals(numerator: int) -> N rounded = round(decimal, 3) text_form = str(rounded) re_parsed_decimal = Decimal(text_form) - result = decimal_to_beats(re_parsed_decimal) + result = round_beats(re_parsed_decimal) assert fraction == result diff --git a/jubeatools/testutils/strategies.py b/jubeatools/testutils/strategies.py index 5d5988e..7fdfa51 100644 --- a/jubeatools/testutils/strategies.py +++ b/jubeatools/testutils/strategies.py @@ -151,34 +151,42 @@ def notes(draw: DrawFunc, options: NoteOption) -> Set[Union[TapNote, LongNote]]: @st.composite -def bpm_strat(draw: DrawFunc) -> Decimal: +def bpms(draw: DrawFunc) -> Decimal: d: Decimal = draw(st.decimals(min_value=1, max_value=1000, places=3)) return d @st.composite -def bpm_change(draw: DrawFunc) -> BPMEvent: +def bpm_changes( + draw: DrawFunc, bpm_strat: st.SearchStrategy[Decimal] = bpms() +) -> BPMEvent: time = draw(beat_time(min_section=1, max_section=10)) - bpm = draw(bpm_strat()) + bpm = draw(bpm_strat) return BPMEvent(time, bpm) @st.composite def timing_info( draw: DrawFunc, - bpm_changes: bool = True, + with_bpm_changes: bool = True, + bpm_strat: st.SearchStrategy[Decimal] = bpms(), + beat_zero_offset_strat: st.SearchStrategy[Decimal] = st.decimals( + min_value=0, max_value=20, places=3 + ), ) -> Timing: - first_bpm = draw(bpm_strat()) + first_bpm = draw(bpm_strat) first_event = BPMEvent(BeatsTime(0), first_bpm) events = [first_event] - if bpm_changes: - raw_bpm_changes = st.lists(bpm_change(), unique_by=get_bpm_change_time) + if with_bpm_changes: + raw_bpm_changes = st.lists( + bpm_changes(bpm_strat), unique_by=get_bpm_change_time + ) sorted_bpm_changes = raw_bpm_changes.map( lambda l: sorted(l, key=get_bpm_change_time) ) other_events = draw(sorted_bpm_changes) events += other_events - beat_zero_offset = draw(st.decimals(min_value=0, max_value=20, places=3)) + beat_zero_offset = draw(beat_zero_offset_strat) return Timing(events=events, beat_zero_offset=beat_zero_offset) diff --git a/jubeatools/testutils/test_patterns.py b/jubeatools/testutils/test_patterns.py new file mode 100644 index 0000000..7525ca8 --- /dev/null +++ b/jubeatools/testutils/test_patterns.py @@ -0,0 +1,29 @@ +from pathlib import Path +from typing import Callable, ContextManager, Optional + +from hypothesis import note + +from jubeatools import song +from jubeatools.formats import DUMPERS, LOADERS +from jubeatools.formats.enum import Format +from jubeatools.formats.guess import guess_format + + +def dump_and_load_then_compare( + format_: Format, + song: song.Song, + temp_path: ContextManager[Path], + bytes_decoder: Callable[[bytes], str], + dump_options: Optional[dict] = None, +) -> None: + dump_options = dump_options or {} + loader = LOADERS[format_] + dumper = DUMPERS[format_] + with temp_path as path: + files = dumper(song, path, **(dump_options or {})) + for path, bytes_ in files.items(): + path.write_bytes(bytes_) + note(f"Wrote to {path} :\n{bytes_decoder(bytes_)}") + assert guess_format(path) == format_ + recovered_song = loader(path) + assert recovered_song == song