From c9fae32ddfee0cfaf2209922453ba1e7ab23a816 Mon Sep 17 00:00:00 2001 From: WerWolv Date: Sun, 20 Jun 2021 23:46:13 +0200 Subject: [PATCH] patterns: Added function if statements, improved returns --- .../libimhex/include/hex/lang/evaluator.hpp | 1 + plugins/libimhex/include/hex/lang/parser.hpp | 6 +- plugins/libimhex/source/lang/evaluator.cpp | 104 +++++++++++------- plugins/libimhex/source/lang/parser.cpp | 89 +++++++++++---- 4 files changed, 136 insertions(+), 64 deletions(-) diff --git a/plugins/libimhex/include/hex/lang/evaluator.hpp b/plugins/libimhex/include/hex/lang/evaluator.hpp index a355c7d0c..f2240afb1 100644 --- a/plugins/libimhex/include/hex/lang/evaluator.hpp +++ b/plugins/libimhex/include/hex/lang/evaluator.hpp @@ -68,6 +68,7 @@ namespace hex::lang { ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node); ASTNodeIntegerLiteral* evaluateMathematicalExpression(ASTNodeNumericExpression *node); void evaluateFunctionDefinition(ASTNodeFunctionDefinition *node); + std::optional evaluateFunctionBody(const std::vector &body); PatternData* findPattern(std::vector currMembers, const ASTNodeRValue::Path &path); PatternData* evaluateAttributes(ASTNode *currNode, PatternData *currPattern); diff --git a/plugins/libimhex/include/hex/lang/parser.hpp b/plugins/libimhex/include/hex/lang/parser.hpp index d48678aba..d05342581 100644 --- a/plugins/libimhex/include/hex/lang/parser.hpp +++ b/plugins/libimhex/include/hex/lang/parser.hpp @@ -73,8 +73,10 @@ namespace hex::lang { void parseAttribute(Attributable *currNode); ASTNode* parseFunctionDefintion(); - ASTNode* parseVariableAssignment(); - ASTNode* parseReturnStatement(); + ASTNode* parseFunctionStatement(); + ASTNode* parseFunctionVariableAssignment(); + ASTNode* parseFunctionReturnStatement(); + ASTNode* parseFunctionConditional(); ASTNode* parseConditional(); ASTNode* parseWhileStatement(); ASTNode* parseType(s32 startIndex); diff --git a/plugins/libimhex/source/lang/evaluator.cpp b/plugins/libimhex/source/lang/evaluator.cpp index 1ada296f3..45f3a950c 100644 --- a/plugins/libimhex/source/lang/evaluator.cpp +++ b/plugins/libimhex/source/lang/evaluator.cpp @@ -480,47 +480,7 @@ namespace hex::lang { } evaluator.m_currOffset = startOffset; - for (auto &statement : body) { - ON_SCOPE_EXIT { evaluator.m_currOffset = startOffset; }; - - if (auto functionCallNode = dynamic_cast(statement); functionCallNode != nullptr) { - auto result = evaluator.evaluateFunctionCall(functionCallNode); - delete result; - } else if (auto varDeclNode = dynamic_cast(statement); varDeclNode != nullptr) { - auto pattern = evaluator.evaluateVariable(varDeclNode); - evaluator.createLocalVariable(varDeclNode->getName(), pattern); - } else if (auto assignmentNode = dynamic_cast(statement); assignmentNode != nullptr) { - if (auto numericExpressionNode = dynamic_cast(assignmentNode->getRValue()); numericExpressionNode != nullptr) { - auto value = evaluator.evaluateMathematicalExpression(numericExpressionNode); - ON_SCOPE_EXIT { delete value; }; - - std::visit([&](auto &&value) { - evaluator.setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value)); - }, value->getValue()); - } else { - evaluator.getConsole().abortEvaluation("invalid rvalue used in assignment"); - } - } else if (auto assignmentNode = dynamic_cast(statement); assignmentNode != nullptr) { - if (auto numericExpressionNode = dynamic_cast(assignmentNode->getRValue()); numericExpressionNode != nullptr) { - auto value = evaluator.evaluateMathematicalExpression(numericExpressionNode); - ON_SCOPE_EXIT { delete value; }; - - std::visit([&](auto &&value) { - evaluator.setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value)); - }, value->getValue()); - } else { - evaluator.getConsole().abortEvaluation("invalid rvalue used in assignment"); - } - } else if (auto returnNode = dynamic_cast(statement); returnNode != nullptr) { - if (auto numericExpressionNode = dynamic_cast(returnNode->getRValue()); numericExpressionNode != nullptr) { - return evaluator.evaluateMathematicalExpression(numericExpressionNode); - } else { - evaluator.getConsole().abortEvaluation("invalid rvalue used in return statement"); - } - } - } - - return nullptr; + return evaluator.evaluateFunctionBody(body).value_or(nullptr); } }; @@ -530,6 +490,68 @@ namespace hex::lang { this->m_definedFunctions.insert({ std::string(node->getName()), function }); } + std::optional Evaluator::evaluateFunctionBody(const std::vector &body) { + std::optional returnResult; + auto startOffset = this->m_currOffset; + + for (auto &statement : body) { + ON_SCOPE_EXIT { this->m_currOffset = startOffset; }; + + if (auto functionCallNode = dynamic_cast(statement); functionCallNode != nullptr) { + auto result = this->evaluateFunctionCall(functionCallNode); + delete result; + } else if (auto varDeclNode = dynamic_cast(statement); varDeclNode != nullptr) { + auto pattern = this->evaluateVariable(varDeclNode); + this->createLocalVariable(varDeclNode->getName(), pattern); + } else if (auto assignmentNode = dynamic_cast(statement); assignmentNode != nullptr) { + if (auto numericExpressionNode = dynamic_cast(assignmentNode->getRValue()); numericExpressionNode != nullptr) { + auto value = this->evaluateMathematicalExpression(numericExpressionNode); + ON_SCOPE_EXIT { delete value; }; + + std::visit([&](auto &&value) { + this->setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value)); + }, value->getValue()); + } else { + this->getConsole().abortEvaluation("invalid rvalue used in assignment"); + } + } else if (auto returnNode = dynamic_cast(statement); returnNode != nullptr) { + if (returnNode->getRValue() == nullptr) { + returnResult = nullptr; + } else if (auto numericExpressionNode = dynamic_cast(returnNode->getRValue()); numericExpressionNode != nullptr) { + returnResult = this->evaluateMathematicalExpression(numericExpressionNode); + } else { + this->getConsole().abortEvaluation("invalid rvalue used in return statement"); + } + } else if (auto conditionalNode = dynamic_cast(statement); conditionalNode != nullptr) { + if (auto numericExpressionNode = dynamic_cast(conditionalNode->getCondition()); numericExpressionNode != nullptr) { + auto condition = this->evaluateMathematicalExpression(numericExpressionNode); + + u32 localVariableStartCount = this->m_localVariables.back()->size(); + u32 localVariableStackStartSize = this->m_localStack.size(); + + if (std::visit([](auto &&value) { return value != 0; }, condition->getValue())) + returnResult = this->evaluateFunctionBody(conditionalNode->getTrueBody()); + else + returnResult = this->evaluateFunctionBody(conditionalNode->getFalseBody()); + + for (u32 i = localVariableStartCount; i < this->m_localVariables.size(); i++) + delete (*this->m_localVariables.back())[i]; + this->m_localVariables.back()->resize(localVariableStartCount); + this->m_localStack.resize(localVariableStackStartSize); + + } else { + this->getConsole().abortEvaluation("invalid rvalue used in return statement"); + } + } + + if (returnResult.has_value()) + return returnResult.value(); + } + + + return { }; + } + PatternData* Evaluator::evaluateAttributes(ASTNode *currNode, PatternData *currPattern) { auto attributableNode = dynamic_cast(currNode); if (attributableNode == nullptr) diff --git a/plugins/libimhex/source/lang/parser.cpp b/plugins/libimhex/source/lang/parser.cpp index 8f2cab371..7be7919bb 100644 --- a/plugins/libimhex/source/lang/parser.cpp +++ b/plugins/libimhex/source/lang/parser.cpp @@ -388,31 +388,41 @@ namespace hex::lang { }; while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { - ASTNode *statement; - if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) - statement = parseFunctionCall(); - else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER, SEPARATOR_SQUAREBRACKETOPEN) && sequence(SEPARATOR_SQUAREBRACKETOPEN))) - statement = parseMemberArrayVariable(); - else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER))) - statement = parseMemberVariable(); - else if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) - statement = parseVariableAssignment(); - else if (MATCHES(sequence(KEYWORD_RETURN))) - statement = parseReturnStatement(); - else - throwParseError("invalid sequence", 0); - - body.push_back(statement); - - if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) - throwParseError("missing ';' at end of expression", -1); + body.push_back(this->parseFunctionStatement()); } bodyCleanup.release(); return new ASTNodeFunctionDefinition(functionName, params, body); } - ASTNode* Parser::parseVariableAssignment() { + ASTNode* Parser::parseFunctionStatement() { + bool needsSemicolon = true; + ASTNode *statement; + + if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) + statement = parseFunctionCall(); + else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER))) + statement = parseMemberVariable(); + else if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) + statement = parseFunctionVariableAssignment(); + else if (MATCHES(sequence(KEYWORD_RETURN))) + statement = parseFunctionReturnStatement(); + else if (MATCHES(sequence(KEYWORD_IF, SEPARATOR_ROUNDBRACKETOPEN))) { + statement = parseFunctionConditional(); + needsSemicolon = false; + } + else + throwParseError("invalid sequence", 0); + + if (needsSemicolon && !MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) { + delete statement; + throwParseError("missing ';' at end of expression", -1); + } + + return statement; + } + + ASTNode* Parser::parseFunctionVariableAssignment() { const auto &lvalue = getValue(-2); auto rvalue = this->parseMathematicalExpression(); @@ -420,8 +430,45 @@ namespace hex::lang { return new ASTNodeAssignment(lvalue, rvalue); } - ASTNode* Parser::parseReturnStatement() { - return new ASTNodeReturnStatement(this->parseMathematicalExpression()); + ASTNode* Parser::parseFunctionReturnStatement() { + if (peek(SEPARATOR_ENDOFEXPRESSION)) + return new ASTNodeReturnStatement(nullptr); + else + return new ASTNodeReturnStatement(this->parseMathematicalExpression()); + } + + ASTNode* Parser::parseFunctionConditional() { + auto condition = parseMathematicalExpression(); + std::vector trueBody, falseBody; + + auto cleanup = SCOPE_GUARD { + delete condition; + for (auto &statement : trueBody) + delete statement; + for (auto &statement : falseBody) + delete statement; + }; + + if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE, SEPARATOR_CURLYBRACKETOPEN))) { + while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { + trueBody.push_back(parseFunctionStatement()); + } + } else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { + trueBody.push_back(parseFunctionStatement()); + } else + throwParseError("expected body of conditional statement"); + + if (MATCHES(sequence(KEYWORD_ELSE, SEPARATOR_CURLYBRACKETOPEN))) { + while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { + falseBody.push_back(parseFunctionStatement()); + } + } else if (MATCHES(sequence(KEYWORD_ELSE))) { + falseBody.push_back(parseFunctionStatement()); + } + + cleanup.release(); + + return new ASTNodeConditionalStatement(condition, trueBody, falseBody); } /* Control flow */