1
0
mirror of synced 2024-12-04 19:17:55 +01:00

Refactor and other fixes

This commit is contained in:
Stepland 2020-07-04 00:11:25 +02:00
parent cfc23c2f90
commit 875e8161ba
5 changed files with 262 additions and 183 deletions

View File

@ -11,15 +11,33 @@ from more_itertools import collapse, intersperse, mark_ends, windowed
from sortedcontainers import SortedDict, SortedKeyList, SortedSet
from jubeatools import __version__
from jubeatools.song import (BeatsTime, Chart, LongNote, Metadata,
NotePosition, Song, TapNote, Timing)
from jubeatools.song import (
BeatsTime,
Chart,
LongNote,
Metadata,
NotePosition,
Song,
TapNote,
Timing,
)
from ..command import dump_command
from .commons import NOTE_SYMBOLS, CIRCLE_FREE_SYMBOLS
from .commons import CIRCLE_FREE_SYMBOLS, NOTE_SYMBOLS
COMMAND_ORDER = SortedSet(
["b", "t", "m", "o", "r", "title", "artist", "lev", "dif", "jacket", "prevpos"]
)
COMMAND_ORDER = [
"b",
"t",
"m",
"o",
"r",
"title",
"artist",
"lev",
"dif",
"jacket",
"prevpos",
]
BEATS_TIME_TO_SYMBOL = {
BeatsTime(1, 4) * index: symbol for index, symbol in enumerate(NOTE_SYMBOLS)
@ -30,124 +48,19 @@ BEATS_TIME_TO_CIRCLE_FREE = {
}
DIRECTION_TO_ARROW = {
NotePosition(-1, 0): "", # U+FF1E : FULLWIDTH GREATER-THAN SIGN
NotePosition( 1, 0): "", # U+FF1C : FULLWIDTH LESS-THAN SIGN
NotePosition( 0, -1): "", # U+2228 : LOGICAL OR
NotePosition( 0, 1): "", # U+2227 : LOGICAL AND
NotePosition(-1, 0): "", # U+FF1E : FULLWIDTH GREATER-THAN SIGN
NotePosition(1, 0): "", # U+FF1C : FULLWIDTH LESS-THAN SIGN
NotePosition(0, -1): "", # U+2228 : LOGICAL OR
NotePosition(0, 1): "", # U+2227 : LOGICAL AND
}
DIRECTION_TO_LINE = {
NotePosition(-1, 0): "", # U+2015 : HORIZONTAL BAR
NotePosition( 1, 0): "",
NotePosition( 0, -1): "", # U+FF5C : FULLWIDTH VERTICAL LINE
NotePosition( 0, 1): "",
NotePosition(-1, 0): "", # U+2015 : HORIZONTAL BAR
NotePosition(1, 0): "",
NotePosition(0, -1): "", # U+FF5C : FULLWIDTH VERTICAL LINE
NotePosition(0, 1): "",
}
def fraction_to_decimal(frac: Fraction):
"Thanks stackoverflow ! https://stackoverflow.com/a/40468867/10768117"
return frac.numerator / Decimal(frac.denominator)
@dataclass
class MonoColumnDumpedSection:
commands: Dict[str, Optional[str]] = field(default_factory=dict)
symbol_definitions: Dict[BeatsTime, str] = field(default_factory=dict)
notes: List[TapNote] = field(default_factory=list)
def render(
self,
current_beat: BeatsTime,
extra_symbols: Dict[BeatsTime, str],
circle_free: bool = False
) -> str:
blocs = []
commands = list(self._dump_commands())
if commands:
blocs.append(commands)
symbols = list(self._dump_symbol_definitions())
if symbols:
blocs.append(symbols)
notes = list(self._dump_notes(current_beat, extra_symbols, circle_free))
if notes:
blocs.append(notes)
return "\n".join(collapse([intersperse("", blocs), "--"]))
def _dump_commands(self) -> Iterator[str]:
keys = chain(COMMAND_ORDER, self.commands.keys() - COMMAND_ORDER)
for key in keys:
try:
value = self.commands[key]
except KeyError:
continue
yield dump_command(key, value)
def _dump_symbol_definitions(self) -> Iterator[str]:
for time, symbol in self.symbol_definitions.items():
decimal_time = fraction_to_decimal(time)
yield f"*{symbol}:{decimal_time:.6f}"
def _dump_notes(
self,
current_beat: BeatsTime,
extra_symbols: Dict[BeatsTime, str],
circle_free: bool = False
) -> Iterator[str]:
frames: List[Dict[NotePosition, str]] = []
frame: Dict[NotePosition, str] = {}
symbols: Dict[BeatsTime, str] = ChainMap(
extra_symbols, BEATS_TIME_TO_SYMBOL
)
for note in self.notes:
if isinstance(note, LongNote):
needed_positions = set(note.positions_covered())
if needed_positions & frame.keys():
frames.append(frame)
frame = {}
direction = note.tail_direction()
arrow = DIRECTION_TO_ARROW[direction]
line = DIRECTION_TO_LINE[direction]
for is_first, is_last, pos in note.positions_covered():
if is_first:
time_in_section = note.time - current_beat
symbol = symbols[time_in_section]
frame[pos] = symbol
if is_last:
frame[pos] = arrow
else:
frame[pos] = line
elif isinstance(note, TapNote):
if note.position in frame:
frames.append(frame)
frame = {}
time_in_section = note.time - current_beat
symbol = symbols[time_in_section]
frame[note.position] = symbol
elif isinstance(note, LongNoteEnd):
if note.position in frame:
frames.append(frame)
frame = {}
time_in_section = note.time - current_beat
if circle_free:
symbol = CIRCLE_FREE_SYMBOLS[time_in_section]
else:
symbol = symbols[time_in_section]
frame[note.position] = symbol
frames.append(frame)
dumped_frames = map(self._dump_frame, frames)
yield from collapse(intersperse("", dumped_frames))
@staticmethod
def _dump_frame(frame: Dict[NotePosition, str]) -> Iterator[str]:
for y in range(4):
yield "".join(frame.get(NotePosition(x, y), "") for x in range(4))
DIFFICULTIES = {"BSC": 1, "ADV": 2, "EXT": 3}
# I put a FUCKTON of extra characters just in case some insane chart uses
@ -159,43 +72,151 @@ DEFAULT_EXTRA_SYMBOLS = (
"アイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワヲン"
)
def _raise_if_unfit_for_mono_column(chart: Chart, timing: Timing, circle_free: bool = False):
def fraction_to_decimal(frac: Fraction):
"Thanks stackoverflow ! https://stackoverflow.com/a/40468867/10768117"
return frac.numerator / Decimal(frac.denominator)
@dataclass
class MonoColumnDumpedSection:
current_beat: BeatsTime
commands: Dict[str, Optional[str]] = field(default_factory=dict)
symbol_definitions: Dict[BeatsTime, str] = field(default_factory=dict)
symbols: Dict[BeatsTime, str] = field(default_factory=dict)
notes: List[TapNote] = field(default_factory=list)
def render(self, circle_free: bool = False,) -> str:
blocs = []
commands = list(self._dump_commands())
if commands:
blocs.append(commands)
symbols = list(self._dump_symbol_definitions())
if symbols:
blocs.append(symbols)
notes = list(self._dump_notes(circle_free))
if notes:
blocs.append(notes)
return "\n".join(collapse([intersperse("", blocs), "--"]))
def _dump_commands(self) -> Iterator[str]:
keys = chain(COMMAND_ORDER, self.commands.keys() - set(COMMAND_ORDER))
for key in keys:
try:
value = self.commands[key]
except KeyError:
continue
yield dump_command(key, value)
def _dump_symbol_definitions(self) -> Iterator[str]:
for time, symbol in self.symbol_definitions.items():
decimal_time = fraction_to_decimal(time)
yield f"*{symbol}:{decimal_time:.6f}"
def _dump_notes(self, circle_free: bool = False,) -> Iterator[str]:
frames: List[Dict[NotePosition, str]] = []
frame: Dict[NotePosition, str] = {}
for note in self.notes:
if isinstance(note, LongNote):
needed_positions = set(note.positions_covered())
if needed_positions & frame.keys():
frames.append(frame)
frame = {}
direction = note.tail_direction()
arrow = DIRECTION_TO_ARROW[direction]
line = DIRECTION_TO_LINE[direction]
for is_first, is_last, pos in mark_ends(note.positions_covered()):
if is_first:
time_in_section = note.time - self.current_beat
symbol = self.symbols[time_in_section]
frame[pos] = symbol
elif is_last:
frame[pos] = arrow
else:
frame[pos] = line
elif isinstance(note, TapNote):
if note.position in frame:
frames.append(frame)
frame = {}
time_in_section = note.time - self.current_beat
symbol = self.symbols[time_in_section]
frame[note.position] = symbol
elif isinstance(note, LongNoteEnd):
if note.position in frame:
frames.append(frame)
frame = {}
time_in_section = note.time - self.current_beat
if circle_free:
symbol = CIRCLE_FREE_SYMBOLS[time_in_section]
else:
symbol = self.symbols[time_in_section]
frame[note.position] = symbol
frames.append(frame)
dumped_frames = map(self._dump_frame, frames)
yield from collapse(intersperse("", dumped_frames))
@staticmethod
def _dump_frame(frame: Dict[NotePosition, str]) -> Iterator[str]:
for y in range(4):
yield "".join(frame.get(NotePosition(x, y), "") for x in range(4))
class Sections(SortedDict):
"""Custom SortedDict that also acts as a defaultdict of
MonoColumnDumpedSection"""
def add_section(self, beat):
if beat not in self:
section = MonoColumnDumpedSection(beat)
self.__setitem__(beat, section)
def __missing__(self, beat):
section = MonoColumnDumpedSection(beat)
self.__setitem__(beat, section)
return section
@dataclass(frozen=True)
class LongNoteEnd:
time: BeatsTime
position: NotePosition
def _raise_if_unfit_for_mono_column(
chart: Chart, timing: Timing, circle_free: bool = False
):
if len(timing.events) < 1:
raise ValueError("No BPM found in file") from None
first_bpm = min(timing.events, key=lambda e: e.time)
if first_bpm.time != 0:
raise ValueError("First BPM event does not happen on beat zero")
if any(not note.tail_is_straight() for note in chart.notes if isinstance(note, LongNote)):
if any(
not note.tail_is_straight()
for note in chart.notes
if isinstance(note, LongNote)
):
raise ValueError(
"Chart contains diagonal long notes, reprensenting these in"
" mono_column format is not supported by jubeatools"
)
if circle_free and any(
(note.time + note.duration) % BeatsTime(1, 4) != 0
for note in chart.notes if isinstance(note, LongNote)
for note in chart.notes
if isinstance(note, LongNote)
):
raise ValueError(
"Chart contains long notes whose ending timing aren't"
" representable in #circlefree mode"
)
class SortedDefaultDict(SortedDict):
def __init__(self, default_factory, *args, **kwargs):
super().__init__(*args, **kwargs)
self.default_factory = default_factory
def __missing__(self, key):
value = self.default_factory()
self.__setitem__(key, value)
return value
@dataclass(frozen=True)
class LongNoteEnd:
time: BeatsTime
position: NotePosition
def _dump_mono_column_chart(
difficulty: str, chart: Chart, metadata: Metadata, timing: Timing
@ -205,36 +226,45 @@ def _dump_mono_column_chart(
timing_events = sorted(timing.events, key=lambda e: e.time)
notes = SortedKeyList(set(chart.notes), key=lambda n: n.time)
# Add long note ends
for note in chart.notes:
if isinstance(note, LongNote):
notes.add(LongNoteEnd(note.time+note.duration, note.position))
last_event_time = max(timing_events[-1].time, notes[-1].time)
last_measure = last_event_time // 4
sections = SortedDefaultDict(
MonoColumnDumpedSection,
{BeatsTime(4) * i: MonoColumnDumpedSection() for i in range(last_measure + 1)},
)
sections[0].commands.update(
o=int(timing.beat_zero_offset * 1000),
m=metadata.audio,
title=metadata.title,
artist=metadata.artist,
lev=int(chart.level),
dif=DIFFICULTIES.get(difficulty, 1),
jacket=metadata.cover,
prevpos=int(metadata.preview_start * 1000),
)
notes.add(LongNoteEnd(note.time + note.duration, note.position))
all_events = SortedKeyList(timing_events + notes, key=lambda n: n.time)
last_event = all_events[-1]
last_measure = last_event.time // 4
sections = Sections()
for i in range(last_measure + 1):
beat = BeatsTime(4) * i
sections.add_section(beat)
header = sections[0].commands
header["o"] = int(timing.beat_zero_offset * 1000)
header["lev"] = int(chart.level)
header["dif"] = DIFFICULTIES.get(difficulty, 1)
if metadata.audio:
header["m"] = metadata.audio
if metadata.title:
header["title"] = metadata.title
if metadata.artist:
header["artist"] = metadata.artist
if metadata.cover:
header["jacket"] = metadata.cover
if metadata.preview is not None:
header["prevpos"] = int(metadata.preview.start * 1000)
if any(isinstance(note, LongNote) for note in chart.notes):
header["holdbyarrow"] = 1
# Potentially create sub-sections for bpm changes
for event in timing_events:
sections[event.time].commands["t"] = event.BPM
# First, Set every single b=... value
for key, next_key in windowed(chain(sections.keys(),[None]), 2):
# First, Set every single b=… value
for key, next_key in windowed(chain(sections.keys(), [None]), 2):
if next_key is None:
sections[key].commands["b"] = 4
else:
sections[key].commands["b"] = fraction_to_decimal(next_key - key)
# Then, trim all the redundant b=...
# Then, trim all the redundant b=
last_b = 4
for section in sections.values():
current_b = section.commands["b"]
@ -243,21 +273,22 @@ def _dump_mono_column_chart(
else:
last_b = current_b
# Fill sections with notes
for key, next_key in windowed(chain(sections.keys(),[None]), 2):
for key, next_key in windowed(chain(sections.keys(), [None]), 2):
sections[key].notes = list(
notes.irange_key(min_key=key, max_key=next_key, inclusive=(True, False))
)
# Define extra symbols
existing_symbols = deepcopy(BEATS_TIME_TO_SYMBOL)
extra_symbols = iter(DEFAULT_EXTRA_SYMBOLS)
all_extra_symbols = {}
for section_start, section in sections.items():
# intentionally not a copy : at the end of this loop every section
# holds a reference to a dict containing every defined symbol
section.symbols = existing_symbols
for note in section.notes:
time_in_section = note.time - section_start
if time_in_section not in existing_symbols:
new_symbol = next(extra_symbols)
section.symbol_definitions[time_in_section] = new_symbol
all_extra_symbols[time_in_section] = new_symbol
existing_symbols[time_in_section] = new_symbol
# Actual output to file
@ -265,7 +296,7 @@ def _dump_mono_column_chart(
file.write(f"// Converted using jubeatools {__version__}\n")
file.write(f"// https://github.com/Stepland/jubeatools\n\n")
for section_start, section in sections.items():
file.write(section.render(section_start, all_extra_symbols) + "\n")
file.write(section.render() + "\n")
return file

View File

@ -28,7 +28,7 @@ from jubeatools.song import (
from ..command import is_command, parse_command
from ..symbol import is_symbol_definition, parse_symbol_definition
from .commons import NOTE_SYMBOLS, CIRCLE_FREE_SYMBOLS
from .commons import CIRCLE_FREE_SYMBOLS, NOTE_SYMBOLS
mono_column_chart_line_grammar = Grammar(
r"""

View File

@ -330,4 +330,4 @@ def test_circle_free():
expected = [
LongNote(BeatsTime(0), NotePosition(3, 3), BeatsTime(7), NotePosition(0, 3))
]
compare_chart_notes(chart, expected)
compare_chart_notes(chart, expected)

View File

@ -1,17 +1,30 @@
from decimal import Decimal
from hypothesis import given
import hypothesis.strategies as st
from hypothesis import given
from jubeatools.song import BeatsTime, LongNote, NotePosition, TapNote, Chart, Timing, BPMEvent, SecondsTime, BeatsTime, Metadata
from jubeatools.song import (
BeatsTime,
BPMEvent,
Chart,
LongNote,
Metadata,
NotePosition,
SecondsTime,
TapNote,
Timing,
)
from ..mono_column.load import MonoColumnParser
from ..mono_column.dump import _dump_mono_column_chart
from ..mono_column.load import MonoColumnParser
@st.composite
def beat_time(draw):
numerator = draw(st.integers(min_value=0, max_value=240*4*10))
return BeatsTime(numerator, 240)
denominator = draw(st.sampled_from([4, 8, 16, 3, 5]))
numerator = draw(st.integers(min_value=0, max_value=denominator * 4 * 10))
return BeatsTime(numerator, denominator)
@st.composite
def note_position(draw):
@ -19,6 +32,7 @@ def note_position(draw):
y = draw(st.integers(min_value=0, max_value=3))
return NotePosition(x, y)
@st.composite
def tap_note(draw):
time = draw(beat_time())
@ -29,13 +43,10 @@ def tap_note(draw):
@given(st.sets(tap_note(), min_size=1, max_size=2000))
def test_tap_notes(notes):
timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))],
beat_zero_offset=SecondsTime(0)
events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0)
)
chart = Chart(
level=0,
timing=timing,
notes=sorted(notes, key=lambda n: (n.time, n.position))
level=0, timing=timing, notes=sorted(notes, key=lambda n: (n.time, n.position))
)
metadata = Metadata("", "", "", "")
string_io = _dump_mono_column_chart("", chart, metadata, timing)
@ -45,3 +56,34 @@ def test_tap_notes(notes):
parser.load_line(line)
actual = set(parser.notes())
assert notes == actual
@st.composite
def long_note(draw):
time = draw(beat_time())
duration = draw(st.integers(min_value=1, max_value=time.denominator * 4 * 3))
position = draw(note_position())
tail_index = draw(st.integers(min_value=0, max_value=5))
if tail_index >= 3:
y = sorted(set(range(4)).difference([position.y]))[tail_index - 3]
tail_tip = NotePosition(position.x, y)
else:
x = sorted(set(range(4)).difference([position.x]))[tail_index]
tail_tip = NotePosition(x, position.y)
return LongNote(time, position, BeatsTime(duration, time.denominator), tail_tip)
@given(long_note())
def test_single_long_note(note):
timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))], beat_zero_offset=SecondsTime(0)
)
chart = Chart(level=0, timing=timing, notes=[note])
metadata = Metadata("", "", "", "")
string_io = _dump_mono_column_chart("", chart, metadata, timing)
chart = string_io.getvalue()
parser = MonoColumnParser()
for line in chart.split("\n"):
parser.load_line(line)
actual = set(parser.notes())
assert set([note]) == actual

View File

@ -105,19 +105,20 @@ class LongNote:
def tail_direction(self) -> NotePosition:
if not self.tail_is_straight():
raise ValueError("Can't get tail direction when it's not straight")
diff = self.tail_tip - self.position
if diff.x == 0:
diff.y //= abs(diff.y)
x, y = (self.tail_tip - self.position).as_tuple()
if x == 0:
y //= abs(y)
else:
diff.x //= abs(diff.x)
return diff
x //= abs(x)
return NotePosition(x, y)
def positions_covered(self) -> Iterator[NotePosition]:
direction = self.tail_direction()
position = self.position
yield position
while position != self.tail_tip:
yield position
position = position + direction
yield position
@dataclass
@ -139,14 +140,19 @@ class Chart:
notes: List[Union[TapNote, LongNote]] = field(default_factory=list)
@dataclass
class Preview:
start: SecondsTime
length: SecondsTime
@dataclass
class Metadata:
title: str
artist: str
audio: Path
cover: Path
preview_start: SecondsTime = SecondsTime(0)
preview_length: SecondsTime = SecondsTime(0)
preview: Optional[Preview] = None
@dataclass