1
0
mirror of synced 2024-12-01 02:37:18 +01:00

patterns: Added inheritance for structs

This commit is contained in:
WerWolv 2021-09-24 00:47:34 +02:00
parent 6713f65040
commit 2edd6cd6c4
2 changed files with 46 additions and 4 deletions

View File

@ -9,6 +9,7 @@
#include <map>
#include <variant>
#include <vector>
#include <ranges>
#include <hex/pattern_language/ast_node_base.hpp>
@ -315,7 +316,7 @@ namespace hex::pl {
else if (Token::isFloatingPoint(this->m_type))
pattern = new PatternDataFloat(offset, size);
else if (this->m_type == Token::ValueType::Boolean)
pattern = new PatternDataBoolean(offset, size);
pattern = new PatternDataBoolean(offset);
else if (this->m_type == Token::ValueType::Character)
pattern = new PatternDataCharacter(offset);
else if (this->m_type == Token::ValueType::Character16)
@ -1015,11 +1016,15 @@ namespace hex::pl {
ASTNodeStruct(const ASTNodeStruct &other) : ASTNode(other), Attributable(other) {
for (const auto &otherMember : other.getMembers())
this->m_members.push_back(otherMember->clone());
for (const auto &otherInheritance : other.getInheritance())
this->m_inheritance.push_back(otherInheritance->clone());
}
~ASTNodeStruct() override {
for (auto &member : this->m_members)
delete member;
for (auto &inheritance : this->m_inheritance)
delete inheritance;
}
[[nodiscard]] ASTNode* clone() const override {
@ -1033,11 +1038,26 @@ namespace hex::pl {
std::vector<PatternData*> memberPatterns;
evaluator->pushScope(pattern, memberPatterns);
for (auto inheritance : this->m_inheritance) {
auto inheritancePatterns = inheritance->createPatterns(evaluator).front();
ON_SCOPE_EXIT {
delete inheritancePatterns;
};
if (auto structPattern = dynamic_cast<PatternDataStruct*>(inheritancePatterns)) {
for (auto member : structPattern->getMembers()) {
memberPatterns.push_back(member->clone());
}
}
}
for (auto member : this->m_members) {
for (auto &memberPattern : member->createPatterns(evaluator)) {
memberPatterns.push_back(memberPattern);
}
}
evaluator->popScope();
pattern->setMembers(memberPatterns);
@ -1049,8 +1069,12 @@ namespace hex::pl {
[[nodiscard]] const std::vector<ASTNode*>& getMembers() const { return this->m_members; }
void addMember(ASTNode *node) { this->m_members.push_back(node); }
[[nodiscard]] const std::vector<ASTNode*>& getInheritance() const { return this->m_inheritance; }
void addInheritance(ASTNode *node) { this->m_inheritance.push_back(node); }
private:
std::vector<ASTNode*> m_members;
std::vector<ASTNode*> m_inheritance;
};
class ASTNodeUnion : public ASTNode, public Attributable {
@ -1342,7 +1366,7 @@ namespace hex::pl {
continue;
} else {
bool found = false;
for (const auto &variable : searchScope) {
for (const auto &variable : searchScope | std::views::reverse) {
if (variable->getVariableName() == name) {
auto newPattern = variable->clone();
delete currPattern;

View File

@ -797,9 +797,27 @@ namespace hex::pl {
// struct Identifier { <(parseMember)...> }
ASTNode* Parser::parseStruct() {
const auto structNode = create(new ASTNodeStruct());
const auto &typeName = getValue<Token::Identifier>(-2).get();
const auto &typeName = getValue<Token::Identifier>(-1).get();
auto structGuard = SCOPE_GUARD { delete structNode; };
if (MATCHES(sequence(OPERATOR_INHERIT, IDENTIFIER))) {
// Inheritance
do {
auto inheritedTypeName = getValue<Token::Identifier>(-1).get();
if (!this->m_types.contains(inheritedTypeName))
throwParseError(hex::format("cannot inherit from unknown type '{}'", inheritedTypeName), -1);
structNode->addInheritance(this->m_types[inheritedTypeName]->clone());
} while (MATCHES(sequence(SEPARATOR_COMMA, IDENTIFIER)));
} else if (MATCHES(sequence(OPERATOR_INHERIT, VALUETYPE_ANY))) {
throwParseError("cannot inherit from builtin type");
}
if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN)))
throwParseError("expected '{' after struct definition", -1);
while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) {
structNode->addMember(parseMember());
}
@ -1032,7 +1050,7 @@ namespace hex::pl {
}
else if (peek(KEYWORD_BE) || peek(KEYWORD_LE) || peek(VALUETYPE_ANY))
statement = parsePlacement();
else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN)))
else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER)))
statement = parseStruct();
else if (MATCHES(sequence(KEYWORD_UNION, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN)))
statement = parseUnion();