From 69bd438fe11dc7b60f5a9c15bc1d3dad6fd46a9d Mon Sep 17 00:00:00 2001 From: WerWolv Date: Sun, 30 Jan 2022 15:18:45 +0100 Subject: [PATCH] pattern: Added parameter packs --- .../include/hex/api/content_registry.hpp | 9 +-- .../include/hex/pattern_language/ast_node.hpp | 70 +++++++++++++++---- .../hex/pattern_language/evaluator.hpp | 19 ++++- .../include/hex/pattern_language/token.hpp | 1 + .../source/pattern_language/evaluator.cpp | 11 ++- .../source/pattern_language/log_console.cpp | 5 +- .../source/pattern_language/parser.cpp | 38 +++++----- 7 files changed, 115 insertions(+), 38 deletions(-) diff --git a/lib/libimhex/include/hex/api/content_registry.hpp b/lib/libimhex/include/hex/api/content_registry.hpp index 182db8c88..e008eb253 100644 --- a/lib/libimhex/include/hex/api/content_registry.hpp +++ b/lib/libimhex/include/hex/api/content_registry.hpp @@ -92,10 +92,11 @@ namespace hex { /* Pattern Language Function Registry. Allows adding of new functions that may be used inside the pattern language */ namespace PatternLanguage { - constexpr static u32 UnlimitedParameters = 0xFFFF'FFFF; - constexpr static u32 MoreParametersThan = 0x8000'0000; - constexpr static u32 LessParametersThan = 0x4000'0000; - constexpr static u32 NoParameters = 0x0000'0000; + constexpr static u32 UnlimitedParameters = 0xFFFF'FFFF; + constexpr static u32 MoreParametersThan = 0x8000'0000; + constexpr static u32 LessParametersThan = 0x4000'0000; + constexpr static u32 ExactlyOrMoreParametersThan = 0x2000'0000; + constexpr static u32 NoParameters = 0x0000'0000; using Namespace = std::vector; using Callback = std::function(hex::pl::Evaluator *, const std::vector &)>; diff --git a/lib/libimhex/include/hex/pattern_language/ast_node.hpp b/lib/libimhex/include/hex/pattern_language/ast_node.hpp index 425245e39..5133caafa 100644 --- a/lib/libimhex/include/hex/pattern_language/ast_node.hpp +++ b/lib/libimhex/include/hex/pattern_language/ast_node.hpp @@ -1527,6 +1527,22 @@ namespace hex::pl { std::vector> m_entries; }; + class ASTNodeParameterPack : public ASTNode { + public: + ASTNodeParameterPack(const std::vector &values) : m_values(values) {} + + [[nodiscard]] ASTNode *clone() const override { + return new ASTNodeParameterPack(*this); + } + + const std::vector &getValues() const { + return this->m_values; + } + + private: + std::vector m_values; + }; + class ASTNodeRValue : public ASTNode { public: using Path = std::vector>; @@ -1554,6 +1570,10 @@ namespace hex::pl { if (this->getPath().size() == 1) { if (auto name = std::get_if(&this->getPath().front()); name != nullptr) { if (*name == "$") return new ASTNodeLiteral(u128(evaluator->dataOffset())); + + auto ¶meterPack = evaluator->getScope(0).parameterPack; + if (parameterPack && *name == parameterPack->name) + return new ASTNodeParameterPack(parameterPack->values); } } @@ -1640,6 +1660,7 @@ namespace hex::pl { PatternData *currPattern = nullptr; i32 scopeIndex = 0; + if (!evaluator->isGlobalScope()) { auto globalScope = evaluator->getGlobalScope().scope; std::copy(globalScope->begin(), globalScope->end(), std::back_inserter(searchScope)); @@ -1697,8 +1718,9 @@ namespace hex::pl { if (name == "$") LogConsole::abortEvaluation("invalid use of placeholder operator in rvalue"); - if (!found) + if (!found) { LogConsole::abortEvaluation(hex::format("no variable named '{}' found", name), this); + } } } else { // Array indexing @@ -1973,10 +1995,16 @@ namespace hex::pl { auto expression = param->evaluate(evaluator); ON_SCOPE_EXIT { delete expression; }; - auto literal = dynamic_cast(expression->evaluate(evaluator)); - ON_SCOPE_EXIT { delete literal; }; + if (auto literal = dynamic_cast(expression->evaluate(evaluator))) { + evaluatedParams.push_back(literal->getValue()); + delete literal; + } else if (auto parameterPack = dynamic_cast(expression->evaluate(evaluator))) { + for (auto &value : parameterPack->getValues()) { + evaluatedParams.push_back(value); + } - evaluatedParams.push_back(literal->getValue()); + delete parameterPack; + } } auto &customFunctions = evaluator->getCustomFunctions(); @@ -1993,10 +2021,13 @@ namespace hex::pl { ; // Don't check parameter count } else if (function.parameterCount & ContentRegistry::PatternLanguage::LessParametersThan) { if (evaluatedParams.size() >= (function.parameterCount & ~ContentRegistry::PatternLanguage::LessParametersThan)) - LogConsole::abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", this->m_functionName, function.parameterCount & ~ContentRegistry::PatternLanguage::LessParametersThan), this); + LogConsole::abortEvaluation(hex::format("too many parameters for function '{0}'. Expected less than {1}", this->m_functionName, function.parameterCount & ~ContentRegistry::PatternLanguage::LessParametersThan), this); } else if (function.parameterCount & ContentRegistry::PatternLanguage::MoreParametersThan) { if (evaluatedParams.size() <= (function.parameterCount & ~ContentRegistry::PatternLanguage::MoreParametersThan)) - LogConsole::abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", this->m_functionName, function.parameterCount & ~ContentRegistry::PatternLanguage::MoreParametersThan), this); + LogConsole::abortEvaluation(hex::format("too few parameters for function '{0}'. Expected more than {1}", this->m_functionName, function.parameterCount & ~ContentRegistry::PatternLanguage::MoreParametersThan), this); + } else if (function.parameterCount & ContentRegistry::PatternLanguage::ExactlyOrMoreParametersThan) { + if (evaluatedParams.size() < (function.parameterCount & ~ContentRegistry::PatternLanguage::ExactlyOrMoreParametersThan)) + LogConsole::abortEvaluation(hex::format("too few parameters for function '{0}'. Expected more than {1}", this->m_functionName, (function.parameterCount - 1) & ~ContentRegistry::PatternLanguage::ExactlyOrMoreParametersThan), this); } else if (function.parameterCount != evaluatedParams.size()) { LogConsole::abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", this->m_functionName, function.parameterCount), this); } @@ -2193,8 +2224,8 @@ namespace hex::pl { class ASTNodeFunctionDefinition : public ASTNode { public: - ASTNodeFunctionDefinition(std::string name, std::vector> params, std::vector body) - : m_name(std::move(name)), m_params(std::move(params)), m_body(std::move(body)) { + ASTNodeFunctionDefinition(std::string name, std::vector> params, std::vector body, std::optional parameterPack) + : m_name(std::move(name)), m_params(std::move(params)), m_body(std::move(body)), m_parameterPack(std::move(parameterPack)) { } ASTNodeFunctionDefinition(const ASTNodeFunctionDefinition &other) : ASTNode(other) { @@ -2235,7 +2266,12 @@ namespace hex::pl { [[nodiscard]] ASTNode *evaluate(Evaluator *evaluator) const override { - evaluator->addCustomFunction(this->m_name, this->m_params.size(), [this](Evaluator *ctx, const std::vector ¶ms) -> std::optional { + size_t paramCount = this->m_params.size(); + + if (this->m_parameterPack.has_value()) + paramCount |= ContentRegistry::PatternLanguage::ExactlyOrMoreParametersThan; + + evaluator->addCustomFunction(this->m_name, paramCount, [this](Evaluator *ctx, const std::vector ¶ms) -> std::optional { std::vector variables; ctx->pushScope(nullptr, variables); @@ -2246,12 +2282,19 @@ namespace hex::pl { ctx->popScope(); }; - u32 paramIndex = 0; - for (const auto &[name, type] : this->m_params) { + if (this->m_parameterPack.has_value()) { + std::vector parameterPackContent; + for (u32 paramIndex = this->m_params.size(); paramIndex < params.size(); paramIndex++) + parameterPackContent.push_back(params[paramIndex]); + + ctx->createParameterPack(this->m_parameterPack.value(), parameterPackContent); + } + + for (u32 paramIndex = 0; paramIndex < this->m_params.size(); paramIndex++) { + const auto &[name, type] = this->m_params[paramIndex]; + ctx->createVariable(name, type, params[paramIndex]); ctx->setVariable(name, params[paramIndex]); - - paramIndex++; } for (auto statement : this->m_body) { @@ -2283,6 +2326,7 @@ namespace hex::pl { std::string m_name; std::vector> m_params; std::vector m_body; + std::optional m_parameterPack; }; class ASTNodeCompoundStatement : public ASTNode { diff --git a/lib/libimhex/include/hex/pattern_language/evaluator.hpp b/lib/libimhex/include/hex/pattern_language/evaluator.hpp index cf99d9deb..62790c281 100644 --- a/lib/libimhex/include/hex/pattern_language/evaluator.hpp +++ b/lib/libimhex/include/hex/pattern_language/evaluator.hpp @@ -46,9 +46,15 @@ namespace hex::pl { return this->m_console; } + struct ParameterPack { + std::string name; + std::vector values; + }; + struct Scope { PatternData *parent; std::vector *scope; + std::optional parameterPack; }; void pushScope(PatternData *parent, std::vector &scope) { if (this->m_scopes.size() > this->getEvaluationDepth()) @@ -63,11 +69,19 @@ namespace hex::pl { this->m_scopes.pop_back(); } - const Scope &getScope(i32 index) { + Scope &getScope(i32 index) { return this->m_scopes[this->m_scopes.size() - 1 + index]; } - const Scope &getGlobalScope() { + const Scope &getScope(i32 index) const { + return this->m_scopes[this->m_scopes.size() - 1 + index]; + } + + Scope &getGlobalScope() { + return this->m_scopes.front(); + } + + const Scope &getGlobalScope() const { return this->m_scopes.front(); } @@ -167,6 +181,7 @@ namespace hex::pl { return this->m_stack; } + void createParameterPack(const std::string &name, const std::vector &values); void createVariable(const std::string &name, ASTNode *type, const std::optional &value = std::nullopt, bool outVariable = false); void setVariable(const std::string &name, const Token::Literal &value); diff --git a/lib/libimhex/include/hex/pattern_language/token.hpp b/lib/libimhex/include/hex/pattern_language/token.hpp index eaaadef3f..1422215fc 100644 --- a/lib/libimhex/include/hex/pattern_language/token.hpp +++ b/lib/libimhex/include/hex/pattern_language/token.hpp @@ -342,6 +342,7 @@ namespace hex::pl { #define VALUETYPE_UNSIGNED COMPONENT(ValueType, Unsigned) #define VALUETYPE_SIGNED COMPONENT(ValueType, Signed) #define VALUETYPE_FLOATINGPOINT COMPONENT(ValueType, FloatingPoint) +#define VALUETYPE_AUTO COMPONENT(ValueType, Auto) #define VALUETYPE_ANY COMPONENT(ValueType, Any) #define SEPARATOR_ROUNDBRACKETOPEN COMPONENT(Separator, RoundBracketOpen) diff --git a/lib/libimhex/source/pattern_language/evaluator.cpp b/lib/libimhex/source/pattern_language/evaluator.cpp index 7ce8b29fd..cc9f97961 100644 --- a/lib/libimhex/source/pattern_language/evaluator.cpp +++ b/lib/libimhex/source/pattern_language/evaluator.cpp @@ -6,6 +6,13 @@ namespace hex::pl { Evaluator *PatternCreationLimiter::s_evaluator = nullptr; + void Evaluator::createParameterPack(const std::string &name, const std::vector &values) { + this->getScope(0).parameterPack = ParameterPack { + name, + values + }; + } + void Evaluator::createVariable(const std::string &name, ASTNode *type, const std::optional &value, bool outVariable) { auto &variables = *this->getScope(0).scope; for (auto &variable : variables) { @@ -15,7 +22,7 @@ namespace hex::pl { } auto startOffset = this->dataOffset(); - auto pattern = type->createPatterns(this).front(); + auto pattern = type == nullptr ? nullptr : type->createPatterns(this).front(); this->dataOffset() = startOffset; if (pattern == nullptr) { @@ -38,7 +45,7 @@ namespace hex::pl { else if (std::get_if(&value.value()) != nullptr) pattern = new PatternDataString(0, 1); else - __builtin_unreachable(); + LogConsole::abortEvaluation("cannot determine type of auto variable", type); } pattern->setVariableName(name); diff --git a/lib/libimhex/source/pattern_language/log_console.cpp b/lib/libimhex/source/pattern_language/log_console.cpp index 76cf74db8..e5a2be189 100644 --- a/lib/libimhex/source/pattern_language/log_console.cpp +++ b/lib/libimhex/source/pattern_language/log_console.cpp @@ -13,7 +13,10 @@ namespace hex::pl { } [[noreturn]] void LogConsole::abortEvaluation(const std::string &message, const ASTNode *node) { - throw EvaluateError(static_cast(node)->getLineNumber(), message); + if (node == nullptr) + abortEvaluation(message); + else + throw EvaluateError(node->getLineNumber(), message); } void LogConsole::clear() { diff --git a/lib/libimhex/source/pattern_language/parser.cpp b/lib/libimhex/source/pattern_language/parser.cpp index 781fbb3a4..a4d237646 100644 --- a/lib/libimhex/source/pattern_language/parser.cpp +++ b/lib/libimhex/source/pattern_language/parser.cpp @@ -431,31 +431,37 @@ namespace hex::pl { ASTNode *Parser::parseFunctionDefinition() { const auto &functionName = getValue(-2).get(); std::vector> params; + std::optional parameterPack; // Parse parameter list bool hasParams = !peek(SEPARATOR_ROUNDBRACKETCLOSE); u32 unnamedParamCount = 0; while (hasParams) { - auto type = parseType(true); + if (MATCHES(sequence(VALUETYPE_AUTO, SEPARATOR_DOT, SEPARATOR_DOT, SEPARATOR_DOT, IDENTIFIER))) { + parameterPack = getValue(-1).get(); - if (MATCHES(sequence(IDENTIFIER))) - params.emplace_back(getValue(-1).get(), type); - else { - params.emplace_back(std::to_string(unnamedParamCount), type); - unnamedParamCount++; - } + if (MATCHES(sequence(SEPARATOR_COMMA))) + throwParseError("parameter pack can only appear at end of parameter list"); - if (!MATCHES(sequence(SEPARATOR_COMMA))) { - if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) + break; + } else { + auto type = parseType(true); + + if (MATCHES(sequence(IDENTIFIER))) + params.emplace_back(getValue(-1).get(), type); + else { + params.emplace_back(std::to_string(unnamedParamCount), type); + unnamedParamCount++; + } + + if (!MATCHES(sequence(SEPARATOR_COMMA))) { break; - else - throwParseError("expected closing ')' after parameter list"); + } } } - if (!hasParams) { - if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) - throwParseError("expected closing ')' after parameter list"); - } + + if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) + throwParseError("expected closing ')' after parameter list"); if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN))) throwParseError("expected opening '{' after function definition"); @@ -473,7 +479,7 @@ namespace hex::pl { } bodyCleanup.release(); - return create(new ASTNodeFunctionDefinition(getNamespacePrefixedName(functionName), params, body)); + return create(new ASTNodeFunctionDefinition(getNamespacePrefixedName(functionName), params, body, parameterPack)); } ASTNode *Parser::parseFunctionVariableDecl() {