1
0
mirror of synced 2024-12-04 19:17:55 +01:00

Loaders can now take in arguments

[eve] Add beat_snap argument to the loader
[eve] Simplify bpm truncation method
[eve] Fix bug in TimeMap.from_seconds
[eve] TimeMap.convert_to_timing_info now accepts a beat_snap
Add denominator argument to round_beats for coarser rounding
Parameterize hypothesis strategies in a less ugly way
This commit is contained in:
Stepland 2021-05-12 01:24:25 +02:00
parent 8322747a15
commit f72e875898
22 changed files with 256 additions and 185 deletions

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import math
from dataclasses import astuple, dataclass
from enum import Enum
from fractions import Fraction
from typing import Union
from itertools import count
from typing import Iterator, Union
from jubeatools import song
@ -64,15 +66,7 @@ class Event:
@classmethod
def from_long_note(cls, note: song.LongNote, time_map: TimeMap) -> Event:
if not note.has_straight_tail():
raise ValueError("Diagonal tails cannot be represented in eve format")
eve_long = EveLong(
duration=duration_in_ticks(note, time_map),
length=len(list(note.positions_covered())) - 1,
direction=DIRECTION_TO_VALUE[note.tail_direction()],
position=note.position.index,
)
eve_long = EveLong.from_jubeatools(note, time_map)
ticks = ticks_at_beat(note.time, time_map)
return Event(time=ticks, command=Command.LONG, value=eve_long.value)
@ -134,6 +128,18 @@ class EveLong:
"outside the screen"
)
@classmethod
def from_jubeatools(cls, note: song.LongNote, time_map: TimeMap) -> EveLong:
if not note.has_straight_tail():
raise ValueError("Diagonal tails cannot be represented in eve format")
return cls(
duration=duration_in_ticks(note, time_map),
length=len(list(note.positions_covered())) - 1,
direction=DIRECTION_TO_VALUE[note.tail_direction()],
position=note.position.index,
)
@classmethod
def from_value(cls, value: int) -> EveLong:
if value < 0:
@ -148,7 +154,10 @@ class EveLong:
@property
def value(self) -> int:
return (
self.duration << 8 + self.length << 6 + self.direction << 4 + self.position
(self.duration << 8)
+ (self.length << 6)
+ (self.direction << 4)
+ self.position
)
@ -172,3 +181,33 @@ def ticks_to_seconds(tick: int) -> Fraction:
def seconds_to_ticks(time: Fraction) -> int:
"""Convert fractional seconds to eve ticks (300 Hz)"""
return round(time * 300)
def value_to_truncated_bpm(value: int) -> Fraction:
"""Only keeps enough significant digits to allow recovering the original
TEMPO line value from the bpm"""
exact_bpm = value_to_bpm(value)
truncated_bpms = iter_truncated(exact_bpm)
bpms_preserving_value = filter(
lambda b: bpm_to_value(b) < value + 1, truncated_bpms
)
return next(bpms_preserving_value)
def iter_truncated(f: Fraction) -> Iterator[Fraction]:
for places in count():
yield truncate_fraction(f, places)
def truncate_fraction(f: Fraction, places: int) -> Fraction:
"""Truncates a fraction to the given number of decimal places"""
exponent = Fraction(10) ** places
return Fraction(math.floor(f * exponent), exponent)
def value_to_bpm(value: int) -> Fraction:
return 6 * 10 ** 7 / Fraction(value)
def bpm_to_value(bpm: Fraction) -> Fraction:
return 6 * 10 ** 7 / bpm

View File

@ -12,7 +12,7 @@ from jubeatools import song
from jubeatools.formats.dump_tools import make_dumper_from_chart_file_dumper
from jubeatools.formats.filetypes import ChartFile
from .commons import AnyNote, Command, Event, ticks_at_beat
from .commons import AnyNote, Command, Event, bpm_to_value, ticks_at_beat
from .timemap import TimeMap
@ -71,14 +71,10 @@ def make_timing_events(
def make_bpm_event(bpm_change: song.BPMEvent, time_map: TimeMap) -> Event:
ticks = ticks_at_beat(bpm_change.time, time_map)
bpm_value = bpm_to_value(Fraction(bpm_change.BPM))
bpm_value = math.floor(bpm_to_value(Fraction(bpm_change.BPM)))
return Event(time=ticks, command=Command.TEMPO, value=bpm_value)
def bpm_to_value(bpm: Fraction) -> int:
return math.floor(60 * 10 ** 6 / bpm)
def choose_end_beat(notes: List[AnyNote]) -> song.BeatsTime:
"""Leave 2 empty measures (4 beats) after the last event"""
last_note_beat = compute_last_note_beat(notes)

View File

@ -1,9 +1,7 @@
import math
from decimal import Decimal
from fractions import Fraction
from functools import reduce
from pathlib import Path
from typing import Iterator, List, Optional
from typing import Any, Iterator, List, Optional
from jubeatools import song
from jubeatools.formats.load_tools import make_folder_loader, round_beats
@ -16,13 +14,14 @@ from .commons import (
EveLong,
Event,
ticks_to_seconds,
value_to_truncated_bpm,
)
from .timemap import BPMAtSecond, TimeMap
def load_eve(path: Path) -> song.Song:
def load_eve(path: Path, *, beat_snap: int = 240, **kwargs: Any) -> song.Song:
files = load_folder(path)
charts = [_load_eve(l, p) for p, l in files.items()]
charts = [_load_eve(l, p, beat_snap=beat_snap) for p, l in files.items()]
return reduce(song.Song.merge, charts)
@ -33,7 +32,7 @@ def load_file(path: Path) -> List[str]:
load_folder = make_folder_loader("*.eve", load_file)
def _load_eve(lines: List[str], file_path: Path) -> song.Song:
def _load_eve(lines: List[str], file_path: Path, *, beat_snap: int = 240) -> song.Song:
events = list(iter_events(lines))
events_by_command = group_by(events, lambda e: e.command)
bpms = [
@ -52,7 +51,7 @@ def _load_eve(lines: List[str], file_path: Path) -> song.Song:
for e in events_by_command[Command.LONG]
]
all_notes = sorted(tap_notes + long_notes, key=lambda n: (n.time, n.position))
timing = time_map.convert_to_timing_info()
timing = time_map.convert_to_timing_info(beat_snap=beat_snap)
chart = song.Chart(level=Decimal(0), timing=timing, notes=all_notes)
dif = guess_difficulty(file_path.stem) or song.Difficulty.EXTREME
return song.Song(metadata=song.Metadata(), charts={dif: chart})
@ -103,34 +102,6 @@ def parse_event(line: str) -> Event:
return Event(tick, command, value)
def value_to_truncated_bpm(value: int) -> Fraction:
"""Only keeps enough significant digits to allow recovering the original
TEMPO line value from the bpm"""
precision = tempo_precision(value)
places = significant_decimal_places(precision) + 1
raw_bpm = value_to_bpm(value)
return truncate_fraction(raw_bpm, places)
def value_to_bpm(value: int) -> Fraction:
return 6 * 10 ** 7 / Fraction(value)
def significant_decimal_places(max_error: float) -> int:
return int(-(math.log(max_error / 5) / math.log(10)))
def tempo_precision(value: int) -> float:
"""Maximum error on the bpm this tempo value corresponds to"""
return (6 * 10 ** 7) / (value * (value + 1))
def truncate_fraction(f: Fraction, places: int) -> Fraction:
"""Truncates a fraction to the given number of decimal places"""
exponent = Fraction(10) ** places
return Fraction(math.floor(f * exponent), exponent)
def make_tap_note(ticks: int, value: int, time_map: TimeMap) -> song.TapNote:
seconds = ticks_to_seconds(ticks)
raw_beats = time_map.beats_at(seconds)

View File

@ -3,8 +3,7 @@ import math
from hypothesis import given
from hypothesis import strategies as st
from ..dump import bpm_to_value
from ..load import value_to_truncated_bpm
from ..commons import bpm_to_value, value_to_truncated_bpm
@given(st.integers(min_value=1, max_value=6 * 10 ** 7))

View File

@ -1,5 +1,6 @@
import tempfile
from contextlib import contextmanager
from decimal import Decimal
from pathlib import Path
from typing import Iterator
@ -12,6 +13,10 @@ from jubeatools.testutils import strategies as jbst
from jubeatools.testutils.test_patterns import dump_and_load_then_compare
from jubeatools.testutils.typing import DrawFunc
simple_beat_strat = jbst.beat_time(
denominator_strat=st.sampled_from([4, 8, 3]), max_section=10
)
@st.composite
def eve_compatible_song(draw: DrawFunc) -> song.Song:
@ -21,10 +26,24 @@ def eve_compatible_song(draw: DrawFunc) -> song.Song:
jbst.chart(
timing_strat=jbst.timing_info(
with_bpm_changes=True,
bpm_strat=st.decimals(min_value=1, max_value=1000, places=2),
bpm_strat=st.decimals(min_value=1, max_value=500, places=2),
beat_zero_offset_strat=st.decimals(min_value=0, max_value=20, places=2),
time_strat=simple_beat_strat,
),
notes_strat=jbst.notes(jbst.NoteOption.LONGS),
notes_strat=jbst.notes(
note_strat=st.one_of(
jbst.tap_note(time_start=simple_beat_strat),
jbst.long_note(
time_strat=simple_beat_strat,
duration_strat=jbst.beat_time(
min_numerator=1,
max_section=3,
denominator_strat=st.sampled_from([4, 8, 3]),
),
),
)
),
level_strat=st.just(Decimal(0)),
)
)
return song.Song(
@ -40,7 +59,7 @@ def open_temp_dir() -> Iterator[Path]:
@given(eve_compatible_song())
@settings(verbosity=Verbosity.debug)
@settings(verbosity=Verbosity.normal)
def test_that_full_chart_roundtrips(song: song.Song) -> None:
dump_and_load_then_compare(
Format.EVE,

View File

@ -0,0 +1,25 @@
from hypothesis import given
from hypothesis import strategies as st
from jubeatools import song
from jubeatools.testutils import strategies as jbst
from ..commons import EveLong
from ..timemap import TimeMap
@given(
jbst.long_note(),
jbst.timing_info(
with_bpm_changes=True,
bpm_strat=st.decimals(min_value=1, max_value=1000, places=2),
beat_zero_offset_strat=st.decimals(min_value=0, max_value=20, places=2),
),
)
def test_that_long_note_roundtrips(
long_note: song.LongNote, timing: song.Timing
) -> None:
time_map = TimeMap.from_timing(timing)
original = EveLong.from_jubeatools(long_note, time_map)
recovered = EveLong.from_value(original.value)
assert recovered == original

View File

@ -81,8 +81,8 @@ class TimeMap:
raise ValueError("No BPM defined")
grouped_by_time = group_by(events, key=lambda e: e.seconds)
for time, events in grouped_by_time.items():
if len(events) > 1:
for time, events_at_time in grouped_by_time.items():
if len(events_at_time) > 1:
raise ValueError(f"Multiple BPMs defined at {time} seconds : {events}")
# take the first BPM change then compute from there
@ -144,10 +144,13 @@ class TimeMap:
)
return bpm_change.beats + beats_since_last_event
def convert_to_timing_info(self) -> song.Timing:
def convert_to_timing_info(self, beat_snap: int = 240) -> song.Timing:
return song.Timing(
events=[
song.BPMEvent(time=round_beats(e.beats), BPM=fraction_to_decimal(e.BPM))
song.BPMEvent(
time=round_beats(e.beats, beat_snap),
BPM=fraction_to_decimal(e.BPM),
)
for e in self.events_by_beats
],
beat_zero_offset=self.beat_zero_offset,

View File

@ -5,7 +5,7 @@ from decimal import Decimal
from functools import reduce
from itertools import product
from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Set, Tuple, Union
from more_itertools import collapse, mark_ends
@ -347,7 +347,7 @@ def _load_memo_file(lines: List[str]) -> Song:
return Song(metadata=metadata, charts=charts)
def load_memo(path: Path) -> Song:
def load_memo(path: Path, **kwargs: Any) -> Song:
files = load_folder(path)
charts = [_load_memo_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts)

View File

@ -4,7 +4,7 @@ from decimal import Decimal
from functools import reduce
from itertools import product
from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Set, Tuple, Union
from more_itertools import mark_ends
@ -338,7 +338,7 @@ def _load_memo1_file(lines: List[str]) -> Song:
return Song(metadata=metadata, charts=charts)
def load_memo1(path: Path) -> Song:
def load_memo1(path: Path, **kwargs: Any) -> Song:
files = load_folder(path)
charts = [_load_memo1_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts)

View File

@ -3,7 +3,7 @@ from decimal import Decimal
from functools import reduce
from itertools import product, zip_longest
from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union
from parsimonious import Grammar, NodeVisitor, ParseError
from parsimonious.nodes import Node
@ -457,7 +457,7 @@ def _load_memo2_file(lines: List[str]) -> Song:
return Song(metadata=metadata, charts=charts)
def load_memo2(path: Path) -> Song:
def load_memo2(path: Path, **kwargs: Any) -> Song:
files = load_folder(path)
charts = [_load_memo2_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts)

View File

@ -4,7 +4,7 @@ from decimal import Decimal
from functools import reduce
from itertools import product
from pathlib import Path
from typing import Dict, Iterator, List, Set, Tuple, Union
from typing import Any, Dict, Iterator, List, Set, Tuple, Union
from parsimonious import Grammar, NodeVisitor, ParseError
from parsimonious.nodes import Node
@ -241,7 +241,7 @@ class MonoColumnParser(JubeatAnalyserParser):
section_starting_beat += section.length
def load_mono_column(path: Path) -> Song:
def load_mono_column(path: Path, **kwargs: Any) -> Song:
files = load_folder(path)
charts = [_load_mono_column_file(lines) for _, lines in files.items()]
return reduce(Song.merge, charts)

View File

@ -17,7 +17,7 @@ from ..test_utils import memo_compatible_song, temp_file_named_txt
from . import example1, example2, example3
@given(jbst.notes(jbst.NoteOption.LONGS))
@given(jbst.notes())
@example(example1.notes)
def test_that_notes_roundtrip(notes: Set[Union[song.TapNote, song.LongNote]]) -> None:
timing = song.Timing(

View File

@ -9,7 +9,6 @@ from jubeatools import song
from jubeatools.formats import Format
from jubeatools.formats.jubeat_analyser.memo1.dump import _dump_memo1_chart
from jubeatools.formats.jubeat_analyser.memo1.load import Memo1Parser
from jubeatools.testutils.strategies import NoteOption
from jubeatools.testutils.strategies import notes as notes_strat
from jubeatools.testutils.test_patterns import dump_and_load_then_compare
@ -17,7 +16,7 @@ from ..test_utils import memo_compatible_song, temp_file_named_txt
from . import example1
@given(notes_strat(NoteOption.LONGS))
@given(notes_strat())
def test_that_notes_roundtrip(notes: List[Union[song.TapNote, song.LongNote]]) -> None:
timing = song.Timing(
events=[song.BPMEvent(song.BeatsTime(0), Decimal(120))],

View File

@ -19,7 +19,6 @@ from jubeatools.song import (
TapNote,
Timing,
)
from jubeatools.testutils.strategies import NoteOption
from jubeatools.testutils.strategies import notes as notes_strat
from jubeatools.testutils.test_patterns import dump_and_load_then_compare
@ -27,7 +26,7 @@ from ..test_utils import memo_compatible_song, temp_file_named_txt
from . import example1, example2, example3
@given(notes_strat(NoteOption.LONGS))
@given(notes_strat())
def test_that_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None:
timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0)

View File

@ -19,7 +19,7 @@ from jubeatools.song import (
TapNote,
Timing,
)
from jubeatools.testutils.strategies import NoteOption, long_note
from jubeatools.testutils.strategies import long_note
from jubeatools.testutils.strategies import notes as notes_strat
from jubeatools.testutils.strategies import tap_note
from jubeatools.testutils.test_patterns import dump_and_load_then_compare
@ -63,7 +63,7 @@ def test_that_a_single_long_note_roundtrips(note: LongNote) -> None:
assert set([note]) == actual
@given(notes_strat(NoteOption.LONGS))
@given(notes_strat())
def test_that_many_notes_roundtrip(notes: List[Union[TapNote, LongNote]]) -> None:
timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0)

View File

@ -14,7 +14,7 @@ from jubeatools.testutils.typing import DrawFunc
def memo_compatible_metadata(draw: DrawFunc) -> song.Metadata:
text_strat = st.text(alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E))
metadata: song.Metadata = draw(
jbst.metadata(text_strat=text_strat, path_start=text_strat)
jbst.metadata(text_strat=text_strat, path_strat=text_strat)
)
metadata.preview = None
metadata.preview_file = None
@ -28,7 +28,7 @@ def memo_compatible_song(draw: DrawFunc) -> song.Song:
chart = draw(
jbst.chart(
timing_strat=jbst.timing_info(with_bpm_changes=True),
notes_strat=jbst.notes(jbst.NoteOption.LONGS),
notes_strat=jbst.notes(),
)
)
metadata: song.Metadata = draw(memo_compatible_metadata())

View File

@ -51,7 +51,8 @@ def make_folder_loader(glob_pattern: str, file_loader: FileLoader) -> FolderLoad
Number = Union[int, float, Decimal, Fraction]
def round_beats(beats: Number) -> song.BeatsTime:
"""Rounds a given beat value to the nearest 1/240th"""
nearest_240th = round(beats * 240)
return song.BeatsTime(nearest_240th, 240)
def round_beats(beats: Number, denominator: int = 240) -> song.BeatsTime:
"""Rounds a given beat value to the nearest given division (default to
nearest 1/240th"""
nearest = round(beats * denominator)
return song.BeatsTime(nearest, denominator)

View File

@ -149,8 +149,8 @@ def _load_memon_note_v0(
return jbt.TapNote(time, position)
def load_memon_legacy(file: Path) -> jbt.Song:
raw_memon = _load_raw_memon(file)
def load_memon_legacy(path: Path, **kwargs: Any) -> jbt.Song:
raw_memon = _load_raw_memon(path)
schema = Memon_legacy()
memon = schema.load(raw_memon)
metadata = jbt.Metadata(
@ -179,8 +179,8 @@ def load_memon_legacy(file: Path) -> jbt.Song:
return jbt.Song(metadata=metadata, charts=charts, common_timing=common_timing)
def load_memon_0_1_0(file: Path) -> jbt.Song:
raw_memon = _load_raw_memon(file)
def load_memon_0_1_0(path: Path, **kwargs: Any) -> jbt.Song:
raw_memon = _load_raw_memon(path)
schema = Memon_0_1_0()
memon = schema.load(raw_memon)
metadata = jbt.Metadata(
@ -209,8 +209,8 @@ def load_memon_0_1_0(file: Path) -> jbt.Song:
return jbt.Song(metadata=metadata, charts=charts, common_timing=common_timing)
def load_memon_0_2_0(file: Path) -> jbt.Song:
raw_memon = _load_raw_memon(file)
def load_memon_0_2_0(path: Path, **kwargs: Any) -> jbt.Song:
raw_memon = _load_raw_memon(path)
schema = Memon_0_2_0()
memon = schema.load(raw_memon)
preview = None

View File

@ -1,13 +1,13 @@
import tempfile
from pathlib import Path
from typing import Set
import hypothesis.strategies as st
from hypothesis import given
from jubeatools import song
from jubeatools.formats.typing import Dumper, Loader
from jubeatools.song import Song
from jubeatools.testutils.strategies import NoteOption, TimingOption
from jubeatools.testutils.strategies import song as song_strat
from jubeatools.testutils import strategies as jbst
from jubeatools.testutils.typing import DrawFunc
from . import (
@ -21,7 +21,7 @@ from . import (
def dump_and_load(
expected_song: Song, dump_function: Dumper, load_function: Loader
expected_song: song.Song, dump_function: Dumper, load_function: Loader
) -> None:
with tempfile.NamedTemporaryFile(mode="wb") as file:
files = dump_function(expected_song, Path(file.name))
@ -35,16 +35,37 @@ def dump_and_load(
@st.composite
def memon_legacy_compatible_song(draw: DrawFunc) -> Song:
def memon_diffs(draw: DrawFunc) -> Set[str]:
simple_diff_names = st.sampled_from(list(d.value for d in song.Difficulty))
diff_names = st.one_of(
simple_diff_names,
st.text(
alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E),
min_size=1,
max_size=20,
),
)
s: Set[str] = draw(st.sets(diff_names, min_size=1, max_size=10))
return s
@st.composite
def memon_legacy_compatible_song(draw: DrawFunc) -> song.Song:
"""Memon versions below v0.2.0 do not support any preview metadata"""
song: Song = draw(song_strat(TimingOption.GLOBAL, True, NoteOption.LONGS))
song.metadata.preview = None
song.metadata.preview_file = None
return song
random_song: song.Song = draw(
jbst.song(
diffs_strat=memon_diffs(),
chart_strat=jbst.chart(timing_strat=st.none()),
common_timing_strat=jbst.timing_info(with_bpm_changes=False),
)
)
random_song.metadata.preview = None
random_song.metadata.preview_file = None
return random_song
@given(memon_legacy_compatible_song())
def test_memon_legacy(song: Song) -> None:
def test_memon_legacy(song: song.Song) -> None:
dump_and_load(song, dump_memon_legacy, load_memon_legacy)
@ -52,18 +73,24 @@ memon_0_1_0_compatible_song = memon_legacy_compatible_song
@given(memon_0_1_0_compatible_song())
def test_memon_0_1_0(song: Song) -> None:
def test_memon_0_1_0(song: song.Song) -> None:
dump_and_load(song, dump_memon_0_1_0, load_memon_0_1_0)
@st.composite
def memon_0_2_0_compatible_song(draw: DrawFunc) -> Song:
def memon_0_2_0_compatible_song(draw: DrawFunc) -> song.Song:
"""Memon v0.2.0 does not support preview_file"""
song: Song = draw(song_strat(TimingOption.GLOBAL, True, NoteOption.LONGS))
song.metadata.preview_file = None
return song
random_song: song.Song = draw(
jbst.song(
diffs_strat=memon_diffs(),
chart_strat=jbst.chart(timing_strat=st.none()),
common_timing_strat=jbst.timing_info(with_bpm_changes=False),
)
)
random_song.metadata.preview_file = None
return random_song
@given(memon_0_2_0_compatible_song())
def test_memon_0_2_0(song: Song) -> None:
def test_memon_0_2_0(song: song.Song) -> None:
dump_and_load(song, dump_memon_0_2_0, load_memon_0_2_0)

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Protocol
from typing import Any, Dict, List, Protocol
from jubeatools.formats.filetypes import ChartFile
from jubeatools.song import Song
@ -24,6 +24,10 @@ class ChartFileDumper(Protocol):
...
# Loaders deserialize a Path to a Song object
# The Path can be a file or a folder depending on the format
Loader = Callable[[Path], Song]
class Loader(Protocol):
"""A Loader deserializes a Path to a Song object and possibly takes in
some options via the kwargs.
The Path can be a file or a folder depending on the format"""
def __call__(self, path: Path, **kwargs: Any) -> Song:
...

View File

@ -6,7 +6,7 @@ from decimal import Decimal
from enum import Flag, auto
from itertools import product
from pathlib import Path
from typing import Any, Dict, Optional, Set, Union
from typing import Dict, Iterable, Optional, Set, Union
import hypothesis.strategies as st
from multidict import MultiDict
@ -34,8 +34,9 @@ def beat_time(
max_section: Optional[int] = None,
min_numerator: Optional[int] = None,
max_numerator: Optional[int] = None,
denominator_strat: st.SearchStrategy[int] = st.sampled_from([4, 8, 16, 3, 5]),
) -> BeatsTime:
denominator = draw(st.sampled_from([4, 8, 16, 3, 5]))
denominator = draw(denominator_strat)
if min_section is not None:
min_value = denominator * 4 * min_section
@ -68,17 +69,25 @@ def note_position(draw: DrawFunc) -> NotePosition:
@st.composite
def tap_note(draw: DrawFunc) -> TapNote:
time = draw(beat_time(max_section=10))
def tap_note(
draw: DrawFunc, time_start: st.SearchStrategy[BeatsTime] = beat_time(max_section=10)
) -> TapNote:
time = draw(time_start)
position = draw(note_position())
return TapNote(time, position)
@st.composite
def long_note(draw: DrawFunc) -> LongNote:
time = draw(beat_time(max_section=10))
def long_note(
draw: DrawFunc,
time_strat: st.SearchStrategy[BeatsTime] = beat_time(max_section=10),
duration_strat: st.SearchStrategy[BeatsTime] = beat_time(
min_numerator=1, max_section=3
),
) -> LongNote:
time = draw(time_strat)
position = draw(note_position())
duration = draw(beat_time(min_numerator=1, max_section=3))
duration = draw(duration_strat)
tail_is_vertical = draw(st.booleans())
tail_offset = draw(st.integers(min_value=1, max_value=3))
if tail_is_vertical:
@ -91,15 +100,6 @@ def long_note(draw: DrawFunc) -> LongNote:
return LongNote(time, position, duration, tail_tip)
class NoteOption(Flag):
"""What kind of notes are allowed to be generated"""
# Long notes
LONGS = auto()
# Intersections between longs and other notes on the same square
COLLISIONS = auto()
@st.composite
def bad_notes(draw: DrawFunc, longs: bool) -> Set[Union[TapNote, LongNote]]:
note_strat = tap_note()
@ -110,16 +110,20 @@ def bad_notes(draw: DrawFunc, longs: bool) -> Set[Union[TapNote, LongNote]]:
@st.composite
def notes(draw: DrawFunc, options: NoteOption) -> Set[Union[TapNote, LongNote]]:
if (NoteOption.COLLISIONS in options) and (NoteOption.LONGS not in options):
raise ValueError("Can't ask for collisions without longs")
note_strat = tap_note()
if NoteOption.LONGS in options:
note_strat = st.one_of(note_strat, long_note())
def notes(
draw: DrawFunc,
collisions: bool = False,
note_strat: st.SearchStrategy[Union[TapNote, LongNote]] = st.one_of(
tap_note(), long_note()
),
beat_time_strat: st.SearchStrategy[BeatsTime] = beat_time(max_section=3),
beat_interval_strat: st.SearchStrategy[BeatsTime] = beat_time(
min_numerator=1, max_section=3
),
) -> Set[Union[TapNote, LongNote]]:
raw_notes: Set[Union[TapNote, LongNote]] = draw(st.sets(note_strat, max_size=32))
if NoteOption.COLLISIONS in options:
if collisions:
return raw_notes
else:
last_notes: Dict[NotePosition, Optional[BeatsTime]] = {
@ -129,11 +133,9 @@ def notes(draw: DrawFunc, options: NoteOption) -> Set[Union[TapNote, LongNote]]:
for note in sorted(raw_notes, key=lambda n: (n.time, n.position)):
last_note_time = last_notes[note.position]
if last_note_time is None:
new_time = draw(beat_time(max_section=3))
new_time = draw(beat_time_strat)
else:
new_time = last_note_time + draw(
beat_time(min_numerator=1, max_section=3)
)
new_time = last_note_time + draw(beat_interval_strat)
if isinstance(note, LongNote):
notes.add(
LongNote(
@ -158,9 +160,11 @@ def bpms(draw: DrawFunc) -> Decimal:
@st.composite
def bpm_changes(
draw: DrawFunc, bpm_strat: st.SearchStrategy[Decimal] = bpms()
draw: DrawFunc,
bpm_strat: st.SearchStrategy[Decimal] = bpms(),
time_strat: st.SearchStrategy[BeatsTime] = beat_time(min_section=1, max_section=10),
) -> BPMEvent:
time = draw(beat_time(min_section=1, max_section=10))
time = draw(time_strat)
bpm = draw(bpm_strat)
return BPMEvent(time, bpm)
@ -173,13 +177,14 @@ def timing_info(
beat_zero_offset_strat: st.SearchStrategy[Decimal] = st.decimals(
min_value=0, max_value=20, places=3
),
time_strat: st.SearchStrategy[BeatsTime] = beat_time(min_section=1, max_section=10),
) -> Timing:
first_bpm = draw(bpm_strat)
first_event = BPMEvent(BeatsTime(0), first_bpm)
events = [first_event]
if with_bpm_changes:
raw_bpm_changes = st.lists(
bpm_changes(bpm_strat), unique_by=get_bpm_change_time
bpm_changes(bpm_strat, time_strat), unique_by=get_bpm_change_time
)
sorted_bpm_changes = raw_bpm_changes.map(
lambda l: sorted(l, key=get_bpm_change_time)
@ -195,15 +200,23 @@ def get_bpm_change_time(b: BPMEvent) -> BeatsTime:
@st.composite
def chart(draw: DrawFunc, timing_strat: Any, notes_strat: Any) -> Chart:
level = Decimal(
draw(
def level(draw: DrawFunc) -> Union[int, Decimal]:
d: Union[int, Decimal] = draw(
st.one_of(
st.integers(min_value=0),
st.decimals(min_value=0, max_value=10.9, places=1),
)
st.integers(min_value=0), st.decimals(min_value=0, max_value=10.9, places=1)
)
)
return d
@st.composite
def chart(
draw: DrawFunc,
timing_strat: st.SearchStrategy[Timing] = timing_info(),
notes_strat: st.SearchStrategy[Iterable[Union[TapNote, LongNote]]] = notes(),
level_strat: st.SearchStrategy[Union[int, Decimal]] = level(),
) -> Chart:
level = Decimal(draw(level_strat))
timing = draw(timing_strat)
notes = draw(notes_strat)
return Chart(
@ -228,15 +241,15 @@ def preview(draw: DrawFunc) -> Preview:
def metadata(
draw: DrawFunc,
text_strat: st.SearchStrategy[str] = st.text(),
path_start: st.SearchStrategy[str] = st.text(),
path_strat: st.SearchStrategy[str] = st.text(),
) -> Metadata:
return Metadata(
title=draw(text_strat),
artist=draw(text_strat),
audio=Path(draw(path_start)),
cover=Path(draw(path_start)),
audio=Path(draw(path_strat)),
cover=Path(draw(path_strat)),
preview=draw(st.one_of(st.none(), preview())),
preview_file=draw(path_start),
preview_file=draw(path_strat),
)
@ -249,44 +262,20 @@ class TimingOption(Flag):
@st.composite
def song(
draw: DrawFunc,
timing_options: TimingOption,
extra_diffs: bool,
notes_options: NoteOption,
) -> Song:
if not ((TimingOption.GLOBAL | TimingOption.PER_CHART) & timing_options):
raise ValueError(
"Invalid timing options, at least one of the flags GLOBAL or PER_CHART must be set"
)
timing_strat = timing_info(TimingOption.BPM_CHANGES in timing_options)
note_strat = notes(notes_options)
diff_name_strat = st.sampled_from(list(d.value for d in Difficulty))
if extra_diffs:
# only go for ascii in extra diffs
# https://en.wikipedia.org/wiki/Basic_Latin_(Unicode_block)
diff_name_strat = st.one_of(
diff_name_strat,
st.text(
alphabet=st.characters(min_codepoint=0x20, max_codepoint=0x7E),
min_size=1,
max_size=20,
diffs_strat: st.SearchStrategy[Set[str]] = st.sets(
st.sampled_from(list(d.value for d in Difficulty)), min_size=1, max_size=3
),
)
diffs = draw(st.sets(diff_name_strat, min_size=1, max_size=10))
common_timing_strat: st.SearchStrategy[Optional[Timing]] = timing_info(),
chart_strat: st.SearchStrategy[Chart] = chart(),
metadata_strat: st.SearchStrategy[Metadata] = metadata(),
) -> Song:
diffs = draw(diffs_strat)
charts: MultiDict[Chart] = MultiDict()
for diff_name in diffs:
chart_timing_strat = st.none()
if TimingOption.PER_CHART in timing_options:
chart_timing_strat = st.one_of(st.none(), timing_strat)
_chart = draw(chart(chart_timing_strat, note_strat))
charts.add(diff_name, _chart)
common_timing_start: st.SearchStrategy[Optional[Timing]] = st.none()
if TimingOption.GLOBAL in timing_options:
common_timing_start = timing_strat
charts.add(diff_name, draw(chart_strat))
return Song(
metadata=draw(metadata()),
metadata=draw(metadata_strat),
charts=charts,
common_timing=draw(common_timing_start),
common_timing=draw(common_timing_strat),
)