diff --git a/include/views/view_pattern_editor.hpp b/include/views/view_pattern_editor.hpp index 6b13369bb..d53ece472 100644 --- a/include/views/view_pattern_editor.hpp +++ b/include/views/view_pattern_editor.hpp @@ -1,9 +1,8 @@ #pragma once #include -#include #include - +#include #include #include diff --git a/libstdpl/math.pl b/libstdpl/math.pl new file mode 100644 index 000000000..f1e55d933 --- /dev/null +++ b/libstdpl/math.pl @@ -0,0 +1,47 @@ +namespace std::math { + + fn min_u(u128 a, u128 b) { + return (a < b) ? a : b; + }; + + fn min_i(s128 a, s128 b) { + return (a < b) ? a : b; + }; + + fn min_f(double a, double b) { + return (a < b) ? a : b; + }; + + + fn max_u(u128 a, u128 b) { + return (a > b) ? a : b; + }; + + fn max_i(s128 a, s128 b) { + return (a > b) ? a : b; + }; + + fn max_f(double a, double b) { + return (a > b) ? a : b; + }; + + + fn abs_i(s128 value) { + return value < 0 ? -value : value; + }; + + fn abs_d(double value) { + return value < 0 ? -value : value; + }; + + + fn ceil(double value) { + s128 cast; + cast = value; + + return cast + 1; + }; + +} + +std::print("{}", std::math::ceil(123.6)); \ No newline at end of file diff --git a/plugins/builtin/source/content/pl_builtin_functions.cpp b/plugins/builtin/source/content/pl_builtin_functions.cpp index a755b734c..b365656b2 100644 --- a/plugins/builtin/source/content/pl_builtin_functions.cpp +++ b/plugins/builtin/source/content/pl_builtin_functions.cpp @@ -3,14 +3,43 @@ #include #include -#include +#include #include #include +#include #include +#include + namespace hex::plugin::builtin { + std::string format(pl::Evaluator *, auto params) { + auto format = pl::Token::literalToString(params[0], true); + std::string message; + + fmt::dynamic_format_arg_store formatArgs; + + for (u32 i = 1; i < params.size(); i++) { + auto ¶m = params[i]; + + std::visit(overloaded { + [&](pl::PatternData* value) { + formatArgs.push_back(hex::format("{} {} @ 0x{:X}", value->getTypeName(), value->getVariableName(), value->getOffset())); + }, + [&](auto &&value) { + formatArgs.push_back(value); + } + }, param); + } + + try { + return fmt::vformat(format, formatArgs); + } catch (fmt::format_error &error) { + hex::pl::LogConsole::abortEvaluation(hex::format("format error: {}", error.what())); + } + } + void registerPatternLanguageFunctions() { using namespace hex::pl; @@ -18,55 +47,37 @@ namespace hex::plugin::builtin { { /* assert(condition, message) */ - ContentRegistry::PatternLanguageFunctions::add(nsStd, "assert", 2, [](auto &ctx, auto params) { - auto condition = AS_TYPE(ASTNodeIntegerLiteral, params[0])->getValue(); - auto message = AS_TYPE(ASTNodeStringLiteral, params[1])->getString(); + ContentRegistry::PatternLanguageFunctions::add(nsStd, "assert", 2, [](Evaluator *ctx, auto params) -> std::optional { + auto condition = Token::literalToBoolean(params[0]); + auto message = std::get(params[1]); - if (LITERAL_COMPARE(condition, condition == 0)) - ctx.getConsole().abortEvaluation(hex::format("assertion failed \"{0}\"", message.data())); + if (!condition) + LogConsole::abortEvaluation(hex::format("assertion failed \"{0}\"", message)); - return nullptr; + return std::nullopt; }); /* assert_warn(condition, message) */ - ContentRegistry::PatternLanguageFunctions::add(nsStd, "assert_warn", 2, [](auto ctx, auto params) { - auto condition = AS_TYPE(ASTNodeIntegerLiteral, params[0])->getValue(); - auto message = AS_TYPE(ASTNodeStringLiteral, params[1])->getString(); + ContentRegistry::PatternLanguageFunctions::add(nsStd, "assert_warn", 2, [](auto *ctx, auto params) -> std::optional { + auto condition = Token::literalToBoolean(params[0]); + auto message = std::get(params[1]); - if (LITERAL_COMPARE(condition, condition == 0)) - ctx.getConsole().log(LogConsole::Level::Warning, hex::format("assertion failed \"{0}\"", message)); + if (!condition) + ctx->getConsole().log(LogConsole::Level::Warning, hex::format("assertion failed \"{0}\"", message)); - return nullptr; + return std::nullopt; }); - /* print(values...) */ - ContentRegistry::PatternLanguageFunctions::add(nsStd, "print", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 0, [](auto &ctx, auto params) { - std::string message; - for (auto& param : params) { - if (auto integerLiteral = dynamic_cast(param); integerLiteral != nullptr) { - std::visit([&](auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_same_v) - message += (char)value; - else if constexpr (std::is_same_v) - message += value == 0 ? "false" : "true"; - else if constexpr (std::is_unsigned_v) - message += std::to_string(static_cast(value)); - else if constexpr (std::is_signed_v) - message += std::to_string(static_cast(value)); - else if constexpr (std::is_floating_point_v) - message += std::to_string(value); - else - message += "< Custom Type >"; - }, integerLiteral->getValue()); - } - else if (auto stringLiteral = dynamic_cast(param); stringLiteral != nullptr) - message += stringLiteral->getString(); - } + /* print(format, args...) */ + ContentRegistry::PatternLanguageFunctions::add(nsStd, "print", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 0, [](Evaluator *ctx, auto params) -> std::optional { + ctx->getConsole().log(LogConsole::Level::Info, format(ctx, params)); - ctx.getConsole().log(LogConsole::Level::Info, message); + return std::nullopt; + }); - return nullptr; + /* format(format, args...) */ + ContentRegistry::PatternLanguageFunctions::add(nsStd, "format", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 0, [](Evaluator *ctx, auto params) -> std::optional { + return format(ctx, params); }); } @@ -75,109 +86,82 @@ namespace hex::plugin::builtin { { /* align_to(alignment, value) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "align_to", 2, [](auto &ctx, auto params) -> ASTNode* { - auto alignment = AS_TYPE(ASTNodeIntegerLiteral, params[0])->getValue(); - auto value = AS_TYPE(ASTNodeIntegerLiteral, params[1])->getValue(); + ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "align_to", 2, [](Evaluator *ctx, auto params) -> std::optional { + auto alignment = Token::literalToUnsigned(params[0]); + auto value = Token::literalToUnsigned(params[1]); - auto result = std::visit([](auto &&alignment, auto &&value) { - u64 remainder = u64(value) % u64(alignment); - return remainder != 0 ? u64(value) + (u64(alignment) - remainder) : u64(value); - }, alignment, value); + u128 remainder = value % alignment; - return new ASTNodeIntegerLiteral(u64(result)); + return remainder != 0 ? value + (alignment - remainder) : value; }); /* base_address() */ - ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "base_address", ContentRegistry::PatternLanguageFunctions::NoParameters, [](auto &ctx, auto params) -> ASTNode* { - return new ASTNodeIntegerLiteral(u64(ImHexApi::Provider::get()->getBaseAddress())); + ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "base_address", ContentRegistry::PatternLanguageFunctions::NoParameters, [](Evaluator *ctx, auto params) -> std::optional { + return u128(ctx->getProvider()->getBaseAddress()); }); /* size() */ - ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "size", ContentRegistry::PatternLanguageFunctions::NoParameters, [](auto &ctx, auto params) -> ASTNode* { - return new ASTNodeIntegerLiteral(u64(ImHexApi::Provider::get()->getActualSize())); + ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "size", ContentRegistry::PatternLanguageFunctions::NoParameters, [](Evaluator *ctx, auto params) -> std::optional { + return u128(ctx->getProvider()->getActualSize()); }); /* find_sequence(occurrence_index, bytes...) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "find_sequence", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 1, [](auto &ctx, auto params) { - auto& occurrenceIndex = AS_TYPE(ASTNodeIntegerLiteral, params[0])->getValue(); + ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "find_sequence", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 1, [](Evaluator *ctx, auto params) -> std::optional { + auto occurrenceIndex = Token::literalToUnsigned(params[0]); + std::vector sequence; for (u32 i = 1; i < params.size(); i++) { - sequence.push_back(std::visit([&](auto &&value) -> u8 { - if (value <= 0xFF) - return value; - else - ctx.getConsole().abortEvaluation("sequence bytes need to fit into 1 byte"); - }, AS_TYPE(ASTNodeIntegerLiteral, params[i])->getValue())); + auto byte = Token::literalToUnsigned(params[i]); + + if (byte > 0xFF) + LogConsole::abortEvaluation(hex::format("byte #{} value out of range: {} > 0xFF", i, u64(byte))); + + sequence.push_back(u8(byte & 0xFF)); } std::vector bytes(sequence.size(), 0x00); u32 occurrences = 0; - for (u64 offset = 0; offset < ImHexApi::Provider::get()->getSize() - sequence.size(); offset++) { - ImHexApi::Provider::get()->read(offset, bytes.data(), bytes.size()); + for (u64 offset = 0; offset < ctx->getProvider()->getSize() - sequence.size(); offset++) { + ctx->getProvider()->read(offset, bytes.data(), bytes.size()); if (bytes == sequence) { - if (LITERAL_COMPARE(occurrenceIndex, occurrences < occurrenceIndex)) { + if (occurrences < occurrenceIndex) { occurrences++; continue; } - return new ASTNodeIntegerLiteral(offset); + return u128(offset); } } - ctx.getConsole().abortEvaluation("failed to find sequence"); + LogConsole::abortEvaluation("failed to find sequence"); }); /* read_unsigned(address, size) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "read_unsigned", 2, [](auto &ctx, auto params) { - auto address = AS_TYPE(ASTNodeIntegerLiteral, params[0])->getValue(); - auto size = AS_TYPE(ASTNodeIntegerLiteral, params[1])->getValue(); + ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "read_unsigned", 2, [](Evaluator *ctx, auto params) -> std::optional { + auto address = Token::literalToUnsigned(params[0]); + auto size = Token::literalToUnsigned(params[1]); - if (LITERAL_COMPARE(address, address >= ImHexApi::Provider::get()->getActualSize())) - ctx.getConsole().abortEvaluation("address out of range"); + if (size > 16) + LogConsole::abortEvaluation("read size out of range"); - return std::visit([&](auto &&address, auto &&size) { - if (size <= 0 || size > 16) - ctx.getConsole().abortEvaluation("invalid read size"); + u128 result = 0; + ctx->getProvider()->read(address, &result, size); - u8 value[(u8)size]; - ImHexApi::Provider::get()->read(address, value, size); - - switch ((u8)size) { - case 1: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 2: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 4: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 8: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 16: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - default: ctx.getConsole().abortEvaluation("invalid read size"); - } - }, address, size); + return result; }); /* read_signed(address, size) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "read_signed", 2, [](auto &ctx, auto params) { - auto address = AS_TYPE(ASTNodeIntegerLiteral, params[0])->getValue(); - auto size = AS_TYPE(ASTNodeIntegerLiteral, params[1])->getValue(); + ContentRegistry::PatternLanguageFunctions::add(nsStdMem, "read_signed", 2, [](Evaluator *ctx, auto params) -> std::optional { + auto address = Token::literalToUnsigned(params[0]); + auto size = Token::literalToUnsigned(params[1]); - if (LITERAL_COMPARE(address, address >= ImHexApi::Provider::get()->getActualSize())) - ctx.getConsole().abortEvaluation("address out of range"); + if (size > 16) + LogConsole::abortEvaluation("read size out of range"); - return std::visit([&](auto &&address, auto &&size) { - if (size <= 0 || size > 16) - ctx.getConsole().abortEvaluation("invalid read size"); - - u8 value[(u8)size]; - ImHexApi::Provider::get()->read(address, value, size); - - switch ((u8)size) { - case 1: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 2: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 4: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 8: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - case 16: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); - default: ctx.getConsole().abortEvaluation("invalid read size"); - } - }, address, size); + s128 value; + ctx->getProvider()->read(address, &value, size); + return hex::signExtend(size * 8, value); }); } @@ -185,31 +169,39 @@ namespace hex::plugin::builtin { ContentRegistry::PatternLanguageFunctions::Namespace nsStdStr = { "std", "str" }; { /* length(string) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdStr, "length", 1, [](auto &ctx, auto params) { - auto string = AS_TYPE(ASTNodeStringLiteral, params[1])->getString(); + ContentRegistry::PatternLanguageFunctions::add(nsStdStr, "length", 1, [](Evaluator *ctx, auto params) -> std::optional { + auto string = Token::literalToString(params[0], false); - return new ASTNodeIntegerLiteral(u32(string.length())); + return u128(string.length()); }); /* at(string, index) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdStr, "at", 2, [](auto &ctx, auto params) { - auto string = AS_TYPE(ASTNodeStringLiteral, params[0])->getString(); - auto index = AS_TYPE(ASTNodeIntegerLiteral, params[1])->getValue(); + ContentRegistry::PatternLanguageFunctions::add(nsStdStr, "at", 2, [](Evaluator *ctx, auto params) -> std::optional { + auto string = Token::literalToString(params[0], false); + auto index = Token::literalToSigned(params[1]); - if (LITERAL_COMPARE(index, index >= string.length() || index < 0)) - ctx.getConsole().abortEvaluation("character index out of bounds"); + if (std::abs(index) >= string.length()) + LogConsole::abortEvaluation("character index out of range"); - return std::visit([&](auto &&value) { return new ASTNodeIntegerLiteral(char(string[u32(value)])); }, index); + if (index >= 0) + return char(string[index]); + else + return char(string[string.length() - -index]); }); - /* compare(left, right) */ - ContentRegistry::PatternLanguageFunctions::add(nsStdStr, "compare", 2, [](auto &ctx, auto params) { - auto left = AS_TYPE(ASTNodeStringLiteral, params[0])->getString(); - auto right = AS_TYPE(ASTNodeStringLiteral, params[1])->getString(); + /* substr(string, pos, count) */ + ContentRegistry::PatternLanguageFunctions::add(nsStdStr, "substr", 3, [](Evaluator *ctx, auto params) -> std::optional { + auto string = Token::literalToString(params[0], false); + auto pos = Token::literalToUnsigned(params[1]); + auto size = Token::literalToUnsigned(params[2]); - return new ASTNodeIntegerLiteral(bool(left == right)); + if (pos > size) + LogConsole::abortEvaluation("character index out of range"); + + return string.substr(pos, size); }); + } } -} \ No newline at end of file +} diff --git a/plugins/libimhex/include/hex/api/content_registry.hpp b/plugins/libimhex/include/hex/api/content_registry.hpp index 3cef73e00..fbf51f974 100644 --- a/plugins/libimhex/include/hex/api/content_registry.hpp +++ b/plugins/libimhex/include/hex/api/content_registry.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -15,7 +16,7 @@ namespace hex { class View; class LanguageDefinition; - namespace pl { class ASTNode; class Evaluator; } + namespace pl { class Evaluator; } namespace dp { class Node; } /* @@ -90,7 +91,7 @@ namespace hex { constexpr static u32 NoParameters = 0x0000'0000; using Namespace = std::vector; - using Callback = std::function&)>; + using Callback = std::function(hex::pl::Evaluator*, const std::vector&)>; struct Function { u32 parameterCount; diff --git a/plugins/libimhex/include/hex/helpers/utils.hpp b/plugins/libimhex/include/hex/helpers/utils.hpp index 430170f48..c266ba9ba 100644 --- a/plugins/libimhex/include/hex/helpers/utils.hpp +++ b/plugins/libimhex/include/hex/helpers/utils.hpp @@ -57,6 +57,11 @@ namespace hex { return (value & mask) >> to; } + constexpr inline s128 signExtend(size_t numBits, s128 value) { + s128 mask = 1U << (numBits - 1); + return (value ^ mask) - mask; + } + template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; diff --git a/plugins/libimhex/include/hex/pattern_language/ast_node.hpp b/plugins/libimhex/include/hex/pattern_language/ast_node.hpp index 0031dff45..da1e0725e 100644 --- a/plugins/libimhex/include/hex/pattern_language/ast_node.hpp +++ b/plugins/libimhex/include/hex/pattern_language/ast_node.hpp @@ -1,6 +1,8 @@ #pragma once -#include "token.hpp" +#include +#include +#include #include #include @@ -8,51 +10,49 @@ #include #include +#include + namespace hex::pl { - class ASTNodeAttribute; + class PatternData; - class Attributable { - protected: - Attributable() = default; - Attributable(const Attributable&) = default; + class ASTNodeAttribute : public ASTNode { public: + explicit ASTNodeAttribute(std::string attribute, std::optional value = std::nullopt) + : ASTNode(), m_attribute(std::move(attribute)), m_value(std::move(value)) { } - void addAttribute(ASTNodeAttribute *attribute) { - this->m_attributes.push_back(attribute); + ~ASTNodeAttribute() override = default; + + ASTNodeAttribute(const ASTNodeAttribute &other) : ASTNode(other) { + this->m_attribute = other.m_attribute; + this->m_value = other.m_value; } - [[nodiscard]] const auto& getAttributes() const { - return this->m_attributes; - } - - private: - std::vector m_attributes; - }; - - class ASTNode { - public: - constexpr ASTNode() = default; - constexpr virtual ~ASTNode() = default; - constexpr ASTNode(const ASTNode &) = default; - - [[nodiscard]] constexpr u32 getLineNumber() const { return this->m_lineNumber; } - [[maybe_unused]] constexpr void setLineNumber(u32 lineNumber) { this->m_lineNumber = lineNumber; } - - [[nodiscard]] virtual ASTNode* clone() const = 0; - - private: - u32 m_lineNumber = 1; - }; - - class ASTNodeIntegerLiteral : public ASTNode { - public: - explicit ASTNodeIntegerLiteral(Token::IntegerLiteral literal) : ASTNode(), m_literal(std::move(literal)) { } - - ASTNodeIntegerLiteral(const ASTNodeIntegerLiteral&) = default; - [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeIntegerLiteral(*this); + return new ASTNodeAttribute(*this); + } + + [[nodiscard]] const std::string& getAttribute() const { + return this->m_attribute; + } + + [[nodiscard]] const std::optional& getValue() const { + return this->m_value; + } + + private: + std::string m_attribute; + std::optional m_value; + }; + + class ASTNodeLiteral : public ASTNode { + public: + explicit ASTNodeLiteral(Token::Literal literal) : ASTNode(), m_literal(literal) { } + + ASTNodeLiteral(const ASTNodeLiteral&) = default; + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeLiteral(*this); } [[nodiscard]] const auto& getValue() const { @@ -60,32 +60,177 @@ namespace hex::pl { } private: - Token::IntegerLiteral m_literal; + Token::Literal m_literal; }; - class ASTNodeNumericExpression : public ASTNode { + class ASTNodeMathematicalExpression : public ASTNode { + #define FLOAT_BIT_OPERATION(name) \ + auto name(hex::floating_point auto left, auto right) const { LogConsole::abortEvaluation("invalid floating point operation", this); return 0; } \ + auto name(auto left, hex::floating_point auto right) const { LogConsole::abortEvaluation("invalid floating point operation", this); return 0; } \ + auto name(hex::floating_point auto left, hex::floating_point auto right) const { LogConsole::abortEvaluation("invalid floating point operation", this); return 0; } \ + auto name(hex::integral auto left, hex::integral auto right) const + + FLOAT_BIT_OPERATION(shiftLeft) { + return left << right; + } + + FLOAT_BIT_OPERATION(shiftRight) { + return left >> right; + } + + FLOAT_BIT_OPERATION(bitAnd) { + return left & right; + } + + FLOAT_BIT_OPERATION(bitOr) { + return left | right; + } + + FLOAT_BIT_OPERATION(bitXor) { + return left ^ right; + } + + FLOAT_BIT_OPERATION(bitNot) { + return ~right; + } + + FLOAT_BIT_OPERATION(modulus) { + return left % right; + } + + #undef FLOAT_BIT_OPERATION public: - ASTNodeNumericExpression(ASTNode *left, ASTNode *right, Token::Operator op) + ASTNodeMathematicalExpression(ASTNode *left, ASTNode *right, Token::Operator op) : ASTNode(), m_left(left), m_right(right), m_operator(op) { } - ~ASTNodeNumericExpression() override { + ~ASTNodeMathematicalExpression() override { delete this->m_left; delete this->m_right; } - ASTNodeNumericExpression(const ASTNodeNumericExpression &other) : ASTNode(other) { + ASTNodeMathematicalExpression(const ASTNodeMathematicalExpression &other) : ASTNode(other) { this->m_operator = other.m_operator; this->m_left = other.m_left->clone(); this->m_right = other.m_right->clone(); } [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeNumericExpression(*this); + return new ASTNodeMathematicalExpression(*this); } - ASTNode *getLeftOperand() { return this->m_left; } - ASTNode *getRightOperand() { return this->m_right; } - Token::Operator getOperator() { return this->m_operator; } + [[nodiscard]] ASTNode* evaluate(Evaluator *evaluator) const override { + if (this->getLeftOperand() == nullptr || this->getRightOperand() == nullptr) + LogConsole::abortEvaluation("attempted to use void expression in mathematical expression", this); + + auto *left = dynamic_cast(this->getLeftOperand()->evaluate(evaluator)); + auto *right = dynamic_cast(this->getRightOperand()->evaluate(evaluator)); + ON_SCOPE_EXIT { delete left; delete right; }; + + return std::visit(overloaded { + // TODO: :notlikethis: + [this](u128 left, PatternData * const &right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](s128 left, PatternData * const &right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](double left, PatternData * const &right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](char left, PatternData * const &right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](bool left, PatternData * const &right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](std::string left, PatternData * const &right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, u128 right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, s128 right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, double right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, char right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, bool right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, std::string right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](PatternData * const &left, PatternData *right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + + [this](auto&& left, std::string right) -> ASTNode* { LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); }, + [this](std::string left, auto&& right) -> ASTNode* { + switch (this->getOperator()) { + case Token::Operator::Star: { + std::string result; + for (auto i = 0; i < right; i++) + result += left; + return new ASTNodeLiteral(result); + } + default: + LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); + } + }, + [this](std::string left, std::string right) -> ASTNode* { + switch (this->getOperator()) { + case Token::Operator::Plus: + return new ASTNodeLiteral(left + right); + case Token::Operator::BoolEquals: + return new ASTNodeLiteral(left == right); + case Token::Operator::BoolNotEquals: + return new ASTNodeLiteral(left != right); + case Token::Operator::BoolGreaterThan: + return new ASTNodeLiteral(left > right); + case Token::Operator::BoolLessThan: + return new ASTNodeLiteral(left < right); + case Token::Operator::BoolGreaterThanOrEquals: + return new ASTNodeLiteral(left >= right); + case Token::Operator::BoolLessThanOrEquals: + return new ASTNodeLiteral(left <= right); + default: + LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); + } + }, + [this](auto &&left, auto &&right) -> ASTNode* { + switch (this->getOperator()) { + case Token::Operator::Plus: + return new ASTNodeLiteral(left + right); + case Token::Operator::Minus: + return new ASTNodeLiteral(left - right); + case Token::Operator::Star: + return new ASTNodeLiteral(left * right); + case Token::Operator::Slash: + if (right == 0) LogConsole::abortEvaluation("division by zero!", this); + return new ASTNodeLiteral(left / right); + case Token::Operator::Percent: + if (right == 0) LogConsole::abortEvaluation("division by zero!", this); + return new ASTNodeLiteral(modulus(left, right)); + case Token::Operator::ShiftLeft: + return new ASTNodeLiteral(shiftLeft(left, right)); + case Token::Operator::ShiftRight: + return new ASTNodeLiteral(shiftRight(left, right)); + case Token::Operator::BitAnd: + return new ASTNodeLiteral(bitAnd(left, right)); + case Token::Operator::BitXor: + return new ASTNodeLiteral(bitXor(left, right)); + case Token::Operator::BitOr: + return new ASTNodeLiteral(bitOr(left, right)); + case Token::Operator::BitNot: + return new ASTNodeLiteral(bitNot(left, right)); + case Token::Operator::BoolEquals: + return new ASTNodeLiteral(left == right); + case Token::Operator::BoolNotEquals: + return new ASTNodeLiteral(left != right); + case Token::Operator::BoolGreaterThan: + return new ASTNodeLiteral(left > right); + case Token::Operator::BoolLessThan: + return new ASTNodeLiteral(left < right); + case Token::Operator::BoolGreaterThanOrEquals: + return new ASTNodeLiteral(left >= right); + case Token::Operator::BoolLessThanOrEquals: + return new ASTNodeLiteral(left <= right); + case Token::Operator::BoolAnd: + return new ASTNodeLiteral(left && right); + case Token::Operator::BoolXor: + return new ASTNodeLiteral(left && !right || !left && right); + case Token::Operator::BoolOr: + return new ASTNodeLiteral(left || right); + case Token::Operator::BoolNot: + return new ASTNodeLiteral(!right); + default: + LogConsole::abortEvaluation("invalid operand used in mathematical expression", this); + } + } + }, left->getValue(), right->getValue()); + } + + [[nodiscard]] ASTNode *getLeftOperand() const { return this->m_left; } + [[nodiscard]] ASTNode *getRightOperand() const { return this->m_right; } + [[nodiscard]] Token::Operator getOperator() const { return this->m_operator; } private: ASTNode *m_left, *m_right; @@ -114,10 +259,31 @@ namespace hex::pl { return new ASTNodeTernaryExpression(*this); } - ASTNode *getFirstOperand() { return this->m_first; } - ASTNode *getSecondOperand() { return this->m_second; } - ASTNode *getThirdOperand() { return this->m_third; } - Token::Operator getOperator() { return this->m_operator; } + [[nodiscard]] ASTNode* evaluate(Evaluator *evaluator) const override { + if (this->getFirstOperand() == nullptr || this->getSecondOperand() == nullptr || this->getThirdOperand() == nullptr) + LogConsole::abortEvaluation("attempted to use void expression in mathematical expression", this); + + auto *first = dynamic_cast(this->getFirstOperand()->evaluate(evaluator)); + auto *second = dynamic_cast(this->getSecondOperand()->evaluate(evaluator)); + auto *third = dynamic_cast(this->getThirdOperand()->evaluate(evaluator)); + ON_SCOPE_EXIT { delete first; delete second; delete third; }; + + auto condition = std::visit(overloaded { + [this](std::string value) -> bool { return !value.empty(); }, + [this](PatternData * const &) -> bool { LogConsole::abortEvaluation("cannot cast custom type to bool", this); }, + [](auto &&value) -> bool { return bool(value); } + }, first->getValue()); + + return std::visit(overloaded { + [condition](const T &second, const T &third) -> ASTNode* { return new ASTNodeLiteral(condition ? second : third); }, + [this](auto &&second, auto &&third) -> ASTNode* { LogConsole::abortEvaluation("operands to ternary expression have different types", this); } + }, second->getValue(), third->getValue()); + } + + [[nodiscard]] ASTNode *getFirstOperand() const { return this->m_first; } + [[nodiscard]] ASTNode *getSecondOperand() const { return this->m_second; } + [[nodiscard]] ASTNode *getThirdOperand() const { return this->m_third; } + [[nodiscard]] Token::Operator getOperator() const { return this->m_operator; } private: ASTNode *m_first, *m_second, *m_third; @@ -135,6 +301,37 @@ namespace hex::pl { return new ASTNodeBuiltinType(*this); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + auto offset = evaluator->dataOffset(); + auto size = Token::getTypeSize(this->m_type); + + evaluator->dataOffset() += size; + + PatternData *pattern; + if (Token::isUnsigned(this->m_type)) + pattern = new PatternDataUnsigned(offset, size); + else if (Token::isSigned(this->m_type)) + pattern = new PatternDataSigned(offset, size); + else if (Token::isFloatingPoint(this->m_type)) + pattern = new PatternDataFloat(offset, size); + else if (this->m_type == Token::ValueType::Boolean) + pattern = new PatternDataBoolean(offset, size); + else if (this->m_type == Token::ValueType::Character) + pattern = new PatternDataCharacter(offset); + else if (this->m_type == Token::ValueType::Character16) + pattern = new PatternDataCharacter16(offset); + else if (this->m_type == Token::ValueType::Padding) + pattern = new PatternDataPadding(offset, 1); + else if (this->m_type == Token::ValueType::String) + pattern = new PatternDataString(offset, 1); + else + LogConsole::abortEvaluation("invalid built-in type", this); + + pattern->setTypeName(Token::getTypeName(this->m_type)); + + return { pattern }; + } + private: const Token::ValueType m_type; }; @@ -166,12 +363,216 @@ namespace hex::pl { [[nodiscard]] ASTNode* getType() { return this->m_type; } [[nodiscard]] std::optional getEndian() const { return this->m_endian; } + [[nodiscard]] ASTNode *evaluate(Evaluator *evaluator) const override { + return this->m_type->evaluate(evaluator); + } + + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + auto patterns = this->m_type->createPatterns(evaluator); + + for (auto &pattern : patterns) { + if (!this->m_name.empty()) + pattern->setTypeName(this->m_name); + pattern->setEndian(this->m_endian.value_or(evaluator->getDefaultEndian())); + } + + return patterns; + } + private: std::string m_name; ASTNode *m_type; std::optional m_endian; }; + class ASTNodeCast : public ASTNode { + public: + ASTNodeCast(ASTNode *value, ASTNode *type) : m_value(value), m_type(type) { } + ASTNodeCast(const ASTNodeCast &other) { + this->m_value = other.m_value->clone(); + this->m_type = other.m_type->clone(); + } + + ~ASTNodeCast() override { + delete this->m_value; + delete this->m_type; + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeCast(*this); + } + + [[nodiscard]] ASTNode* evaluate(Evaluator *evaluator) const override { + auto literal = dynamic_cast(this->m_value->evaluate(evaluator)); + auto type = dynamic_cast(this->m_type->evaluate(evaluator))->getType(); + + auto typePattern = this->m_type->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete typePattern; }; + + return std::visit(overloaded { + [&, this](PatternData * value) -> ASTNode* { LogConsole::abortEvaluation(hex::format("cannot cast custom type '{}' to '{}'", value->getTypeName(), Token::getTypeName(type)), this); }, + [&, this](const std::string&) -> ASTNode* { LogConsole::abortEvaluation(hex::format("cannot cast string to '{}'", Token::getTypeName(type)), this); }, + [&, this](auto &&value) -> ASTNode* { + auto endianAdjustedValue = hex::changeEndianess(value, typePattern->getSize(), typePattern->getEndian()); + switch (type) { + case Token::ValueType::Unsigned8Bit: + return new ASTNodeLiteral(u128(u8(endianAdjustedValue))); + case Token::ValueType::Unsigned16Bit: + return new ASTNodeLiteral(u128(u16(endianAdjustedValue))); + case Token::ValueType::Unsigned32Bit: + return new ASTNodeLiteral(u128(u32(endianAdjustedValue))); + case Token::ValueType::Unsigned64Bit: + return new ASTNodeLiteral(u128(u64(endianAdjustedValue))); + case Token::ValueType::Unsigned128Bit: + return new ASTNodeLiteral(u128(endianAdjustedValue)); + case Token::ValueType::Signed8Bit: + return new ASTNodeLiteral(s128(s8(endianAdjustedValue))); + case Token::ValueType::Signed16Bit: + return new ASTNodeLiteral(s128(s16(endianAdjustedValue))); + case Token::ValueType::Signed32Bit: + return new ASTNodeLiteral(s128(s32(endianAdjustedValue))); + case Token::ValueType::Signed64Bit: + return new ASTNodeLiteral(s128(s64(endianAdjustedValue))); + case Token::ValueType::Signed128Bit: + return new ASTNodeLiteral(s128(endianAdjustedValue)); + case Token::ValueType::Float: + return new ASTNodeLiteral(double(float(endianAdjustedValue))); + case Token::ValueType::Double: + return new ASTNodeLiteral(double(endianAdjustedValue)); + case Token::ValueType::Character: + return new ASTNodeLiteral(char(endianAdjustedValue)); + case Token::ValueType::Character16: + return new ASTNodeLiteral(u128(char16_t(endianAdjustedValue))); + case Token::ValueType::Boolean: + return new ASTNodeLiteral(bool(endianAdjustedValue)); + default: + LogConsole::abortEvaluation(hex::format("cannot cast value to '{}'", Token::getTypeName(type)), this); + } + }, + }, literal->getValue()); + } + + private: + ASTNode *m_value; + ASTNode *m_type; + }; + + class ASTNodeWhileStatement : public ASTNode { + public: + explicit ASTNodeWhileStatement(ASTNode *condition, std::vector body) + : ASTNode(), m_condition(condition), m_body(std::move(body)) { } + + ~ASTNodeWhileStatement() override { + delete this->m_condition; + + for (auto &statement : this->m_body) + delete statement; + } + + ASTNodeWhileStatement(const ASTNodeWhileStatement &other) : ASTNode(other) { + this->m_condition = other.m_condition->clone(); + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeWhileStatement(*this); + } + + [[nodiscard]] ASTNode* getCondition() { + return this->m_condition; + } + + [[nodiscard]] const std::vector& getBody() { + return this->m_body; + } + + FunctionResult execute(Evaluator *evaluator) override { + + while (evaluateCondition(evaluator)) { + auto variables = *evaluator->getScope(0).scope; + u32 startVariableCount = variables.size(); + ON_SCOPE_EXIT { + s64 stackSize = evaluator->getStack().size(); + for (u32 i = startVariableCount; i < variables.size(); i++) { + stackSize--; + delete variables[i]; + } + if (stackSize < 0) LogConsole::abortEvaluation("stack pointer underflow!", this); + evaluator->getStack().resize(stackSize); + }; + + evaluator->pushScope(nullptr, variables); + ON_SCOPE_EXIT { evaluator->popScope(); }; + + for (auto &statement : this->m_body) { + auto [executionStopped, result] = statement->execute(evaluator); + if (executionStopped) { + return { true, result }; + } + } + } + + return { false, { } }; + } + + [[nodiscard]] + bool evaluateCondition(Evaluator *evaluator) const { + auto literal = dynamic_cast(this->m_condition->evaluate(evaluator)); + ON_SCOPE_EXIT { delete literal; }; + + return std::visit(overloaded { + [](std::string value) -> bool { return !value.empty(); }, + [this](PatternData * const &) -> bool { LogConsole::abortEvaluation("cannot cast custom type to bool", this); }, + [](auto &&value) -> bool { return value != 0; } + }, literal->getValue()); + } + + private: + ASTNode *m_condition; + std::vector m_body; + }; + + inline void applyVariableAttributes(Evaluator *evaluator, const Attributable *attributable, PatternData *pattern) { + for (ASTNodeAttribute *attribute : attributable->getAttributes()) { + auto &name = attribute->getAttribute(); + auto value = attribute->getValue(); + + auto node = reinterpret_cast(attributable); + + auto requiresValue = [&]() { + if (!value.has_value()) + LogConsole::abortEvaluation(hex::format("used attribute '{}' without providing a value", name), node); + return true; + }; + + auto noValue = [&]() { + if (value.has_value()) + LogConsole::abortEvaluation(hex::format("provided a value to attribute '{}' which doesn't take one", name), node); + return true; + }; + + if (name == "color" && requiresValue()) { + u32 color = strtoul(value->c_str(), nullptr, 16); + pattern->setColor(color); + } else if (name == "name" && requiresValue()) { + pattern->setVariableName(*value); + } else if (name == "comment" && requiresValue()) { + pattern->setComment(*value); + } else if (name == "hidden" && noValue()) { + pattern->setHidden(true); + } else if (name == "format" && requiresValue()) { + auto functions = evaluator->getCustomFunctions(); + if (!functions.contains(*value)) + LogConsole::abortEvaluation(hex::format("cannot find formatter function '{}'", *value), node); + + const auto &function = functions[*value]; + if (function.parameterCount != 1) + LogConsole::abortEvaluation("formatter function needs exactly one parameter", node); + + pattern->setFormatterFunction(function, evaluator); + } + } + } + class ASTNodeVariableDecl : public ASTNode, public Attributable { public: ASTNodeVariableDecl(std::string name, ASTNode *type, ASTNode *placementOffset = nullptr) @@ -200,6 +601,32 @@ namespace hex::pl { [[nodiscard]] constexpr ASTNode* getType() const { return this->m_type; } [[nodiscard]] constexpr auto getPlacementOffset() const { return this->m_placementOffset; } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + if (this->m_placementOffset != nullptr) { + auto offset = dynamic_cast(this->m_placementOffset->evaluate(evaluator)); + ON_SCOPE_EXIT { delete offset; }; + + evaluator->dataOffset() = std::visit(overloaded { + [this](std::string) -> u64 { LogConsole::abortEvaluation("placement offset cannot be a string", this); }, + [this](PatternData * const &) -> u64 { LogConsole::abortEvaluation("placement offset cannot be a custom type", this); }, + [](auto &&offset) -> u64 { return offset; } + }, offset->getValue()); + } + + auto pattern = this->m_type->createPatterns(evaluator).front(); + pattern->setVariableName(this->m_name); + + applyVariableAttributes(evaluator, this, pattern); + + return { pattern }; + } + + FunctionResult execute(Evaluator *evaluator) override { + evaluator->createVariable(this->getName(), this->getType()); + + return { false, { } }; + } + private: std::string m_name; ASTNode *m_type; @@ -235,6 +662,43 @@ namespace hex::pl { return new ASTNodeArrayVariableDecl(*this); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + if (this->m_placementOffset != nullptr) { + auto offset = dynamic_cast(this->m_placementOffset->evaluate(evaluator)); + ON_SCOPE_EXIT { delete offset; }; + + evaluator->dataOffset() = std::visit(overloaded { + [this](std::string) -> u64 { LogConsole::abortEvaluation("placement offset cannot be a string", this); }, + [this](PatternData * const &) -> u64 { LogConsole::abortEvaluation("placement offset cannot be a custom type", this); }, + [](auto &&offset) -> u64 { return offset; } + }, offset->getValue()); + } + + auto type = this->m_type->evaluate(evaluator); + ON_SCOPE_EXIT { delete type; }; + + PatternData *pattern; + if (dynamic_cast(type)) + pattern = createStaticArray(evaluator); + else if (auto attributable = dynamic_cast(type)) { + auto &attributes = attributable->getAttributes(); + + bool isStaticType = std::any_of(attributes.begin(), attributes.end(), [](ASTNodeAttribute *attribute) { + return attribute->getAttribute() == "static" && !attribute->getValue().has_value(); + }); + + if (isStaticType) + pattern = createStaticArray(evaluator); + else + pattern = createDynamicArray(evaluator); + } else { + LogConsole::abortEvaluation("invalid type used in array", this); + } + + applyVariableAttributes(evaluator, this, pattern); + return { pattern }; + } + [[nodiscard]] const std::string& getName() const { return this->m_name; } [[nodiscard]] constexpr ASTNode* getType() const { return this->m_type; } [[nodiscard]] constexpr ASTNode* getSize() const { return this->m_size; } @@ -245,6 +709,166 @@ namespace hex::pl { ASTNode *m_type; ASTNode *m_size; ASTNode *m_placementOffset; + + PatternData* createStaticArray(Evaluator *evaluator) const { + u64 startOffset = evaluator->dataOffset(); + + PatternData *templatePattern = this->m_type->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete templatePattern; }; + + evaluator->dataOffset() = startOffset; + + u128 entryCount = 0; + + if (this->m_size != nullptr) { + auto sizeNode = this->m_size->evaluate(evaluator); + ON_SCOPE_EXIT { delete sizeNode; }; + + if (auto literal = dynamic_cast(sizeNode)) { + entryCount = std::visit(overloaded { + [this](std::string) -> u128 { LogConsole::abortEvaluation("cannot use string to index array", this); }, + [this](PatternData*) -> u128 { LogConsole::abortEvaluation("cannot use custom type to index array", this); }, + [](auto &&size) -> u128 { return size; } + }, literal->getValue()); + } else if (auto whileStatement = dynamic_cast(sizeNode)) { + while (whileStatement->evaluateCondition(evaluator)) { + entryCount++; + evaluator->dataOffset() += templatePattern->getSize(); + } + } + } else { + std::vector buffer(templatePattern->getSize()); + while (true) { + if (evaluator->dataOffset() >= evaluator->getProvider()->getActualSize() - buffer.size()) + LogConsole::abortEvaluation("reached end of file before finding end of unsized array", this); + + evaluator->getProvider()->read(evaluator->dataOffset(), buffer.data(), buffer.size()); + evaluator->dataOffset() += buffer.size(); + + entryCount++; + + bool reachedEnd = true; + for (u8 &byte : buffer) { + if (byte != 0x00) { + reachedEnd = false; + break; + } + } + + if (reachedEnd) break; + } + } + + PatternData *outputPattern; + if (dynamic_cast(templatePattern)) { + outputPattern = new PatternDataPadding(startOffset, 0); + } else if (dynamic_cast(templatePattern)) { + outputPattern = new PatternDataString(startOffset, 0); + } else if (dynamic_cast(templatePattern)) { + outputPattern = new PatternDataString16(startOffset, 0); + } else { + auto arrayPattern = new PatternDataStaticArray(startOffset, 0); + arrayPattern->setEntries(templatePattern->clone(), entryCount); + outputPattern = arrayPattern; + } + + outputPattern->setVariableName(this->m_name); + outputPattern->setEndian(templatePattern->getEndian()); + outputPattern->setColor(templatePattern->getColor()); + outputPattern->setTypeName(templatePattern->getTypeName()); + outputPattern->setSize(templatePattern->getSize() * entryCount); + + evaluator->dataOffset() = startOffset + outputPattern->getSize(); + + return outputPattern; + } + + PatternData* createDynamicArray(Evaluator *evaluator) const { + auto arrayPattern = new PatternDataDynamicArray(evaluator->dataOffset(), 0); + arrayPattern->setVariableName(this->m_name); + + std::vector entries; + size_t size = 0; + u64 entryCount = 0; + + if (this->m_size != nullptr) { + auto sizeNode = this->m_size->evaluate(evaluator); + ON_SCOPE_EXIT { delete sizeNode; }; + + { + auto templatePattern = this->m_type->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete templatePattern; }; + + arrayPattern->setTypeName(templatePattern->getTypeName()); + evaluator->dataOffset() -= templatePattern->getSize(); + } + + if (auto literal = dynamic_cast(sizeNode)) { + entryCount = std::visit(overloaded{ + [this](std::string) -> u128 { LogConsole::abortEvaluation("cannot use string to index array", this); }, + [this](PatternData*) -> u128 { LogConsole::abortEvaluation("cannot use custom type to index array", this); }, + [](auto &&size) -> u128 { return size; } + }, literal->getValue()); + + for (u64 i = 0; i < entryCount; i++) { + auto pattern = this->m_type->createPatterns(evaluator).front(); + + pattern->setVariableName(hex::format("[{}]", i)); + pattern->setEndian(arrayPattern->getEndian()); + pattern->setColor(arrayPattern->getColor()); + entries.push_back(pattern); + + size += pattern->getSize(); + } + } else if (auto whileStatement = dynamic_cast(sizeNode)) { + while (whileStatement->evaluateCondition(evaluator)) { + auto pattern = this->m_type->createPatterns(evaluator).front(); + + pattern->setVariableName(hex::format("[{}]", entryCount)); + pattern->setEndian(arrayPattern->getEndian()); + pattern->setColor(arrayPattern->getColor()); + entries.push_back(pattern); + + entryCount++; + size += pattern->getSize(); + } + } + } else { + while (true) { + auto pattern = this->m_type->createPatterns(evaluator).front(); + std::vector buffer(pattern->getSize()); + + if (evaluator->dataOffset() >= evaluator->getProvider()->getActualSize() - buffer.size()) { + delete pattern; + LogConsole::abortEvaluation("reached end of file before finding end of unsized array", this); + } + + pattern->setVariableName(hex::format("[{}]", entryCount)); + pattern->setEndian(arrayPattern->getEndian()); + pattern->setColor(arrayPattern->getColor()); + entries.push_back(pattern); + + size += pattern->getSize(); + entryCount++; + + evaluator->getProvider()->read(evaluator->dataOffset() - pattern->getSize(), buffer.data(), buffer.size()); + bool reachedEnd = true; + for (u8 &byte : buffer) { + if (byte != 0x00) { + reachedEnd = false; + break; + } + } + + if (reachedEnd) break; + } + } + + arrayPattern->setEntries(entries); + arrayPattern->setSize(size); + + return arrayPattern; + } }; class ASTNodePointerVariableDecl : public ASTNode, public Attributable { @@ -278,6 +902,30 @@ namespace hex::pl { [[nodiscard]] constexpr ASTNode* getSizeType() const { return this->m_sizeType; } [[nodiscard]] constexpr auto getPlacementOffset() const { return this->m_placementOffset; } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + if (this->m_placementOffset != nullptr) { + auto offset = dynamic_cast(this->m_placementOffset->evaluate(evaluator)); + ON_SCOPE_EXIT { delete offset; }; + + evaluator->dataOffset() = std::visit(overloaded { + [this](std::string) -> u64 { LogConsole::abortEvaluation("placement offset cannot be a string", this); }, + [this](PatternData*) -> u64 { LogConsole::abortEvaluation("placement offset cannot be a custom type", this); }, + [](auto &&offset) -> u64 { return u64(offset); } + }, offset->getValue()); + } + + auto sizePattern = this->m_sizeType->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete sizePattern; }; + + auto pattern = new PatternDataPointer(evaluator->dataOffset(), sizePattern->getSize()); + pattern->setPointedAtPattern(this->m_type->createPatterns(evaluator).front()); + pattern->setVariableName(this->m_name); + + applyVariableAttributes(evaluator, this, pattern); + + return { pattern }; + } + private: std::string m_name; ASTNode *m_type; @@ -307,6 +955,27 @@ namespace hex::pl { return this->m_variables; } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + std::vector patterns; + + for (auto &node : this->m_variables) { + auto newPatterns = node->createPatterns(evaluator); + patterns.insert(patterns.end(), newPatterns.begin(), newPatterns.end()); + } + + return patterns; + } + + FunctionResult execute(Evaluator *evaluator) override { + for (auto &variable : this->m_variables) { + auto variableDecl = dynamic_cast(variable); + + evaluator->createVariable(variableDecl->getName(), variableDecl->getType()->evaluate(evaluator)); + } + + return { false, { } }; + } + private: std::vector m_variables; }; @@ -329,6 +998,26 @@ namespace hex::pl { return new ASTNodeStruct(*this); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + auto pattern = new PatternDataStruct(evaluator->dataOffset(), 0); + + u64 startOffset = evaluator->dataOffset(); + std::vector memberPatterns; + + evaluator->pushScope(pattern, memberPatterns); + for (auto member : this->m_members) { + for (auto &memberPattern : member->createPatterns(evaluator)) { + memberPatterns.push_back(memberPattern); + } + } + evaluator->popScope(); + + pattern->setMembers(memberPatterns); + pattern->setSize(evaluator->dataOffset() - startOffset); + + return { pattern }; + } + [[nodiscard]] const std::vector& getMembers() const { return this->m_members; } void addMember(ASTNode *node) { this->m_members.push_back(node); } @@ -354,6 +1043,30 @@ namespace hex::pl { return new ASTNodeUnion(*this); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + auto pattern = new PatternDataUnion(evaluator->dataOffset(), 0); + + size_t size = 0; + std::vector memberPatterns; + u64 startOffset = evaluator->dataOffset(); + + evaluator->pushScope(pattern, memberPatterns); + for (auto member : this->m_members) { + for (auto &memberPattern : member->createPatterns(evaluator)) { + memberPattern->setOffset(startOffset); + memberPatterns.push_back(memberPattern); + size = std::max(memberPattern->getSize(), size); + } + } + evaluator->popScope(); + + evaluator->dataOffset() = startOffset + size; + pattern->setMembers(memberPatterns); + pattern->setSize(size); + + return { pattern }; + } + [[nodiscard]] const std::vector& getMembers() const { return this->m_members; } void addMember(ASTNode *node) { this->m_members.push_back(node); } @@ -367,7 +1080,7 @@ namespace hex::pl { ASTNodeEnum(const ASTNodeEnum &other) : ASTNode(other), Attributable(other) { for (const auto &[name, entry] : other.getEntries()) - this->m_entries.insert({ name, entry->clone() }); + this->m_entries.emplace(name, entry->clone()); this->m_underlyingType = other.m_underlyingType->clone(); } @@ -381,6 +1094,27 @@ namespace hex::pl { return new ASTNodeEnum(*this); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + auto pattern = new PatternDataEnum(evaluator->dataOffset(), 0); + + std::vector> enumEntries; + for (const auto &[name, value] : this->m_entries) { + auto literal = dynamic_cast(value->evaluate(evaluator)); + ON_SCOPE_EXIT { delete literal; }; + + enumEntries.emplace_back(literal->getValue(), name); + } + + pattern->setEnumValues(enumEntries); + + auto underlying = this->m_underlyingType->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete underlying; }; + pattern->setSize(underlying->getSize()); + pattern->setEndian(underlying->getEndian()); + + return { pattern }; + } + [[nodiscard]] const std::map& getEntries() const { return this->m_entries; } void addEntry(const std::string &name, ASTNode* expression) { this->m_entries.insert({ name, expression }); } @@ -412,6 +1146,38 @@ namespace hex::pl { [[nodiscard]] const std::vector>& getEntries() const { return this->m_entries; } void addEntry(const std::string &name, ASTNode* size) { this->m_entries.emplace_back(name, size); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + auto pattern = new PatternDataBitfield(evaluator->dataOffset(), 0); + + size_t bitOffset = 0; + std::vector fields; + evaluator->pushScope(pattern, fields); + for (auto [name, bitSizeNode] : this->m_entries) { + auto literal = bitSizeNode->evaluate(evaluator); + ON_SCOPE_EXIT { delete literal; }; + + u8 bitSize = std::visit(overloaded { + [this](std::string) -> u8 { LogConsole::abortEvaluation("bitfield field size cannot be a string", this); }, + [this](PatternData*) -> u8 { LogConsole::abortEvaluation("bitfield field size cannot be a custom type", this); }, + [](auto &&offset) -> u8 { return static_cast(offset); } + }, dynamic_cast(literal)->getValue()); + + auto field = new PatternDataBitfieldField(evaluator->dataOffset(), bitOffset, bitSize); + field->setVariableName(name); + + bitOffset += bitSize; + fields.push_back(field); + } + evaluator->popScope(); + + pattern->setSize((bitOffset + 7) / 8); + pattern->setFields(fields); + + evaluator->dataOffset() += pattern->getSize(); + + return { pattern }; + } + private: std::vector> m_entries; }; @@ -435,10 +1201,187 @@ namespace hex::pl { return new ASTNodeRValue(*this); } - const Path& getPath() { + [[nodiscard]] + const Path& getPath() const { return this->m_path; } + [[nodiscard]] ASTNode* evaluate(Evaluator *evaluator) const override { + if (this->getPath().size() == 1) { + if (auto name = std::get_if(&this->getPath().front()); name != nullptr) { + if (*name == "$") return new ASTNodeLiteral(u128(evaluator->dataOffset())); + } + } + + auto pattern = this->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete pattern; }; + + auto readValue = [&evaluator](auto &value, PatternData *pattern) { + if (pattern->isLocal()) { + auto &literal = evaluator->getStack()[pattern->getOffset()]; + + std::visit(overloaded { + [&](std::string &assignmentValue) { }, + [&](auto &&assignmentValue) { std::memcpy(&value, &assignmentValue, pattern->getSize()); } + }, literal); + } + else + evaluator->getProvider()->read(pattern->getOffset(), &value, pattern->getSize()); + }; + + Token::Literal literal; + if (dynamic_cast(pattern)) { + u128 value = 0; + readValue(value, pattern); + literal = value; + } else if (dynamic_cast(pattern)) { + s128 value = 0; + readValue(value, pattern); + literal = value; + } else if (dynamic_cast(pattern)) { + if (pattern->getSize() == sizeof(u16)) { + u16 value = 0; + readValue(value, pattern); + literal = double(float16ToFloat32(value)); + } else if (pattern->getSize() == sizeof(float)) { + float value = 0; + readValue(value, pattern); + literal = double(value); + } else if (pattern->getSize() == sizeof(double)) { + double value = 0; + readValue(value, pattern); + literal = value; + } else LogConsole::abortEvaluation("invalid floating point type access", this); + } else if (dynamic_cast(pattern)) { + char value = 0; + readValue(value, pattern); + literal = value; + } else if (dynamic_cast(pattern)) { + bool value = false; + readValue(value, pattern); + literal = value; + } else if (dynamic_cast(pattern)) { + std::string value; + + if (pattern->isLocal()) { + auto &literal = evaluator->getStack()[pattern->getOffset()]; + + std::visit(overloaded { + [&](std::string &assignmentValue) { value = assignmentValue; }, + [&, this](auto &&assignmentValue) { LogConsole::abortEvaluation(hex::format("cannot assign '{}' to string", pattern->getTypeName()), this); } + }, literal); + } + else + evaluator->getProvider()->read(pattern->getOffset(), value.data(), pattern->getSize()); + + literal = value; + } else if (auto bitfieldFieldPattern = dynamic_cast(pattern)) { + u64 value = 0; + readValue(value, pattern); + literal = u128(hex::extract(bitfieldFieldPattern->getBitOffset() + (bitfieldFieldPattern->getBitSize() - 1), bitfieldFieldPattern->getBitOffset(), value)); + } else { + literal = pattern->clone(); + } + + return new ASTNodeLiteral(literal); + } + + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + s32 scopeIndex = 0; + + auto searchScope = *evaluator->getScope(scopeIndex).scope; + PatternData *currPattern = nullptr; + + for (const auto &part : this->getPath()) { + + if (part.index() == 0) { + // Variable access + auto name = std::get(part); + + if (name == "parent") { + scopeIndex--; + searchScope = *evaluator->getScope(scopeIndex).scope; + currPattern = searchScope.front()->getParent(); + } else if (name == "this") { + searchScope = *evaluator->getScope(scopeIndex).scope; + + auto currParent = evaluator->getScope(0).parent; + + if (currParent == nullptr) + LogConsole::abortEvaluation("invalid use of 'this' outside of struct-like type", this); + + currPattern = currParent->clone(); + continue; + } else { + bool found = false; + for (const auto &variable : searchScope) { + if (variable->getVariableName() == name) { + auto newPattern = variable->clone(); + delete currPattern; + currPattern = newPattern; + found = true; + break; + } + } + + if (name == "$") + LogConsole::abortEvaluation("invalid use of placeholder operator in rvalue"); + + if (!found) + LogConsole::abortEvaluation(hex::format("no variable named '{}' found", name), this); + } + } else { + // Array indexing + auto index = dynamic_cast(std::get(part)->evaluate(evaluator)); + ON_SCOPE_EXIT { delete index; }; + + std::visit(overloaded { + [](std::string) { throw std::string("cannot use string to index array"); }, + [](PatternData * const &) { throw std::string("cannot use custom type to index array"); }, + [&, this](auto &&index) { + if (auto dynamicArrayPattern = dynamic_cast(currPattern)) { + if (index >= searchScope.size() || index < 0) + LogConsole::abortEvaluation("array index out of bounds", this); + + auto newPattern = searchScope[index]->clone(); + delete currPattern; + currPattern = newPattern; + } + else if (auto staticArrayPattern = dynamic_cast(currPattern)) { + if (index >= staticArrayPattern->getEntryCount() || index < 0) + LogConsole::abortEvaluation("array index out of bounds", this); + + auto newPattern = searchScope.front()->clone(); + delete currPattern; + currPattern = newPattern; + currPattern->setOffset(staticArrayPattern->getOffset() + index * staticArrayPattern->getSize()); + } + } + }, index->getValue()); + } + + if (auto pointerPattern = dynamic_cast(currPattern)) { + auto newPattern = pointerPattern->getPointedAtPattern()->clone(); + delete currPattern; + currPattern = newPattern; + } + + if (auto structPattern = dynamic_cast(currPattern)) + searchScope = structPattern->getMembers(); + else if (auto unionPattern = dynamic_cast(currPattern)) + searchScope = unionPattern->getMembers(); + else if (auto bitfieldPattern = dynamic_cast(currPattern)) + searchScope = bitfieldPattern->getFields(); + else if (auto dynamicArrayPattern = dynamic_cast(currPattern)) + searchScope = dynamicArrayPattern->getEntries(); + else if (auto staticArrayPattern = dynamic_cast(currPattern)) + searchScope = { staticArrayPattern->getTemplate() }; + + } + + return { currPattern }; + } + private: Path m_path; }; @@ -488,6 +1431,19 @@ namespace hex::pl { return new ASTNodeConditionalStatement(*this); } + [[nodiscard]] std::vector createPatterns(Evaluator *evaluator) const override { + std::vector patterns; + + auto &body = evaluateCondition(evaluator) ? this->m_trueBody : this->m_falseBody; + + for (auto &node : body) { + auto newPatterns = node->createPatterns(evaluator); + patterns.insert(patterns.end(), newPatterns.begin(), newPatterns.end()); + } + + return patterns; + } + [[nodiscard]] ASTNode* getCondition() { return this->m_condition; } @@ -500,44 +1456,50 @@ namespace hex::pl { return this->m_falseBody; } + FunctionResult execute(Evaluator *evaluator) override { + auto &body = evaluateCondition(evaluator) ? this->m_trueBody : this->m_falseBody; + + auto variables = *evaluator->getScope(0).scope; + u32 startVariableCount = variables.size(); + ON_SCOPE_EXIT { + s64 stackSize = evaluator->getStack().size(); + for (u32 i = startVariableCount; i < variables.size(); i++) { + stackSize--; + delete variables[i]; + } + if (stackSize < 0) LogConsole::abortEvaluation("stack pointer underflow!", this); + evaluator->getStack().resize(stackSize); + }; + + evaluator->pushScope(nullptr, variables); + ON_SCOPE_EXIT { evaluator->popScope(); }; + for (auto &statement : body) { + auto [executionStopped, result] = statement->execute(evaluator); + if (executionStopped) { + return { true, result }; + } + } + + return { false, { } }; + } + private: + [[nodiscard]] + bool evaluateCondition(Evaluator *evaluator) const { + auto literal = dynamic_cast(this->m_condition->evaluate(evaluator)); + ON_SCOPE_EXIT { delete literal; }; + + return std::visit(overloaded { + [](std::string value) -> bool { return !value.empty(); }, + [this](PatternData * const &) -> bool { LogConsole::abortEvaluation("cannot cast custom type to bool", this); }, + [](auto &&value) -> bool { return value != 0; } + }, literal->getValue()); + } + ASTNode *m_condition; std::vector m_trueBody, m_falseBody; }; - class ASTNodeWhileStatement : public ASTNode { - public: - explicit ASTNodeWhileStatement(ASTNode *condition, std::vector body) - : ASTNode(), m_condition(condition), m_body(std::move(body)) { } - - ~ASTNodeWhileStatement() override { - delete this->m_condition; - - for (auto &statement : this->m_body) - delete statement; - } - - ASTNodeWhileStatement(const ASTNodeWhileStatement &other) : ASTNode(other) { - this->m_condition = other.m_condition->clone(); - } - - [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeWhileStatement(*this); - } - - [[nodiscard]] ASTNode* getCondition() { - return this->m_condition; - } - - [[nodiscard]] const std::vector& getBody() { - return this->m_body; - } - - private: - ASTNode *m_condition; - std::vector m_body; - }; - class ASTNodeFunctionCall : public ASTNode { public: explicit ASTNodeFunctionCall(std::string functionName, std::vector params) @@ -567,62 +1529,66 @@ namespace hex::pl { return this->m_params; } + [[nodiscard]] ASTNode* evaluate(Evaluator *evaluator) const override { + std::vector evaluatedParams; + for (auto param : this->m_params) { + auto expression = param->evaluate(evaluator); + ON_SCOPE_EXIT { delete expression; }; + + auto literal = dynamic_cast(expression->evaluate(evaluator)); + ON_SCOPE_EXIT { delete literal; }; + + evaluatedParams.push_back(literal->getValue()); + } + + auto &customFunctions = evaluator->getCustomFunctions(); + auto functions = ContentRegistry::PatternLanguageFunctions::getEntries(); + + for (auto &func : customFunctions) + functions.insert(func); + + if (!functions.contains(this->m_functionName)) + LogConsole::abortEvaluation(hex::format("call to unknown function '{}'", this->m_functionName), this); + + auto function = functions[this->m_functionName]; + if (function.parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) { + ; // Don't check parameter count + } + else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) { + if (evaluatedParams.size() >= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)) + LogConsole::abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", this->m_functionName, function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan), this); + } else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) { + if (evaluatedParams.size() <= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)) + LogConsole::abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", this->m_functionName, function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan), this); + } else if (function.parameterCount != evaluatedParams.size()) { + LogConsole::abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", this->m_functionName, function.parameterCount), this); + } + + try { + auto result = functions[this->m_functionName].func(evaluator, evaluatedParams); + + if (result.has_value()) + return new ASTNodeLiteral(result.value()); + else + return new ASTNodeMathematicalExpression(nullptr, nullptr, Token::Operator::Plus); + } catch (std::string &error) { + LogConsole::abortEvaluation(error, this); + } + + return nullptr; + } + + FunctionResult execute(Evaluator *evaluator) override { + delete this->evaluate(evaluator); + + return { false, { } }; + } + private: std::string m_functionName; std::vector m_params; }; - class ASTNodeStringLiteral : public ASTNode { - public: - explicit ASTNodeStringLiteral(std::string string) : ASTNode(), m_string(std::move(string)) { } - - ~ASTNodeStringLiteral() override = default; - - ASTNodeStringLiteral(const ASTNodeStringLiteral &other) : ASTNode(other) { - this->m_string = other.m_string; - } - - [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeStringLiteral(*this); - } - - [[nodiscard]] const std::string& getString() { - return this->m_string; - } - - private: - std::string m_string; - }; - - class ASTNodeAttribute : public ASTNode { - public: - explicit ASTNodeAttribute(std::string attribute, std::optional value = std::nullopt) - : ASTNode(), m_attribute(std::move(attribute)), m_value(std::move(value)) { } - - ~ASTNodeAttribute() override = default; - - ASTNodeAttribute(const ASTNodeAttribute &other) : ASTNode(other) { - this->m_attribute = other.m_attribute; - this->m_value = other.m_value; - } - - [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeAttribute(*this); - } - - [[nodiscard]] const std::string& getAttribute() const { - return this->m_attribute; - } - - [[nodiscard]] const std::optional& getValue() const { - return this->m_value; - } - - private: - std::string m_attribute; - std::optional m_value; - }; - class ASTNodeTypeOperator : public ASTNode { public: ASTNodeTypeOperator(Token::Operator op, ASTNode *expression) : m_op(op), m_expression(expression) { @@ -642,64 +1608,41 @@ namespace hex::pl { delete this->m_expression; } + [[nodiscard]] Token::Operator getOperator() const { return this->m_op; } + [[nodiscard]] ASTNode* getExpression() const { return this->m_expression; } + [[nodiscard]] + ASTNode* evaluate(Evaluator *evaluator) const override { + auto pattern = this->m_expression->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { delete pattern; }; + + switch (this->getOperator()) { + case Token::Operator::AddressOf: + return new ASTNodeLiteral(u128(pattern->getOffset())); + case Token::Operator::SizeOf: + return new ASTNodeLiteral(u128(pattern->getSize())); + default: + LogConsole::abortEvaluation("invalid type operator", this); + } + } + + private: Token::Operator m_op; ASTNode *m_expression; }; - class ASTNodeFunctionDefinition : public ASTNode { - public: - ASTNodeFunctionDefinition(std::string name, std::vector params, std::vector body) - : m_name(std::move(name)), m_params(std::move(params)), m_body(std::move(body)) { - - } - - ASTNodeFunctionDefinition(const ASTNodeFunctionDefinition &other) : ASTNode(other) { - this->m_name = other.m_name; - this->m_params = other.m_params; - - for (auto statement : other.m_body) { - this->m_body.push_back(statement->clone()); - } - } - - [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeFunctionDefinition(*this); - } - - ~ASTNodeFunctionDefinition() override { - for (auto statement : this->m_body) - delete statement; - } - - [[nodiscard]] const std::string& getName() const { - return this->m_name; - } - - [[nodiscard]] const auto& getParams() const { - return this->m_params; - } - - [[nodiscard]] const auto& getBody() const { - return this->m_body; - } - - private: - std::string m_name; - std::vector m_params; - std::vector m_body; - }; class ASTNodeAssignment : public ASTNode { public: + // TODO: Implement this ASTNodeAssignment(std::string lvalueName, ASTNode *rvalue) : m_lvalueName(std::move(lvalueName)), m_rvalue(rvalue) { } @@ -725,6 +1668,15 @@ namespace hex::pl { return this->m_rvalue; } + FunctionResult execute(Evaluator *evaluator) override { + auto literal = dynamic_cast(this->getRValue()->evaluate(evaluator)); + ON_SCOPE_EXIT { delete literal; }; + + evaluator->setVariable(this->getLValueName(), literal->getValue()); + + return { false, { } }; + } + private: std::string m_lvalueName; ASTNode *m_rvalue; @@ -732,7 +1684,8 @@ namespace hex::pl { class ASTNodeReturnStatement : public ASTNode { public: - ASTNodeReturnStatement(ASTNode *rvalue) : m_rvalue(rvalue) { + // TODO: Implement this + explicit ASTNodeReturnStatement(ASTNode *rvalue) : m_rvalue(rvalue) { } @@ -748,11 +1701,106 @@ namespace hex::pl { delete this->m_rvalue; } - [[nodiscard]] ASTNode* getRValue() const { + [[nodiscard]] ASTNode* getReturnValue() const { return this->m_rvalue; } + FunctionResult execute(Evaluator *evaluator) override { + auto returnValue = this->getReturnValue(); + + if (returnValue == nullptr) + return { true, std::nullopt }; + else { + auto literal = dynamic_cast(returnValue->evaluate(evaluator)); + ON_SCOPE_EXIT { delete literal; }; + + return { true, literal->getValue() }; + } + } + private: ASTNode *m_rvalue; }; + + class ASTNodeFunctionDefinition : public ASTNode { + public: + // TODO: Implement this + ASTNodeFunctionDefinition(std::string name, std::map params, std::vector body) + : m_name(std::move(name)), m_params(std::move(params)), m_body(std::move(body)) { + + } + + ASTNodeFunctionDefinition(const ASTNodeFunctionDefinition &other) : ASTNode(other) { + this->m_name = other.m_name; + this->m_params = other.m_params; + + for (const auto &[name, type] : other.m_params) { + this->m_params.emplace(name, type->clone()); + } + + for (auto statement : other.m_body) { + this->m_body.push_back(statement->clone()); + } + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeFunctionDefinition(*this); + } + + ~ASTNodeFunctionDefinition() override { + for (auto &[name, type] : this->m_params) + delete type; + for (auto statement : this->m_body) + delete statement; + } + + [[nodiscard]] const std::string& getName() const { + return this->m_name; + } + + [[nodiscard]] const auto& getParams() const { + return this->m_params; + } + + [[nodiscard]] const auto& getBody() const { + return this->m_body; + } + + [[nodiscard]] ASTNode* evaluate(Evaluator *evaluator) const override { + + evaluator->addCustomFunction(this->m_name, this->m_params.size(), [this](Evaluator *ctx, const std::vector& params) -> std::optional { + std::vector variables; + + ctx->pushScope(nullptr, variables); + ON_SCOPE_EXIT { ctx->popScope(); }; + + u32 paramIndex = 0; + for (const auto &[name, type] : this->m_params) { + ctx->createVariable(name, type); + ctx->setVariable(name, params[paramIndex]); + + paramIndex++; + } + + for (auto statement : this->m_body) { + auto [executionStopped, result] = statement->execute(ctx); + + if (executionStopped) { + return result; + } + } + + return { }; + }); + + return nullptr; + } + + + private: + std::string m_name; + std::map m_params; + std::vector m_body; + }; + } \ No newline at end of file diff --git a/plugins/libimhex/include/hex/pattern_language/ast_node_base.hpp b/plugins/libimhex/include/hex/pattern_language/ast_node_base.hpp new file mode 100644 index 000000000..624488f04 --- /dev/null +++ b/plugins/libimhex/include/hex/pattern_language/ast_node_base.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include + +namespace hex::pl { + + class ASTNode; + class ASTNodeAttribute; + + class PatternData; + class Evaluator; + + class Attributable { + protected: + Attributable() = default; + + Attributable(const Attributable &) = default; + + public: + + void addAttribute(ASTNodeAttribute *attribute) { + this->m_attributes.push_back(attribute); + } + + [[nodiscard]] const auto &getAttributes() const { + return this->m_attributes; + } + + private: + std::vector m_attributes; + }; + + class Clonable { + public: + [[nodiscard]] + virtual ASTNode* clone() const = 0; + }; + + class ASTNode : public Clonable { + public: + constexpr ASTNode() = default; + + constexpr virtual ~ASTNode() = default; + + constexpr ASTNode(const ASTNode &) = default; + + [[nodiscard]] constexpr u32 getLineNumber() const { return this->m_lineNumber; } + + [[maybe_unused]] constexpr void setLineNumber(u32 lineNumber) { this->m_lineNumber = lineNumber; } + + [[nodiscard]] virtual ASTNode *evaluate(Evaluator *evaluator) const { return this->clone(); } + + [[nodiscard]] virtual std::vector createPatterns(Evaluator *evaluator) const { return {}; } + + using FunctionResult = std::pair>; + virtual FunctionResult execute(Evaluator *evaluator) { throw std::pair(this->getLineNumber(), "cannot execute non-function statement"); } + + private: + u32 m_lineNumber = 1; + }; + +} \ No newline at end of file diff --git a/plugins/libimhex/include/hex/pattern_language/evaluator.hpp b/plugins/libimhex/include/hex/pattern_language/evaluator.hpp index 7487626c1..e0241b9bd 100644 --- a/plugins/libimhex/include/hex/pattern_language/evaluator.hpp +++ b/plugins/libimhex/include/hex/pattern_language/evaluator.hpp @@ -1,93 +1,96 @@ #pragma once -#include - -#include -#include -#include - #include -#include #include +#include #include -#define LITERAL_COMPARE(literal, cond) std::visit([&](auto &&literal) { return (cond) != 0; }, literal) -#define AS_TYPE(type, value) ctx.template asType(value) +#include +#include namespace hex::prv { class Provider; } namespace hex::pl { class PatternData; + class ASTNode; class Evaluator { public: Evaluator() = default; - std::optional> evaluate(const std::vector& ast); + std::optional> evaluate(const std::vector &ast); - LogConsole& getConsole() { return this->m_console; } - - void setDefaultEndian(std::endian endian) { this->m_defaultDataEndian = endian; } - void setRecursionLimit(u32 limit) { this->m_recursionLimit = limit; } - void setProvider(prv::Provider *provider) { this->m_provider = provider; } - [[nodiscard]] std::endian getCurrentEndian() const { return this->m_endianStack.back(); } - - PatternData* patternFromName(const ASTNodeRValue::Path &name); - - template - T* asType(ASTNode *param) { - if (auto evaluatedParam = dynamic_cast(param); evaluatedParam != nullptr) - return evaluatedParam; - else - this->getConsole().abortEvaluation("function got wrong type of parameter"); + [[nodiscard]] + LogConsole& getConsole() { + return this->m_console; } + struct Scope { PatternData *parent; std::vector* scope; }; + void pushScope(PatternData *parent, std::vector &scope) { this->m_scopes.push_back({ parent, &scope }); } + void popScope() { this->m_scopes.pop_back(); } + const Scope& getScope(s32 index) { + static Scope empty; + + if (index > 0 || -index >= this->m_scopes.size()) return empty; + return this->m_scopes[this->m_scopes.size() - 1 + index]; + } + + const Scope& getGlobalScope() { + return this->m_scopes.front(); + } + + void setProvider(prv::Provider *provider) { + this->m_provider = provider; + } + + [[nodiscard]] + prv::Provider *getProvider() const { + return this->m_provider; + } + + void setDefaultEndian(std::endian endian) { + this->m_defaultEndian = endian; + } + + [[nodiscard]] + std::endian getDefaultEndian() const { + return this->m_defaultEndian; + } + + u64& dataOffset() { return this->m_currOffset; } + + bool addCustomFunction(const std::string &name, u32 numParams, const ContentRegistry::PatternLanguageFunctions::Callback &function) { + const auto [iter, inserted] = this->m_customFunctions.insert({ name, { numParams, function } }); + + return inserted; + } + + [[nodiscard]] + const std::map& getCustomFunctions() const { + return this->m_customFunctions; + } + + [[nodiscard]] + std::vector& getStack() { + return this->m_stack; + } + + void createVariable(const std::string &name, ASTNode *type); + + void setVariable(const std::string &name, const Token::Literal& value); + private: - std::map m_types; - prv::Provider* m_provider = nullptr; - std::endian m_defaultDataEndian = std::endian::native; - u64 m_currOffset = 0; - std::vector m_endianStack; - std::vector m_globalMembers; - std::vector*> m_currMembers; - std::vector*> m_localVariables; - std::vector m_currMemberScope; - std::vector m_localStack; - std::map m_definedFunctions; + u64 m_currOffset; + prv::Provider *m_provider = nullptr; LogConsole m_console; - u32 m_recursionLimit; - u32 m_currRecursionDepth; + std::endian m_defaultEndian = std::endian::native; - void createLocalVariable(const std::string &varName, PatternData *pattern); - void setLocalVariableValue(const std::string &varName, const void *value, size_t size); - - ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node); - ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node); - ASTNode* evaluateFunctionCall(ASTNodeFunctionCall *node); - ASTNodeIntegerLiteral* evaluateTypeOperator(ASTNodeTypeOperator *typeOperatorNode); - ASTNodeIntegerLiteral* evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op); - ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node); - ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node); - ASTNodeIntegerLiteral* evaluateMathematicalExpression(ASTNodeNumericExpression *node); - void evaluateFunctionDefinition(ASTNodeFunctionDefinition *node); - std::optional evaluateFunctionBody(const std::vector &body); - - PatternData* findPattern(std::vector currMembers, const ASTNodeRValue::Path &path); - PatternData* evaluateAttributes(ASTNode *currNode, PatternData *currPattern); - PatternData* evaluateBuiltinType(ASTNodeBuiltinType *node); - void evaluateMember(ASTNode *node, std::vector &currMembers, bool increaseOffset); - PatternData* evaluateStruct(ASTNodeStruct *node); - PatternData* evaluateUnion(ASTNodeUnion *node); - PatternData* evaluateEnum(ASTNodeEnum *node); - PatternData* evaluateBitfield(ASTNodeBitfield *node); - PatternData* evaluateType(ASTNodeTypeDecl *node); - PatternData* evaluateVariable(ASTNodeVariableDecl *node); - PatternData* evaluateArray(ASTNodeArrayVariableDecl *node); - PatternData* evaluateStaticArray(ASTNodeArrayVariableDecl *node); - PatternData* evaluateDynamicArray(ASTNodeArrayVariableDecl *node); - PatternData* evaluatePointer(ASTNodePointerVariableDecl *node); + std::vector m_scopes; + std::map m_customFunctions; + std::vector m_customFunctionDefinitions; + std::vector m_stack; }; } \ No newline at end of file diff --git a/plugins/libimhex/include/hex/pattern_language/log_console.hpp b/plugins/libimhex/include/hex/pattern_language/log_console.hpp index d0fbc4404..d86470177 100644 --- a/plugins/libimhex/include/hex/pattern_language/log_console.hpp +++ b/plugins/libimhex/include/hex/pattern_language/log_console.hpp @@ -7,8 +7,12 @@ #include #include +#include + namespace hex::pl { + class ASTNode; + class LogConsole { public: enum Level { @@ -18,9 +22,10 @@ namespace hex::pl { Error }; - const auto& getLog() { return this->m_consoleLog; } + [[nodiscard]] + const auto& getLog() const { return this->m_consoleLog; } - using EvaluateError = std::string; + using EvaluateError = std::pair; void log(Level level, const std::string &message) { switch (level) { @@ -32,16 +37,29 @@ namespace hex::pl { } } - [[noreturn]] void abortEvaluation(const std::string &message) { - throw EvaluateError(message); + [[noreturn]] + static void abortEvaluation(const std::string &message) { + throw EvaluateError(0, message); + } + + [[noreturn]] + static void abortEvaluation(const std::string &message, const auto *node) { + throw EvaluateError(static_cast(node)->getLineNumber(), message); } void clear() { this->m_consoleLog.clear(); + this->m_lastHardError = { }; } + void setHardError(const EvaluateError &error) { this->m_lastHardError = error; } + + [[nodiscard]] + const LogConsole::EvaluateError& getLastHardError() { return this->m_lastHardError; }; + private: std::vector> m_consoleLog; + EvaluateError m_lastHardError; }; } \ No newline at end of file diff --git a/plugins/libimhex/include/hex/pattern_language/parser.hpp b/plugins/libimhex/include/hex/pattern_language/parser.hpp index 8e79ffdc7..a24e156c1 100644 --- a/plugins/libimhex/include/hex/pattern_language/parser.hpp +++ b/plugins/libimhex/include/hex/pattern_language/parser.hpp @@ -37,6 +37,11 @@ namespace hex::pl { return this->m_curr[index].lineNumber; } + auto* create(auto *node) { + node->setLineNumber(this->getLineNumber(-1)); + return node; + } + template const T& getValue(s32 index) const { auto value = std::get_if(&this->m_curr[index].value); @@ -68,6 +73,7 @@ namespace hex::pl { ASTNode* parseScopeResolution(); ASTNode* parseRValue(ASTNodeRValue::Path &path); ASTNode* parseFactor(); + ASTNode* parseCastExpression(); ASTNode* parseUnaryExpression(); ASTNode* parseMultiplicativeExpression(); ASTNode* parseAdditiveExpression(); @@ -83,7 +89,7 @@ namespace hex::pl { ASTNode* parseTernaryConditional(); ASTNode* parseMathematicalExpression(); - ASTNode* parseFunctionDefintion(); + ASTNode* parseFunctionDefinition(); ASTNode* parseFunctionStatement(); ASTNode* parseFunctionVariableAssignment(); ASTNode* parseFunctionReturnStatement(); @@ -93,7 +99,7 @@ namespace hex::pl { void parseAttribute(Attributable *currNode); ASTNode* parseConditional(); ASTNode* parseWhileStatement(); - ASTNodeTypeDecl* parseType(); + ASTNodeTypeDecl* parseType(bool allowString = false); ASTNode* parseUsingDeclaration(); ASTNode* parsePadding(); ASTNode* parseMemberVariable(ASTNodeTypeDecl *type); @@ -238,12 +244,6 @@ namespace hex::pl { return this->m_curr[index].type == type && this->m_curr[index] == value; } - bool peekOptional(Token::Type type, auto value, u32 index = 0) { - if (index >= this->m_matchedOptionals.size()) - return false; - return peek(type, value, std::distance(this->m_curr, this->m_matchedOptionals[index])); - } - }; } \ No newline at end of file diff --git a/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp b/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp index 2eaa60807..fc08bc045 100644 --- a/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp +++ b/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -87,6 +88,13 @@ namespace hex::pl { [[nodiscard]] PatternData* getParent() const { return this->m_parent; } void setParent(PatternData *parent) { this->m_parent = parent; } + [[nodiscard]] std::string getDisplayName() const { return this->m_displayName.value_or(this->m_variableName); } + void setDisplayName(const std::string &name) { this->m_displayName = name; } + + void setFormatterFunction(const ContentRegistry::PatternLanguageFunctions::Function &function, Evaluator *evaluator) { + this->m_formatterFunction = { function, evaluator }; + } + virtual void createEntry(prv::Provider* &provider) = 0; [[nodiscard]] virtual std::string getFormattedName() const = 0; @@ -117,9 +125,9 @@ namespace hex::pl { static bool sortPatternDataTable(ImGuiTableSortSpecs *sortSpecs, prv::Provider *provider, pl::PatternData* left, pl::PatternData* right) { if (sortSpecs->Specs->ColumnUserID == ImGui::GetID("name")) { if (sortSpecs->Specs->SortDirection == ImGuiSortDirection_Ascending) - return left->getVariableName() > right->getVariableName(); + return left->getDisplayName() > right->getDisplayName(); else - return left->getVariableName() < right->getVariableName(); + return left->getDisplayName() < right->getDisplayName(); } else if (sortSpecs->Specs->ColumnUserID == ImGui::GetID("offset")) { if (sortSpecs->Specs->SortDirection == ImGuiSortDirection_Ascending) @@ -208,16 +216,16 @@ namespace hex::pl { } protected: - void createDefaultEntry(const std::string &value) const { + void createDefaultEntry(const std::string &value, const Token::Literal &literal) const { ImGui::TableNextRow(); - ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); ImGui::TableNextColumn(); if (ImGui::Selectable(("##PatternDataLine"s + std::to_string(this->getOffset())).c_str(), false, ImGuiSelectableFlags_SpanAllColumns | ImGuiSelectableFlags_AllowItemOverlap)) { EventManager::post(Region { this->getOffset(), this->getSize() }); } this->drawCommentTooltip(); ImGui::SameLine(); - ImGui::Text("%s", this->getVariableName().c_str()); + ImGui::Text("%s", this->getDisplayName().c_str()); ImGui::TableNextColumn(); ImGui::ColorButton("color", ImColor(this->getColor()), ImGuiColorEditFlags_NoTooltip, ImVec2(ImGui::GetColumnWidth(), ImGui::GetTextLineHeight())); ImGui::TableNextColumn(); @@ -227,7 +235,20 @@ namespace hex::pl { ImGui::TableNextColumn(); ImGui::TextColored(ImColor(0xFF9BC64D), "%s", this->getFormattedName().c_str()); ImGui::TableNextColumn(); - ImGui::Text("%s", value.c_str()); + + if (!this->m_formatterFunction.has_value()) + ImGui::Text("%s", value.c_str()); + else { + auto &[func, evaluator] = this->m_formatterFunction.value(); + auto result = func.func(evaluator, { literal }); + + if (result.has_value()) { + if (auto displayValue = std::get_if(&result.value()); displayValue != nullptr) + ImGui::Text("%s", displayValue->c_str()); + } else { + ImGui::Text("???"); + } + } } void drawCommentTooltip() const { @@ -248,10 +269,13 @@ namespace hex::pl { size_t m_size; u32 m_color; + std::optional m_displayName; std::string m_variableName; std::optional m_comment; std::string m_typeName; + std::optional> m_formatterFunction; + PatternData *m_parent; bool m_local = false; }; @@ -299,7 +323,7 @@ namespace hex::pl { ImGui::TableNextRow(); ImGui::TableNextColumn(); - bool open = ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + bool open = ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); this->drawCommentTooltip(); ImGui::TableNextColumn(); ImGui::ColorButton("color", ImColor(this->getColor()), ImGuiColorEditFlags_NoTooltip, ImVec2(ImGui::GetColumnWidth(), ImGui::GetTextLineHeight())); @@ -354,7 +378,7 @@ namespace hex::pl { void setPointedAtPattern(PatternData *pattern) { this->m_pointedAt = pattern; - this->m_pointedAt->setVariableName("*" + this->getVariableName()); + this->m_pointedAt->setVariableName("*" + this->getDisplayName()); } [[nodiscard]] PatternData* getPointedAtPattern() { @@ -380,11 +404,11 @@ namespace hex::pl { } void createEntry(prv::Provider* &provider) override { - u64 data = 0; + u128 data = 0; provider->read(this->getOffset(), &data, this->getSize()); data = hex::changeEndianess(data, this->getSize(), this->getEndian()); - this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", data, data, this->getSize() * 2)); + this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", data, data, this->getSize() * 2), data); } [[nodiscard]] std::string getFormattedName() const override { @@ -411,42 +435,12 @@ namespace hex::pl { } void createEntry(prv::Provider* &provider) override { - u128 data = 0; + s128 data = 0; provider->read(this->getOffset(), &data, this->getSize()); data = hex::changeEndianess(data, this->getSize(), this->getEndian()); - switch (this->getSize()) { - case 1: { - s8 signedData; - std::memcpy(&signedData, &data, 1); - this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", signedData, data, 1 * 2)); - } - break; - case 2: { - s16 signedData; - std::memcpy(&signedData, &data, 2); - this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", signedData, data, 2 * 2)); - } - break; - case 4: { - s32 signedData; - std::memcpy(&signedData, &data, 4); - this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", signedData, data, 4 * 2)); - } - break; - case 8: { - s64 signedData; - std::memcpy(&signedData, &data, 8); - this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", signedData, data, 8 * 2)); - } - break; - case 16: { - s128 signedData; - std::memcpy(&signedData, &data, 16); - this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", signedData, data, 16 * 2)); - } - break; - } + data = hex::signExtend(this->getSize() * 8, data); + this->createDefaultEntry(hex::format("{:d} (0x{:0{}X})", data, data, 1 * 2), data); } [[nodiscard]] std::string getFormattedName() const override { @@ -478,13 +472,13 @@ namespace hex::pl { provider->read(this->getOffset(), &data, 4); data = hex::changeEndianess(data, 4, this->getEndian()); - this->createDefaultEntry(hex::format("{:e} (0x{:0{}X})", *reinterpret_cast(&data), data, this->getSize() * 2)); + this->createDefaultEntry(hex::format("{:e} (0x{:0{}X})", *reinterpret_cast(&data), data, this->getSize() * 2), *reinterpret_cast(&data)); } else if (this->getSize() == 8) { u64 data = 0; provider->read(this->getOffset(), &data, 8); data = hex::changeEndianess(data, 8, this->getEndian()); - this->createDefaultEntry(hex::format("{:e} (0x{:0{}X})", *reinterpret_cast(&data), data, this->getSize() * 2)); + this->createDefaultEntry(hex::format("{:e} (0x{:0{}X})", *reinterpret_cast(&data), data, this->getSize() * 2), *reinterpret_cast(&data)); } } @@ -513,11 +507,11 @@ namespace hex::pl { provider->read(this->getOffset(), &boolean, 1); if (boolean == 0) - this->createDefaultEntry("false"); + this->createDefaultEntry("false", false); else if (boolean == 1) - this->createDefaultEntry("true"); + this->createDefaultEntry("true", true); else - this->createDefaultEntry("true*"); + this->createDefaultEntry("true*", true); } [[nodiscard]] std::string getFormattedName() const override { @@ -540,7 +534,7 @@ namespace hex::pl { char character; provider->read(this->getOffset(), &character, 1); - this->createDefaultEntry(hex::format("'{0}'", character)); + this->createDefaultEntry(hex::format("'{0}'", character), character); } [[nodiscard]] std::string getFormattedName() const override { @@ -564,7 +558,8 @@ namespace hex::pl { provider->read(this->getOffset(), &character, 2); character = hex::changeEndianess(character, this->getEndian()); - this->createDefaultEntry(hex::format("'{0}'", std::wstring_convert, char16_t>{}.to_bytes(character))); + u128 literal = character; + this->createDefaultEntry(hex::format("'{0}'", std::wstring_convert, char16_t>{}.to_bytes(character)), literal); } [[nodiscard]] std::string getFormattedName() const override { @@ -587,7 +582,7 @@ namespace hex::pl { std::string buffer(this->getSize(), 0x00); provider->read(this->getOffset(), buffer.data(), this->getSize()); - this->createDefaultEntry(hex::format("\"{0}\"", makeDisplayable(buffer.data(), this->getSize()).c_str())); + this->createDefaultEntry(hex::format("\"{0}\"", makeDisplayable(buffer.data(), this->getSize()).c_str()), buffer); } [[nodiscard]] std::string getFormattedName() const override { @@ -615,7 +610,7 @@ namespace hex::pl { auto utf8String = std::wstring_convert, char16_t>{}.to_bytes(buffer); - this->createDefaultEntry(hex::format("\"{0}\"", utf8String)) ; + this->createDefaultEntry(hex::format("\"{0}\"", utf8String), utf8String); } [[nodiscard]] std::string getFormattedName() const override { @@ -662,7 +657,7 @@ namespace hex::pl { ImGui::TableNextRow(); ImGui::TableNextColumn(); - bool open = ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + bool open = ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); this->drawCommentTooltip(); ImGui::TableNextColumn(); ImGui::ColorButton("color", ImColor(this->getColor()), ImGuiColorEditFlags_NoTooltip, ImVec2(ImGui::GetColumnWidth(), ImGui::GetTextLineHeight())); @@ -777,7 +772,7 @@ namespace hex::pl { ImGui::TableNextRow(); ImGui::TableNextColumn(); - bool open = ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + bool open = ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); this->drawCommentTooltip(); ImGui::TableNextColumn(); ImGui::ColorButton("color", ImColor(this->getColor()), ImGuiColorEditFlags_NoTooltip, ImVec2(ImGui::GetColumnWidth(), ImGui::GetTextLineHeight())); @@ -866,7 +861,8 @@ namespace hex::pl { this->m_template = templ; this->m_entryCount = count; - this->m_template->setColor(this->getColor()); + this->setColor(this->m_template->getColor()); + this->m_template->setEndian(templ->getEndian()); this->m_template->setParent(this); } @@ -914,7 +910,7 @@ namespace hex::pl { void createEntry(prv::Provider* &provider) override { ImGui::TableNextRow(); ImGui::TableNextColumn(); - bool open = ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth); + bool open = ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth); this->drawCommentTooltip(); ImGui::TableNextColumn(); ImGui::TableNextColumn(); @@ -1045,7 +1041,7 @@ namespace hex::pl { void createEntry(prv::Provider* &provider) override { ImGui::TableNextRow(); ImGui::TableNextColumn(); - bool open = ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + bool open = ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); this->drawCommentTooltip(); ImGui::TableNextColumn(); ImGui::TableNextColumn(); @@ -1163,14 +1159,18 @@ namespace hex::pl { bool foundValue = false; for (auto &[entryValueLiteral, entryName] : this->m_enumValues) { - bool matches = std::visit([&, name = entryName](auto &&entryValue) { - if (value == entryValue) { - valueString += name; - foundValue = true; - return true; - } + bool matches = std::visit(overloaded { + [&, name = entryName](auto &&entryValue) { + if (value == entryValue) { + valueString += name; + foundValue = true; + return true; + } - return false; + return false; + }, + [](std::string) { return false; }, + [](PatternData*) { return false; } }, entryValueLiteral); if (matches) break; @@ -1180,14 +1180,14 @@ namespace hex::pl { valueString += "???"; ImGui::TableNextRow(); - ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); this->drawCommentTooltip(); ImGui::TableNextColumn(); if (ImGui::Selectable(("##PatternDataLine"s + std::to_string(this->getOffset())).c_str(), false, ImGuiSelectableFlags_SpanAllColumns)) { EventManager::post(Region { this->getOffset(), this->getSize() }); } ImGui::SameLine(); - ImGui::Text("%s", this->getVariableName().c_str()); + ImGui::Text("%s", this->getDisplayName().c_str()); ImGui::TableNextColumn(); ImGui::ColorButton("color", ImColor(this->getColor()), ImGuiColorEditFlags_NoTooltip, ImVec2(ImGui::GetColumnWidth(), ImGui::GetTextLineHeight())); ImGui::TableNextColumn(); @@ -1208,7 +1208,7 @@ namespace hex::pl { return this->m_enumValues; } - void setEnumValues(const std::vector> &enumValues) { + void setEnumValues(const std::vector> &enumValues) { this->m_enumValues = enumValues; } @@ -1229,7 +1229,7 @@ namespace hex::pl { } private: - std::vector> m_enumValues; + std::vector> m_enumValues; }; @@ -1252,9 +1252,9 @@ namespace hex::pl { std::reverse(value.begin(), value.end()); ImGui::TableNextRow(); - ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); ImGui::TableNextColumn(); - ImGui::Text("%s", this->getVariableName().c_str()); + ImGui::Text("%s", this->getDisplayName().c_str()); ImGui::TableNextColumn(); ImGui::ColorButton("color", ImColor(this->getColor()), ImGuiColorEditFlags_NoTooltip, ImVec2(ImGui::GetColumnWidth(), ImGui::GetTextLineHeight())); ImGui::TableNextColumn(); @@ -1309,6 +1309,11 @@ namespace hex::pl { } + PatternDataBitfield(const PatternDataBitfield &other) : PatternData(other) { + for (auto &field : other.m_fields) + this->m_fields.push_back(field->clone()); + } + ~PatternDataBitfield() override { for (auto field : this->m_fields) delete field; @@ -1327,7 +1332,7 @@ namespace hex::pl { ImGui::TableNextRow(); ImGui::TableNextColumn(); - bool open = ImGui::TreeNodeEx(this->getVariableName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); + bool open = ImGui::TreeNodeEx(this->getDisplayName().c_str(), ImGuiTreeNodeFlags_SpanFullWidth | ImGuiTreeNodeFlags_AllowItemOverlap); this->drawCommentTooltip(); ImGui::TableNextColumn(); ImGui::TableNextColumn(); diff --git a/plugins/libimhex/include/hex/pattern_language/pattern_language.hpp b/plugins/libimhex/include/hex/pattern_language/pattern_language.hpp index e3649a051..0b1814efe 100644 --- a/plugins/libimhex/include/hex/pattern_language/pattern_language.hpp +++ b/plugins/libimhex/include/hex/pattern_language/pattern_language.hpp @@ -21,6 +21,8 @@ namespace hex::pl { class Evaluator; class PatternData; + class ASTNode; + class PatternLanguage { public: PatternLanguage(); @@ -39,6 +41,8 @@ namespace hex::pl { Validator *m_validator; Evaluator *m_evaluator; + std::vector m_currAST; + prv::Provider *m_provider = nullptr; std::endian m_defaultEndian = std::endian::native; u32 m_recursionLimit = 32; diff --git a/plugins/libimhex/include/hex/pattern_language/token.hpp b/plugins/libimhex/include/hex/pattern_language/token.hpp index ac94a0976..990b0f4ca 100644 --- a/plugins/libimhex/include/hex/pattern_language/token.hpp +++ b/plugins/libimhex/include/hex/pattern_language/token.hpp @@ -6,8 +6,12 @@ #include #include +#include + namespace hex::pl { + class PatternData; + class Token { public: enum class Type : u64 { @@ -31,6 +35,7 @@ namespace hex::pl { If, Else, Parent, + This, While, Function, Return, @@ -85,6 +90,7 @@ namespace hex::pl { Boolean = 0x14, Float = 0x42, Double = 0x82, + String = 0x15, CustomType = 0x00, Padding = 0x1F, @@ -108,8 +114,21 @@ namespace hex::pl { EndOfProgram }; - using IntegerLiteral = std::variant; - using ValueTypes = std::variant; + struct Identifier { + explicit Identifier(std::string identifier) : m_identifier(std::move(identifier)) { } + + [[nodiscard]] + const std::string &get() const { return this->m_identifier; } + + auto operator<=>(const Identifier&) const = default; + bool operator==(const Identifier&) const = default; + + private: + std::string m_identifier; + }; + + using Literal = std::variant; + using ValueTypes = std::variant; Token(Type type, auto value, u32 lineNumber) : type(type), value(value), lineNumber(lineNumber) { @@ -131,6 +150,58 @@ namespace hex::pl { return static_cast(type) >> 4; } + static u128 literalToUnsigned(const pl::Token::Literal &literal) { + return std::visit(overloaded { + [](std::string) -> u128 { throw std::string("expected integral type, got string"); }, + [](PatternData*) -> u128 { throw std::string("expected integral type, got custom type"); }, + [](auto &&value) -> u128 { return value; } + }, + literal); + } + + static s128 literalToSigned(const pl::Token::Literal &literal) { + return std::visit(overloaded { + [](std::string) -> s128 { throw std::string("expected integral type, got string"); }, + [](PatternData*) -> s128 { throw std::string("expected integral type, got custom type"); }, + [](auto &&value) -> s128 { return value; } + }, + literal); + } + + static double literalToFloatingPoint(const pl::Token::Literal &literal) { + return std::visit(overloaded { + [](std::string) -> double { throw std::string("expected integral type, got string"); }, + [](PatternData*) -> double { throw std::string("expected integral type, got custom type"); }, + [](auto &&value) -> double { return value; } + }, + literal); + } + + static bool literalToBoolean(const pl::Token::Literal &literal) { + return std::visit(overloaded { + [](std::string) -> bool { throw std::string("expected integral type, got string"); }, + [](PatternData*) -> bool { throw std::string("expected integral type, got custom type"); }, + [](auto &&value) -> bool { return value != 0; } + }, + literal); + } + + static std::string literalToString(const pl::Token::Literal &literal, bool cast) { + if (!cast && std::get_if(&literal) == nullptr) + throw std::string("expected string type, got integral"); + + return std::visit(overloaded { + [](std::string value) -> std::string { return value; }, + [](u128 value) -> std::string { return std::to_string(u64(value)); }, + [](s128 value) -> std::string { return std::to_string(s64(value)); }, + [](bool value) -> std::string { return value ? "true" : "false"; }, + [](char value) -> std::string { return std::string() + value; }, + [](PatternData*) -> std::string { throw std::string("expected integral type, got custom type"); }, + [](auto &&value) -> std::string { return std::to_string(value); } + }, + literal); + } + [[nodiscard]] constexpr static auto getTypeName(const pl::Token::ValueType type) { switch (type) { case ValueType::Signed8Bit: return "s8"; @@ -147,6 +218,7 @@ namespace hex::pl { case ValueType::Double: return "double"; case ValueType::Character: return "char"; case ValueType::Character16: return "char16"; + case ValueType::Padding: return "padding"; default: return "< ??? >"; } } @@ -204,14 +276,15 @@ namespace hex::pl { #define KEYWORD_IF COMPONENT(Keyword, If) #define KEYWORD_ELSE COMPONENT(Keyword, Else) #define KEYWORD_PARENT COMPONENT(Keyword, Parent) +#define KEYWORD_THIS COMPONENT(Keyword, This) #define KEYWORD_WHILE COMPONENT(Keyword, While) #define KEYWORD_FUNCTION COMPONENT(Keyword, Function) #define KEYWORD_RETURN COMPONENT(Keyword, Return) #define KEYWORD_NAMESPACE COMPONENT(Keyword, Namespace) -#define INTEGER hex::pl::Token::Type::Integer, hex::pl::Token::IntegerLiteral(u64(0)) +#define INTEGER hex::pl::Token::Type::Integer, hex::pl::Token::Literal(u128(0)) #define IDENTIFIER hex::pl::Token::Type::Identifier, "" -#define STRING hex::pl::Token::Type::String, "" +#define STRING hex::pl::Token::Type::String, hex::pl::Token::Literal("") #define OPERATOR_AT COMPONENT(Operator, AtDeclaration) #define OPERATOR_ASSIGNMENT COMPONENT(Operator, Assignment) diff --git a/plugins/libimhex/source/api/content_registry.cpp b/plugins/libimhex/source/api/content_registry.cpp index 23da41cc4..8e30cbafa 100644 --- a/plugins/libimhex/source/api/content_registry.cpp +++ b/plugins/libimhex/source/api/content_registry.cpp @@ -167,7 +167,7 @@ namespace hex { /* Pattern Language Functions */ - void ContentRegistry::PatternLanguageFunctions::add(const Namespace &ns, const std::string &name, u32 parameterCount, const std::function&)> &func) { + void ContentRegistry::PatternLanguageFunctions::add(const Namespace &ns, const std::string &name, u32 parameterCount, const ContentRegistry::PatternLanguageFunctions::Callback &func) { std::string functionName; for (auto &scope : ns) functionName += scope + "::"; diff --git a/plugins/libimhex/source/pattern_language/evaluator.cpp b/plugins/libimhex/source/pattern_language/evaluator.cpp index 8919132f3..241c8efd1 100644 --- a/plugins/libimhex/source/pattern_language/evaluator.cpp +++ b/plugins/libimhex/source/pattern_language/evaluator.cpp @@ -1,1309 +1,123 @@ #include - -#include -#include - -#include -#include - -#include -#include - -#include +#include namespace hex::pl { - ASTNodeIntegerLiteral* Evaluator::evaluateScopeResolution(ASTNodeScopeResolution *node) { - ASTNode *currScope = nullptr; - for (const auto &identifier : node->getPath()) { - if (currScope == nullptr) { - if (!this->m_types.contains(identifier)) - break; - - currScope = this->m_types[identifier.data()]; - } else if (auto enumNode = dynamic_cast(currScope); enumNode != nullptr) { - if (!enumNode->getEntries().contains(identifier)) - break; - else - return evaluateMathematicalExpression(static_cast(enumNode->getEntries().at(identifier))); + void Evaluator::createVariable(const std::string &name, ASTNode *type) { + auto &variables = *this->getScope(0).scope; + for (auto &variable : variables) { + if (variable->getVariableName() == name) { + LogConsole::abortEvaluation(hex::format("variable with name '{}' already exists", name)); } } - this->getConsole().abortEvaluation("failed to find identifier"); - } + auto pattern = type->createPatterns(this).front(); - PatternData* Evaluator::findPattern(std::vector currMembers, const ASTNodeRValue::Path &path) { - PatternData *currPattern = nullptr; - for (const auto &part : path) { - if (auto stringPart = std::get_if(&part); stringPart != nullptr) { - if (*stringPart == "parent") { - if (currPattern == nullptr) { - if (!this->m_currMemberScope.empty()) - currPattern = this->m_currMemberScope.back(); - - if (currPattern == nullptr) - this->getConsole().abortEvaluation("attempted to get parent of global namespace"); - } - - auto parent = currPattern->getParent(); - - if (parent == nullptr) { - this->getConsole().abortEvaluation("no parent available for identifier"); - } else { - currPattern = parent; - } - } else { - if (currPattern != nullptr) { - if (auto structPattern = dynamic_cast(currPattern); structPattern != nullptr) - currMembers = structPattern->getMembers(); - else if (auto unionPattern = dynamic_cast(currPattern); unionPattern != nullptr) - currMembers = unionPattern->getMembers(); - else if (auto bitfieldPattern = dynamic_cast(currPattern); bitfieldPattern != nullptr) { - currMembers = bitfieldPattern->getFields(); - } - else if (auto dynamicArrayPattern = dynamic_cast(currPattern); dynamicArrayPattern != nullptr) { - currMembers = dynamicArrayPattern->getEntries(); - continue; - } - else if (auto staticArrayPattern = dynamic_cast(currPattern); staticArrayPattern != nullptr) { - currMembers = { staticArrayPattern->getTemplate() }; - continue; - } - else - this->getConsole().abortEvaluation("tried to access member of a non-struct/union type"); - } - - auto candidate = std::find_if(currMembers.begin(), currMembers.end(), [&](auto member) { - return member->getVariableName() == *stringPart; - }); - - if (candidate != currMembers.end()) - currPattern = *candidate; - else - return nullptr; - } - } else if (auto nodePart = std::get_if(&part); nodePart != nullptr) { - if (auto numericalExpressionNode = dynamic_cast(*nodePart)) { - auto arrayIndexNode = evaluateMathematicalExpression(numericalExpressionNode); - ON_SCOPE_EXIT { delete arrayIndexNode; }; - - if (currPattern != nullptr) { - if (auto dynamicArrayPattern = dynamic_cast(currPattern); dynamicArrayPattern != nullptr) { - std::visit([this](auto &&arrayIndex) { - if (std::is_floating_point_v) - this->getConsole().abortEvaluation("cannot use float to index into array"); - }, arrayIndexNode->getValue()); - - std::visit([&](auto &&arrayIndex){ - if (arrayIndex >= 0 && arrayIndex < dynamicArrayPattern->getEntries().size()) - currPattern = dynamicArrayPattern->getEntries()[arrayIndex]; - else - this->getConsole().abortEvaluation(hex::format("tried to access out of bounds index {} of '{}'", arrayIndex, currPattern->getVariableName())); - }, arrayIndexNode->getValue()); - - } else if (auto staticArrayPattern = dynamic_cast(currPattern); staticArrayPattern != nullptr) { - std::visit([this](auto &&arrayIndex) { - if (std::is_floating_point_v) - this->getConsole().abortEvaluation("cannot use float to index into array"); - }, arrayIndexNode->getValue()); - - std::visit([&](auto &&arrayIndex){ - if (arrayIndex >= 0 && arrayIndex < staticArrayPattern->getEntryCount()) { - currPattern = staticArrayPattern->getTemplate(); - currPattern->setOffset(staticArrayPattern->getOffset() + arrayIndex * staticArrayPattern->getSize()); - } - else - this->getConsole().abortEvaluation(hex::format("tried to access out of bounds index {} of '{}'", arrayIndex, currPattern->getVariableName())); - }, arrayIndexNode->getValue()); - - } - else - this->getConsole().abortEvaluation("tried to index into non-array type"); - } - } else { - this->getConsole().abortEvaluation(hex::format("invalid node in rvalue path. This is a bug!'")); - } - } - - if (auto pointerPattern = dynamic_cast(currPattern); pointerPattern != nullptr) - currPattern = pointerPattern->getPointedAtPattern(); - } - - return currPattern; - } - - PatternData* Evaluator::patternFromName(const ASTNodeRValue::Path &path) { - - PatternData *currPattern = nullptr; - - // Local variable access - if (!this->m_localVariables.empty()) - currPattern = this->findPattern(*this->m_localVariables.back(), path); - - // If no local variable was found try local structure members - if (!this->m_currMembers.empty()) { - currPattern = this->findPattern(*this->m_currMembers.back(), path); - } - - // If no local member was found, try globally - if (currPattern == nullptr) { - currPattern = this->findPattern(this->m_globalMembers, path); - } - - // If still no pattern was found, the path is invalid - if (currPattern == nullptr) { - std::string identifier; - for (const auto& part : path) { - if (part.index() == 0) { - // Path part is a identifier - identifier += std::get(part); - } else if (part.index() == 1) { - // Path part is a array index - identifier += "[..]"; - } - - identifier += "."; - } - identifier.pop_back(); - this->getConsole().abortEvaluation(hex::format("no identifier with name '{}' found", identifier)); - } - - return currPattern; - } - - ASTNodeIntegerLiteral* Evaluator::evaluateRValue(ASTNodeRValue *node) { - if (node->getPath().size() == 1) { - if (auto part = std::get_if(&node->getPath()[0]); part != nullptr && *part == "$") - return new ASTNodeIntegerLiteral(this->m_currOffset); - } - - auto currPattern = this->patternFromName(node->getPath()); - - if (auto unsignedPattern = dynamic_cast(currPattern); unsignedPattern != nullptr) { - - u8 value[unsignedPattern->getSize()]; - if (currPattern->isLocal()) - std::memcpy(value, this->m_localStack.data() + unsignedPattern->getOffset(), unsignedPattern->getSize()); - else - this->m_provider->read(unsignedPattern->getOffset(), value, unsignedPattern->getSize()); - - switch (unsignedPattern->getSize()) { - case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, unsignedPattern->getEndian())); - case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 2, unsignedPattern->getEndian())); - case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 4, unsignedPattern->getEndian())); - case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 8, unsignedPattern->getEndian())); - case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 16, unsignedPattern->getEndian())); - default: this->getConsole().abortEvaluation("invalid rvalue size"); - } - } else if (auto signedPattern = dynamic_cast(currPattern); signedPattern != nullptr) { - u8 value[signedPattern->getSize()]; - if (currPattern->isLocal()) - std::memcpy(value, this->m_localStack.data() + signedPattern->getOffset(), signedPattern->getSize()); - else - this->m_provider->read(signedPattern->getOffset(), value, signedPattern->getSize()); - - switch (signedPattern->getSize()) { - case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, signedPattern->getEndian())); - case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 2, signedPattern->getEndian())); - case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 4, signedPattern->getEndian())); - case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 8, signedPattern->getEndian())); - case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 16, signedPattern->getEndian())); - default: this->getConsole().abortEvaluation("invalid rvalue size"); - } - } else if (auto boolPattern = dynamic_cast(currPattern); boolPattern != nullptr) { - u8 value[boolPattern->getSize()]; - if (currPattern->isLocal()) - std::memcpy(value, this->m_localStack.data() + boolPattern->getOffset(), boolPattern->getSize()); - else - this->m_provider->read(boolPattern->getOffset(), value, boolPattern->getSize()); - - return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, boolPattern->getEndian())); - } else if (auto charPattern = dynamic_cast(currPattern); charPattern != nullptr) { - u8 value[charPattern->getSize()]; - if (currPattern->isLocal()) - std::memcpy(value, this->m_localStack.data() + charPattern->getOffset(), charPattern->getSize()); - else - this->m_provider->read(charPattern->getOffset(), value, charPattern->getSize()); - - return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, charPattern->getEndian())); - } else if (auto char16Pattern = dynamic_cast(currPattern); char16Pattern != nullptr) { - u8 value[char16Pattern->getSize()]; - if (currPattern->isLocal()) - std::memcpy(value, this->m_localStack.data() + char16Pattern->getOffset(), char16Pattern->getSize()); - else - this->m_provider->read(char16Pattern->getOffset(), value, char16Pattern->getSize()); - - return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, char16Pattern->getEndian())); - } else if (auto enumPattern = dynamic_cast(currPattern); enumPattern != nullptr) { - u8 value[enumPattern->getSize()]; - if (currPattern->isLocal()) - std::memcpy(value, this->m_localStack.data() + enumPattern->getOffset(), enumPattern->getSize()); - else - this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize()); - - switch (enumPattern->getSize()) { - case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, enumPattern->getEndian())); - case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 2, enumPattern->getEndian())); - case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 4, enumPattern->getEndian())); - case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 8, enumPattern->getEndian())); - case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 16, enumPattern->getEndian())); - default: this->getConsole().abortEvaluation("invalid rvalue size"); - } - } else if (auto bitfieldFieldPattern = dynamic_cast(currPattern); bitfieldFieldPattern != nullptr) { - std::vector value(bitfieldFieldPattern->getSize()); - if (currPattern->isLocal()) - std::memcpy(value.data(), this->m_localStack.data() + bitfieldFieldPattern->getOffset(), value.size()); - else - this->m_provider->read(bitfieldFieldPattern->getOffset(), value.data(), value.size()); - - return new ASTNodeIntegerLiteral(hex::extract(bitfieldFieldPattern->getBitOffset() + (bitfieldFieldPattern->getBitSize() - 1), bitfieldFieldPattern->getBitOffset(), value)); - } else - this->getConsole().abortEvaluation("tried to use non-integer value in numeric expression"); - } - - ASTNode* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) { - std::vector evaluatedParams; - ON_SCOPE_EXIT { - for (auto ¶m : evaluatedParams) - delete param; - }; - - for (auto ¶m : node->getParams()) { - if (auto numericExpression = dynamic_cast(param); numericExpression != nullptr) - evaluatedParams.push_back(this->evaluateMathematicalExpression(numericExpression)); - else if (auto typeOperatorExpression = dynamic_cast(param); typeOperatorExpression != nullptr) - evaluatedParams.push_back(this->evaluateTypeOperator(typeOperatorExpression)); - else if (auto stringLiteral = dynamic_cast(param); stringLiteral != nullptr) - evaluatedParams.push_back(stringLiteral->clone()); - } - - ContentRegistry::PatternLanguageFunctions::Function *function; - if (this->m_definedFunctions.contains(node->getFunctionName().data())) - function = &this->m_definedFunctions[node->getFunctionName().data()]; - else if (ContentRegistry::PatternLanguageFunctions::getEntries().contains(node->getFunctionName().data())) - function = &ContentRegistry::PatternLanguageFunctions::getEntries()[node->getFunctionName().data()]; - else - this->getConsole().abortEvaluation(hex::format("no function named '{0}' found", node->getFunctionName().data())); - - if (function->parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) { - ; // Don't check parameter count - } - else if (function->parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) { - if (evaluatedParams.size() >= (function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)) - this->getConsole().abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)); - } else if (function->parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) { - if (evaluatedParams.size() <= (function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)) - this->getConsole().abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)); - } else if (function->parameterCount != evaluatedParams.size()) { - this->getConsole().abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount)); - } - - return function->func(*this, evaluatedParams); - } - - ASTNodeIntegerLiteral* Evaluator::evaluateTypeOperator(ASTNodeTypeOperator *typeOperatorNode) { - if (auto rvalue = dynamic_cast(typeOperatorNode->getExpression()); rvalue != nullptr) { - auto pattern = this->patternFromName(rvalue->getPath()); - - switch (typeOperatorNode->getOperator()) { - case Token::Operator::AddressOf: - return new ASTNodeIntegerLiteral(static_cast(pattern->getOffset())); - case Token::Operator::SizeOf: - return new ASTNodeIntegerLiteral(static_cast(pattern->getSize())); - default: - this->getConsole().abortEvaluation("invalid type operator used. This is a bug!"); - } - } else { - this->getConsole().abortEvaluation("non-rvalue used in type operator"); - } - } - -#define FLOAT_BIT_OPERATION(name) \ - auto name(hex::floating_point auto left, auto right) { throw std::runtime_error(""); return 0; } \ - auto name(auto left, hex::floating_point auto right) { throw std::runtime_error(""); return 0; } \ - auto name(hex::floating_point auto left, hex::floating_point auto right) { throw std::runtime_error(""); return 0; } \ - auto name(hex::integral auto left, hex::integral auto right) - - namespace { - - FLOAT_BIT_OPERATION(shiftLeft) { - return left << right; - } - - FLOAT_BIT_OPERATION(shiftRight) { - return left >> right; - } - - FLOAT_BIT_OPERATION(bitAnd) { - return left & right; - } - - FLOAT_BIT_OPERATION(bitOr) { - return left | right; - } - - FLOAT_BIT_OPERATION(bitXor) { - return left ^ right; - } - - FLOAT_BIT_OPERATION(bitNot) { - return ~right; - } - - FLOAT_BIT_OPERATION(modulus) { - return left % right; - } - - } - - ASTNodeIntegerLiteral* Evaluator::evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op) { - try { - return std::visit([&](auto &&leftValue, auto &&rightValue) -> ASTNodeIntegerLiteral * { - switch (op) { - case Token::Operator::Plus: - return new ASTNodeIntegerLiteral(leftValue + rightValue); - case Token::Operator::Minus: - return new ASTNodeIntegerLiteral(leftValue - rightValue); - case Token::Operator::Star: - return new ASTNodeIntegerLiteral(leftValue * rightValue); - case Token::Operator::Slash: - if (rightValue == 0) - this->getConsole().abortEvaluation("Division by zero"); - return new ASTNodeIntegerLiteral(leftValue / rightValue); - case Token::Operator::Percent: - if (rightValue == 0) - this->getConsole().abortEvaluation("Division by zero"); - return new ASTNodeIntegerLiteral(modulus(leftValue, rightValue)); - case Token::Operator::ShiftLeft: - return new ASTNodeIntegerLiteral(shiftLeft(leftValue, rightValue)); - case Token::Operator::ShiftRight: - return new ASTNodeIntegerLiteral(shiftRight(leftValue, rightValue)); - case Token::Operator::BitAnd: - return new ASTNodeIntegerLiteral(bitAnd(leftValue, rightValue)); - case Token::Operator::BitXor: - return new ASTNodeIntegerLiteral(bitXor(leftValue, rightValue)); - case Token::Operator::BitOr: - return new ASTNodeIntegerLiteral(bitOr(leftValue, rightValue)); - case Token::Operator::BitNot: - return new ASTNodeIntegerLiteral(bitNot(leftValue, rightValue)); - case Token::Operator::BoolEquals: - return new ASTNodeIntegerLiteral(leftValue == rightValue); - case Token::Operator::BoolNotEquals: - return new ASTNodeIntegerLiteral(leftValue != rightValue); - case Token::Operator::BoolGreaterThan: - return new ASTNodeIntegerLiteral(leftValue > rightValue); - case Token::Operator::BoolLessThan: - return new ASTNodeIntegerLiteral(leftValue < rightValue); - case Token::Operator::BoolGreaterThanOrEquals: - return new ASTNodeIntegerLiteral(leftValue >= rightValue); - case Token::Operator::BoolLessThanOrEquals: - return new ASTNodeIntegerLiteral(leftValue <= rightValue); - case Token::Operator::BoolAnd: - return new ASTNodeIntegerLiteral(leftValue && rightValue); - case Token::Operator::BoolXor: - return new ASTNodeIntegerLiteral(leftValue && !rightValue || !leftValue && rightValue); - case Token::Operator::BoolOr: - return new ASTNodeIntegerLiteral(leftValue || rightValue); - case Token::Operator::BoolNot: - return new ASTNodeIntegerLiteral(!rightValue); - default: - this->getConsole().abortEvaluation("invalid operator used in mathematical expression"); - } - - }, left->getValue(), right->getValue()); - } catch (std::runtime_error &e) { - this->getConsole().abortEvaluation("bitwise operations on floating point numbers are forbidden"); - } - } - - ASTNodeIntegerLiteral* Evaluator::evaluateOperand(ASTNode *node) { - if (auto exprLiteral = dynamic_cast(node); exprLiteral != nullptr) - return exprLiteral; - else if (auto exprExpression = dynamic_cast(node); exprExpression != nullptr) - return evaluateMathematicalExpression(exprExpression); - else if (auto exprRvalue = dynamic_cast(node); exprRvalue != nullptr) - return evaluateRValue(exprRvalue); - else if (auto exprScopeResolution = dynamic_cast(node); exprScopeResolution != nullptr) - return evaluateScopeResolution(exprScopeResolution); - else if (auto exprTernary = dynamic_cast(node); exprTernary != nullptr) - return evaluateTernaryExpression(exprTernary); - else if (auto exprFunctionCall = dynamic_cast(node); exprFunctionCall != nullptr) { - auto returnValue = evaluateFunctionCall(exprFunctionCall); - - if (returnValue == nullptr) - this->getConsole().abortEvaluation("function returning void used in expression"); - else if (auto integerNode = dynamic_cast(returnValue); integerNode != nullptr) - return integerNode; - else - this->getConsole().abortEvaluation("function not returning a numeric value used in expression"); - } else if (auto typeOperator = dynamic_cast(node); typeOperator != nullptr) - return evaluateTypeOperator(typeOperator); - else - this->getConsole().abortEvaluation("invalid operand"); - } - - ASTNodeIntegerLiteral* Evaluator::evaluateTernaryExpression(ASTNodeTernaryExpression *node) { - switch (node->getOperator()) { - case Token::Operator::TernaryConditional: { - auto condition = this->evaluateOperand(node->getFirstOperand()); - ON_SCOPE_EXIT { delete condition; }; - - if (std::visit([](auto &&value){ return value != 0; }, condition->getValue())) - return this->evaluateOperand(node->getSecondOperand()); - else - return this->evaluateOperand(node->getThirdOperand()); - } - default: - this->getConsole().abortEvaluation("invalid operator used in ternary expression"); - } - } - - ASTNodeIntegerLiteral* Evaluator::evaluateMathematicalExpression(ASTNodeNumericExpression *node) { - auto leftInteger = this->evaluateOperand(node->getLeftOperand()); - auto rightInteger = this->evaluateOperand(node->getRightOperand()); - - return evaluateOperator(leftInteger, rightInteger, node->getOperator()); - } - - void Evaluator::createLocalVariable(const std::string &varName, PatternData *pattern) { - auto startOffset = this->m_currOffset; - ON_SCOPE_EXIT { this->m_currOffset = startOffset; }; - - auto endOfStack = this->m_localStack.size(); - - for (auto &variable : *this->m_localVariables.back()) { - if (variable->getVariableName() == varName) - this->getConsole().abortEvaluation(hex::format("redefinition of variable {}", varName)); - } - - this->m_localStack.resize(endOfStack + pattern->getSize()); - - pattern->setVariableName(std::string(varName)); - pattern->setOffset(endOfStack); + pattern->setVariableName(name); + pattern->setOffset(this->getStack().size()); pattern->setLocal(true); - this->m_localVariables.back()->push_back(pattern); - std::memset(this->m_localStack.data() + pattern->getOffset(), 0x00, pattern->getSize()); + this->getStack().emplace_back(); + variables.push_back(pattern); } - void Evaluator::setLocalVariableValue(const std::string &varName, const void *value, size_t size) { - PatternData *varPattern = nullptr; - for (auto &var : *this->m_localVariables.back()) { - if (var->getVariableName() == varName) - varPattern = var; + void Evaluator::setVariable(const std::string &name, const Token::Literal& value) { + PatternData *pattern = nullptr; + + auto &variables = *this->getScope(0).scope; + for (auto &variable : variables) { + if (variable->getVariableName() == name) { + pattern = variable; + break; + } } - std::memset(this->m_localStack.data() + varPattern->getOffset(), 0x00, varPattern->getSize()); - std::memcpy(this->m_localStack.data() + varPattern->getOffset(), value, std::min(varPattern->getSize(), size)); - } + if (pattern == nullptr) + LogConsole::abortEvaluation(hex::format("no variable with name '{}' found", name)); - void Evaluator::evaluateFunctionDefinition(ASTNodeFunctionDefinition *node) { - ContentRegistry::PatternLanguageFunctions::Function function = { - (u32)node->getParams().size(), - [paramNames = node->getParams(), body = node->getBody()](Evaluator& evaluator, std::vector ¶ms) -> ASTNode* { - // Create local variables from parameters - std::vector localVariables; - evaluator.m_localVariables.push_back(&localVariables); - - ON_SCOPE_EXIT { - u32 stackSizeToDrop = 0; - for (auto &localVar : *evaluator.m_localVariables.back()) { - stackSizeToDrop += localVar->getSize(); - delete localVar; - } - evaluator.m_localVariables.pop_back(); - evaluator.m_localStack.resize(evaluator.m_localStack.size() - stackSizeToDrop); - }; - - auto startOffset = evaluator.m_currOffset; - for (u32 i = 0; i < params.size(); i++) { - if (auto integerLiteralNode = dynamic_cast(params[i]); integerLiteralNode != nullptr) { - std::visit([&](auto &&value) { - using Type = std::remove_cvref_t; - - PatternData *pattern; - if constexpr (std::is_unsigned_v) - pattern = new PatternDataUnsigned(0, sizeof(value)); - else if constexpr (std::is_signed_v) - pattern = new PatternDataSigned(0, sizeof(value)); - else if constexpr (std::is_floating_point_v) - pattern = new PatternDataFloat(0, sizeof(value)); - else return; - - evaluator.createLocalVariable(paramNames[i], pattern); - evaluator.setLocalVariableValue(paramNames[i], &value, sizeof(value)); - }, integerLiteralNode->getValue()); - } else if (auto stringLiteralNode = dynamic_cast(params[i]); stringLiteralNode != nullptr) { - auto string = stringLiteralNode->getString(); - - evaluator.createLocalVariable(paramNames[i], new PatternDataString(0, string.length())); - evaluator.setLocalVariableValue(paramNames[i], string.data(), string.length()); - } else - evaluator.getConsole().abortEvaluation(hex::format("cannot create local variable {}, invalid type", paramNames[i])); - } - evaluator.m_currOffset = startOffset; - - return evaluator.evaluateFunctionBody(body).value_or(nullptr); - } - }; - - if (this->m_definedFunctions.contains(std::string(node->getName()))) - this->getConsole().abortEvaluation(hex::format("redefinition of function {}", node->getName())); - - this->m_definedFunctions.insert({ std::string(node->getName()), function }); - } - - std::optional Evaluator::evaluateFunctionBody(const std::vector &body) { - std::optional returnResult; - auto startOffset = this->m_currOffset; - - for (auto &statement : body) { - ON_SCOPE_EXIT { this->m_currOffset = startOffset; }; - - if (auto functionCallNode = dynamic_cast(statement); functionCallNode != nullptr) { - auto result = this->evaluateFunctionCall(functionCallNode); - delete result; - } else if (auto varDeclNode = dynamic_cast(statement); varDeclNode != nullptr) { - auto pattern = this->evaluateVariable(varDeclNode); - this->createLocalVariable(varDeclNode->getName(), pattern); - } else if (auto multiVarDeclNode = dynamic_cast(statement); multiVarDeclNode != nullptr) { - for (auto &delc : multiVarDeclNode->getVariables()) { - if (auto varDecl = dynamic_cast(delc); varDecl != nullptr) { - auto pattern = this->evaluateVariable(varDecl); - this->createLocalVariable(varDecl->getName(), pattern); - } else - this->getConsole().abortEvaluation("invalid multi-variable declaration"); - } - } else if (auto assignmentNode = dynamic_cast(statement); assignmentNode != nullptr) { - if (auto numericExpressionNode = dynamic_cast(assignmentNode->getRValue()); numericExpressionNode != nullptr) { - auto value = this->evaluateMathematicalExpression(numericExpressionNode); - ON_SCOPE_EXIT { delete value; }; - - std::visit([&](auto &&value) { - this->setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value)); - }, value->getValue()); - } else { - this->getConsole().abortEvaluation("invalid rvalue used in assignment"); - } - } else if (auto returnNode = dynamic_cast(statement); returnNode != nullptr) { - if (returnNode->getRValue() == nullptr) { - returnResult = nullptr; - } else if (auto numericExpressionNode = dynamic_cast(returnNode->getRValue()); numericExpressionNode != nullptr) { - returnResult = this->evaluateMathematicalExpression(numericExpressionNode); - } else { - this->getConsole().abortEvaluation("invalid rvalue used in return statement"); - } - } else if (auto conditionalNode = dynamic_cast(statement); conditionalNode != nullptr) { - if (auto numericExpressionNode = dynamic_cast(conditionalNode->getCondition()); numericExpressionNode != nullptr) { - auto condition = this->evaluateMathematicalExpression(numericExpressionNode); - - u32 localVariableStartCount = this->m_localVariables.back()->size(); - u32 localVariableStackStartSize = this->m_localStack.size(); - - if (std::visit([](auto &&value) { return value != 0; }, condition->getValue())) - returnResult = this->evaluateFunctionBody(conditionalNode->getTrueBody()); + Token::Literal castedLiteral = std::visit(overloaded { + [&](double &value) -> Token::Literal { + if (dynamic_cast(pattern)) + return u128(value); + else if (dynamic_cast(pattern)) + return s128(value); + else if (dynamic_cast(pattern)) + return value; else - returnResult = this->evaluateFunctionBody(conditionalNode->getFalseBody()); - - for (u32 i = localVariableStartCount; i < this->m_localVariables.back()->size(); i++) - delete (*this->m_localVariables.back())[i]; - this->m_localVariables.back()->resize(localVariableStartCount); - this->m_localStack.resize(localVariableStackStartSize); - - } else { - this->getConsole().abortEvaluation("invalid rvalue used in return statement"); - } - } else if (auto whileLoopNode = dynamic_cast(statement); whileLoopNode != nullptr) { - if (auto numericExpressionNode = dynamic_cast(whileLoopNode->getCondition()); numericExpressionNode != nullptr) { - auto condition = this->evaluateMathematicalExpression(numericExpressionNode); - - while (std::visit([](auto &&value) { return value != 0; }, condition->getValue())) { - u32 localVariableStartCount = this->m_localVariables.back()->size(); - u32 localVariableStackStartSize = this->m_localStack.size(); - - returnResult = this->evaluateFunctionBody(whileLoopNode->getBody()); - if (returnResult.has_value()) - break; - - for (u32 i = localVariableStartCount; i < this->m_localVariables.back()->size(); i++) - delete (*this->m_localVariables.back())[i]; - this->m_localVariables.back()->resize(localVariableStartCount); - this->m_localStack.resize(localVariableStackStartSize); - - condition = this->evaluateMathematicalExpression(numericExpressionNode); + LogConsole::abortEvaluation(hex::format("cannot cast type 'double' to type '{}'", pattern->getTypeName())); + }, + [&](const std::string &value) -> Token::Literal { + if (dynamic_cast(pattern)) + return value; + else + LogConsole::abortEvaluation(hex::format("cannot cast type 'string' to type '{}'", pattern->getTypeName())); + }, + [&](PatternData * const &value) -> Token::Literal { + if (value->getTypeName() == pattern->getTypeName()) + return value; + else + LogConsole::abortEvaluation(hex::format("cannot cast type '{}' to type '{}'", value->getTypeName(), pattern->getTypeName())); + }, + [&](auto &&value) -> Token::Literal { + if (dynamic_cast(pattern)) + return u128(value); + else if (dynamic_cast(pattern)) + return s128(value); + else if (dynamic_cast(pattern)) + return char(value); + else if (dynamic_cast(pattern)) + return bool(value); + else if (dynamic_cast(pattern)) + return double(value); + else + LogConsole::abortEvaluation(hex::format("cannot cast type 'string' to type '{}'", pattern->getTypeName())); } + }, value); - } else { - this->getConsole().abortEvaluation("invalid rvalue used in return statement"); - } - } - - if (returnResult.has_value()) - return returnResult.value(); - } - - - return { }; + this->getStack().back() = castedLiteral; } - PatternData* Evaluator::evaluateAttributes(ASTNode *currNode, PatternData *currPattern) { - auto attributableNode = dynamic_cast(currNode); - if (attributableNode == nullptr) - this->getConsole().abortEvaluation("attributes applied to invalid expression"); + std::optional> Evaluator::evaluate(const std::vector &ast) { + this->m_stack.clear(); + this->m_customFunctions.clear(); + this->m_scopes.clear(); - auto handleVariableAttributes = [this, &currPattern](auto attribute, auto value) { + for (auto &func : this->m_customFunctionDefinitions) + delete func; + this->m_customFunctionDefinitions.clear(); - if (attribute == "color" && value.has_value()) - currPattern->setColor(hex::changeEndianess(u32(strtoul(value->data(), nullptr, 16)) << 8, std::endian::big)); - else if (attribute == "name" && value.has_value()) - currPattern->setVariableName(value->data()); - else if (attribute == "comment" && value.has_value()) - currPattern->setComment(value->data()); - else if (attribute == "hidden" && value.has_value()) - currPattern->setHidden(true); - else - this->getConsole().abortEvaluation("unknown or invalid attribute"); - - }; - - auto &attributes = attributableNode->getAttributes(); - - if (attributes.empty()) - return currPattern; - - if (auto variableDeclNode = dynamic_cast(currNode); variableDeclNode != nullptr) { - for (auto &attribute : attributes) - handleVariableAttributes(attribute->getAttribute(), attribute->getValue()); - } else if (auto arrayDeclNode = dynamic_cast(currNode); arrayDeclNode != nullptr) { - for (auto &attribute : attributes) - handleVariableAttributes(attribute->getAttribute(), attribute->getValue()); - } else if (auto pointerDeclNode = dynamic_cast(currNode); pointerDeclNode != nullptr) { - for (auto &attribute : attributes) - handleVariableAttributes(attribute->getAttribute(), attribute->getValue()); - } else if (auto structNode = dynamic_cast(currNode); structNode != nullptr) { - this->getConsole().abortEvaluation("unknown or invalid attribute"); - } else if (auto unionNode = dynamic_cast(currNode); unionNode != nullptr) { - this->getConsole().abortEvaluation("unknown or invalid attribute"); - } else if (auto enumNode = dynamic_cast(currNode); enumNode != nullptr) { - this->getConsole().abortEvaluation("unknown or invalid attribute"); - } else if (auto bitfieldNode = dynamic_cast(currNode); bitfieldNode != nullptr) { - this->getConsole().abortEvaluation("unknown or invalid attribute"); - } else - this->getConsole().abortEvaluation("attributes applied to invalid expression"); - - return currPattern; - } - - PatternData* Evaluator::evaluateBuiltinType(ASTNodeBuiltinType *node) { - auto &type = node->getType(); - auto typeSize = Token::getTypeSize(type); - - PatternData *pattern; - - if (type == Token::ValueType::Character) - pattern = new PatternDataCharacter(this->m_currOffset); - else if (type == Token::ValueType::Character16) - pattern = new PatternDataCharacter16(this->m_currOffset); - else if (type == Token::ValueType::Boolean) - pattern = new PatternDataBoolean(this->m_currOffset); - else if (Token::isUnsigned(type)) - pattern = new PatternDataUnsigned(this->m_currOffset, typeSize); - else if (Token::isSigned(type)) - pattern = new PatternDataSigned(this->m_currOffset, typeSize); - else if (Token::isFloatingPoint(type)) - pattern = new PatternDataFloat(this->m_currOffset, typeSize); - else if (type == Token::ValueType::Padding) - pattern = new PatternDataPadding(this->m_currOffset, 1); - else - this->getConsole().abortEvaluation("invalid builtin type"); - - this->m_currOffset += typeSize; - - pattern->setTypeName(Token::getTypeName(type)); - pattern->setEndian(this->getCurrentEndian()); - - return pattern; - } - - void Evaluator::evaluateMember(ASTNode *node, std::vector &currMembers, bool increaseOffset) { - auto startOffset = this->m_currOffset; - - if (auto memberVariableNode = dynamic_cast(node); memberVariableNode != nullptr) - currMembers.push_back(this->evaluateVariable(memberVariableNode)); - else if (auto memberMultiVariableNode = dynamic_cast(node); memberMultiVariableNode != nullptr) { - for (auto decl : memberMultiVariableNode->getVariables()) { - if (auto variableDecl = dynamic_cast(decl); variableDecl != nullptr) - currMembers.push_back(this->evaluateVariable(variableDecl)); - else - this->getConsole().abortEvaluation("invalid multi-variable declaration"); - } - } - else if (auto memberArrayNode = dynamic_cast(node); memberArrayNode != nullptr) - currMembers.push_back(this->evaluateArray(memberArrayNode)); - else if (auto memberPointerNode = dynamic_cast(node); memberPointerNode != nullptr) - currMembers.push_back(this->evaluatePointer(memberPointerNode)); - else if (auto conditionalNode = dynamic_cast(node); conditionalNode != nullptr) { - auto condition = this->evaluateMathematicalExpression(static_cast(conditionalNode->getCondition())); - - if (std::visit([](auto &&value) { return value != 0; }, condition->getValue())) { - for (auto &statement : conditionalNode->getTrueBody()) { - this->evaluateMember(statement, currMembers, increaseOffset); - } - } else { - for (auto &statement : conditionalNode->getFalseBody()) { - this->evaluateMember(statement, currMembers, increaseOffset); - } - } - - delete condition; - } - else - this->getConsole().abortEvaluation("invalid struct member"); - - if (!increaseOffset) - this->m_currOffset = startOffset; - } - - PatternData* Evaluator::evaluateStruct(ASTNodeStruct *node) { - std::vector memberPatterns; - - auto structPattern = new PatternDataStruct(this->m_currOffset, 0); - structPattern->setParent(this->m_currMemberScope.back()); - - this->m_currMembers.push_back(&memberPatterns); - this->m_currMemberScope.push_back(structPattern); - ON_SCOPE_EXIT { - this->m_currMembers.pop_back(); - this->m_currMemberScope.pop_back(); - }; - - this->m_currRecursionDepth++; - if (this->m_currRecursionDepth > this->m_recursionLimit) - this->getConsole().abortEvaluation(hex::format("evaluation depth exceeds maximum of {0}. Use #pragma eval_depth to increase the maximum", this->m_recursionLimit)); - - auto startOffset = this->m_currOffset; - for (auto &member : node->getMembers()) { - this->evaluateMember(member, memberPatterns, true); - structPattern->setMembers(memberPatterns); - } - structPattern->setSize(this->m_currOffset - startOffset); - - this->m_currRecursionDepth--; - - return this->evaluateAttributes(node, structPattern); - } - - PatternData* Evaluator::evaluateUnion(ASTNodeUnion *node) { - std::vector memberPatterns; - - auto unionPattern = new PatternDataUnion(this->m_currOffset, 0); - unionPattern->setParent(this->m_currMemberScope.back()); - - this->m_currMembers.push_back(&memberPatterns); - this->m_currMemberScope.push_back(unionPattern); - ON_SCOPE_EXIT { - this->m_currMembers.pop_back(); - this->m_currMemberScope.pop_back(); - }; - - auto startOffset = this->m_currOffset; - - this->m_currRecursionDepth++; - if (this->m_currRecursionDepth > this->m_recursionLimit) - this->getConsole().abortEvaluation(hex::format("evaluation depth exceeds maximum of {0}. Use #pragma eval_depth to increase the maximum", this->m_recursionLimit)); - - for (auto &member : node->getMembers()) { - this->evaluateMember(member, memberPatterns, false); - unionPattern->setMembers(memberPatterns); - } - - this->m_currRecursionDepth--; - - size_t size = 0; - for (const auto &pattern : memberPatterns) - size = std::max(size, pattern->getSize()); - unionPattern->setSize(size); - - this->m_currOffset += size; - - return this->evaluateAttributes(node, unionPattern); - } - - PatternData* Evaluator::evaluateEnum(ASTNodeEnum *node) { - std::vector> entryPatterns; - - auto underlyingType = dynamic_cast(node->getUnderlyingType()); - if (underlyingType == nullptr) - this->getConsole().abortEvaluation("enum underlying type was not ASTNodeTypeDecl. This is a bug"); - - size_t size; - auto builtinUnderlyingType = dynamic_cast(underlyingType->getType()); - if (builtinUnderlyingType != nullptr) - size = Token::getTypeSize(builtinUnderlyingType->getType()); - else - this->getConsole().abortEvaluation("invalid enum underlying type"); - - auto startOffset = this->m_currOffset; - for (auto &[name, value] : node->getEntries()) { - auto expression = dynamic_cast(value); - if (expression == nullptr) - this->getConsole().abortEvaluation("invalid expression in enum value"); - - auto valueNode = evaluateMathematicalExpression(expression); - ON_SCOPE_EXIT { delete valueNode; }; - - entryPatterns.emplace_back(valueNode->getValue(), name); - } - - this->m_currOffset += size; - - auto enumPattern = new PatternDataEnum(startOffset, size); - enumPattern->setSize(size); - enumPattern->setEnumValues(entryPatterns); - - return this->evaluateAttributes(node, enumPattern); - } - - PatternData* Evaluator::evaluateBitfield(ASTNodeBitfield *node) { - std::vector entryPatterns; - - auto startOffset = this->m_currOffset; - size_t bits = 0; - for (auto &[name, value] : node->getEntries()) { - auto expression = dynamic_cast(value); - if (expression == nullptr) - this->getConsole().abortEvaluation("invalid expression in bitfield field size"); - - auto valueNode = evaluateMathematicalExpression(expression); - ON_SCOPE_EXIT { delete valueNode; }; - - auto fieldBits = std::visit([this] (auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_floating_point_v) - this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - - if (fieldBits > 64 || fieldBits <= 0) - this->getConsole().abortEvaluation("bitfield entry must occupy between 1 and 64 bits"); - - auto fieldPattern = new PatternDataBitfieldField(startOffset, bits, fieldBits); - fieldPattern->setVariableName(name); - fieldPattern->setEndian(this->getCurrentEndian()); - entryPatterns.push_back(fieldPattern); - - bits += fieldBits; - } - - size_t size = (bits + 7) / 8; - this->m_currOffset += size; - - auto bitfieldPattern = new PatternDataBitfield(startOffset, size); - bitfieldPattern->setFields(entryPatterns); - - return this->evaluateAttributes(node, bitfieldPattern); - } - - PatternData* Evaluator::evaluateType(ASTNodeTypeDecl *node) { - auto type = node->getType(); - - if (type == nullptr) - type = this->m_types[node->getName().data()]; - - this->m_endianStack.push_back(node->getEndian().value_or(this->m_defaultDataEndian)); - - PatternData *pattern; - - if (auto builtinTypeNode = dynamic_cast(type); builtinTypeNode != nullptr) - return this->evaluateBuiltinType(builtinTypeNode); - else if (auto typeDeclNode = dynamic_cast(type); typeDeclNode != nullptr) - pattern = this->evaluateType(typeDeclNode); - else if (auto structNode = dynamic_cast(type); structNode != nullptr) - pattern = this->evaluateStruct(structNode); - else if (auto unionNode = dynamic_cast(type); unionNode != nullptr) - pattern = this->evaluateUnion(unionNode); - else if (auto enumNode = dynamic_cast(type); enumNode != nullptr) - pattern = this->evaluateEnum(enumNode); - else if (auto bitfieldNode = dynamic_cast(type); bitfieldNode != nullptr) - pattern = this->evaluateBitfield(bitfieldNode); - else - this->getConsole().abortEvaluation("type could not be evaluated"); - - if (!node->getName().empty()) - pattern->setTypeName(node->getName().data()); - - pattern->setEndian(this->getCurrentEndian()); - - this->m_endianStack.pop_back(); - - return pattern; - } - - PatternData* Evaluator::evaluateVariable(ASTNodeVariableDecl *node) { - - if (auto offset = dynamic_cast(node->getPlacementOffset()); offset != nullptr) { - auto valueNode = evaluateMathematicalExpression(offset); - ON_SCOPE_EXIT { delete valueNode; }; - - this->m_currOffset = std::visit([this] (auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_floating_point_v) - this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - } - - if (this->m_currOffset < this->m_provider->getBaseAddress() || this->m_currOffset >= this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) { - if (node->getPlacementOffset() != nullptr) - this->getConsole().abortEvaluation("variable placed out of range"); - else - return nullptr; - } - - PatternData *pattern; - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) - pattern = this->evaluateType(typeDecl); - else if (auto builtinTypeDecl = dynamic_cast(node->getType()); builtinTypeDecl != nullptr) - pattern = this->evaluateBuiltinType(builtinTypeDecl); - else - this->getConsole().abortEvaluation("ASTNodeVariableDecl had an invalid type. This is a bug!"); - - pattern->setVariableName(node->getName().data()); - - return this->evaluateAttributes(node, pattern); - } - - PatternData* Evaluator::evaluateArray(ASTNodeArrayVariableDecl *node) { - // Evaluate placement of array - if (auto offset = dynamic_cast(node->getPlacementOffset()); offset != nullptr) { - auto valueNode = evaluateMathematicalExpression(offset); - ON_SCOPE_EXIT { delete valueNode; }; - - this->m_currOffset = std::visit([this] (auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_floating_point_v) - this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - } - - // Check if placed in range of the data - if (this->m_currOffset < this->m_provider->getBaseAddress() || this->m_currOffset >= this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) { - if (node->getPlacementOffset() != nullptr) - this->getConsole().abortEvaluation("variable placed out of range"); - else - return nullptr; - } - - - auto type = static_cast(node->getType())->getType(); - - if (dynamic_cast(type) != nullptr) - return this->evaluateStaticArray(node); - - auto attributes = dynamic_cast(type)->getAttributes(); - - bool isStaticType = std::any_of(attributes.begin(), attributes.end(), [](ASTNodeAttribute *attribute) { - return attribute->getAttribute() == "static" && !attribute->getValue().has_value(); - }); - - if (isStaticType) - return this->evaluateStaticArray(node); - else - return this->evaluateDynamicArray(node); - } - - PatternData* Evaluator::evaluateStaticArray(ASTNodeArrayVariableDecl *node) { - std::optional color; - - ssize_t arraySize = 0; - - auto startOffset = this->m_currOffset; - PatternData *templatePattern; - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) - templatePattern = this->evaluateType(typeDecl); - else if (auto builtinTypeDecl = dynamic_cast(node->getType()); builtinTypeDecl != nullptr) - templatePattern = this->evaluateBuiltinType(builtinTypeDecl); - else - this->getConsole().abortEvaluation("ASTNodeVariableDecl had an invalid type. This is a bug!"); - - auto entrySize = this->m_currOffset - startOffset; - - ON_SCOPE_EXIT { delete templatePattern; }; - - auto sizeNode = node->getSize(); - if (auto numericExpression = dynamic_cast(sizeNode); numericExpression != nullptr) { - // Parse explicit size of array - auto valueNode = this->evaluateMathematicalExpression(numericExpression); - ON_SCOPE_EXIT { delete valueNode; }; - - arraySize = std::visit([this] (auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_floating_point_v) - this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - } else if (auto whileLoopExpression = dynamic_cast(sizeNode); whileLoopExpression != nullptr) { - // Parse while loop based size of array - auto conditionNode = this->evaluateMathematicalExpression(static_cast(whileLoopExpression->getCondition())); - ON_SCOPE_EXIT { delete conditionNode; }; - - while (std::visit([](auto &&value) { return value != 0; }, conditionNode->getValue())) { - arraySize++; - - delete conditionNode; - conditionNode = this->evaluateMathematicalExpression(static_cast(whileLoopExpression->getCondition())); - } - } else { - // Parse unsized array - - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) { - if (auto builtinType = dynamic_cast(typeDecl->getType()); builtinType != nullptr) { - std::vector bytes(Token::getTypeSize(builtinType->getType()), 0x00); - u64 offset = startOffset; - - do { - this->m_provider->read(offset, bytes.data(), bytes.size()); - offset += bytes.size(); - arraySize++; - } while (!std::all_of(bytes.begin(), bytes.end(), [](u8 byte){ return byte == 0x00; }) && offset < this->m_provider->getSize()); - } - } - } - - if (arraySize < 0) - this->getConsole().abortEvaluation("array size cannot be negative"); - - PatternData *pattern; - if (dynamic_cast(templatePattern) != nullptr) - pattern = new PatternDataString(startOffset, entrySize * arraySize, color.value_or(0)); - else if (dynamic_cast(templatePattern) != nullptr) - pattern = new PatternDataString16(startOffset, entrySize * arraySize, color.value_or(0)); - else if (dynamic_cast(templatePattern) != nullptr) - pattern = new PatternDataPadding(startOffset, entrySize * arraySize); - else { - auto arrayPattern = new PatternDataStaticArray(startOffset, entrySize * arraySize, color.value_or(0)); - arrayPattern->setTypeName(templatePattern->getTypeName()); - arrayPattern->setEntries(templatePattern->clone(), arraySize); - - pattern = arrayPattern; - } - - pattern->setVariableName(node->getName().data()); - pattern->setEndian(this->getCurrentEndian()); - - this->m_currOffset = startOffset + entrySize * arraySize; - - return this->evaluateAttributes(node, pattern); - - } - - PatternData* Evaluator::evaluateDynamicArray(ASTNodeArrayVariableDecl *node) { - auto startOffset = this->m_currOffset; - - std::vector entries; - std::optional color; - - auto addEntry = [this, node, &entries, &color](u64 index) { - PatternData *entry; - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) - entry = this->evaluateType(typeDecl); - else - this->getConsole().abortEvaluation("ASTNodeVariableDecl had an invalid type. This is a bug!"); - - entry->setVariableName(hex::format("[{0}]", index)); - entry->setEndian(this->getCurrentEndian()); - - if (!color.has_value()) - color = entry->getColor(); - entry->setColor(color.value_or(0)); - - if (this->m_currOffset > this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) { - delete entry; - return; - } - - entries.push_back(entry); - }; - - auto sizeNode = node->getSize(); - if (auto numericExpression = dynamic_cast(sizeNode); numericExpression != nullptr) { - // Parse explicit size of array - auto valueNode = this->evaluateMathematicalExpression(numericExpression); - ON_SCOPE_EXIT { delete valueNode; }; - - auto arraySize = std::visit([this] (auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_floating_point_v) - this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - - if (arraySize < 0) - this->getConsole().abortEvaluation("array size cannot be negative"); - - for (u64 i = 0; i < arraySize; i++) { - addEntry(i); - } - - } else if (auto whileLoopExpression = dynamic_cast(sizeNode); whileLoopExpression != nullptr) { - // Parse while loop based size of array - auto conditionNode = this->evaluateMathematicalExpression(static_cast(whileLoopExpression->getCondition())); - ON_SCOPE_EXIT { delete conditionNode; }; - - u64 index = 0; - while (std::visit([](auto &&value) { return value != 0; }, conditionNode->getValue())) { - - addEntry(index); - index++; - - delete conditionNode; - conditionNode = this->evaluateMathematicalExpression(static_cast(whileLoopExpression->getCondition())); - } - } - - auto deleteEntries = SCOPE_GUARD { - for (auto &entry : entries) - delete entry; - }; - - if (node->getSize() == nullptr) - this->getConsole().abortEvaluation("no bounds provided for array"); - auto pattern = new PatternDataDynamicArray(startOffset, (this->m_currOffset - startOffset), color.value_or(0)); - - deleteEntries.release(); - - pattern->setEntries(entries); - pattern->setVariableName(node->getName().data()); - pattern->setEndian(this->getCurrentEndian()); - - return this->evaluateAttributes(node, pattern); - } - - PatternData* Evaluator::evaluatePointer(ASTNodePointerVariableDecl *node) { - s128 pointerOffset; - if (auto offset = dynamic_cast(node->getPlacementOffset()); offset != nullptr) { - auto valueNode = evaluateMathematicalExpression(offset); - ON_SCOPE_EXIT { delete valueNode; }; - - pointerOffset = std::visit([this] (auto &&value) { - using Type = std::remove_cvref_t; - if constexpr (std::is_floating_point_v) - this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - this->m_currOffset = pointerOffset; - } else { - pointerOffset = this->m_currOffset; - } - - if (this->m_currOffset < this->m_provider->getBaseAddress() || this->m_currOffset >= this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) { - if (node->getPlacementOffset() != nullptr) - this->getConsole().abortEvaluation("variable placed out of range"); - else - return nullptr; - } - - PatternData *sizeType; - - auto underlyingType = dynamic_cast(node->getSizeType()); - if (underlyingType == nullptr) - this->getConsole().abortEvaluation("underlying type is not ASTNodeTypeDecl. This is a bug"); - - if (auto builtinTypeNode = dynamic_cast(underlyingType->getType()); builtinTypeNode != nullptr) { - sizeType = evaluateBuiltinType(builtinTypeNode); - } else - this->getConsole().abortEvaluation("pointer size is not a builtin type"); - - size_t pointerSize = sizeType->getSize(); - - u128 pointedAtOffset = 0; - this->m_provider->read(pointerOffset, &pointedAtOffset, pointerSize); - this->m_currOffset = hex::changeEndianess(pointedAtOffset, pointerSize, underlyingType->getEndian().value_or(this->m_defaultDataEndian)); - - delete sizeType; - - - if (this->m_currOffset > this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) - this->getConsole().abortEvaluation("pointer points past the end of the data"); - - PatternData *pointedAt; - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) - pointedAt = this->evaluateType(typeDecl); - else if (auto builtinTypeDecl = dynamic_cast(node->getType()); builtinTypeDecl != nullptr) - pointedAt = this->evaluateBuiltinType(builtinTypeDecl); - else - this->getConsole().abortEvaluation("ASTNodeVariableDecl had an invalid type. This is a bug!"); - - this->m_currOffset = pointerOffset + pointerSize; - - auto pattern = new PatternDataPointer(pointerOffset, pointerSize); - - pattern->setVariableName(node->getName().data()); - pattern->setEndian(this->getCurrentEndian()); - pattern->setPointedAtPattern(pointedAt); - - return this->evaluateAttributes(node, pattern); - } - - std::optional> Evaluator::evaluate(const std::vector &ast) { - - this->m_globalMembers.clear(); - this->m_types.clear(); - this->m_endianStack.clear(); - this->m_definedFunctions.clear(); - this->m_currOffset = 0; + std::vector patterns; try { - for (const auto& node : ast) { - if (auto typeDeclNode = dynamic_cast(node); typeDeclNode != nullptr) { - if (this->m_types[typeDeclNode->getName().data()] == nullptr) - this->m_types[typeDeclNode->getName().data()] = typeDeclNode->getType(); - } - } - - for (const auto& [name, node] : this->m_types) { - if (auto typeDeclNode = static_cast(node); typeDeclNode->getType() == nullptr) - this->getConsole().abortEvaluation(hex::format("unresolved type '{}'", name)); - } - - for (const auto& node : ast) { - this->m_currMembers.clear(); - this->m_currMemberScope.clear(); - this->m_currMemberScope.push_back(nullptr); - - this->m_endianStack.push_back(this->m_defaultDataEndian); - this->m_currRecursionDepth = 0; - - PatternData *pattern = nullptr; - - if (auto variableDeclNode = dynamic_cast(node); variableDeclNode != nullptr) { - pattern = this->evaluateVariable(variableDeclNode); - } else if (auto arrayDeclNode = dynamic_cast(node); arrayDeclNode != nullptr) { - pattern = this->evaluateArray(arrayDeclNode); - } else if (auto pointerDeclNode = dynamic_cast(node); pointerDeclNode != nullptr) { - pattern = this->evaluatePointer(pointerDeclNode); - } else if (auto typeDeclNode = dynamic_cast(node); typeDeclNode != nullptr) { - // Handled above - } else if (auto functionCallNode = dynamic_cast(node); functionCallNode != nullptr) { - auto result = this->evaluateFunctionCall(functionCallNode); - delete result; - } else if (auto functionDefNode = dynamic_cast(node); functionDefNode != nullptr) { - this->evaluateFunctionDefinition(functionDefNode); + pushScope(nullptr, patterns); + for (auto node : ast) { + if (dynamic_cast(node)) { + ;// Don't create patterns from type declarations + } else if (dynamic_cast(node)) { + delete node->evaluate(this); + } else if (dynamic_cast(node)) { + this->m_customFunctionDefinitions.push_back(node->evaluate(this)); + } else { + auto newPatterns = node->createPatterns(this); + patterns.insert(patterns.end(), newPatterns.begin(), newPatterns.end()); } - if (pattern != nullptr) - this->m_globalMembers.push_back(pattern); - - this->m_endianStack.clear(); } - } catch (LogConsole::EvaluateError &e) { - this->getConsole().log(LogConsole::Level::Error, e); + popScope(); + } catch (const LogConsole::EvaluateError &error) { + this->m_console.log(LogConsole::Level::Error, error.second); - return { }; + if (error.first != 0) + this->m_console.setHardError(error); + + for (auto &pattern : patterns) + delete pattern; + patterns.clear(); + + return std::nullopt; } - return this->m_globalMembers; + return patterns; } } \ No newline at end of file diff --git a/plugins/libimhex/source/pattern_language/lexer.cpp b/plugins/libimhex/source/pattern_language/lexer.cpp index 1d06fabe4..8529a5050 100644 --- a/plugins/libimhex/source/pattern_language/lexer.cpp +++ b/plugins/libimhex/source/pattern_language/lexer.cpp @@ -28,9 +28,9 @@ namespace hex::pl { return string.find_first_not_of("0123456789ABCDEFabcdef.xUL"); } - std::optional parseIntegerLiteral(const std::string &string) { + std::optional parseIntegerLiteral(const std::string &string) { Token::ValueType type = Token::ValueType::Any; - Token::IntegerLiteral result; + Token::Literal result; u8 base; @@ -38,20 +38,8 @@ namespace hex::pl { auto numberData = std::string_view(string).substr(0, endPos); if (numberData.ends_with('U')) { - type = Token::ValueType::Unsigned32Bit; - numberData.remove_suffix(1); - } else if (numberData.ends_with("UL")) { - type = Token::ValueType::Unsigned64Bit; - numberData.remove_suffix(2); - } else if (numberData.ends_with("ULL")) { type = Token::ValueType::Unsigned128Bit; - numberData.remove_suffix(3); - } else if (numberData.ends_with("L")) { - type = Token::ValueType::Signed64Bit; numberData.remove_suffix(1); - } else if (numberData.ends_with("LL")) { - type = Token::ValueType::Signed128Bit; - numberData.remove_suffix(2); } else if (!numberData.starts_with("0x") && !numberData.starts_with("0b")) { if (numberData.ends_with('F')) { type = Token::ValueType::Float; @@ -98,7 +86,7 @@ namespace hex::pl { } else return { }; if (type == Token::ValueType::Any) - type = Token::ValueType::Signed32Bit; + type = Token::ValueType::Signed128Bit; if (numberData.length() == 0) @@ -119,10 +107,6 @@ namespace hex::pl { } switch (type) { - case Token::ValueType::Unsigned32Bit: return { u32(integer) }; - case Token::ValueType::Signed32Bit: return { s32(integer) }; - case Token::ValueType::Unsigned64Bit: return { u64(integer) }; - case Token::ValueType::Signed64Bit: return { s64(integer) }; case Token::ValueType::Unsigned128Bit: return { u128(integer) }; case Token::ValueType::Signed128Bit: return { s128(integer) }; default: return { }; @@ -379,7 +363,7 @@ namespace hex::pl { auto [c, charSize] = character.value(); - tokens.emplace_back(VALUE_TOKEN(Integer, c)); + tokens.emplace_back(VALUE_TOKEN(Integer, Token::Literal(c))); offset += charSize; } else if (c == '\"') { auto string = getStringLiteral(code.substr(offset)); @@ -389,7 +373,7 @@ namespace hex::pl { auto [s, stringSize] = string.value(); - tokens.emplace_back(VALUE_TOKEN(String, s)); + tokens.emplace_back(VALUE_TOKEN(String, Token::Literal(s))); offset += stringSize; } else if (std::isalpha(c)) { std::string identifier = matchTillInvalid(&code[offset], [](char c) -> bool { return std::isalnum(c) || c == '_'; }); @@ -415,11 +399,13 @@ namespace hex::pl { else if (identifier == "else") tokens.emplace_back(TOKEN(Keyword, Else)); else if (identifier == "false") - tokens.emplace_back(VALUE_TOKEN(Integer, bool(0))); + tokens.emplace_back(VALUE_TOKEN(Integer, Token::Literal(false))); else if (identifier == "true") - tokens.emplace_back(VALUE_TOKEN(Integer, bool(1))); + tokens.emplace_back(VALUE_TOKEN(Integer, Token::Literal(true))); else if (identifier == "parent") tokens.emplace_back(TOKEN(Keyword, Parent)); + else if (identifier == "this") + tokens.emplace_back(TOKEN(Keyword, This)); else if (identifier == "while") tokens.emplace_back(TOKEN(Keyword, While)); else if (identifier == "fn") @@ -460,13 +446,15 @@ namespace hex::pl { tokens.emplace_back(TOKEN(ValueType, Character16)); else if (identifier == "bool") tokens.emplace_back(TOKEN(ValueType, Boolean)); + else if (identifier == "str") + tokens.emplace_back(TOKEN(ValueType, String)); else if (identifier == "padding") tokens.emplace_back(TOKEN(ValueType, Padding)); // If it's not a keyword and a builtin type, it has to be an identifier else - tokens.emplace_back(VALUE_TOKEN(Identifier, identifier)); + tokens.emplace_back(VALUE_TOKEN(Identifier, Token::Identifier(identifier))); offset += identifier.length(); } else if (std::isdigit(c)) { @@ -476,7 +464,7 @@ namespace hex::pl { throwLexerError("invalid integer literal", lineNumber); - tokens.emplace_back(VALUE_TOKEN(Integer, integer.value())); + tokens.emplace_back(VALUE_TOKEN(Integer, Token::Literal(integer.value()))); offset += getIntegerLiteralLength(&code[offset]); } else throwLexerError("unknown token", lineNumber); diff --git a/plugins/libimhex/source/pattern_language/parser.cpp b/plugins/libimhex/source/pattern_language/parser.cpp index b40f401c6..a0851c0fb 100644 --- a/plugins/libimhex/source/pattern_language/parser.cpp +++ b/plugins/libimhex/source/pattern_language/parser.cpp @@ -1,13 +1,9 @@ #include -#include - #include #define MATCHES(x) (begin() && x) -#define TO_NUMERIC_EXPRESSION(node) new ASTNodeNumericExpression((node), new ASTNodeIntegerLiteral(s32(0)), Token::Operator::Plus) - // Definition syntax: // [A] : Either A or no token // [A|B] : Either A, B or no token @@ -33,10 +29,7 @@ namespace hex::pl { }; while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { - if (MATCHES(sequence(STRING))) - params.push_back(parseStringLiteral()); - else - params.push_back(parseMathematicalExpression()); + params.push_back(parseMathematicalExpression()); if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE))) throwParseError("unexpected ',' at end of function parameter list", -1); @@ -49,18 +42,18 @@ namespace hex::pl { paramCleanup.release(); - return new ASTNodeFunctionCall(functionName, params); + return create(new ASTNodeFunctionCall(functionName, params)); } ASTNode* Parser::parseStringLiteral() { - return new ASTNodeStringLiteral(getValue(-1)); + return create(new ASTNodeLiteral(getValue(-1))); } std::string Parser::parseNamespaceResolution() { std::string name; while (true) { - name += getValue(-1); + name += getValue(-1).get(); if (MATCHES(sequence(OPERATOR_SCOPERESOLUTION, IDENTIFIER))) { name += "::"; @@ -77,14 +70,14 @@ namespace hex::pl { std::string typeName; while (true) { - typeName += getValue(-1); + typeName += getValue(-1).get(); if (MATCHES(sequence(OPERATOR_SCOPERESOLUTION, IDENTIFIER))) { if (peek(OPERATOR_SCOPERESOLUTION, 0) && peek(IDENTIFIER, 1)) { typeName += "::"; continue; } else { - return new ASTNodeScopeResolution({ typeName, getValue(-1) }); + return create(new ASTNodeScopeResolution({ typeName, getValue(-1).get() })); } } else @@ -97,9 +90,11 @@ namespace hex::pl { // ASTNode* Parser::parseRValue(ASTNodeRValue::Path &path) { if (peek(IDENTIFIER, -1)) - path.push_back(getValue(-1)); + path.push_back(getValue(-1).get()); else if (peek(KEYWORD_PARENT, -1)) path.emplace_back("parent"); + else if (peek(KEYWORD_THIS, -1)) + path.emplace_back("this"); if (MATCHES(sequence(SEPARATOR_SQUAREBRACKETOPEN))) { path.push_back(parseMathematicalExpression()); @@ -113,13 +108,13 @@ namespace hex::pl { else throwParseError("expected member name or 'parent' keyword", -1); } else - return new ASTNodeRValue(path); + return create(new ASTNodeRValue(path)); } // ASTNode* Parser::parseFactor() { if (MATCHES(sequence(INTEGER))) - return TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral(getValue(-1))); + return new ASTNodeLiteral(getValue(-1)); else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETOPEN))) { auto node = this->parseMathematicalExpression(); if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { @@ -135,34 +130,54 @@ namespace hex::pl { if (isFunction) { - return TO_NUMERIC_EXPRESSION(this->parseFunctionCall()); + return this->parseFunctionCall(); } else if (peek(OPERATOR_SCOPERESOLUTION, 0)) { - return TO_NUMERIC_EXPRESSION(this->parseScopeResolution()); + return this->parseScopeResolution(); } else { ASTNodeRValue::Path path; - return TO_NUMERIC_EXPRESSION(this->parseRValue(path)); + return this->parseRValue(path); } - } else if (MATCHES(oneOf(KEYWORD_PARENT))) { + } else if (MATCHES(oneOf(KEYWORD_PARENT, KEYWORD_THIS))) { ASTNodeRValue::Path path; - return TO_NUMERIC_EXPRESSION(this->parseRValue(path)); + return this->parseRValue(path); } else if (MATCHES(sequence(OPERATOR_DOLLAR))) { - return TO_NUMERIC_EXPRESSION(new ASTNodeRValue({ "$" })); + return new ASTNodeRValue({ "$" }); } else if (MATCHES(oneOf(OPERATOR_ADDRESSOF, OPERATOR_SIZEOF) && sequence(SEPARATOR_ROUNDBRACKETOPEN))) { auto op = getValue(-2); - if (!MATCHES(oneOf(IDENTIFIER, KEYWORD_PARENT))) { + if (!MATCHES(oneOf(IDENTIFIER, KEYWORD_PARENT, KEYWORD_THIS))) { throwParseError("expected rvalue identifier"); } ASTNodeRValue::Path path; - auto node = new ASTNodeTypeOperator(op, this->parseRValue(path)); + auto node = create(new ASTNodeTypeOperator(op, this->parseRValue(path))); if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { delete node; throwParseError("expected closing parenthesis"); } - return TO_NUMERIC_EXPRESSION(node); + return node; } else - throwParseError("expected integer or parenthesis"); + throwParseError("expected value or parenthesis"); + } + + ASTNode* Parser::parseCastExpression() { + if (peek(KEYWORD_BE) || peek(KEYWORD_LE) || peek(VALUETYPE_ANY)) { + auto type = parseType(); + auto builtinType = dynamic_cast(type->getType()); + + if (builtinType == nullptr) + throwParseError("invalid type used for pointer size", -1); + + if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETOPEN))) + throwParseError("expected '(' before cast expression", -1); + + auto node = parseFactor(); + + if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) + throwParseError("expected ')' after cast expression", -1); + + return new ASTNodeCast(node, type); + } else return parseFactor(); } // <+|-|!|~> (parseFactor) @@ -170,10 +185,12 @@ namespace hex::pl { if (MATCHES(oneOf(OPERATOR_PLUS, OPERATOR_MINUS, OPERATOR_BOOLNOT, OPERATOR_BITNOT))) { auto op = getValue(-1); - return new ASTNodeNumericExpression(new ASTNodeIntegerLiteral(0), this->parseFactor(), op); + return create(new ASTNodeMathematicalExpression(new ASTNodeLiteral(0), this->parseCastExpression(), op)); + } else if (MATCHES(sequence(STRING))) { + return this->parseStringLiteral(); } - return this->parseFactor(); + return this->parseCastExpression(); } // (parseUnaryExpression) <*|/|%> (parseUnaryExpression) @@ -184,7 +201,7 @@ namespace hex::pl { while (MATCHES(oneOf(OPERATOR_STAR, OPERATOR_SLASH, OPERATOR_PERCENT))) { auto op = getValue(-1); - node = new ASTNodeNumericExpression(node, this->parseUnaryExpression(), op); + node = create(new ASTNodeMathematicalExpression(node, this->parseUnaryExpression(), op)); } nodeCleanup.release(); @@ -200,7 +217,7 @@ namespace hex::pl { while (MATCHES(variant(OPERATOR_PLUS, OPERATOR_MINUS))) { auto op = getValue(-1); - node = new ASTNodeNumericExpression(node, this->parseMultiplicativeExpression(), op); + node = create(new ASTNodeMathematicalExpression(node, this->parseMultiplicativeExpression(), op)); } nodeCleanup.release(); @@ -216,7 +233,7 @@ namespace hex::pl { while (MATCHES(variant(OPERATOR_SHIFTLEFT, OPERATOR_SHIFTRIGHT))) { auto op = getValue(-1); - node = new ASTNodeNumericExpression(node, this->parseAdditiveExpression(), op); + node = create(new ASTNodeMathematicalExpression(node, this->parseAdditiveExpression(), op)); } nodeCleanup.release(); @@ -232,7 +249,7 @@ namespace hex::pl { while (MATCHES(sequence(OPERATOR_BOOLGREATERTHAN) || sequence(OPERATOR_BOOLLESSTHAN) || sequence(OPERATOR_BOOLGREATERTHANOREQUALS) || sequence(OPERATOR_BOOLLESSTHANOREQUALS))) { auto op = getValue(-1); - node = new ASTNodeNumericExpression(node, this->parseShiftExpression(), op); + node = create(new ASTNodeMathematicalExpression(node, this->parseShiftExpression(), op)); } nodeCleanup.release(); @@ -248,7 +265,7 @@ namespace hex::pl { while (MATCHES(sequence(OPERATOR_BOOLEQUALS) || sequence(OPERATOR_BOOLNOTEQUALS))) { auto op = getValue(-1); - node = new ASTNodeNumericExpression(node, this->parseRelationExpression(), op); + node = create(new ASTNodeMathematicalExpression(node, this->parseRelationExpression(), op)); } nodeCleanup.release(); @@ -263,7 +280,7 @@ namespace hex::pl { auto nodeCleanup = SCOPE_GUARD { delete node; }; while (MATCHES(sequence(OPERATOR_BITAND))) { - node = new ASTNodeNumericExpression(node, this->parseEqualityExpression(), Token::Operator::BitAnd); + node = create(new ASTNodeMathematicalExpression(node, this->parseEqualityExpression(), Token::Operator::BitAnd)); } nodeCleanup.release(); @@ -278,7 +295,7 @@ namespace hex::pl { auto nodeCleanup = SCOPE_GUARD { delete node; }; while (MATCHES(sequence(OPERATOR_BITXOR))) { - node = new ASTNodeNumericExpression(node, this->parseBinaryAndExpression(), Token::Operator::BitXor); + node = create(new ASTNodeMathematicalExpression(node, this->parseBinaryAndExpression(), Token::Operator::BitXor)); } nodeCleanup.release(); @@ -293,7 +310,7 @@ namespace hex::pl { auto nodeCleanup = SCOPE_GUARD { delete node; }; while (MATCHES(sequence(OPERATOR_BITOR))) { - node = new ASTNodeNumericExpression(node, this->parseBinaryXorExpression(), Token::Operator::BitOr); + node = create(new ASTNodeMathematicalExpression(node, this->parseBinaryXorExpression(), Token::Operator::BitOr)); } nodeCleanup.release(); @@ -308,7 +325,7 @@ namespace hex::pl { auto nodeCleanup = SCOPE_GUARD { delete node; }; while (MATCHES(sequence(OPERATOR_BOOLAND))) { - node = new ASTNodeNumericExpression(node, this->parseBinaryOrExpression(), Token::Operator::BitOr); + node = create(new ASTNodeMathematicalExpression(node, this->parseBinaryOrExpression(), Token::Operator::BitOr)); } nodeCleanup.release(); @@ -323,7 +340,7 @@ namespace hex::pl { auto nodeCleanup = SCOPE_GUARD { delete node; }; while (MATCHES(sequence(OPERATOR_BOOLXOR))) { - node = new ASTNodeNumericExpression(node, this->parseBooleanAnd(), Token::Operator::BitOr); + node = create(new ASTNodeMathematicalExpression(node, this->parseBooleanAnd(), Token::Operator::BitOr)); } nodeCleanup.release(); @@ -338,7 +355,7 @@ namespace hex::pl { auto nodeCleanup = SCOPE_GUARD { delete node; }; while (MATCHES(sequence(OPERATOR_BOOLOR))) { - node = new ASTNodeNumericExpression(node, this->parseBooleanXor(), Token::Operator::BitOr); + node = create(new ASTNodeMathematicalExpression(node, this->parseBooleanXor(), Token::Operator::BitOr)); } nodeCleanup.release(); @@ -359,7 +376,7 @@ namespace hex::pl { throwParseError("expected ':' in ternary expression"); auto third = this->parseBooleanOr(); - node = TO_NUMERIC_EXPRESSION(new ASTNodeTernaryExpression(node, second, third, Token::Operator::TernaryConditional)); + node = create(new ASTNodeTernaryExpression(node, second, third, Token::Operator::TernaryConditional)); } nodeCleanup.release(); @@ -381,14 +398,19 @@ namespace hex::pl { if (!MATCHES(sequence(IDENTIFIER))) throwParseError("expected attribute expression"); - auto attribute = this->getValue(-1); + auto attribute = getValue(-1).get(); if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETOPEN, STRING, SEPARATOR_ROUNDBRACKETCLOSE))) { - auto value = this->getValue(-2); - currNode->addAttribute(new ASTNodeAttribute(attribute, value)); + auto value = getValue(-2); + auto string = std::get_if(&value); + + if (string == nullptr) + throwParseError("expected string attribute argument"); + + currNode->addAttribute(create(new ASTNodeAttribute(attribute, *string))); } else - currNode->addAttribute(new ASTNodeAttribute(attribute)); + currNode->addAttribute(create(new ASTNodeAttribute(attribute))); } while (MATCHES(sequence(SEPARATOR_COMMA))); @@ -398,16 +420,24 @@ namespace hex::pl { /* Functions */ - ASTNode* Parser::parseFunctionDefintion() { - const auto &functionName = getValue(-2); - std::vector params; + ASTNode* Parser::parseFunctionDefinition() { + const auto &functionName = getValue(-2).get(); + std::map params; // Parse parameter list - bool hasParams = MATCHES(sequence(IDENTIFIER)); + bool hasParams = !peek(SEPARATOR_ROUNDBRACKETCLOSE); + u32 unnamedParamCount = 0; while (hasParams) { - params.push_back(getValue(-1)); + auto type = parseType(true); - if (!MATCHES(sequence(SEPARATOR_COMMA, IDENTIFIER))) { + if (MATCHES(sequence(IDENTIFIER))) + params.emplace(getValue(-1).get(), type); + else { + params.emplace(std::to_string(unnamedParamCount), type); + unnamedParamCount++; + } + + if (!MATCHES(sequence(SEPARATOR_COMMA))) { if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) break; else @@ -435,7 +465,7 @@ namespace hex::pl { } bodyCleanup.release(); - return new ASTNodeFunctionDefinition(getNamespacePrefixedName(functionName), params, body); + return create(new ASTNodeFunctionDefinition(getNamespacePrefixedName(functionName), params, body)); } ASTNode* Parser::parseFunctionStatement() { @@ -487,18 +517,18 @@ namespace hex::pl { } ASTNode* Parser::parseFunctionVariableAssignment() { - const auto &lvalue = getValue(-2); + const auto &lvalue = getValue(-2).get(); auto rvalue = this->parseMathematicalExpression(); - return new ASTNodeAssignment(lvalue, rvalue); + return create(new ASTNodeAssignment(lvalue, rvalue)); } ASTNode* Parser::parseFunctionReturnStatement() { if (peek(SEPARATOR_ENDOFEXPRESSION)) - return new ASTNodeReturnStatement(nullptr); + return create(new ASTNodeReturnStatement(nullptr)); else - return new ASTNodeReturnStatement(this->parseMathematicalExpression()); + return create(new ASTNodeReturnStatement(this->parseMathematicalExpression())); } ASTNode* Parser::parseFunctionConditional() { @@ -532,7 +562,7 @@ namespace hex::pl { cleanup.release(); - return new ASTNodeConditionalStatement(condition, trueBody, falseBody); + return create(new ASTNodeConditionalStatement(condition, trueBody, falseBody)); } ASTNode* Parser::parseFunctionWhileLoop() { @@ -556,7 +586,7 @@ namespace hex::pl { cleanup.release(); - return new ASTNodeWhileStatement(condition, body); + return create(new ASTNodeWhileStatement(condition, body)); } /* Control flow */ @@ -593,7 +623,7 @@ namespace hex::pl { cleanup.release(); - return new ASTNodeConditionalStatement(condition, trueBody, falseBody); + return create(new ASTNodeConditionalStatement(condition, trueBody, falseBody)); } // while ((parseMathematicalExpression)) @@ -609,13 +639,13 @@ namespace hex::pl { cleanup.release(); - return new ASTNodeWhileStatement(condition, { }); + return create(new ASTNodeWhileStatement(condition, { })); } /* Type declarations */ - // [be|le] - ASTNodeTypeDecl* Parser::parseType() { + // [be|le] + ASTNodeTypeDecl* Parser::parseType(bool allowString) { std::optional endian; if (MATCHES(sequence(KEYWORD_LE))) @@ -627,24 +657,28 @@ namespace hex::pl { std::string typeName = parseNamespaceResolution(); if (this->m_types.contains(typeName)) - return new ASTNodeTypeDecl({ }, this->m_types[typeName]->clone(), endian); + return create(new ASTNodeTypeDecl({ }, this->m_types[typeName]->clone(), endian)); else if (this->m_types.contains(getNamespacePrefixedName(typeName))) - return new ASTNodeTypeDecl({ }, this->m_types[getNamespacePrefixedName(typeName)]->clone(), endian); + return create(new ASTNodeTypeDecl({ }, this->m_types[getNamespacePrefixedName(typeName)]->clone(), endian)); else throwParseError(hex::format("unknown type '{}'", typeName)); } else if (MATCHES(sequence(VALUETYPE_ANY))) { // Builtin type - return new ASTNodeTypeDecl({ }, new ASTNodeBuiltinType(getValue(-1)), endian); + auto type = getValue(-1); + if (!allowString && type == Token::ValueType::String) + throwParseError("cannot use 'str' in this context. Use a character array instead"); + + return create(new ASTNodeTypeDecl({ }, new ASTNodeBuiltinType(type), endian)); } else throwParseError("failed to parse type. Expected identifier or builtin type"); } // using Identifier = (parseType) ASTNode* Parser::parseUsingDeclaration() { - auto name = getValue(-2); + auto name = getValue(-2).get(); auto *type = dynamic_cast(parseType()); if (type == nullptr) throwParseError("invalid type used in variable declaration", -1); - return new ASTNodeTypeDecl(name, type, type->getEndian()); + return create(new ASTNodeTypeDecl(name, type, type->getEndian())); } // padding[(parseMathematicalExpression)] @@ -656,7 +690,7 @@ namespace hex::pl { throwParseError("expected closing ']' at end of array declaration", -1); } - return new ASTNodeArrayVariableDecl({ }, new ASTNodeTypeDecl({ }, new ASTNodeBuiltinType(Token::ValueType::Padding)), size);; + return create(new ASTNodeArrayVariableDecl({ }, new ASTNodeTypeDecl({ }, new ASTNodeBuiltinType(Token::ValueType::Padding)), size)); } // (parseType) Identifier @@ -667,19 +701,19 @@ namespace hex::pl { auto variableCleanup = SCOPE_GUARD { for (auto var : variables) delete var; }; do { - variables.push_back(new ASTNodeVariableDecl(getValue(-1), type->clone())); + variables.push_back(create(new ASTNodeVariableDecl(getValue(-1).get(), type->clone()))); } while (MATCHES(sequence(SEPARATOR_COMMA, IDENTIFIER))); variableCleanup.release(); - return new ASTNodeMultiVariableDecl(variables); + return create(new ASTNodeMultiVariableDecl(variables)); } else - return new ASTNodeVariableDecl(getValue(-1), type->clone()); + return create(new ASTNodeVariableDecl(getValue(-1).get(), type->clone())); } // (parseType) Identifier[(parseMathematicalExpression)] ASTNode* Parser::parseMemberArrayVariable(ASTNodeTypeDecl *type) { - auto name = getValue(-2); + auto name = getValue(-2).get(); ASTNode *size = nullptr; auto sizeCleanup = SCOPE_GUARD { delete size; }; @@ -696,12 +730,12 @@ namespace hex::pl { sizeCleanup.release(); - return new ASTNodeArrayVariableDecl(name, type->clone(), size); + return create(new ASTNodeArrayVariableDecl(name, type->clone(), size)); } // (parseType) *Identifier : (parseType) ASTNode* Parser::parseMemberPointerVariable(ASTNodeTypeDecl *type) { - auto name = getValue(-2); + auto name = getValue(-2).get(); auto sizeType = parseType(); @@ -712,7 +746,7 @@ namespace hex::pl { throwParseError("invalid type used for pointer size", -1); } - return new ASTNodePointerVariableDecl(name, type->clone(), sizeType); + return create(new ASTNodePointerVariableDecl(name, type->clone(), sizeType)); } // [(parsePadding)|(parseMemberVariable)|(parseMemberArrayVariable)|(parseMemberPointerVariable)] @@ -758,8 +792,8 @@ namespace hex::pl { // struct Identifier { <(parseMember)...> } ASTNode* Parser::parseStruct() { - const auto structNode = new ASTNodeStruct(); - const auto &typeName = getValue(-2); + const auto structNode = create(new ASTNodeStruct()); + const auto &typeName = getValue(-2).get(); auto structGuard = SCOPE_GUARD { delete structNode; }; while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { @@ -768,13 +802,13 @@ namespace hex::pl { structGuard.release(); - return new ASTNodeTypeDecl(typeName, structNode); + return create(new ASTNodeTypeDecl(typeName, structNode)); } // union Identifier { <(parseMember)...> } ASTNode* Parser::parseUnion() { - const auto unionNode = new ASTNodeUnion(); - const auto &typeName = getValue(-2); + const auto unionNode = create(new ASTNodeUnion()); + const auto &typeName = getValue(-2).get(); auto unionGuard = SCOPE_GUARD { delete unionNode; }; while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { @@ -783,17 +817,17 @@ namespace hex::pl { unionGuard.release(); - return new ASTNodeTypeDecl(typeName, unionNode); + return create(new ASTNodeTypeDecl(typeName, unionNode)); } // enum Identifier : (parseType) { <...> } ASTNode* Parser::parseEnum() { - auto typeName = getValue(-2); + auto typeName = getValue(-2).get(); auto underlyingType = parseType(); if (underlyingType->getEndian().has_value()) throwParseError("underlying type may not have an endian specification", -2); - const auto enumNode = new ASTNodeEnum(underlyingType); + const auto enumNode = create(new ASTNodeEnum(underlyingType)); auto enumGuard = SCOPE_GUARD { delete enumNode; }; if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN))) @@ -802,7 +836,7 @@ namespace hex::pl { ASTNode *lastEntry = nullptr; while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) { - auto name = getValue(-2); + auto name = getValue(-2).get(); auto value = parseMathematicalExpression(); enumNode->addEntry(name, value); @@ -810,11 +844,11 @@ namespace hex::pl { } else if (MATCHES(sequence(IDENTIFIER))) { ASTNode *valueExpr; - auto name = getValue(-1); + auto name = getValue(-1).get(); if (enumNode->getEntries().empty()) - valueExpr = lastEntry = TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral(u8(0))); + valueExpr = lastEntry = create(new ASTNodeLiteral(u128(0))); else - valueExpr = lastEntry = new ASTNodeNumericExpression(lastEntry->clone(), new ASTNodeIntegerLiteral(s32(1)), Token::Operator::Plus); + valueExpr = lastEntry = create(new ASTNodeMathematicalExpression(lastEntry->clone(), new ASTNodeLiteral(u128(1)), Token::Operator::Plus)); enumNode->addEntry(name, valueExpr); } @@ -833,19 +867,19 @@ namespace hex::pl { enumGuard.release(); - return new ASTNodeTypeDecl(typeName, enumNode); + return create(new ASTNodeTypeDecl(typeName, enumNode)); } // bitfield Identifier { } ASTNode* Parser::parseBitfield() { - std::string typeName = getValue(-2); + std::string typeName = getValue(-2).get(); - const auto bitfieldNode = new ASTNodeBitfield(); + const auto bitfieldNode = create(new ASTNodeBitfield()); auto enumGuard = SCOPE_GUARD { delete bitfieldNode; }; while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { if (MATCHES(sequence(IDENTIFIER, OPERATOR_INHERIT))) { - auto name = getValue(-2); + auto name = getValue(-2).get(); bitfieldNode->addEntry(name, parseMathematicalExpression()); } else if (MATCHES(sequence(SEPARATOR_ENDOFPROGRAM))) @@ -862,24 +896,24 @@ namespace hex::pl { enumGuard.release(); - return new ASTNodeTypeDecl(typeName, bitfieldNode); + return create(new ASTNodeTypeDecl(typeName, bitfieldNode)); } // (parseType) Identifier @ Integer ASTNode* Parser::parseVariablePlacement(ASTNodeTypeDecl *type) { - auto name = getValue(-1); + auto name = getValue(-1).get(); if (!MATCHES(sequence(OPERATOR_AT))) throwParseError("expected placement instruction", -1); auto placementOffset = parseMathematicalExpression(); - return new ASTNodeVariableDecl(name, type, placementOffset); + return create(new ASTNodeVariableDecl(name, type, placementOffset)); } // (parseType) Identifier[[(parseMathematicalExpression)]] @ Integer ASTNode* Parser::parseArrayVariablePlacement(ASTNodeTypeDecl *type) { - auto name = getValue(-2); + auto name = getValue(-2).get(); ASTNode *size = nullptr; auto sizeCleanup = SCOPE_GUARD { delete size; }; @@ -901,12 +935,12 @@ namespace hex::pl { sizeCleanup.release(); - return new ASTNodeArrayVariableDecl(name, type, size, placementOffset); + return create(new ASTNodeArrayVariableDecl(name, type, size, placementOffset)); } // (parseType) *Identifier : (parseType) @ Integer ASTNode* Parser::parsePointerVariablePlacement(ASTNodeTypeDecl *type) { - auto name = getValue(-2); + auto name = getValue(-2).get(); auto sizeType = parseType(); auto sizeCleanup = SCOPE_GUARD { delete sizeType; }; @@ -925,7 +959,7 @@ namespace hex::pl { sizeCleanup.release(); - return new ASTNodePointerVariableDecl(name, type, sizeType, placementOffset); + return create(new ASTNodePointerVariableDecl(name, type, sizeType, placementOffset)); } std::vector Parser::parseNamespace() { @@ -937,7 +971,7 @@ namespace hex::pl { this->m_currNamespace.push_back(this->m_currNamespace.back()); while (true) { - this->m_currNamespace.back().push_back(getValue(-1)); + this->m_currNamespace.back().push_back(getValue(-1).get()); if (MATCHES(sequence(OPERATOR_SCOPERESOLUTION, IDENTIFIER))) continue; @@ -974,7 +1008,7 @@ namespace hex::pl { // <(parseUsingDeclaration)|(parseVariablePlacement)|(parseStruct)> std::vector Parser::parseStatements() { - ASTNode *statement = nullptr; + ASTNode *statement; if (MATCHES(sequence(KEYWORD_USING, IDENTIFIER, OPERATOR_ASSIGNMENT))) statement = parseUsingDeclaration(); @@ -1003,7 +1037,7 @@ namespace hex::pl { else if (MATCHES(sequence(KEYWORD_BITFIELD, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) statement = parseBitfield(); else if (MATCHES(sequence(KEYWORD_FUNCTION, IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) - statement = parseFunctionDefintion(); + statement = parseFunctionDefinition(); else if (MATCHES(sequence(KEYWORD_NAMESPACE))) return parseNamespace(); else throwParseError("invalid sequence", 0); @@ -1018,7 +1052,7 @@ namespace hex::pl { while (MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))); if (auto typeDecl = dynamic_cast(statement); typeDecl != nullptr) { - auto typeName = getNamespacePrefixedName(typeDecl->getName().data()); + auto typeName = getNamespacePrefixedName(typeDecl->getName()); if (this->m_types.contains(typeName)) throwParseError(hex::format("redefinition of type '{}'", typeName)); diff --git a/plugins/libimhex/source/pattern_language/pattern_language.cpp b/plugins/libimhex/source/pattern_language/pattern_language.cpp index b1c8f01c8..4f47c7533 100644 --- a/plugins/libimhex/source/pattern_language/pattern_language.cpp +++ b/plugins/libimhex/source/pattern_language/pattern_language.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -61,7 +62,6 @@ namespace hex::pl { delete this->m_lexer; delete this->m_parser; delete this->m_validator; - delete this->m_evaluator; } @@ -70,6 +70,10 @@ namespace hex::pl { this->m_evaluator->getConsole().clear(); this->m_evaluator->setProvider(provider); + for (auto &node : this->m_currAST) + delete node; + this->m_currAST.clear(); + auto preprocessedCode = this->m_preprocessor->preprocess(string); if (!preprocessedCode.has_value()) { this->m_currError = this->m_preprocessor->getError(); @@ -77,7 +81,7 @@ namespace hex::pl { } this->m_evaluator->setDefaultEndian(this->m_defaultEndian); - this->m_evaluator->setRecursionLimit(this->m_recursionLimit); + // this->m_evaluator->setRecursionLimit(this->m_recursionLimit); auto tokens = this->m_lexer->lex(preprocessedCode.value()); if (!tokens.has_value()) { @@ -91,22 +95,15 @@ namespace hex::pl { return { }; } - ON_SCOPE_EXIT { - for(auto &node : ast.value()) - delete node; - }; + this->m_currAST = ast.value(); - auto validatorResult = this->m_validator->validate(ast.value()); - if (!validatorResult) { - this->m_currError = this->m_validator->getError(); + auto patterns = this->m_evaluator->evaluate(ast.value()); + if (!patterns.has_value()) { + this->m_currError = this->m_evaluator->getConsole().getLastHardError(); return { }; } - auto patternData = this->m_evaluator->evaluate(ast.value()); - if (!patternData.has_value()) - return { }; - - return patternData.value(); + return patterns; } std::optional> PatternLanguage::executeFile(prv::Provider *provider, const std::string &path) { diff --git a/plugins/libimhex/source/pattern_language/preprocessor.cpp b/plugins/libimhex/source/pattern_language/preprocessor.cpp index 982217539..d40d29c01 100644 --- a/plugins/libimhex/source/pattern_language/preprocessor.cpp +++ b/plugins/libimhex/source/pattern_language/preprocessor.cpp @@ -25,8 +25,9 @@ namespace hex::pl { output.reserve(code.length()); try { + bool startOfLine = true; while (offset < code.length()) { - if (code[offset] == '#') { + if (code[offset] == '#' && startOfLine) { offset += 1; if (code.substr(offset, 7) == "include") { @@ -164,8 +165,11 @@ namespace hex::pl { throwPreprocessorError("unterminated comment", lineNumber - 1); } - if (code[offset] == '\n') + if (code[offset] == '\n') { lineNumber++; + startOfLine = true; + } else if (!std::isspace(code[offset])) + startOfLine = false; output += code[offset]; offset += 1; diff --git a/source/views/view_pattern_editor.cpp b/source/views/view_pattern_editor.cpp index 1347265d3..dd82a258f 100644 --- a/source/views/view_pattern_editor.cpp +++ b/source/views/view_pattern_editor.cpp @@ -23,21 +23,21 @@ namespace hex { static TextEditor::LanguageDefinition langDef; if (!initialized) { static const char* const keywords[] = { - "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$", "while", "fn", "return", "namespace" + "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "this", "parent", "addressof", "sizeof", "$", "while", "fn", "return", "namespace" }; for (auto& k : keywords) langDef.mKeywords.insert(k); - static std::pair builtInTypes[] = { - { "u8", 1 }, { "u16", 2 }, { "u32", 4 }, { "u64", 8 }, { "u128", 16 }, - { "s8", 1 }, { "s16", 2 }, { "s32", 4 }, { "s64", 8 }, { "s128", 16 }, - { "float", 4 }, { "double", 8 }, { "char", 1 }, { "char16", 2 }, { "bool", 1 }, { "padding", 1 } + static const char* const builtInTypes[] = { + "u8", "u16", "u32", "u64", "u128", + "s8", "s16", "s32", "s64", "s128", + "float", "double", "char", "char16", + "bool", "padding", "str" }; - for (const auto &[name, size] : builtInTypes) { + for (const auto name : builtInTypes) { TextEditor::Identifier id; - id.mDeclaration = std::to_string(size); - id.mDeclaration += size == 1 ? " byte" : " bytes"; + id.mDeclaration = "Built-in type"; langDef.mIdentifiers.insert(std::make_pair(std::string(name), id)); } diff --git a/tests/include/test_patterns/test_pattern_enums.hpp b/tests/include/test_patterns/test_pattern_enums.hpp index fdc79b75b..ed0524fa2 100644 --- a/tests/include/test_patterns/test_pattern_enums.hpp +++ b/tests/include/test_patterns/test_pattern_enums.hpp @@ -9,10 +9,10 @@ namespace hex::test { TestPatternEnums() : TestPattern("Enums"){ auto testEnum = create("TestEnum", "testEnum", 0x120, sizeof(u32)); testEnum->setEnumValues({ - { s32(0x0000), "A" }, - { s32(0x1234), "B" }, - { s32(0x1235), "C" }, - { s32(0x1236), "D" }, + { u128(0x0000), "A" }, + { s128(0x1234), "B" }, + { u128(0x1235), "C" }, + { u128(0x1236), "D" }, }); addPattern(testEnum); diff --git a/tests/include/test_patterns/test_pattern_literals.hpp b/tests/include/test_patterns/test_pattern_literals.hpp index a1365b01e..3a51fc47b 100644 --- a/tests/include/test_patterns/test_pattern_literals.hpp +++ b/tests/include/test_patterns/test_pattern_literals.hpp @@ -19,7 +19,7 @@ namespace hex::test { std::assert(255 == 0xFF, MSG); std::assert(0xAA == 0b10101010, MSG); std::assert(12345 != 67890, MSG); - std::assert(100ULL == 0x64ULL, MSG); + std::assert(100U == 0x64U, MSG); std::assert(-100 == -0x64, MSG); std::assert(3.14159F > 1.414D, MSG); std::assert('A' == 0x41, MSG); diff --git a/tests/include/test_patterns/test_pattern_math.hpp b/tests/include/test_patterns/test_pattern_math.hpp index b2bf5d4ea..0d277822a 100644 --- a/tests/include/test_patterns/test_pattern_math.hpp +++ b/tests/include/test_patterns/test_pattern_math.hpp @@ -33,7 +33,7 @@ namespace hex::test { std::assert(0xFF00FF | 0x00AA00 == 0xFFAAFF, "| operator error"); std::assert(0xFFFFFF & 0x00FF00 == 0x00FF00, "& operator error"); std::assert(0xFFFFFF ^ 0x00AA00 == 0xFF55FF, "^ operator error"); - std::assert(~0xFFFFFFFF == 0x00, "~ operator error"); + std::assert(~0x00 == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, "~ operator error"); std::assert(0xAA >> 4 == 0x0A, ">> operator error"); std::assert(0xAA << 4 == 0xAA0, "<< operator error"); @@ -46,7 +46,7 @@ namespace hex::test { // Special operators std::assert($ == 0, "$ operator error"); - std::assert((10 == 20) ? 30 : 40 == 40, "?: operator error"); + std::assert(((10 == 20) ? 30 : 40) == 40, "?: operator error"); // Type operators struct TypeTest { u32 x, y, z; }; diff --git a/tests/include/test_patterns/test_pattern_padding.hpp b/tests/include/test_patterns/test_pattern_padding.hpp index 6f3b81681..ba07877a5 100644 --- a/tests/include/test_patterns/test_pattern_padding.hpp +++ b/tests/include/test_patterns/test_pattern_padding.hpp @@ -10,7 +10,7 @@ namespace hex::test { auto testStruct = create("TestStruct", "testStruct", 0x100, sizeof(s32) + 20 + sizeof(u8[0x10])); auto variable = create("s32", "variable", 0x100, sizeof(s32)); - auto padding = create("", "", 0x100 + sizeof(s32), 20); + auto padding = create("padding", "", 0x100 + sizeof(s32), 20); auto array = create("u8", "array", 0x100 + sizeof(s32) + 20, sizeof(u8[0x10])); array->setEntries(create("u8", "", 0x100 + sizeof(s32) + 20, sizeof(u8)), 0x10); diff --git a/tests/source/main.cpp b/tests/source/main.cpp index 4962912a7..e64640089 100644 --- a/tests/source/main.cpp +++ b/tests/source/main.cpp @@ -17,14 +17,14 @@ using namespace hex::test; void addFunctions() { hex::ContentRegistry::PatternLanguageFunctions::Namespace nsStd = { "std" }; - hex::ContentRegistry::PatternLanguageFunctions::add(nsStd, "assert", 2, [](auto &ctx, auto params) { - auto condition = AS_TYPE(hex::pl::ASTNodeIntegerLiteral, params[0])->getValue(); - auto message = AS_TYPE(hex::pl::ASTNodeStringLiteral, params[1])->getString(); + hex::ContentRegistry::PatternLanguageFunctions::add(nsStd, "assert", 2, [](Evaluator *ctx, auto params) -> Token::Literal { + auto condition = Token::literalToBoolean(params[0]); + auto message = Token::literalToString(params[1], false); - if (LITERAL_COMPARE(condition, condition == 0)) - ctx.getConsole().abortEvaluation(hex::format("assertion failed \"{0}\"", message.data())); + if (!condition) + LogConsole::abortEvaluation(hex::format("assertion failed \"{0}\"", message)); - return nullptr; + return { }; }); } @@ -67,7 +67,7 @@ int test(int argc, char **argv) { hex::log::fatal("Error during compilation!"); if (auto error = language.getError(); error.has_value()) - hex::log::info("Compile error: {}:{}", error->first, error->second); + hex::log::info("Compile error: {} : {}", error->first, error->second); else for (auto &[level, message] : language.getConsoleLog()) hex::log::info("Evaluate error: {}", message); @@ -93,10 +93,10 @@ int test(int argc, char **argv) { // Check if the produced patterns are the ones expected for (u32 i = 0; i < currTest->getPatterns().size(); i++) { - auto &left = *patterns->at(i); - auto &right = *currTest->getPatterns().at(i); + auto &evaluatedPattern = *patterns->at(i); + auto &controlPattern = *currTest->getPatterns().at(i); - if (left != right) { + if (evaluatedPattern != controlPattern) { hex::log::fatal("Pattern with name {}:{} didn't match template", patterns->at(i)->getTypeName(), patterns->at(i)->getVariableName()); return EXIT_FAILURE; }