diff --git a/lib/libimhex/include/hex/pattern_language/ast/ast_node_type_decl.hpp b/lib/libimhex/include/hex/pattern_language/ast/ast_node_type_decl.hpp index 656460601..9a7c4a486 100644 --- a/lib/libimhex/include/hex/pattern_language/ast/ast_node_type_decl.hpp +++ b/lib/libimhex/include/hex/pattern_language/ast/ast_node_type_decl.hpp @@ -8,6 +8,8 @@ namespace hex::pl { class ASTNodeTypeDecl : public ASTNode, public Attributable { public: + ASTNodeTypeDecl(std::string name) : m_forwardDeclared(true), m_name(name) { } + ASTNodeTypeDecl(std::string name, std::shared_ptr type, std::optional endian = std::nullopt) : ASTNode(), m_name(std::move(name)), m_type(std::move(type)), m_endian(endian) { } @@ -69,7 +71,18 @@ namespace hex::pl { Attributable::addAttribute(std::move(attribute)); } + [[nodiscard]] + bool isForwardDeclared() const { + return this->m_forwardDeclared; + } + + void setType(std::shared_ptr type) { + this->m_forwardDeclared = false; + this->m_type = type; + } + private: + bool m_forwardDeclared = false; std::string m_name; std::shared_ptr m_type; std::optional m_endian; diff --git a/lib/libimhex/include/hex/pattern_language/parser.hpp b/lib/libimhex/include/hex/pattern_language/parser.hpp index 16119bd90..e4366cec2 100644 --- a/lib/libimhex/include/hex/pattern_language/parser.hpp +++ b/lib/libimhex/include/hex/pattern_language/parser.hpp @@ -34,7 +34,7 @@ namespace hex::pl { TokenIter m_curr; TokenIter m_originalPosition, m_partOriginalPosition; - std::unordered_map> m_types; + std::unordered_map> m_types; std::vector m_matchedOptionals; std::vector> m_currNamespace; @@ -121,6 +121,7 @@ namespace hex::pl { std::shared_ptr parseUnion(); std::shared_ptr parseEnum(); std::shared_ptr parseBitfield(); + void parseForwardDeclaration(); std::unique_ptr parseVariablePlacement(const std::shared_ptr &type); std::unique_ptr parseArrayVariablePlacement(const std::shared_ptr &type); std::unique_ptr parsePointerVariablePlacement(const std::shared_ptr &type); diff --git a/lib/libimhex/source/pattern_language/parser.cpp b/lib/libimhex/source/pattern_language/parser.cpp index d477351b5..99413737e 100644 --- a/lib/libimhex/source/pattern_language/parser.cpp +++ b/lib/libimhex/source/pattern_language/parser.cpp @@ -706,10 +706,7 @@ namespace hex::pl { // using Identifier = (parseType) std::shared_ptr Parser::parseUsingDeclaration() { - auto name = parseNamespaceResolution(); - - if (!MATCHES(sequence(OPERATOR_ASSIGNMENT))) - throwParserError("expected '=' after type name of using declaration"); + auto name = getNamespacePrefixedName(getValue(-2).get()); auto type = parseType(); @@ -976,6 +973,16 @@ namespace hex::pl { return typeDecl; } + // using Identifier; + void Parser::parseForwardDeclaration() { + std::string typeName = getNamespacePrefixedName(getValue(-1).get()); + + if (this->m_types.contains(typeName)) + return; + + this->m_types.insert({ typeName, create(new ASTNodeTypeDecl(typeName) )}); + } + // (parseType) Identifier @ Integer std::unique_ptr Parser::parseVariablePlacement(const std::shared_ptr &type) { bool inVariable = false; @@ -1088,8 +1095,10 @@ namespace hex::pl { std::vector> Parser::parseStatements() { std::shared_ptr statement; - if (MATCHES(sequence(KEYWORD_USING, IDENTIFIER))) + if (MATCHES(sequence(KEYWORD_USING, IDENTIFIER, OPERATOR_ASSIGNMENT))) statement = parseUsingDeclaration(); + else if (MATCHES(sequence(KEYWORD_USING, IDENTIFIER))) + parseForwardDeclaration(); else if (peek(IDENTIFIER)) { auto originalPos = this->m_curr; this->m_curr++; @@ -1118,7 +1127,7 @@ namespace hex::pl { return parseNamespace(); else throwParserError("invalid sequence", 0); - if (MATCHES(sequence(SEPARATOR_SQUAREBRACKETOPEN, SEPARATOR_SQUAREBRACKETOPEN))) + if (statement && MATCHES(sequence(SEPARATOR_SQUAREBRACKETOPEN, SEPARATOR_SQUAREBRACKETOPEN))) parseAttribute(dynamic_cast(statement.get())); if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) @@ -1128,19 +1137,28 @@ namespace hex::pl { while (MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) ; + if (!statement) + return { }; + return hex::moveToVector(std::move(statement)); } std::shared_ptr Parser::addType(const std::string &name, std::unique_ptr &&node, std::optional endian) { auto typeName = getNamespacePrefixedName(name); - if (this->m_types.contains(typeName)) - throwParserError(hex::format("redefinition of type '{}'", typeName)); + if (this->m_types.contains(typeName) && this->m_types.at(typeName)->isForwardDeclared()) { + this->m_types.at(typeName)->setType(std::move(node)); - std::shared_ptr typeDecl = create(new ASTNodeTypeDecl(typeName, std::move(node), endian)); - this->m_types.insert({ typeName, typeDecl }); + return this->m_types.at(typeName); + } else { + if (this->m_types.contains(typeName)) + throwParserError(hex::format("redefinition of type '{}'", typeName)); - return typeDecl; + std::shared_ptr typeDecl = create(new ASTNodeTypeDecl(typeName, std::move(node), endian)); + this->m_types.insert({ typeName, typeDecl }); + + return typeDecl; + } } // <(parseNamespace)...> EndOfProgram diff --git a/lib/libimhex/source/pattern_language/validator.cpp b/lib/libimhex/source/pattern_language/validator.cpp index 3a07eb8bd..0f6d71d94 100644 --- a/lib/libimhex/source/pattern_language/validator.cpp +++ b/lib/libimhex/source/pattern_language/validator.cpp @@ -33,7 +33,8 @@ namespace hex::pl { if (!types.insert(typeDeclNode->getName().data()).second) throwValidatorError(hex::format("redefinition of type '{0}'", typeDeclNode->getName().data()), typeDeclNode->getLineNumber()); - this->validate(hex::moveToVector>(typeDeclNode->getType()->clone())); + if (!typeDeclNode->isForwardDeclared()) + this->validate(hex::moveToVector>(typeDeclNode->getType()->clone())); } else if (auto structNode = dynamic_cast(node.get()); structNode != nullptr) { this->validate(structNode->getMembers()); } else if (auto unionNode = dynamic_cast(node.get()); unionNode != nullptr) {