1
0
mirror of synced 2024-12-04 19:17:55 +01:00

Loaders can now take in arguments

[eve] Add beat_snap argument to the loader
[eve] Simplify bpm truncation method
[eve] Fix bug in TimeMap.from_seconds
[eve] TimeMap.convert_to_timing_info now accepts a beat_snap
Add denominator argument to round_beats for coarser rounding
Parameterize hypothesis strategies in a less ugly way
This commit is contained in:
Stepland 2021-05-12 01:24:25 +02:00
parent 8322747a15
commit f72e875898
22 changed files with 256 additions and 185 deletions

View File

@ -1,9 +1,11 @@
from __future__ import annotations from __future__ import annotations
import math
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import Union from itertools import count
from typing import Iterator, Union
from jubeatools import song from jubeatools import song
@ -64,15 +66,7 @@ class Event:
@classmethod @classmethod
def from_long_note(cls, note: song.LongNote, time_map: TimeMap) -> Event: def from_long_note(cls, note: song.LongNote, time_map: TimeMap) -> Event:
if not note.has_straight_tail(): eve_long = EveLong.from_jubeatools(note, time_map)
raise ValueError("Diagonal tails cannot be represented in eve format")
eve_long = EveLong(
duration=duration_in_ticks(note, time_map),
length=len(list(note.positions_covered())) - 1,
direction=DIRECTION_TO_VALUE[note.tail_direction()],
position=note.position.index,
)
ticks = ticks_at_beat(note.time, time_map) ticks = ticks_at_beat(note.time, time_map)
return Event(time=ticks, command=Command.LONG, value=eve_long.value) return Event(time=ticks, command=Command.LONG, value=eve_long.value)
@ -134,6 +128,18 @@ class EveLong:
"outside the screen" "outside the screen"
) )
@classmethod
def from_jubeatools(cls, note: song.LongNote, time_map: TimeMap) -> EveLong:
if not note.has_straight_tail():
raise ValueError("Diagonal tails cannot be represented in eve format")
return cls(
duration=duration_in_ticks(note, time_map),
length=len(list(note.positions_covered())) - 1,
direction=DIRECTION_TO_VALUE[note.tail_direction()],
position=note.position.index,
)
@classmethod @classmethod
def from_value(cls, value: int) -> EveLong: def from_value(cls, value: int) -> EveLong:
if value < 0: if value < 0:
@ -148,7 +154,10 @@ class EveLong:
@property @property
def value(self) -> int: def value(self) -> int:
return ( return (
self.duration << 8 + self.length << 6 + self.direction << 4 + self.position (self.duration << 8)
+ (self.length << 6)
+ (self.direction << 4)
+ self.position
) )
@ -172,3 +181,33 @@ def ticks_to_seconds(tick: int) -> Fraction:
def seconds_to_ticks(time: Fraction) -> int: def seconds_to_ticks(time: Fraction) -> int:
"""Convert fractional seconds to eve ticks (300 Hz)""" """Convert fractional seconds to eve ticks (300 Hz)"""
return round(time * 300) return round(time * 300)
def value_to_truncated_bpm(value: int) -> Fraction:
"""Only keeps enough significant digits to allow recovering the original
TEMPO line value from the bpm"""
exact_bpm = value_to_bpm(value)
truncated_bpms = iter_truncated(exact_bpm)
bpms_preserving_value = filter(
lambda b: bpm_to_value(b) < value + 1, truncated_bpms
)
return next(bpms_preserving_value)
def iter_truncated(f: Fraction) -> Iterator[Fraction]:
for places in count():
yield truncate_fraction(f, places)
def truncate_fraction(f: Fraction, places: int) -> Fraction:
"""Truncates a fraction to the given number of decimal places"""
exponent = Fraction(10) ** places
return Fraction(math.floor(f * exponent), exponent)
def value_to_bpm(value: int) -> Fraction:
return 6 * 10 ** 7 / Fraction(value)
def bpm_to_value(bpm: Fraction) -> Fraction:
return 6 * 10 ** 7 / bpm

View File

@ -12,7 +12,7 @@ from jubeatools import song
from jubeatools.formats.dump_tools import make_dumper_from_chart_file_dumper from jubeatools.formats.dump_tools import make_dumper_from_chart_file_dumper
from jubeatools.formats.filetypes import ChartFile from jubeatools.formats.filetypes import ChartFile
from .commons import AnyNote, Command, Event, ticks_at_beat from .commons import AnyNote, Command, Event, bpm_to_value, ticks_at_beat
from .timemap import TimeMap from .timemap import TimeMap
@ -71,14 +71,10 @@ def make_timing_events(
def make_bpm_event(bpm_change: song.BPMEvent, time_map: TimeMap) -> Event: def make_bpm_event(bpm_change: song.BPMEvent, time_map: TimeMap) -> Event:
ticks = ticks_at_beat(bpm_change.time, time_map) ticks = ticks_at_beat(bpm_change.time, time_map)
bpm_value = bpm_to_value(Fraction(bpm_change.BPM)) bpm_value = math.floor(bpm_to_value(Fraction(bpm_change.BPM)))
return Event(time=ticks, command=Command.TEMPO, value=bpm_value) return Event(time=ticks, command=Command.TEMPO, value=bpm_value)
def bpm_to_value(bpm: Fraction) -> int:
return math.floor(60 * 10 ** 6 / bpm)
def choose_end_beat(notes: List[AnyNote]) -> song.BeatsTime: def choose_end_beat(notes: List[AnyNote]) -> song.BeatsTime:
"""Leave 2 empty measures (4 beats) after the last event""" """Leave 2 empty measures (4 beats) after the last event"""
last_note_beat = compute_last_note_beat(notes) last_note_beat = compute_last_note_beat(notes)

View File

@ -1,9 +1,7 @@
import math
from decimal import Decimal from decimal import Decimal
from fractions import Fraction
from functools import reduce from functools import reduce
from pathlib import Path from pathlib import Path
from typing import Iterator, List, Optional from typing import Any, Iterator, List, Optional
from jubeatools import song from jubeatools import song
from jubeatools.formats.load_tools import make_folder_loader, round_beats from jubeatools.formats.load_tools import make_folder_loader, round_beats
@ -16,13 +14,14 @@ from .commons import (
EveLong, EveLong,
Event, Event,
ticks_to_seconds, ticks_to_seconds,
value_to_truncated_bpm,
) )
from .timemap import BPMAtSecond, TimeMap from .timemap import BPMAtSecond, TimeMap
def load_eve(path: Path) -> song.Song: def load_eve(path: Path, *, beat_snap: int = 240, **kwargs: Any) -> song.Song:
files = load_folder(path) files = load_folder(path)
charts = [_load_eve(l, p) for p, l in files.items()] charts = [_load_eve(l, p, beat_snap=beat_snap) for p, l in files.items()]
return reduce(song.Song.merge, charts) return reduce(song.Song.merge, charts)
@ -33,7 +32,7 @@ def load_file(path: Path) -> List[str]:
load_folder = make_folder_loader("*.eve", load_file) load_folder = make_folder_loader("*.eve", load_file)
def _load_eve(lines: List[str], file_path: Path) -> song.Song: def _load_eve(lines: List[str], file_path: Path, *, beat_snap: int = 240) -> song.Song:
events = list(iter_events(lines)) events = list(iter_events(lines))
events_by_command = group_by(events, lambda e: e.command) events_by_command = group_by(events, lambda e: e.command)
bpms = [ bpms = [
@ -52,7 +51,7 @@ def _load_eve(lines: List[str], file_path: Path) -> song.Song:
for e in events_by_command[Command.LONG] for e in events_by_command[Command.LONG]
] ]
all_notes = sorted(tap_notes + long_notes, key=lambda n: (n.time, n.position)) all_notes = sorted(tap_notes + long_notes, key=lambda n: (n.time, n.position))
timing = time_map.convert_to_timing_info() timing = time_map.convert_to_timing_info(beat_snap=beat_snap)
chart = song.Chart(level=Decimal(0), timing=timing, notes=all_notes) chart = song.Chart(level=Decimal(0), timing=timing, notes=all_notes)
dif = guess_difficulty(file_path.stem) or song.Difficulty.EXTREME dif = guess_difficulty(file_path.stem) or song.Difficulty.EXTREME
return song.Song(metadata=song.Metadata(), charts={dif: chart}) return song.Song(metadata=song.Metadata(), charts={dif: chart})
@ -103,34 +102,6 @@ def parse_event(line: str) -> Event:
return Event(tick, command, value) return Event(tick, command, value)
def value_to_truncated_bpm(value: int) -> Fraction:
"""Only keeps enough significant digits to allow recovering the original
TEMPO line value from the bpm"""
precision = tempo_precision(value)
places = significant_decimal_places(precision) + 1
raw_bpm = value_to_bpm(value)
return truncate_fraction(raw_bpm, places)
def value_to_bpm(value: int) -> Fraction:
return 6 * 10 ** 7 / Fraction(value)
def significant_decimal_places(max_error: float) -> int:
return int(-(math.log(max_error / 5) / math.log(10)))
def tempo_precision(value: int) -> float:
"""Maximum error on the bpm this tempo value corresponds to"""
return (6 * 10 ** 7) / (value * (value + 1))
def truncate_fraction(f: Fraction, places: int) -> Fraction:
"""Truncates a fraction to the given number of decimal places"""
exponent = Fraction(10) ** places
return Fraction(math.floor(f * exponent), exponent)
def make_tap_note(ticks: int, value: int, time_map: TimeMap) -> song.TapNote: def make_tap_note(ticks: int, value: int, time_map: TimeMap) -> song.TapNote:
seconds = ticks_to_seconds(ticks) seconds = ticks_to_seconds(ticks)
raw_beats = time_map.beats_at(seconds) raw_beats = time_map.beats_at(seconds)

View File

@ -3,8 +3,7 @@ import math
from hypothesis import given from hypothesis import given
from hypothesis import strategies as st from hypothesis import strategies as st
from ..dump import bpm_to_value from ..commons import bpm_to_value, value_to_truncated_bpm
from ..load import value_to_truncated_bpm
@given(st.integers(min_value=1, max_value=6 * 10 ** 7)) @given(st.integers(min_value=1, max_value=6 * 10 ** 7))

View File

@ -1,5 +1,6 @@
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from decimal import Decimal
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import Iterator
@ -12,6 +13,10 @@ from jubeatools.testutils import strategies as jbst
from jubeatools.testutils.test_patterns import dump_and_load_then_compare from jubeatools.testutils.test_patterns import dump_and_load_then_compare
from jubeatools.testutils.typing import DrawFunc from jubeatools.testutils.typing import DrawFunc
simple_beat_strat = jbst.beat_time(
denominator_strat=st.sampled_from([4, 8, 3]), max_section=10
)
@st.composite @st.composite
def eve_compatible_song(draw: DrawFunc) -> song.Song: def eve_compatible_song(draw: DrawFunc) -> song.Song:
@ -21,10 +26,24 @@ def eve_compatible_song(draw: DrawFunc) -> song.Song:
jbst.chart( jbst.chart(
timing_strat=jbst.timing_info( timing_strat=jbst.timing_info(
with_bpm_changes=True, with_bpm_changes=True,
bpm_strat=st.decimals(min_value=1, max_value=1000, places=2), bpm_strat=st.decimals(min_value=1, max_value=500, places=2),
beat_zero_offset_strat=st.decimals(min_value=0, max_value=20, places=2), beat_zero_offset_strat=st.decimals(min_value=0, max_value=20, places=2),
time_strat=simple_beat_strat,
), ),
notes_strat=jbst.notes(jbst.NoteOption.LONGS), notes_strat=jbst.notes(
note_strat=st.one_of(
jbst.tap_note(time_start=simple_beat_strat),
jbst.long_note(
time_strat=simple_beat_strat,
duration_strat=jbst.beat_time(
min_numerator=1,
max_section=3,
denominator_strat=st.sampled_from([4, 8, 3]),
),
),
)
),
level_strat=st.just(Decimal(0)),
) )
) )
return song.Song( return song.Song(
@ -40,7 +59,7 @@ def open_temp_dir() -> Iterator[Path]:
@given(eve_compatible_song()) @given(eve_compatible_song())
@settings(verbosity=Verbosity.debug) @settings(verbosity=Verbosity.normal)
def test_that_full_chart_roundtrips(song: song.Song) -> None: def test_that_full_chart_roundtrips(song: song.Song) -> None:
dump_and_load_then_compare( dump_and_load_then_compare(
Format.EVE, Format.EVE,

View File

@ -0,0 +1,25 @@
from hypothesis import given
from hypothesis import strategies as st
from jubeatools import song
from jubeatools.testutils import strategies as jbst
from ..commons import EveLong
from ..timemap import TimeMap
@given(
jbst.long_note(),
jbst.timing_info(
with_bpm_changes=True,
bpm_strat=st.decimals(min_value=1, max_value=1000, places=2),
beat_zero_offset_strat=st.decimals(min_value=0, max_value=20, places=2),
),
)
def test_that_long_note_roundtrips(
long_note: song.LongNote, timing: song.Timing
) -> None:
time_map = TimeMap.from_timing(timing)
original = EveLong.from_jubeatools(long_note, time_map)
recovered = EveLong.from_value(original.value)
assert recovered == original

View File

@ -81,8 +81,8 @@ class TimeMap:
raise ValueError("No BPM defined") raise ValueError("No BPM defined")
grouped_by_time = group_by(events, key=lambda e: e.seconds) grouped_by_time = group_by(events, key=lambda e: e.seconds)
for time, events in grouped_by_time.items(): for time, events_at_time in grouped_by_time.items():
if len(events) > 1: if len(events_at_time) > 1:
raise ValueError(f"Multiple BPMs defined at {time} seconds : {events}") raise ValueError(f"Multiple BPMs defined at {time} seconds : {events}")
# take the first BPM change then compute from there # take the first BPM change then compute from there
@ -144,10 +144,13 @@ class TimeMap:
) )
return bpm_change.beats + beats_since_last_event return bpm_change.beats + beats_since_last_event
def convert_to_timing_info(self) -> song.Timing: def convert_to_timing_info(self, beat_snap: int = 240) -> song.Timing:
return song.Timing( return song.Timing(
events=[ events=[
song.BPMEvent(time=round_beats(e.beats), BPM=fraction_to_decimal(e.BPM)) song.BPMEvent(
time=round_beats(e.beats, beat_snap),
BPM=fraction_to_decimal(e.BPM),
)
for e in self.events_by_beats for e in self.events_by_beats
], ],
beat_zero_offset=self.beat_zero_offset, beat_zero_offset=self.beat_zero_offset,

View File

@ -97,7 +97,7 @@ def looks_like_eve(path: Path) -> bool:
line = f.readline() line = f.readline()
if line.strip(): if line.strip():
return looks_like_eve_line(next(f)) return looks_like_eve_line(next(f))
return False return False

View File

@ -5,7 +5,7 @@ from decimal import Decimal
from functools import reduce from functools import reduce
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Set, Tuple, Union from typing import Any, Dict, Iterator, List, Mapping, Set, Tuple, Union
from more_itertools import collapse, mark_ends from more_itertools import collapse, mark_ends
@ -347,7 +347,7 @@ def _load_memo_file(lines: List[str]) -> Song:
return Song(metadata=metadata, charts=charts) return Song(metadata=metadata, charts=charts)
def load_memo(path: Path) -> Song: def load_memo(path: Path, **kwargs: Any) -> Song:
files = load_folder(path) files = load_folder(path)
charts = [_load_memo_file(lines) for _, lines in files.items()] charts = [_load_memo_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts) return reduce(Song.merge, charts)

View File

@ -4,7 +4,7 @@ from decimal import Decimal
from functools import reduce from functools import reduce
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Set, Tuple, Union from typing import Any, Dict, Iterator, List, Mapping, Set, Tuple, Union
from more_itertools import mark_ends from more_itertools import mark_ends
@ -338,7 +338,7 @@ def _load_memo1_file(lines: List[str]) -> Song:
return Song(metadata=metadata, charts=charts) return Song(metadata=metadata, charts=charts)
def load_memo1(path: Path) -> Song: def load_memo1(path: Path, **kwargs: Any) -> Song:
files = load_folder(path) files = load_folder(path)
charts = [_load_memo1_file(lines) for _, lines in files.items()] charts = [_load_memo1_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts) return reduce(Song.merge, charts)

View File

@ -3,7 +3,7 @@ from decimal import Decimal
from functools import reduce from functools import reduce
from itertools import product, zip_longest from itertools import product, zip_longest
from pathlib import Path from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union
from parsimonious import Grammar, NodeVisitor, ParseError from parsimonious import Grammar, NodeVisitor, ParseError
from parsimonious.nodes import Node from parsimonious.nodes import Node
@ -457,7 +457,7 @@ def _load_memo2_file(lines: List[str]) -> Song:
return Song(metadata=metadata, charts=charts) return Song(metadata=metadata, charts=charts)
def load_memo2(path: Path) -> Song: def load_memo2(path: Path, **kwargs: Any) -> Song:
files = load_folder(path) files = load_folder(path)
charts = [_load_memo2_file(lines) for _, lines in files.items()] charts = [_load_memo2_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts) return reduce(Song.merge, charts)

View File

@ -4,7 +4,7 @@ from decimal import Decimal
from functools import reduce from functools import reduce
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Dict, Iterator, List, Set, Tuple, Union from typing import Any, Dict, Iterator, List, Set, Tuple, Union
from parsimonious import Grammar, NodeVisitor, ParseError from parsimonious import Grammar, NodeVisitor, ParseError
from parsimonious.nodes import Node from parsimonious.nodes import Node
@ -241,7 +241,7 @@ class MonoColumnParser(JubeatAnalyserParser):
section_starting_beat += section.length section_starting_beat += section.length
def load_mono_column(path: Path) -> Song: def load_mono_column(path: Path, **kwargs: Any) -> Song:
files = load_folder(path) files = load_folder(path)
charts = [_load_mono_column_file(lines) for _, lines in files.items()] charts = [_load_mono_column_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts) return reduce(Song.merge, charts)

View File

@ -17,7 +17,7 @@ from ..test_utils import memo_compatible_song, temp_file_named_txt
from . import example1, example2, example3 from . import example1, example2, example3
@given(jbst.notes(jbst.NoteOption.LONGS)) @given(jbst.notes())
@example(example1.notes) @example(example1.notes)
def test_that_notes_roundtrip(notes: Set[Union[song.TapNote, song.LongNote]]) -> None: def test_that_notes_roundtrip(notes: Set[Union[song.TapNote, song.LongNote]]) -> None:
timing = song.Timing( timing = song.Timing(

View File

@ -9,7 +9,6 @@ from jubeatools import song
from jubeatools.formats import Format from jubeatools.formats import Format
from jubeatools.formats.jubeat_analyser.memo1.dump import _dump_memo1_chart from jubeatools.formats.jubeat_analyser.memo1.dump import _dump_memo1_chart
from jubeatools.formats.jubeat_analyser.memo1.load import Memo1Parser from jubeatools.formats.jubeat_analyser.memo1.load import Memo1Parser
from jubeatools.testutils.strategies import NoteOption
from jubeatools.testutils.strategies import notes as notes_strat from jubeatools.testutils.strategies import notes as notes_strat
from jubeatools.testutils.test_patterns import dump_and_load_then_compare from jubeatools.testutils.test_patterns import dump_and_load_then_compare
@ -17,7 +16,7 @@ from ..test_utils import memo_compatible_song, temp_file_named_txt
from . import example1 from . import example1
@given(notes_strat(NoteOption.LONGS)) @given(notes_strat())
def test_that_notes_roundtrip(notes: List[Union[song.TapNote, song.LongNote]]) -> None: def test_that_notes_roundtrip(notes: List[Union[song.TapNote, song.LongNote]]) -> None:
timing = song.Timing( timing = song.Timing(
events=[song.BPMEvent(song.BeatsTime(0), Decimal(120))], events=[song.BPMEvent(song.BeatsTime(0), Decimal(120))],

View File

@ -19,7 +19,6 @@ from jubeatools.song import (
TapNote, TapNote,
Timing, Timing,
) )
from jubeatools.testutils.strategies import NoteOption
from jubeatools.testutils.strategies import notes as notes_strat from jubeatools.testutils.strategies import notes as notes_strat
from jubeatools.testutils.test_patterns import dump_and_load_then_compare from jubeatools.testutils.test_patterns import dump_and_load_then_compare
@ -27,7 +26,7 @@ from ..test_utils import memo_compatible_song, temp_file_named_txt
from . import example1, example2, example3 from . import example1, example2, example3
@given(notes_strat(NoteOption.LONGS)) @given(notes_strat())
def test_that_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None: def test_that_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None:
timing = Timing( timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0) events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0)

View File

@ -19,7 +19,7 @@ from jubeatools.song import (
TapNote, TapNote,
Timing, Timing,
) )
from jubeatools.testutils.strategies import NoteOption, long_note from jubeatools.testutils.strategies import long_note
from jubeatools.testutils.strategies import notes as notes_strat from jubeatools.testutils.strategies import notes as notes_strat
from jubeatools.testutils.strategies import tap_note from jubeatools.testutils.strategies import tap_note
from jubeatools.testutils.test_patterns import dump_and_load_then_compare from jubeatools.testutils.test_patterns import dump_and_load_then_compare
@ -63,7 +63,7 @@ def test_that_a_single_long_note_roundtrips(note: LongNote) -> None:
assert set([note]) == actual assert set([note]) == actual
@given(notes_strat(NoteOption.LONGS)) @given(notes_strat())
def test_that_many_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None: def test_that_many_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None:
timing = Timing( timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0) events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0)

View File

@ -14,7 +14,7 @@ from jubeatools.testutils.typing import DrawFunc
def memo_compatible_metadata(draw: DrawFunc) -> song.Metadata: def memo_compatible_metadata(draw: DrawFunc) -> song.Metadata:
text_strat = st.text(alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E)) text_strat = st.text(alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E))
metadata: song.Metadata = draw( metadata: song.Metadata = draw(
jbst.metadata(text_strat=text_strat, path_start=text_strat) jbst.metadata(text_strat=text_strat, path_strat=text_strat)
) )
metadata.preview = None metadata.preview = None
metadata.preview_file = None metadata.preview_file = None
@ -28,7 +28,7 @@ def memo_compatible_song(draw: DrawFunc) -> song.Song:
chart = draw( chart = draw(
jbst.chart( jbst.chart(
timing_strat=jbst.timing_info(with_bpm_changes=True), timing_strat=jbst.timing_info(with_bpm_changes=True),
notes_strat=jbst.notes(jbst.NoteOption.LONGS), notes_strat=jbst.notes(),
) )
) )
metadata: song.Metadata = draw(memo_compatible_metadata()) metadata: song.Metadata = draw(memo_compatible_metadata())

View File

@ -51,7 +51,8 @@ def make_folder_loader(glob_pattern: str, file_loader: FileLoader) -> FolderLoad
Number = Union[int, float, Decimal, Fraction] Number = Union[int, float, Decimal, Fraction]
def round_beats(beats: Number) -> song.BeatsTime: def round_beats(beats: Number, denominator: int = 240) -> song.BeatsTime:
"""Rounds a given beat value to the nearest 1/240th""" """Rounds a given beat value to the nearest given division (default to
nearest_240th = round(beats * 240) nearest 1/240th"""
return song.BeatsTime(nearest_240th, 240) nearest = round(beats * denominator)
return song.BeatsTime(nearest, denominator)

View File

@ -149,8 +149,8 @@ def _load_memon_note_v0(
return jbt.TapNote(time, position) return jbt.TapNote(time, position)
def load_memon_legacy(file: Path) -> jbt.Song: def load_memon_legacy(path: Path, **kwargs: Any) -> jbt.Song:
raw_memon = _load_raw_memon(file) raw_memon = _load_raw_memon(path)
schema = Memon_legacy() schema = Memon_legacy()
memon = schema.load(raw_memon) memon = schema.load(raw_memon)
metadata = jbt.Metadata( metadata = jbt.Metadata(
@ -179,8 +179,8 @@ def load_memon_legacy(file: Path) -> jbt.Song:
return jbt.Song(metadata=metadata, charts=charts, common_timing=common_timing) return jbt.Song(metadata=metadata, charts=charts, common_timing=common_timing)
def load_memon_0_1_0(file: Path) -> jbt.Song: def load_memon_0_1_0(path: Path, **kwargs: Any) -> jbt.Song:
raw_memon = _load_raw_memon(file) raw_memon = _load_raw_memon(path)
schema = Memon_0_1_0() schema = Memon_0_1_0()
memon = schema.load(raw_memon) memon = schema.load(raw_memon)
metadata = jbt.Metadata( metadata = jbt.Metadata(
@ -209,8 +209,8 @@ def load_memon_0_1_0(file: Path) -> jbt.Song:
return jbt.Song(metadata=metadata, charts=charts, common_timing=common_timing) return jbt.Song(metadata=metadata, charts=charts, common_timing=common_timing)
def load_memon_0_2_0(file: Path) -> jbt.Song: def load_memon_0_2_0(path: Path, **kwargs: Any) -> jbt.Song:
raw_memon = _load_raw_memon(file) raw_memon = _load_raw_memon(path)
schema = Memon_0_2_0() schema = Memon_0_2_0()
memon = schema.load(raw_memon) memon = schema.load(raw_memon)
preview = None preview = None

View File

@ -1,13 +1,13 @@
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Set
import hypothesis.strategies as st import hypothesis.strategies as st
from hypothesis import given from hypothesis import given
from jubeatools import song
from jubeatools.formats.typing import Dumper, Loader from jubeatools.formats.typing import Dumper, Loader
from jubeatools.song import Song from jubeatools.testutils import strategies as jbst
from jubeatools.testutils.strategies import NoteOption, TimingOption
from jubeatools.testutils.strategies import song as song_strat
from jubeatools.testutils.typing import DrawFunc from jubeatools.testutils.typing import DrawFunc
from . import ( from . import (
@ -21,7 +21,7 @@ from . import (
def dump_and_load( def dump_and_load(
expected_song: Song, dump_function: Dumper, load_function: Loader expected_song: song.Song, dump_function: Dumper, load_function: Loader
) -> None: ) -> None:
with tempfile.NamedTemporaryFile(mode="wb") as file: with tempfile.NamedTemporaryFile(mode="wb") as file:
files = dump_function(expected_song, Path(file.name)) files = dump_function(expected_song, Path(file.name))
@ -35,16 +35,37 @@ def dump_and_load(
@st.composite @st.composite
def memon_legacy_compatible_song(draw: DrawFunc) -> Song: def memon_diffs(draw: DrawFunc) -> Set[str]:
simple_diff_names = st.sampled_from(list(d.value for d in song.Difficulty))
diff_names = st.one_of(
simple_diff_names,
st.text(
alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E),
min_size=1,
max_size=20,
),
)
s: Set[str] = draw(st.sets(diff_names, min_size=1, max_size=10))
return s
@st.composite
def memon_legacy_compatible_song(draw: DrawFunc) -> song.Song:
"""Memon versions below v0.2.0 do not support any preview metadata""" """Memon versions below v0.2.0 do not support any preview metadata"""
song: Song = draw(song_strat(TimingOption.GLOBAL, True, NoteOption.LONGS)) random_song: song.Song = draw(
song.metadata.preview = None jbst.song(
song.metadata.preview_file = None diffs_strat=memon_diffs(),
return song chart_strat=jbst.chart(timing_strat=st.none()),
common_timing_strat=jbst.timing_info(with_bpm_changes=False),
)
)
random_song.metadata.preview = None
random_song.metadata.preview_file = None
return random_song
@given(memon_legacy_compatible_song()) @given(memon_legacy_compatible_song())
def test_memon_legacy(song: Song) -> None: def test_memon_legacy(song: song.Song) -> None:
dump_and_load(song, dump_memon_legacy, load_memon_legacy) dump_and_load(song, dump_memon_legacy, load_memon_legacy)
@ -52,18 +73,24 @@ memon_0_1_0_compatible_song = memon_legacy_compatible_song
@given(memon_0_1_0_compatible_song()) @given(memon_0_1_0_compatible_song())
def test_memon_0_1_0(song: Song) -> None: def test_memon_0_1_0(song: song.Song) -> None:
dump_and_load(song, dump_memon_0_1_0, load_memon_0_1_0) dump_and_load(song, dump_memon_0_1_0, load_memon_0_1_0)
@st.composite @st.composite
def memon_0_2_0_compatible_song(draw: DrawFunc) -> Song: def memon_0_2_0_compatible_song(draw: DrawFunc) -> song.Song:
"""Memon v0.2.0 does not support preview_file""" """Memon v0.2.0 does not support preview_file"""
song: Song = draw(song_strat(TimingOption.GLOBAL, True, NoteOption.LONGS)) random_song: song.Song = draw(
song.metadata.preview_file = None jbst.song(
return song diffs_strat=memon_diffs(),
chart_strat=jbst.chart(timing_strat=st.none()),
common_timing_strat=jbst.timing_info(with_bpm_changes=False),
)
)
random_song.metadata.preview_file = None
return random_song
@given(memon_0_2_0_compatible_song()) @given(memon_0_2_0_compatible_song())
def test_memon_0_2_0(song: Song) -> None: def test_memon_0_2_0(song: song.Song) -> None:
dump_and_load(song, dump_memon_0_2_0, load_memon_0_2_0) dump_and_load(song, dump_memon_0_2_0, load_memon_0_2_0)

View File

@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Protocol from typing import Any, Dict, List, Protocol
from jubeatools.formats.filetypes import ChartFile from jubeatools.formats.filetypes import ChartFile
from jubeatools.song import Song from jubeatools.song import Song
@ -24,6 +24,10 @@ class ChartFileDumper(Protocol):
... ...
# Loaders deserialize a Path to a Song object class Loader(Protocol):
# The Path can be a file or a folder depending on the format """A Loader deserializes a Path to a Song object and possibly takes in
Loader = Callable[[Path], Song] some options via the kwargs.
The Path can be a file or a folder depending on the format"""
def __call__(self, path: Path, **kwargs: Any) -> Song:
...

View File

@ -6,7 +6,7 @@ from decimal import Decimal
from enum import Flag, auto from enum import Flag, auto
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Set, Union from typing import Dict, Iterable, Optional, Set, Union
import hypothesis.strategies as st import hypothesis.strategies as st
from multidict import MultiDict from multidict import MultiDict
@ -34,8 +34,9 @@ def beat_time(
max_section: Optional[int] = None, max_section: Optional[int] = None,
min_numerator: Optional[int] = None, min_numerator: Optional[int] = None,
max_numerator: Optional[int] = None, max_numerator: Optional[int] = None,
denominator_strat: st.SearchStrategy[int] = st.sampled_from([4, 8, 16, 3, 5]),
) -> BeatsTime: ) -> BeatsTime:
denominator = draw(st.sampled_from([4, 8, 16, 3, 5])) denominator = draw(denominator_strat)
if min_section is not None: if min_section is not None:
min_value = denominator * 4 * min_section min_value = denominator * 4 * min_section
@ -68,17 +69,25 @@ def note_position(draw: DrawFunc) -> NotePosition:
@st.composite @st.composite
def tap_note(draw: DrawFunc) -> TapNote: def tap_note(
time = draw(beat_time(max_section=10)) draw: DrawFunc, time_start: st.SearchStrategy[BeatsTime] = beat_time(max_section=10)
) -> TapNote:
time = draw(time_start)
position = draw(note_position()) position = draw(note_position())
return TapNote(time, position) return TapNote(time, position)
@st.composite @st.composite
def long_note(draw: DrawFunc) -> LongNote: def long_note(
time = draw(beat_time(max_section=10)) draw: DrawFunc,
time_strat: st.SearchStrategy[BeatsTime] = beat_time(max_section=10),
duration_strat: st.SearchStrategy[BeatsTime] = beat_time(
min_numerator=1, max_section=3
),
) -> LongNote:
time = draw(time_strat)
position = draw(note_position()) position = draw(note_position())
duration = draw(beat_time(min_numerator=1, max_section=3)) duration = draw(duration_strat)
tail_is_vertical = draw(st.booleans()) tail_is_vertical = draw(st.booleans())
tail_offset = draw(st.integers(min_value=1, max_value=3)) tail_offset = draw(st.integers(min_value=1, max_value=3))
if tail_is_vertical: if tail_is_vertical:
@ -91,15 +100,6 @@ def long_note(draw: DrawFunc) -> LongNote:
return LongNote(time, position, duration, tail_tip) return LongNote(time, position, duration, tail_tip)
class NoteOption(Flag):
"""What kind of notes are allowed to be generated"""
# Long notes
LONGS = auto()
# Intersections between longs and other notes on the same square
COLLISIONS = auto()
@st.composite @st.composite
def bad_notes(draw: DrawFunc, longs: bool) -> Set[Union[TapNote, LongNote]]: def bad_notes(draw: DrawFunc, longs: bool) -> Set[Union[TapNote, LongNote]]:
note_strat = tap_note() note_strat = tap_note()
@ -110,16 +110,20 @@ def bad_notes(draw: DrawFunc, longs: bool) -> Set[Union[TapNote, LongNote]]:
@st.composite @st.composite
def notes(draw: DrawFunc, options: NoteOption) -> Set[Union[TapNote, LongNote]]: def notes(
if (NoteOption.COLLISIONS in options) and (NoteOption.LONGS not in options): draw: DrawFunc,
raise ValueError("Can't ask for collisions without longs") collisions: bool = False,
note_strat: st.SearchStrategy[Union[TapNote, LongNote]] = st.one_of(
note_strat = tap_note() tap_note(), long_note()
if NoteOption.LONGS in options: ),
note_strat = st.one_of(note_strat, long_note()) beat_time_strat: st.SearchStrategy[BeatsTime] = beat_time(max_section=3),
beat_interval_strat: st.SearchStrategy[BeatsTime] = beat_time(
min_numerator=1, max_section=3
),
) -> Set[Union[TapNote, LongNote]]:
raw_notes: Set[Union[TapNote, LongNote]] = draw(st.sets(note_strat, max_size=32)) raw_notes: Set[Union[TapNote, LongNote]] = draw(st.sets(note_strat, max_size=32))
if NoteOption.COLLISIONS in options: if collisions:
return raw_notes return raw_notes
else: else:
last_notes: Dict[NotePosition, Optional[BeatsTime]] = { last_notes: Dict[NotePosition, Optional[BeatsTime]] = {
@ -129,11 +133,9 @@ def notes(draw: DrawFunc, options: NoteOption) -> Set[Union[TapNote, LongNote]]:
for note in sorted(raw_notes, key=lambda n: (n.time, n.position)): for note in sorted(raw_notes, key=lambda n: (n.time, n.position)):
last_note_time = last_notes[note.position] last_note_time = last_notes[note.position]
if last_note_time is None: if last_note_time is None:
new_time = draw(beat_time(max_section=3)) new_time = draw(beat_time_strat)
else: else:
new_time = last_note_time + draw( new_time = last_note_time + draw(beat_interval_strat)
beat_time(min_numerator=1, max_section=3)
)
if isinstance(note, LongNote): if isinstance(note, LongNote):
notes.add( notes.add(
LongNote( LongNote(
@ -158,9 +160,11 @@ def bpms(draw: DrawFunc) -> Decimal:
@st.composite @st.composite
def bpm_changes( def bpm_changes(
draw: DrawFunc, bpm_strat: st.SearchStrategy[Decimal] = bpms() draw: DrawFunc,
bpm_strat: st.SearchStrategy[Decimal] = bpms(),
time_strat: st.SearchStrategy[BeatsTime] = beat_time(min_section=1, max_section=10),
) -> BPMEvent: ) -> BPMEvent:
time = draw(beat_time(min_section=1, max_section=10)) time = draw(time_strat)
bpm = draw(bpm_strat) bpm = draw(bpm_strat)
return BPMEvent(time, bpm) return BPMEvent(time, bpm)
@ -173,13 +177,14 @@ def timing_info(
beat_zero_offset_strat: st.SearchStrategy[Decimal] = st.decimals( beat_zero_offset_strat: st.SearchStrategy[Decimal] = st.decimals(
min_value=0, max_value=20, places=3 min_value=0, max_value=20, places=3
), ),
time_strat: st.SearchStrategy[BeatsTime] = beat_time(min_section=1, max_section=10),
) -> Timing: ) -> Timing:
first_bpm = draw(bpm_strat) first_bpm = draw(bpm_strat)
first_event = BPMEvent(BeatsTime(0), first_bpm) first_event = BPMEvent(BeatsTime(0), first_bpm)
events = [first_event] events = [first_event]
if with_bpm_changes: if with_bpm_changes:
raw_bpm_changes = st.lists( raw_bpm_changes = st.lists(
bpm_changes(bpm_strat), unique_by=get_bpm_change_time bpm_changes(bpm_strat, time_strat), unique_by=get_bpm_change_time
) )
sorted_bpm_changes = raw_bpm_changes.map( sorted_bpm_changes = raw_bpm_changes.map(
lambda l: sorted(l, key=get_bpm_change_time) lambda l: sorted(l, key=get_bpm_change_time)
@ -195,15 +200,23 @@ def get_bpm_change_time(b: BPMEvent) -> BeatsTime:
@st.composite @st.composite
def chart(draw: DrawFunc, timing_strat: Any, notes_strat: Any) -> Chart: def level(draw: DrawFunc) -> Union[int, Decimal]:
level = Decimal( d: Union[int, Decimal] = draw(
draw( st.one_of(
st.one_of( st.integers(min_value=0), st.decimals(min_value=0, max_value=10.9, places=1)
st.integers(min_value=0),
st.decimals(min_value=0, max_value=10.9, places=1),
)
) )
) )
return d
@st.composite
def chart(
draw: DrawFunc,
timing_strat: st.SearchStrategy[Timing] = timing_info(),
notes_strat: st.SearchStrategy[Iterable[Union[TapNote, LongNote]]] = notes(),
level_strat: st.SearchStrategy[Union[int, Decimal]] = level(),
) -> Chart:
level = Decimal(draw(level_strat))
timing = draw(timing_strat) timing = draw(timing_strat)
notes = draw(notes_strat) notes = draw(notes_strat)
return Chart( return Chart(
@ -228,15 +241,15 @@ def preview(draw: DrawFunc) -> Preview:
def metadata( def metadata(
draw: DrawFunc, draw: DrawFunc,
text_strat: st.SearchStrategy[str] = st.text(), text_strat: st.SearchStrategy[str] = st.text(),
path_start: st.SearchStrategy[str] = st.text(), path_strat: st.SearchStrategy[str] = st.text(),
) -> Metadata: ) -> Metadata:
return Metadata( return Metadata(
title=draw(text_strat), title=draw(text_strat),
artist=draw(text_strat), artist=draw(text_strat),
audio=Path(draw(path_start)), audio=Path(draw(path_strat)),
cover=Path(draw(path_start)), cover=Path(draw(path_strat)),
preview=draw(st.one_of(st.none(), preview())), preview=draw(st.one_of(st.none(), preview())),
preview_file=draw(path_start), preview_file=draw(path_strat),
) )
@ -249,44 +262,20 @@ class TimingOption(Flag):
@st.composite @st.composite
def song( def song(
draw: DrawFunc, draw: DrawFunc,
timing_options: TimingOption, diffs_strat: st.SearchStrategy[Set[str]] = st.sets(
extra_diffs: bool, st.sampled_from(list(d.value for d in Difficulty)), min_size=1, max_size=3
notes_options: NoteOption, ),
common_timing_strat: st.SearchStrategy[Optional[Timing]] = timing_info(),
chart_strat: st.SearchStrategy[Chart] = chart(),
metadata_strat: st.SearchStrategy[Metadata] = metadata(),
) -> Song: ) -> Song:
if not ((TimingOption.GLOBAL | TimingOption.PER_CHART) & timing_options): diffs = draw(diffs_strat)
raise ValueError(
"Invalid timing options, at least one of the flags GLOBAL or PER_CHART must be set"
)
timing_strat = timing_info(TimingOption.BPM_CHANGES in timing_options)
note_strat = notes(notes_options)
diff_name_strat = st.sampled_from(list(d.value for d in Difficulty))
if extra_diffs:
# only go for ascii in extra diffs
# https://en.wikipedia.org/wiki/Basic_Latin_(Unicode_block)
diff_name_strat = st.one_of(
diff_name_strat,
st.text(
alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E),
min_size=1,
max_size=20,
),
)
diffs = draw(st.sets(diff_name_strat, min_size=1, max_size=10))
charts: MultiDict[Chart] = MultiDict() charts: MultiDict[Chart] = MultiDict()
for diff_name in diffs: for diff_name in diffs:
chart_timing_strat = st.none() charts.add(diff_name, draw(chart_strat))
if TimingOption.PER_CHART in timing_options:
chart_timing_strat = st.one_of(st.none(), timing_strat)
_chart = draw(chart(chart_timing_strat, note_strat))
charts.add(diff_name, _chart)
common_timing_start: st.SearchStrategy[Optional[Timing]] = st.none()
if TimingOption.GLOBAL in timing_options:
common_timing_start = timing_strat
return Song( return Song(
metadata=draw(metadata()), metadata=draw(metadata_strat),
charts=charts, charts=charts,
common_timing=draw(common_timing_start), common_timing=draw(common_timing_strat),
) )