diff --git a/plugins/libimhex/include/hex/pattern_language/ast_node.hpp b/plugins/libimhex/include/hex/pattern_language/ast_node.hpp index df6134548..4f5a1dc3e 100644 --- a/plugins/libimhex/include/hex/pattern_language/ast_node.hpp +++ b/plugins/libimhex/include/hex/pattern_language/ast_node.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -315,7 +316,7 @@ namespace hex::pl { else if (Token::isFloatingPoint(this->m_type)) pattern = new PatternDataFloat(offset, size); else if (this->m_type == Token::ValueType::Boolean) - pattern = new PatternDataBoolean(offset, size); + pattern = new PatternDataBoolean(offset); else if (this->m_type == Token::ValueType::Character) pattern = new PatternDataCharacter(offset); else if (this->m_type == Token::ValueType::Character16) @@ -1015,11 +1016,15 @@ namespace hex::pl { ASTNodeStruct(const ASTNodeStruct &other) : ASTNode(other), Attributable(other) { for (const auto &otherMember : other.getMembers()) this->m_members.push_back(otherMember->clone()); + for (const auto &otherInheritance : other.getInheritance()) + this->m_inheritance.push_back(otherInheritance->clone()); } ~ASTNodeStruct() override { for (auto &member : this->m_members) delete member; + for (auto &inheritance : this->m_inheritance) + delete inheritance; } [[nodiscard]] ASTNode* clone() const override { @@ -1033,11 +1038,26 @@ namespace hex::pl { std::vector memberPatterns; evaluator->pushScope(pattern, memberPatterns); + + for (auto inheritance : this->m_inheritance) { + auto inheritancePatterns = inheritance->createPatterns(evaluator).front(); + ON_SCOPE_EXIT { + delete inheritancePatterns; + }; + + if (auto structPattern = dynamic_cast(inheritancePatterns)) { + for (auto member : structPattern->getMembers()) { + memberPatterns.push_back(member->clone()); + } + } + } + for (auto member : this->m_members) { for (auto &memberPattern : member->createPatterns(evaluator)) { memberPatterns.push_back(memberPattern); } } + evaluator->popScope(); pattern->setMembers(memberPatterns); @@ -1049,8 +1069,12 @@ namespace hex::pl { [[nodiscard]] const std::vector& getMembers() const { return this->m_members; } void addMember(ASTNode *node) { this->m_members.push_back(node); } + [[nodiscard]] const std::vector& getInheritance() const { return this->m_inheritance; } + void addInheritance(ASTNode *node) { this->m_inheritance.push_back(node); } + private: std::vector m_members; + std::vector m_inheritance; }; class ASTNodeUnion : public ASTNode, public Attributable { @@ -1342,7 +1366,7 @@ namespace hex::pl { continue; } else { bool found = false; - for (const auto &variable : searchScope) { + for (const auto &variable : searchScope | std::views::reverse) { if (variable->getVariableName() == name) { auto newPattern = variable->clone(); delete currPattern; diff --git a/plugins/libimhex/source/pattern_language/parser.cpp b/plugins/libimhex/source/pattern_language/parser.cpp index 19025bc36..c55c0cb58 100644 --- a/plugins/libimhex/source/pattern_language/parser.cpp +++ b/plugins/libimhex/source/pattern_language/parser.cpp @@ -797,9 +797,27 @@ namespace hex::pl { // struct Identifier { <(parseMember)...> } ASTNode* Parser::parseStruct() { const auto structNode = create(new ASTNodeStruct()); - const auto &typeName = getValue(-2).get(); + const auto &typeName = getValue(-1).get(); auto structGuard = SCOPE_GUARD { delete structNode; }; + if (MATCHES(sequence(OPERATOR_INHERIT, IDENTIFIER))) { + // Inheritance + + do { + auto inheritedTypeName = getValue(-1).get(); + if (!this->m_types.contains(inheritedTypeName)) + throwParseError(hex::format("cannot inherit from unknown type '{}'", inheritedTypeName), -1); + + structNode->addInheritance(this->m_types[inheritedTypeName]->clone()); + } while (MATCHES(sequence(SEPARATOR_COMMA, IDENTIFIER))); + + } else if (MATCHES(sequence(OPERATOR_INHERIT, VALUETYPE_ANY))) { + throwParseError("cannot inherit from builtin type"); + } + + if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN))) + throwParseError("expected '{' after struct definition", -1); + while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { structNode->addMember(parseMember()); } @@ -1032,7 +1050,7 @@ namespace hex::pl { } else if (peek(KEYWORD_BE) || peek(KEYWORD_LE) || peek(VALUETYPE_ANY)) statement = parsePlacement(); - else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) + else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER))) statement = parseStruct(); else if (MATCHES(sequence(KEYWORD_UNION, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) statement = parseUnion();