1
0
mirror of synced 2024-12-01 02:37:18 +01:00

patterns: Added inheritance for structs

This commit is contained in:
WerWolv 2021-09-24 00:47:34 +02:00
parent 6713f65040
commit 2edd6cd6c4
2 changed files with 46 additions and 4 deletions

View File

@ -9,6 +9,7 @@
#include <map> #include <map>
#include <variant> #include <variant>
#include <vector> #include <vector>
#include <ranges>
#include <hex/pattern_language/ast_node_base.hpp> #include <hex/pattern_language/ast_node_base.hpp>
@ -315,7 +316,7 @@ namespace hex::pl {
else if (Token::isFloatingPoint(this->m_type)) else if (Token::isFloatingPoint(this->m_type))
pattern = new PatternDataFloat(offset, size); pattern = new PatternDataFloat(offset, size);
else if (this->m_type == Token::ValueType::Boolean) else if (this->m_type == Token::ValueType::Boolean)
pattern = new PatternDataBoolean(offset, size); pattern = new PatternDataBoolean(offset);
else if (this->m_type == Token::ValueType::Character) else if (this->m_type == Token::ValueType::Character)
pattern = new PatternDataCharacter(offset); pattern = new PatternDataCharacter(offset);
else if (this->m_type == Token::ValueType::Character16) else if (this->m_type == Token::ValueType::Character16)
@ -1015,11 +1016,15 @@ namespace hex::pl {
ASTNodeStruct(const ASTNodeStruct &other) : ASTNode(other), Attributable(other) { ASTNodeStruct(const ASTNodeStruct &other) : ASTNode(other), Attributable(other) {
for (const auto &otherMember : other.getMembers()) for (const auto &otherMember : other.getMembers())
this->m_members.push_back(otherMember->clone()); this->m_members.push_back(otherMember->clone());
for (const auto &otherInheritance : other.getInheritance())
this->m_inheritance.push_back(otherInheritance->clone());
} }
~ASTNodeStruct() override { ~ASTNodeStruct() override {
for (auto &member : this->m_members) for (auto &member : this->m_members)
delete member; delete member;
for (auto &inheritance : this->m_inheritance)
delete inheritance;
} }
[[nodiscard]] ASTNode* clone() const override { [[nodiscard]] ASTNode* clone() const override {
@ -1033,11 +1038,26 @@ namespace hex::pl {
std::vector<PatternData*> memberPatterns; std::vector<PatternData*> memberPatterns;
evaluator->pushScope(pattern, memberPatterns); evaluator->pushScope(pattern, memberPatterns);
for (auto inheritance : this->m_inheritance) {
auto inheritancePatterns = inheritance->createPatterns(evaluator).front();
ON_SCOPE_EXIT {
delete inheritancePatterns;
};
if (auto structPattern = dynamic_cast<PatternDataStruct*>(inheritancePatterns)) {
for (auto member : structPattern->getMembers()) {
memberPatterns.push_back(member->clone());
}
}
}
for (auto member : this->m_members) { for (auto member : this->m_members) {
for (auto &memberPattern : member->createPatterns(evaluator)) { for (auto &memberPattern : member->createPatterns(evaluator)) {
memberPatterns.push_back(memberPattern); memberPatterns.push_back(memberPattern);
} }
} }
evaluator->popScope(); evaluator->popScope();
pattern->setMembers(memberPatterns); pattern->setMembers(memberPatterns);
@ -1049,8 +1069,12 @@ namespace hex::pl {
[[nodiscard]] const std::vector<ASTNode*>& getMembers() const { return this->m_members; } [[nodiscard]] const std::vector<ASTNode*>& getMembers() const { return this->m_members; }
void addMember(ASTNode *node) { this->m_members.push_back(node); } void addMember(ASTNode *node) { this->m_members.push_back(node); }
[[nodiscard]] const std::vector<ASTNode*>& getInheritance() const { return this->m_inheritance; }
void addInheritance(ASTNode *node) { this->m_inheritance.push_back(node); }
private: private:
std::vector<ASTNode*> m_members; std::vector<ASTNode*> m_members;
std::vector<ASTNode*> m_inheritance;
}; };
class ASTNodeUnion : public ASTNode, public Attributable { class ASTNodeUnion : public ASTNode, public Attributable {
@ -1342,7 +1366,7 @@ namespace hex::pl {
continue; continue;
} else { } else {
bool found = false; bool found = false;
for (const auto &variable : searchScope) { for (const auto &variable : searchScope | std::views::reverse) {
if (variable->getVariableName() == name) { if (variable->getVariableName() == name) {
auto newPattern = variable->clone(); auto newPattern = variable->clone();
delete currPattern; delete currPattern;

View File

@ -797,9 +797,27 @@ namespace hex::pl {
// struct Identifier { <(parseMember)...> } // struct Identifier { <(parseMember)...> }
ASTNode* Parser::parseStruct() { ASTNode* Parser::parseStruct() {
const auto structNode = create(new ASTNodeStruct()); const auto structNode = create(new ASTNodeStruct());
const auto &typeName = getValue<Token::Identifier>(-2).get(); const auto &typeName = getValue<Token::Identifier>(-1).get();
auto structGuard = SCOPE_GUARD { delete structNode; }; auto structGuard = SCOPE_GUARD { delete structNode; };
if (MATCHES(sequence(OPERATOR_INHERIT, IDENTIFIER))) {
// Inheritance
do {
auto inheritedTypeName = getValue<Token::Identifier>(-1).get();
if (!this->m_types.contains(inheritedTypeName))
throwParseError(hex::format("cannot inherit from unknown type '{}'", inheritedTypeName), -1);
structNode->addInheritance(this->m_types[inheritedTypeName]->clone());
} while (MATCHES(sequence(SEPARATOR_COMMA, IDENTIFIER)));
} else if (MATCHES(sequence(OPERATOR_INHERIT, VALUETYPE_ANY))) {
throwParseError("cannot inherit from builtin type");
}
if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN)))
throwParseError("expected '{' after struct definition", -1);
while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) {
structNode->addMember(parseMember()); structNode->addMember(parseMember());
} }
@ -1032,7 +1050,7 @@ namespace hex::pl {
} }
else if (peek(KEYWORD_BE) || peek(KEYWORD_LE) || peek(VALUETYPE_ANY)) else if (peek(KEYWORD_BE) || peek(KEYWORD_LE) || peek(VALUETYPE_ANY))
statement = parsePlacement(); statement = parsePlacement();
else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER)))
statement = parseStruct(); statement = parseStruct();
else if (MATCHES(sequence(KEYWORD_UNION, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) else if (MATCHES(sequence(KEYWORD_UNION, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN)))
statement = parseUnion(); statement = parseUnion();