1
0
mirror of synced 2024-12-04 19:17:55 +01:00

#memo loader untested version

This commit is contained in:
Stepland 2020-07-14 17:33:05 +02:00
parent 2a72284f09
commit 8abdf8edd6
3 changed files with 115 additions and 61 deletions

View File

@ -4,10 +4,11 @@ from copy import deepcopy
from dataclasses import dataclass
from decimal import Decimal
from functools import reduce
from itertools import chain, product
from itertools import chain, product, zip_longest
from typing import Mapping, Dict, Iterator, List, Optional, Set, Tuple, Union
import constraint
from more_itertools import mark_ends
from parsimonious import Grammar, NodeVisitor, ParseError
from path import Path
@ -26,6 +27,7 @@ from ..command import is_command, parse_command
from ..files import load_files
from ..parser import (
CIRCLE_FREE_TO_DECIMAL_TIME,
CIRCLE_FREE_TO_NOTE_SYMBOL,
LONG_ARROWS,
LONG_DIRECTION,
JubeatAnalyserParser,
@ -37,7 +39,7 @@ from ..parser import (
split_double_byte_line,
)
from ..symbol_definition import is_symbol_definition, parse_symbol_definition
from ..symbols import CIRCLE_FREE_SYMBOLS
from ..symbols import CIRCLE_FREE_SYMBOLS, NOTE_SYMBOLS
memo_chart_line_grammar = Grammar(
r"""
@ -104,6 +106,15 @@ class MemoFrame:
res += len(t)
return Decimal("0.25") * res
def __str__(self):
res = []
for pos, time in zip_longest(self.position_part, self.timing_part):
line = [f"{''.join(pos)}"]
if time is not None:
line += [f"|{''.join(time)}|"]
res += [" ".join(line)]
return "\n".join(res)
@dataclass
class MemoLoadedSection:
@ -112,6 +123,18 @@ class MemoLoadedSection:
length: Decimal
tempo: Decimal
def __str__(self):
res = []
if self.length != 4:
res += [f"b={self.length}", ""]
for symbol, time in self.symbols:
res += [f"*{symbol}:{time}", ""]
for _, is_last, frame in mark_ends(self.frames):
res += [str(frame)]
if not is_last:
res += [""]
return "\n".join(res)
# Any unicode character that's both :
# - confusable with a dash/hyphen
@ -138,7 +161,7 @@ class MemoParser(JubeatAnalyserParser):
self.frames: List[MemoFrame] = []
self.sections: List[MemoLoadedSection] = []
self.only_timingless_frames = False
def do_memo(self):
...
@ -149,15 +172,15 @@ class MemoParser(JubeatAnalyserParser):
def do_bpp(self, value):
if self.sections or self.frames:
raise ValueError(
"jubeatools does not handle changes of #bpp halfway"
)
raise ValueError("jubeatools does not handle changes of #bpp halfway")
else:
self._do_bpp(value)
def append_chart_line(self, line: MemoChartLine):
if len(line.position.encode("shift_jis_2004")) != 4 * self.bytes_per_panel:
raise SyntaxError(f"Invalid chart line for #bpp={self.bytes_per_panel} : {line}")
raise SyntaxError(
f"Invalid chart line for #bpp={self.bytes_per_panel} : {line}"
)
if line.timing is not None and self.bytes_per_panel == 2:
if len(line.timing.encode("shift_jis_2004")) % 2 != 0:
raise SyntaxError(f"Invalid chart line for #bpp=2 : {line}")
@ -170,7 +193,7 @@ class MemoParser(JubeatAnalyserParser):
return split_double_byte_line(line)
else:
return list(line)
def _frames_duration(self) -> Decimal:
return sum(frame.duration for frame in self.frames)
@ -191,7 +214,7 @@ class MemoParser(JubeatAnalyserParser):
if self._frames_duration() >= self.beats_per_section:
# then the current frame starts a new section
self._push_section()
self.frames.append(frame)
def _push_section(self):
@ -238,43 +261,67 @@ class MemoParser(JubeatAnalyserParser):
yield from self._iter_notes()
else:
yield from self._iter_notes_without_longs()
def _iter_frames(self) -> Iterator[Tuple[Decimal, Mapping[str, Decimal], MemoLoadedSection]]:
"""iterate over tuples of frame_starting_beat, frame, section_starting_beat, section"""
def _iter_frames(
self,
) -> Iterator[
Tuple[Mapping[str, Decimal], Decimal, MemoFrame, Decimal, MemoLoadedSection]
]:
"""iterate over tuples of
currently_defined_symbols, frame_starting_beat, frame, section_starting_beat, section"""
local_symbols: Dict[str, Decimal] = {}
section_starting_beat = Decimal(0)
for section in self.sections:
frame_starting_beat = Decimal(0)
for frame in section.frames:
yield frame_starting_beat, frame, section_starting_beat, section
if frame.timing_part:
local_symbols = {
symbol: Decimal("0.25") * i
for i, symbol in enumerate(chain(frame.timing_part))
if symbol not in EMPTY_BEAT_SYMBOLS
}
currently_defined_symbols = ChainMap(local_symbols, section.symbols)
yield currently_defined_symbols, frame_starting_beat, frame, section_starting_beat, section
frame_starting_beat += frame.duration
section_starting_beat += section.length
def _iter_notes(self) -> Iterator[Union[TapNote, LongNote]]:
unfinished_longs: Dict[NotePosition, UnfinishedLongNote] = {}
for section_starting_beat, section, bloc in self._iter_blocs():
for (
currently_defined_symbols,
frame_starting_beat,
frame,
section_starting_beat,
section,
) in self._iter_frames():
should_skip: Set[NotePosition] = set()
# 1/3 : look for ends to unfinished long notes
for pos, unfinished_long in unfinished_longs.items():
x, y = pos.as_tuple()
symbol = bloc[y][x]
if self.circle_free:
if symbol in CIRCLE_FREE_SYMBOLS:
should_skip.add(pos)
symbol_time = CIRCLE_FREE_TO_DECIMAL_TIME[symbol]
note_time = decimal_to_beats(section_starting_beat + symbol_time)
yield unfinished_long.ends_at(note_time)
elif symbol in section.symbols:
symbol = frame.position_part[y][x]
if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS:
circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol]
try:
symbol_time = currently_defined_symbols[circled_symbol]
except KeyError:
raise SyntaxError(
"Can't have a note symbol on the holding square of"
" an unfinished long note when #circlefree is on"
"Chart section positional part constains the circle free "
f"symbol '{symbol}' but the associated circled symbol "
f"'{circled_symbol}' could not be found in the timing part:\n"
f"{section}"
)
else:
if symbol in section.symbols:
should_skip.add(pos)
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(section_starting_beat + symbol_time)
yield unfinished_long.ends_at(note_time)
try:
symbol_time = currently_defined_symbols[symbol]
except KeyError:
continue
should_skip.add(pos)
note_time = decimal_to_beats(
section_starting_beat + frame_starting_beat + symbol_time
)
yield unfinished_long.ends_at(note_time)
unfinished_longs = {
k: unfinished_longs[k] for k in unfinished_longs.keys() - should_skip
@ -286,7 +333,7 @@ class MemoParser(JubeatAnalyserParser):
pos = NotePosition(x, y)
if pos in should_skip:
continue
symbol = bloc[y][x]
symbol = frame.position_part[y][x]
if symbol not in LONG_ARROWS:
continue
# at this point we are sure we have a long arrow
@ -297,8 +344,8 @@ class MemoParser(JubeatAnalyserParser):
while 0 <= candidate.x < 4 and 0 <= candidate.y < 4:
if candidate in should_skip:
continue
new_symbol = bloc[candidate.y][candidate.x]
if new_symbol in section.symbols:
new_symbol = frame.position_part[candidate.y][candidate.x]
if new_symbol in currently_defined_symbols:
note_candidates.add(candidate)
candidate += 𝛿pos
# if no notes have been crossed, we just ignore the arrow
@ -315,23 +362,20 @@ class MemoParser(JubeatAnalyserParser):
solutions = problem.getSolutions()
if not solutions:
raise SyntaxError(
"Invalid long note arrow pattern in bloc :\n"
+ "\n".join("".join(line) for line in bloc)
"Invalid long note arrow pattern in section :\n" + str(section)
)
solution = min(solutions, key=long_note_solution_heuristic)
if len(solutions) > 1 and not is_simple_solution(
solution, arrow_to_note_candidates
):
warnings.warn(
"Ambiguous arrow pattern in bloc :\n"
+ "\n".join("".join(line) for line in bloc)
+ "\n"
"The resulting long notes might not be what you expect"
"Ambiguous arrow pattern in section :\n" + str(section) + "\n"
"The chosen long notes might not be what you expect"
)
for arrow_pos, note_pos in solution.items():
should_skip.add(arrow_pos)
should_skip.add(note_pos)
symbol = bloc[note_pos.y][note_pos.x]
symbol = frame.position_part[note_pos.y][note_pos.x]
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(section_starting_beat + symbol_time)
unfinished_longs[note_pos] = UnfinishedLongNote(
@ -343,31 +387,34 @@ class MemoParser(JubeatAnalyserParser):
position = NotePosition(x, y)
if position in should_skip:
continue
symbol = bloc[y][x]
if symbol in section.symbols:
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(section_starting_beat + symbol_time)
yield TapNote(note_time, position)
symbol = frame.position_part[y][x]
try:
symbol_time = currently_defined_symbols[symbol]
except KeyError:
continue
note_time = decimal_to_beats(
section_starting_beat + frame_starting_beat + symbol_time
)
yield TapNote(note_time, position)
def _iter_notes_without_longs(self) -> Iterator[TapNote]:
local_symbols: Dict[str, Decimal] = {}
for frame_starting_beat, frame, section_starting_beat, section in self._iter_frames():
# define local note symbols according to what's found in the timing part
if frame.timing_part:
local_symbols = {
symbol: Decimal("0.25") * i
for i, symbol in enumerate(chain(frame.timing_part))
if symbol not in EMPTY_BEAT_SYMBOLS
}
currently_defined_symbols = ChainMap(local_symbols, section.symbols)
# cross compare with the position information
for (
currently_defined_symbols,
frame_starting_beat,
frame,
section_starting_beat,
_,
) in self._iter_frames():
# cross compare symbols with the position information
for y, x in product(range(4), range(4)):
symbol = frame.position_part[y][x]
try:
symbol_time = currently_defined_symbols[symbol]
except KeyError:
continue
note_time = decimal_to_beats(section_starting_beat + frame_starting_beat + symbol_time)
note_time = decimal_to_beats(
section_starting_beat + frame_starting_beat + symbol_time
)
position = NotePosition(x, y)
yield TapNote(note_time, position)
@ -382,7 +429,6 @@ def _load_memo_file(lines: List[str]) -> Song:
f"Error while parsing memo line {i} :\n" f"{type(e).__name__}: {e}"
) from None
# finish the current section
parser.finish_last_few_notes()
metadata = Metadata(
title=parser.title,
@ -410,4 +456,4 @@ def _load_memo_file(lines: List[str]) -> Song:
def load_memo(path: Path) -> Song:
files = load_files(path)
charts = [_load_memo_file(lines) for _, lines in files.items()]
return reduce(lambda a, b: a.merge(b), charts)
return reduce(Song.merge, charts)

View File

@ -204,7 +204,9 @@ class MonoColumnParser(JubeatAnalyserParser):
if symbol in CIRCLE_FREE_SYMBOLS:
should_skip.add(pos)
symbol_time = CIRCLE_FREE_TO_DECIMAL_TIME[symbol]
note_time = decimal_to_beats(section_starting_beat + symbol_time)
note_time = decimal_to_beats(
section_starting_beat + symbol_time
)
yield unfinished_long.ends_at(note_time)
elif symbol in section.symbols:
raise SyntaxError(
@ -215,7 +217,9 @@ class MonoColumnParser(JubeatAnalyserParser):
if symbol in section.symbols:
should_skip.add(pos)
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(section_starting_beat + symbol_time)
note_time = decimal_to_beats(
section_starting_beat + symbol_time
)
yield unfinished_long.ends_at(note_time)
unfinished_longs = {

View File

@ -25,6 +25,8 @@ CIRCLE_FREE_TO_DECIMAL_TIME = {
c: Decimal("0.25") * i for i, c in enumerate(CIRCLE_FREE_SYMBOLS)
}
CIRCLE_FREE_TO_NOTE_SYMBOL = dict(zip(CIRCLE_FREE_SYMBOLS, NOTE_SYMBOLS))
LONG_ARROWS = LONG_ARROW_LEFT | LONG_ARROW_DOWN | LONG_ARROW_UP | LONG_ARROW_RIGHT
LONG_DIRECTION = {
@ -136,7 +138,9 @@ class JubeatAnalyserParser:
def do_t(self, value):
self.current_tempo = Decimal(value)
self.timing_events.append(BPMEvent(self.section_starting_beat, BPM=self.current_tempo))
self.timing_events.append(
BPMEvent(self.section_starting_beat, BPM=self.current_tempo)
)
def do_o(self, value):
self.offset = int(value)