diff --git a/include/lang/ast_node.hpp b/include/lang/ast_node.hpp index afecffb63..54a53c0f7 100644 --- a/include/lang/ast_node.hpp +++ b/include/lang/ast_node.hpp @@ -479,5 +479,27 @@ namespace hex::lang { std::vector m_params; }; + class ASTNodeStringLiteral : public ASTNode { + public: + explicit ASTNodeStringLiteral(std::string_view string) : ASTNode(), m_string(string) { } + + ~ASTNodeStringLiteral() override { } + + ASTNodeStringLiteral(const ASTNodeStringLiteral &other) : ASTNode(other) { + this->m_string = other.m_string; + } + + ASTNode* clone() const override { + return new ASTNodeStringLiteral(*this); + } + + [[nodiscard]] std::string_view getString() { + return this->m_string; + } + + private: + std::string m_string; + }; + } \ No newline at end of file diff --git a/include/lang/evaluator.hpp b/include/lang/evaluator.hpp index 81bbda059..d6b14f02f 100644 --- a/include/lang/evaluator.hpp +++ b/include/lang/evaluator.hpp @@ -3,6 +3,7 @@ #include #include "providers/provider.hpp" +#include "helpers/utils.hpp" #include "lang/pattern_data.hpp" #include "ast_node.hpp" @@ -29,7 +30,7 @@ namespace hex::lang { constexpr static u32 NoParameters = 0x0000'0000; u32 parameterCount; - std::function)> func; + std::function)> func; }; private: @@ -54,7 +55,7 @@ namespace hex::lang { return this->m_endianStack.back(); } - void addFunction(std::string_view name, u32 parameterCount, std::function)> func) { + void addFunction(std::string_view name, u32 parameterCount, std::function)> func) { if (this->m_functions.contains(name.data())) throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1); @@ -81,7 +82,17 @@ namespace hex::lang { PatternData* evaluatePointer(ASTNodePointerVariableDecl *node); - #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* name(std::vector params) + template + T* asType(ASTNode *param) { + if (auto evaluatedParam = dynamic_cast(param); evaluatedParam != nullptr) + return evaluatedParam; + else + throwEvaluateError("function got wrong type of parameter", 1); + } + + + + #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* TOKEN_CONCAT(builtin_, name)(std::vector params) BUILTIN_FUNCTION(findSequence); BUILTIN_FUNCTION(readUnsigned); diff --git a/include/lang/parser.hpp b/include/lang/parser.hpp index f494ba956..fbf9e31ae 100644 --- a/include/lang/parser.hpp +++ b/include/lang/parser.hpp @@ -54,6 +54,7 @@ namespace hex::lang { } ASTNode* parseFunctionCall(); + ASTNode* parseStringLiteral(); ASTNode* parseScopeResolution(std::vector &path); ASTNode* parseRValue(std::vector &path); ASTNode* parseFactor(); diff --git a/source/lang/builtin_functions.cpp b/source/lang/builtin_functions.cpp index 78a2e0ddf..69071f51c 100644 --- a/source/lang/builtin_functions.cpp +++ b/source/lang/builtin_functions.cpp @@ -2,12 +2,12 @@ namespace hex::lang { - #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::name(std::vector params) + #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::TOKEN_CONCAT(builtin_, name)(std::vector params) #define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal) BUILTIN_FUNCTION(findSequence) { - auto& occurrenceIndex = params[0]->getValue(); + auto& occurrenceIndex = asType(params[0])->getValue(); std::vector sequence; for (u32 i = 1; i < params.size(); i++) { sequence.push_back(std::visit([](auto &&value) -> u8 { @@ -15,7 +15,7 @@ namespace hex::lang { return value; else throwEvaluateError("sequence bytes need to fit into 1 byte", 1); - }, params[i]->getValue())); + }, asType(params[i])->getValue())); } std::vector bytes(sequence.size(), 0x00); @@ -37,8 +37,8 @@ namespace hex::lang { } BUILTIN_FUNCTION(readUnsigned) { - auto address = params[0]->getValue(); - auto size = params[1]->getValue(); + auto address = asType(params[0])->getValue(); + auto size = asType(params[1])->getValue(); if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize())) throwEvaluateError("address out of range", 1); @@ -62,8 +62,8 @@ namespace hex::lang { } BUILTIN_FUNCTION(readSigned) { - auto address = params[0]->getValue(); - auto size = params[1]->getValue(); + auto address = asType(params[0])->getValue(); + auto size = asType(params[1])->getValue(); if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize())) throwEvaluateError("address out of range", 1); @@ -76,13 +76,13 @@ namespace hex::lang { this->m_provider->read(address, value, size); switch ((u8)size) { - case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast(value), 1, this->getCurrentEndian()) }); - case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast(value), 2, this->getCurrentEndian()) }); - case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast(value), 4, this->getCurrentEndian()) }); - case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast(value), 8, this->getCurrentEndian()) }); - case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast(value), 16, this->getCurrentEndian()) }); - default: throwEvaluateError("invalid read size", 1); - } + case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast(value), 1, this->getCurrentEndian()) }); + case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast(value), 2, this->getCurrentEndian()) }); + case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast(value), 4, this->getCurrentEndian()) }); + case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast(value), 8, this->getCurrentEndian()) }); + case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast(value), 16, this->getCurrentEndian()) }); + default: throwEvaluateError("invalid read size", 1); + } }, address, size); } diff --git a/source/lang/evaluator.cpp b/source/lang/evaluator.cpp index 184e34905..f98cd9513 100644 --- a/source/lang/evaluator.cpp +++ b/source/lang/evaluator.cpp @@ -14,15 +14,15 @@ namespace hex::lang { : m_provider(provider), m_defaultDataEndian(defaultDataEndian) { this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) { - return this->findSequence(params); + return this->builtin_findSequence(params); }); this->addFunction("readUnsigned", 2, [this](auto params) { - return this->readUnsigned(params); + return this->builtin_readUnsigned(params); }); this->addFunction("readSigned", 2, [this](auto params) { - return this->readSigned(params); + return this->builtin_readSigned(params); }); } @@ -126,14 +126,18 @@ namespace hex::lang { } ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) { - std::vector evaluatedParams; + std::vector evaluatedParams; ScopeExit paramCleanup([&] { for (auto ¶m : evaluatedParams) delete param; }); - for (auto ¶m : node->getParams()) - evaluatedParams.push_back(this->evaluateMathematicalExpression(static_cast(param))); + for (auto ¶m : node->getParams()) { + if (auto numericExpression = dynamic_cast(param); numericExpression != nullptr) + evaluatedParams.push_back(this->evaluateMathematicalExpression(numericExpression)); + else if (auto stringLiteral = dynamic_cast(param); stringLiteral != nullptr) + evaluatedParams.push_back(stringLiteral->clone()); + } if (!this->m_functions.contains(node->getFunctionName().data())) throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber()); @@ -278,8 +282,14 @@ namespace hex::lang { return evaluateScopeResolution(exprScopeResolution); else if (auto exprTernary = dynamic_cast(node); exprTernary != nullptr) return evaluateTernaryExpression(exprTernary); - else if (auto exprFunctionCall = dynamic_cast(node); exprFunctionCall != nullptr) - return evaluateFunctionCall(exprFunctionCall); + else if (auto exprFunctionCall = dynamic_cast(node); exprFunctionCall != nullptr) { + auto returnValue = evaluateFunctionCall(exprFunctionCall); + + if (returnValue == nullptr) + throwEvaluateError("function returning void used in expression", node->getLineNumber()); + else + return returnValue; + } else throwEvaluateError("invalid operand", node->getLineNumber()); } @@ -675,6 +685,9 @@ namespace hex::lang { this->m_globalMembers.push_back(this->evaluatePointer(pointerDeclNode)); } else if (auto typeDeclNode = dynamic_cast(node); typeDeclNode != nullptr) { this->m_types[typeDeclNode->getName().data()] = typeDeclNode->getType(); + } else if (auto functionCallNode = dynamic_cast(node); functionCallNode != nullptr) { + auto result = this->evaluateFunctionCall(functionCallNode); + delete result; } this->m_endianStack.pop_back(); diff --git a/source/lang/parser.cpp b/source/lang/parser.cpp index ee5045c7f..ea2fca299 100644 --- a/source/lang/parser.cpp +++ b/source/lang/parser.cpp @@ -28,7 +28,10 @@ namespace hex::lang { }); while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { - params.push_back(parseMathematicalExpression()); + if (MATCHES(sequence(STRING))) + params.push_back(parseStringLiteral()); + else + params.push_back(parseMathematicalExpression()); if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE))) throwParseError("unexpected ',' at end of function parameter list", -1); @@ -41,7 +44,11 @@ namespace hex::lang { paramCleanup.release(); - return TO_NUMERIC_EXPRESSION(new ASTNodeFunctionCall(functionName, params)); + return new ASTNodeFunctionCall(functionName, params); + } + + ASTNode* Parser::parseStringLiteral() { + return new ASTNodeStringLiteral(getValue(-1)); } // Identifier:: @@ -86,7 +93,7 @@ namespace hex::lang { this->m_curr--; return this->parseScopeResolution(path); } else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) { - return this->parseFunctionCall(); + return TO_NUMERIC_EXPRESSION(this->parseFunctionCall()); } else if (MATCHES(sequence(IDENTIFIER))) { std::vector path; return this->parseRValue(path); @@ -591,6 +598,8 @@ namespace hex::lang { statement = parseEnum(); else if (MATCHES(sequence(KEYWORD_BITFIELD, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) statement = parseBitfield(); + else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) + statement = parseFunctionCall(); else throwParseError("invalid sequence", 0); if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION)))