From 21f8fb40902609292b8e400b6b814fad243a5bc4 Mon Sep 17 00:00:00 2001 From: WerWolv Date: Thu, 17 Jun 2021 23:13:58 +0200 Subject: [PATCH] patterns: Added while statement for array sizing --- .../source/content/lang_builtin_functions.cpp | 38 +++--- .../libimhex/include/hex/helpers/utils.hpp | 3 + .../libimhex/include/hex/lang/ast_node.hpp | 24 ++++ plugins/libimhex/include/hex/lang/parser.hpp | 1 + plugins/libimhex/include/hex/lang/token.hpp | 4 +- plugins/libimhex/source/lang/evaluator.cpp | 124 +++++++++++------- plugins/libimhex/source/lang/lexer.cpp | 6 +- plugins/libimhex/source/lang/parser.cpp | 30 ++++- source/views/view_pattern.cpp | 2 +- 9 files changed, 156 insertions(+), 76 deletions(-) diff --git a/plugins/builtin/source/content/lang_builtin_functions.cpp b/plugins/builtin/source/content/lang_builtin_functions.cpp index d2fcecb3f..1f5981ad7 100644 --- a/plugins/builtin/source/content/lang_builtin_functions.cpp +++ b/plugins/builtin/source/content/lang_builtin_functions.cpp @@ -126,23 +126,27 @@ namespace hex::plugin::builtin { std::string message; for (auto& param : params) { if (auto integerLiteral = dynamic_cast(param); integerLiteral != nullptr) { - switch (integerLiteral->getType()) { - case Token::ValueType::Character: message += std::get(integerLiteral->getValue()); break; - case Token::ValueType::Unsigned8Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Signed8Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Unsigned16Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Signed16Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Unsigned32Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Signed32Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Unsigned64Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Signed64Bit: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Unsigned128Bit: message += hex::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Signed128Bit: message += hex::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Float: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Double: message += std::to_string(std::get(integerLiteral->getValue())); break; - case Token::ValueType::Boolean: message += std::get(integerLiteral->getValue()) ? "true" : "false"; break; - case Token::ValueType::CustomType: message += "< Custom Type >"; break; - } + std::visit([&](auto &&value) { + switch (integerLiteral->getType()) { + case lang::Token::ValueType::Character: message += (char)value; break; + case lang::Token::ValueType::Boolean: message += value == 0 ? "false" : "true"; break; + case lang::Token::ValueType::Unsigned8Bit: + case lang::Token::ValueType::Unsigned16Bit: + case lang::Token::ValueType::Unsigned32Bit: + case lang::Token::ValueType::Unsigned64Bit: + case lang::Token::ValueType::Unsigned128Bit: + message += std::to_string(static_cast(value)); + break; + case lang::Token::ValueType::Signed8Bit: + case lang::Token::ValueType::Signed16Bit: + case lang::Token::ValueType::Signed32Bit: + case lang::Token::ValueType::Signed64Bit: + case lang::Token::ValueType::Signed128Bit: + message += std::to_string(static_cast(value)); + break; + default: message += "< Custom Type >"; + } + }, integerLiteral->getValue()); } else if (auto stringLiteral = dynamic_cast(param); stringLiteral != nullptr) message += stringLiteral->getString(); diff --git a/plugins/libimhex/include/hex/helpers/utils.hpp b/plugins/libimhex/include/hex/helpers/utils.hpp index 5bc606f3d..ec151b348 100644 --- a/plugins/libimhex/include/hex/helpers/utils.hpp +++ b/plugins/libimhex/include/hex/helpers/utils.hpp @@ -108,6 +108,9 @@ namespace hex { template struct always_false : std::false_type {}; + template struct overloaded : Ts... { using Ts::operator()...; }; + template overloaded(Ts...) -> overloaded; + template constexpr T changeEndianess(T value, std::endian endian) { if (endian == std::endian::native) diff --git a/plugins/libimhex/include/hex/lang/ast_node.hpp b/plugins/libimhex/include/hex/lang/ast_node.hpp index 79f62e381..d3fa72e5c 100644 --- a/plugins/libimhex/include/hex/lang/ast_node.hpp +++ b/plugins/libimhex/include/hex/lang/ast_node.hpp @@ -482,6 +482,30 @@ namespace hex::lang { std::vector m_trueBody, m_falseBody; }; + class ASTNodeWhileStatement : public ASTNode { + public: + explicit ASTNodeWhileStatement(ASTNode *condition) : ASTNode(), m_condition(condition) { } + + ~ASTNodeWhileStatement() override { + delete this->m_condition; + } + + ASTNodeWhileStatement(const ASTNodeWhileStatement &other) : ASTNode(other) { + this->m_condition = other.m_condition->clone(); + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeWhileStatement(*this); + } + + [[nodiscard]] ASTNode* getCondition() { + return this->m_condition; + } + + private: + ASTNode *m_condition; + }; + class ASTNodeFunctionCall : public ASTNode { public: explicit ASTNodeFunctionCall(std::string_view functionName, std::vector params) diff --git a/plugins/libimhex/include/hex/lang/parser.hpp b/plugins/libimhex/include/hex/lang/parser.hpp index b330523cc..3d0678375 100644 --- a/plugins/libimhex/include/hex/lang/parser.hpp +++ b/plugins/libimhex/include/hex/lang/parser.hpp @@ -73,6 +73,7 @@ namespace hex::lang { void parseAttribute(Attributable *currNode); ASTNode* parseConditional(); + ASTNode* parseWhileStatement(); ASTNode* parseType(s32 startIndex); ASTNode* parseUsingDeclaration(); ASTNode* parsePadding(); diff --git a/plugins/libimhex/include/hex/lang/token.hpp b/plugins/libimhex/include/hex/lang/token.hpp index 2680e5176..4fbcb524f 100644 --- a/plugins/libimhex/include/hex/lang/token.hpp +++ b/plugins/libimhex/include/hex/lang/token.hpp @@ -32,7 +32,8 @@ namespace hex::lang { BigEndian, If, Else, - Parent + Parent, + While }; enum class Operator { @@ -225,6 +226,7 @@ namespace hex::lang { #define KEYWORD_IF COMPONENT(Keyword, If) #define KEYWORD_ELSE COMPONENT(Keyword, Else) #define KEYWORD_PARENT COMPONENT(Keyword, Parent) +#define KEYWORD_WHILE COMPONENT(Keyword, While) #define INTEGER hex::lang::Token::Type::Integer, hex::lang::Token::IntegerLiteral(hex::lang::Token::ValueType::Any, u64(0)) #define IDENTIFIER hex::lang::Token::Type::Identifier, "" diff --git a/plugins/libimhex/source/lang/evaluator.cpp b/plugins/libimhex/source/lang/evaluator.cpp index 50d22b299..ab79d0e3c 100644 --- a/plugins/libimhex/source/lang/evaluator.cpp +++ b/plugins/libimhex/source/lang/evaluator.cpp @@ -141,9 +141,6 @@ namespace hex::lang { } ASTNodeIntegerLiteral* Evaluator::evaluateRValue(ASTNodeRValue *node) { - if (this->m_currMembers.empty() && this->m_globalMembers.empty()) - this->getConsole().abortEvaluation("no variables available"); - if (node->getPath().size() == 1) { if (auto part = std::get_if(&node->getPath()[0]); part != nullptr && *part == "$") return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, this->m_currOffset }); @@ -201,7 +198,7 @@ namespace hex::lang { for (auto ¶m : node->getParams()) { if (auto numericExpression = dynamic_cast(param); numericExpression != nullptr) evaluatedParams.push_back(this->evaluateMathematicalExpression(numericExpression)); - else if (auto typeOperatorExpression = dynamic_cast(param)) + else if (auto typeOperatorExpression = dynamic_cast(param); typeOperatorExpression != nullptr) evaluatedParams.push_back(this->evaluateTypeOperator(typeOperatorExpression)); else if (auto stringLiteral = dynamic_cast(param); stringLiteral != nullptr) evaluatedParams.push_back(stringLiteral->clone()); @@ -486,6 +483,8 @@ namespace hex::lang { pattern = new PatternDataSigned(this->m_currOffset, typeSize); else if (Token::isFloatingPoint(type)) pattern = new PatternDataFloat(this->m_currOffset, typeSize); + else if (type == Token::ValueType::Padding) + pattern = new PatternDataPadding(this->m_currOffset, 1); else this->getConsole().abortEvaluation("invalid builtin type"); @@ -734,6 +733,7 @@ namespace hex::lang { PatternData* Evaluator::evaluateArray(ASTNodeArrayVariableDecl *node) { + // Evaluate placement of array if (auto offset = dynamic_cast(node->getPlacementOffset()); offset != nullptr) { auto valueNode = evaluateMathematicalExpression(offset); ON_SCOPE_EXIT { delete valueNode; }; @@ -745,6 +745,7 @@ namespace hex::lang { }, valueNode->getValue()); } + // Check if placed in range of the data if (this->m_currOffset < this->m_provider->getBaseAddress() || this->m_currOffset >= this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) { if (node->getPlacementOffset() != nullptr) this->getConsole().abortEvaluation("variable placed out of range"); @@ -754,59 +755,19 @@ namespace hex::lang { auto startOffset = this->m_currOffset; - ASTNodeIntegerLiteral *valueNode; - u64 arraySize = 0; - - if (node->getSize() != nullptr) { - if (auto sizeNumericExpression = dynamic_cast(node->getSize()); sizeNumericExpression != nullptr) - valueNode = evaluateMathematicalExpression(sizeNumericExpression); - else - this->getConsole().abortEvaluation("array size not a numeric expression"); - - ON_SCOPE_EXIT { delete valueNode; }; - - arraySize = std::visit([this, node, type = valueNode->getType()] (auto &&value) { - if (Token::isFloatingPoint(type)) - this->getConsole().abortEvaluation("array size must be an integer value"); - return static_cast(value); - }, valueNode->getValue()); - - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) { - if (auto builtinType = dynamic_cast(typeDecl->getType()); builtinType != nullptr) { - if (builtinType->getType() == Token::ValueType::Padding) { - this->m_currOffset += arraySize; - return new PatternDataPadding(startOffset, arraySize); - } - } - } - } else { - if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) { - if (auto builtinType = dynamic_cast(typeDecl->getType()); builtinType != nullptr) { - std::vector bytes(Token::getTypeSize(builtinType->getType()), 0x00); - u64 offset = startOffset; - - do { - this->m_provider->read(offset, bytes.data(), bytes.size()); - offset += bytes.size(); - arraySize++; - } while (!std::all_of(bytes.begin(), bytes.end(), [](u8 byte){ return byte == 0x00; }) && offset < this->m_provider->getSize()); - } - } - } - std::vector entries; std::optional color; - for (s128 i = 0; i < arraySize; i++) { + + auto addEntry = [this, node, &entries, &color](u64 index) { PatternData *entry; if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) entry = this->evaluateType(typeDecl); - else if (auto builtinTypeDecl = dynamic_cast(node->getType()); builtinTypeDecl != nullptr) { + else if (auto builtinTypeDecl = dynamic_cast(node->getType()); builtinTypeDecl != nullptr) entry = this->evaluateBuiltinType(builtinTypeDecl); - } else this->getConsole().abortEvaluation("ASTNodeVariableDecl had an invalid type. This is a bug!"); - entry->setVariableName(hex::format("[{0}]", (u64)i)); + entry->setVariableName(hex::format("[{0}]", index)); entry->setEndian(this->getCurrentEndian()); if (!color.has_value()) @@ -815,17 +776,76 @@ namespace hex::lang { if (this->m_currOffset > this->m_provider->getActualSize() + this->m_provider->getBaseAddress()) { delete entry; - break; + return; } entries.push_back(entry); + }; + auto sizeNode = node->getSize(); + if (auto numericExpression = dynamic_cast(sizeNode); numericExpression != nullptr) { + // Parse explicit size of array + auto valueNode = this->evaluateMathematicalExpression(numericExpression); + ON_SCOPE_EXIT { delete valueNode; }; + + auto arraySize = std::visit([this, node, type = valueNode->getType()] (auto &&value) { + if (Token::isFloatingPoint(type)) + this->getConsole().abortEvaluation("array size must be an integer value"); + return static_cast(value); + }, valueNode->getValue()); + + for (u64 i = 0; i < arraySize; i++) { + addEntry(i); + } + + } else if (auto whileLoopExpression = dynamic_cast(sizeNode); whileLoopExpression != nullptr) { + // Parse while loop based size of array + auto conditionNode = this->evaluateMathematicalExpression(static_cast(whileLoopExpression->getCondition())); + ON_SCOPE_EXIT { delete conditionNode; }; + + u64 index = 0; + while (std::visit([](auto &&value) { return value != 0; }, conditionNode->getValue())) { + + addEntry(index); + index++; + + delete conditionNode; + conditionNode = this->evaluateMathematicalExpression(static_cast(whileLoopExpression->getCondition())); + } + } else { + // Parse unsized array + + if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) { + if (auto builtinType = dynamic_cast(typeDecl->getType()); builtinType != nullptr) { + std::vector bytes(Token::getTypeSize(builtinType->getType()), 0x00); + u64 offset = startOffset; + + u64 index = 0; + do { + this->m_provider->read(offset, bytes.data(), bytes.size()); + offset += bytes.size(); + addEntry(index); + index++; + } while (!std::all_of(bytes.begin(), bytes.end(), [](u8 byte){ return byte == 0x00; }) && offset < this->m_provider->getSize()); + } + } + } + + auto deleteEntries = SCOPE_GUARD { + for (auto &entry : entries) + delete entry; + }; + + if (auto typeDecl = dynamic_cast(node->getType()); typeDecl != nullptr) { + if (auto builtinType = dynamic_cast(typeDecl->getType()); builtinType != nullptr) { + if (builtinType->getType() == Token::ValueType::Padding) + return new PatternDataPadding(startOffset, this->m_currOffset - startOffset); + } } PatternData *pattern; - if (entries.empty()) { + if (entries.empty()) pattern = new PatternDataPadding(startOffset, 0); - } else if (dynamic_cast(entries[0]) != nullptr) pattern = new PatternDataString(startOffset, (this->m_currOffset - startOffset), color.value_or(0)); else if (dynamic_cast(entries[0]) != nullptr) @@ -836,6 +856,8 @@ namespace hex::lang { auto arrayPattern = new PatternDataArray(startOffset, (this->m_currOffset - startOffset), color.value_or(0)); arrayPattern->setEntries(entries); + deleteEntries.release(); + pattern = arrayPattern; } diff --git a/plugins/libimhex/source/lang/lexer.cpp b/plugins/libimhex/source/lang/lexer.cpp index 1210b7b3d..7611dde5e 100644 --- a/plugins/libimhex/source/lang/lexer.cpp +++ b/plugins/libimhex/source/lang/lexer.cpp @@ -223,7 +223,7 @@ namespace hex::lang { if (string.empty()) return { }; - if (!string[0] != '\'') + if (string[0] != '\'') return { }; @@ -234,7 +234,7 @@ namespace hex::lang { auto &[c, charSize] = character.value(); - if (string.length() >= charSize || string[charSize] != '\'') + if (string.length() >= charSize + 2 && string[charSize + 1] != '\'') return { }; return {{ c, charSize + 2 }}; @@ -422,6 +422,8 @@ namespace hex::lang { tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(1)))); else if (identifier == "parent") tokens.emplace_back(TOKEN(Keyword, Parent)); + else if (identifier == "while") + tokens.emplace_back(TOKEN(Keyword, While)); // Check for built-in types else if (identifier == "u8") diff --git a/plugins/libimhex/source/lang/parser.cpp b/plugins/libimhex/source/lang/parser.cpp index c2bc224de..58eedbcbe 100644 --- a/plugins/libimhex/source/lang/parser.cpp +++ b/plugins/libimhex/source/lang/parser.cpp @@ -84,7 +84,7 @@ namespace hex::lang { else throwParseError("expected member name or 'parent' keyword", -1); } else - return new ASTNodeRValue(path); + return TO_NUMERIC_EXPRESSION(new ASTNodeRValue(path)); } // @@ -108,7 +108,7 @@ namespace hex::lang { ASTNodeRValue::Path path; return TO_NUMERIC_EXPRESSION(this->parseRValue(path)); } else if (MATCHES(sequence(OPERATOR_DOLLAR))) { - return new ASTNodeRValue({ "$" }); + return TO_NUMERIC_EXPRESSION(new ASTNodeRValue({ "$" })); } else if (MATCHES(oneOf(OPERATOR_ADDRESSOF, OPERATOR_SIZEOF) && sequence(SEPARATOR_ROUNDBRACKETOPEN))) { auto op = getValue(-2); @@ -395,6 +395,22 @@ namespace hex::lang { return new ASTNodeConditionalStatement(condition, trueBody, falseBody); } + // while ((parseMathematicalExpression)) + ASTNode* Parser::parseWhileStatement() { + auto condition = parseMathematicalExpression(); + + auto cleanup = SCOPE_GUARD { + delete condition; + }; + + if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) + throwParseError("expected closing ')' after while head"); + + cleanup.release(); + + return new ASTNodeWhileStatement(condition); + } + /* Type declarations */ // [be|le] @@ -459,7 +475,10 @@ namespace hex::lang { auto sizeCleanup = SCOPE_GUARD { delete size; }; if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE))) { - size = parseMathematicalExpression(); + if (MATCHES(sequence(KEYWORD_WHILE, SEPARATOR_ROUNDBRACKETOPEN))) + size = parseWhileStatement(); + else + size = parseMathematicalExpression(); if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE))) throwParseError("expected closing ']' at end of array declaration", -1); @@ -662,7 +681,10 @@ namespace hex::lang { auto sizeCleanup = SCOPE_GUARD { delete size; }; if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE))) { - size = parseMathematicalExpression(); + if (MATCHES(sequence(KEYWORD_WHILE, SEPARATOR_ROUNDBRACKETOPEN))) + size = parseWhileStatement(); + else + size = parseMathematicalExpression(); if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE))) throwParseError("expected closing ']' at end of array declaration", -1); diff --git a/source/views/view_pattern.cpp b/source/views/view_pattern.cpp index cf29d8a74..498c7d6e0 100644 --- a/source/views/view_pattern.cpp +++ b/source/views/view_pattern.cpp @@ -15,7 +15,7 @@ namespace hex { static TextEditor::LanguageDefinition langDef; if (!initialized) { static const char* const keywords[] = { - "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$" + "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$", "while" }; for (auto& k : keywords) langDef.mKeywords.insert(k);