1
0
mirror of synced 2025-02-17 18:59:21 +01:00

Added functions with string literals as parameter

This commit is contained in:
WerWolv 2021-01-09 21:47:11 +01:00
parent e28d6e7451
commit c5d023822d
6 changed files with 84 additions and 28 deletions

View File

@ -479,5 +479,27 @@ namespace hex::lang {
std::vector<ASTNode*> m_params; std::vector<ASTNode*> m_params;
}; };
class ASTNodeStringLiteral : public ASTNode {
public:
explicit ASTNodeStringLiteral(std::string_view string) : ASTNode(), m_string(string) { }
~ASTNodeStringLiteral() override { }
ASTNodeStringLiteral(const ASTNodeStringLiteral &other) : ASTNode(other) {
this->m_string = other.m_string;
}
ASTNode* clone() const override {
return new ASTNodeStringLiteral(*this);
}
[[nodiscard]] std::string_view getString() {
return this->m_string;
}
private:
std::string m_string;
};
} }

View File

@ -3,6 +3,7 @@
#include <hex.hpp> #include <hex.hpp>
#include "providers/provider.hpp" #include "providers/provider.hpp"
#include "helpers/utils.hpp"
#include "lang/pattern_data.hpp" #include "lang/pattern_data.hpp"
#include "ast_node.hpp" #include "ast_node.hpp"
@ -29,7 +30,7 @@ namespace hex::lang {
constexpr static u32 NoParameters = 0x0000'0000; constexpr static u32 NoParameters = 0x0000'0000;
u32 parameterCount; u32 parameterCount;
std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func; std::function<ASTNodeIntegerLiteral*(std::vector<ASTNode*>)> func;
}; };
private: private:
@ -54,7 +55,7 @@ namespace hex::lang {
return this->m_endianStack.back(); return this->m_endianStack.back();
} }
void addFunction(std::string_view name, u32 parameterCount, std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func) { void addFunction(std::string_view name, u32 parameterCount, std::function<ASTNodeIntegerLiteral*(std::vector<ASTNode*>)> func) {
if (this->m_functions.contains(name.data())) if (this->m_functions.contains(name.data()))
throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1); throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1);
@ -81,7 +82,17 @@ namespace hex::lang {
PatternData* evaluatePointer(ASTNodePointerVariableDecl *node); PatternData* evaluatePointer(ASTNodePointerVariableDecl *node);
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* name(std::vector<ASTNodeIntegerLiteral*> params) template<typename T>
T* asType(ASTNode *param) {
if (auto evaluatedParam = dynamic_cast<T*>(param); evaluatedParam != nullptr)
return evaluatedParam;
else
throwEvaluateError("function got wrong type of parameter", 1);
}
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* TOKEN_CONCAT(builtin_, name)(std::vector<ASTNode*> params)
BUILTIN_FUNCTION(findSequence); BUILTIN_FUNCTION(findSequence);
BUILTIN_FUNCTION(readUnsigned); BUILTIN_FUNCTION(readUnsigned);

View File

@ -54,6 +54,7 @@ namespace hex::lang {
} }
ASTNode* parseFunctionCall(); ASTNode* parseFunctionCall();
ASTNode* parseStringLiteral();
ASTNode* parseScopeResolution(std::vector<std::string> &path); ASTNode* parseScopeResolution(std::vector<std::string> &path);
ASTNode* parseRValue(std::vector<std::string> &path); ASTNode* parseRValue(std::vector<std::string> &path);
ASTNode* parseFactor(); ASTNode* parseFactor();

View File

@ -2,12 +2,12 @@
namespace hex::lang { namespace hex::lang {
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::name(std::vector<ASTNodeIntegerLiteral*> params) #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::TOKEN_CONCAT(builtin_, name)(std::vector<ASTNode*> params)
#define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal) #define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal)
BUILTIN_FUNCTION(findSequence) { BUILTIN_FUNCTION(findSequence) {
auto& occurrenceIndex = params[0]->getValue(); auto& occurrenceIndex = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
std::vector<u8> sequence; std::vector<u8> sequence;
for (u32 i = 1; i < params.size(); i++) { for (u32 i = 1; i < params.size(); i++) {
sequence.push_back(std::visit([](auto &&value) -> u8 { sequence.push_back(std::visit([](auto &&value) -> u8 {
@ -15,7 +15,7 @@ namespace hex::lang {
return value; return value;
else else
throwEvaluateError("sequence bytes need to fit into 1 byte", 1); throwEvaluateError("sequence bytes need to fit into 1 byte", 1);
}, params[i]->getValue())); }, asType<ASTNodeIntegerLiteral>(params[i])->getValue()));
} }
std::vector<u8> bytes(sequence.size(), 0x00); std::vector<u8> bytes(sequence.size(), 0x00);
@ -37,8 +37,8 @@ namespace hex::lang {
} }
BUILTIN_FUNCTION(readUnsigned) { BUILTIN_FUNCTION(readUnsigned) {
auto address = params[0]->getValue(); auto address = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto size = params[1]->getValue(); auto size = asType<ASTNodeIntegerLiteral>(params[1])->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize())) if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range", 1); throwEvaluateError("address out of range", 1);
@ -62,8 +62,8 @@ namespace hex::lang {
} }
BUILTIN_FUNCTION(readSigned) { BUILTIN_FUNCTION(readSigned) {
auto address = params[0]->getValue(); auto address = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto size = params[1]->getValue(); auto size = asType<ASTNodeIntegerLiteral>(params[1])->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize())) if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range", 1); throwEvaluateError("address out of range", 1);
@ -76,13 +76,13 @@ namespace hex::lang {
this->m_provider->read(address, value, size); this->m_provider->read(address, value, size);
switch ((u8)size) { switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) }); case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) }); case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) }); case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) }); case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) }); case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid read size", 1); default: throwEvaluateError("invalid read size", 1);
} }
}, address, size); }, address, size);
} }

View File

@ -14,15 +14,15 @@ namespace hex::lang {
: m_provider(provider), m_defaultDataEndian(defaultDataEndian) { : m_provider(provider), m_defaultDataEndian(defaultDataEndian) {
this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) { this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) {
return this->findSequence(params); return this->builtin_findSequence(params);
}); });
this->addFunction("readUnsigned", 2, [this](auto params) { this->addFunction("readUnsigned", 2, [this](auto params) {
return this->readUnsigned(params); return this->builtin_readUnsigned(params);
}); });
this->addFunction("readSigned", 2, [this](auto params) { this->addFunction("readSigned", 2, [this](auto params) {
return this->readSigned(params); return this->builtin_readSigned(params);
}); });
} }
@ -126,14 +126,18 @@ namespace hex::lang {
} }
ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) { ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) {
std::vector<ASTNodeIntegerLiteral*> evaluatedParams; std::vector<ASTNode*> evaluatedParams;
ScopeExit paramCleanup([&] { ScopeExit paramCleanup([&] {
for (auto &param : evaluatedParams) for (auto &param : evaluatedParams)
delete param; delete param;
}); });
for (auto &param : node->getParams()) for (auto &param : node->getParams()) {
evaluatedParams.push_back(this->evaluateMathematicalExpression(static_cast<ASTNodeNumericExpression*>(param))); if (auto numericExpression = dynamic_cast<ASTNodeNumericExpression*>(param); numericExpression != nullptr)
evaluatedParams.push_back(this->evaluateMathematicalExpression(numericExpression));
else if (auto stringLiteral = dynamic_cast<ASTNodeStringLiteral*>(param); stringLiteral != nullptr)
evaluatedParams.push_back(stringLiteral->clone());
}
if (!this->m_functions.contains(node->getFunctionName().data())) if (!this->m_functions.contains(node->getFunctionName().data()))
throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber()); throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber());
@ -278,8 +282,14 @@ namespace hex::lang {
return evaluateScopeResolution(exprScopeResolution); return evaluateScopeResolution(exprScopeResolution);
else if (auto exprTernary = dynamic_cast<ASTNodeTernaryExpression*>(node); exprTernary != nullptr) else if (auto exprTernary = dynamic_cast<ASTNodeTernaryExpression*>(node); exprTernary != nullptr)
return evaluateTernaryExpression(exprTernary); return evaluateTernaryExpression(exprTernary);
else if (auto exprFunctionCall = dynamic_cast<ASTNodeFunctionCall*>(node); exprFunctionCall != nullptr) else if (auto exprFunctionCall = dynamic_cast<ASTNodeFunctionCall*>(node); exprFunctionCall != nullptr) {
return evaluateFunctionCall(exprFunctionCall); auto returnValue = evaluateFunctionCall(exprFunctionCall);
if (returnValue == nullptr)
throwEvaluateError("function returning void used in expression", node->getLineNumber());
else
return returnValue;
}
else else
throwEvaluateError("invalid operand", node->getLineNumber()); throwEvaluateError("invalid operand", node->getLineNumber());
} }
@ -675,6 +685,9 @@ namespace hex::lang {
this->m_globalMembers.push_back(this->evaluatePointer(pointerDeclNode)); this->m_globalMembers.push_back(this->evaluatePointer(pointerDeclNode));
} else if (auto typeDeclNode = dynamic_cast<ASTNodeTypeDecl*>(node); typeDeclNode != nullptr) { } else if (auto typeDeclNode = dynamic_cast<ASTNodeTypeDecl*>(node); typeDeclNode != nullptr) {
this->m_types[typeDeclNode->getName().data()] = typeDeclNode->getType(); this->m_types[typeDeclNode->getName().data()] = typeDeclNode->getType();
} else if (auto functionCallNode = dynamic_cast<ASTNodeFunctionCall*>(node); functionCallNode != nullptr) {
auto result = this->evaluateFunctionCall(functionCallNode);
delete result;
} }
this->m_endianStack.pop_back(); this->m_endianStack.pop_back();

View File

@ -28,7 +28,10 @@ namespace hex::lang {
}); });
while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) {
params.push_back(parseMathematicalExpression()); if (MATCHES(sequence(STRING)))
params.push_back(parseStringLiteral());
else
params.push_back(parseMathematicalExpression());
if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE))) if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE)))
throwParseError("unexpected ',' at end of function parameter list", -1); throwParseError("unexpected ',' at end of function parameter list", -1);
@ -41,7 +44,11 @@ namespace hex::lang {
paramCleanup.release(); paramCleanup.release();
return TO_NUMERIC_EXPRESSION(new ASTNodeFunctionCall(functionName, params)); return new ASTNodeFunctionCall(functionName, params);
}
ASTNode* Parser::parseStringLiteral() {
return new ASTNodeStringLiteral(getValue<std::string>(-1));
} }
// Identifier::<Identifier[::]...> // Identifier::<Identifier[::]...>
@ -86,7 +93,7 @@ namespace hex::lang {
this->m_curr--; this->m_curr--;
return this->parseScopeResolution(path); return this->parseScopeResolution(path);
} else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) { } else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) {
return this->parseFunctionCall(); return TO_NUMERIC_EXPRESSION(this->parseFunctionCall());
} else if (MATCHES(sequence(IDENTIFIER))) { } else if (MATCHES(sequence(IDENTIFIER))) {
std::vector<std::string> path; std::vector<std::string> path;
return this->parseRValue(path); return this->parseRValue(path);
@ -591,6 +598,8 @@ namespace hex::lang {
statement = parseEnum(); statement = parseEnum();
else if (MATCHES(sequence(KEYWORD_BITFIELD, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) else if (MATCHES(sequence(KEYWORD_BITFIELD, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN)))
statement = parseBitfield(); statement = parseBitfield();
else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN)))
statement = parseFunctionCall();
else throwParseError("invalid sequence", 0); else throwParseError("invalid sequence", 0);
if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION)))