1
0
mirror of synced 2024-11-28 09:30:51 +01:00

Added functions with string literals as parameter

This commit is contained in:
WerWolv 2021-01-09 21:47:11 +01:00
parent e28d6e7451
commit c5d023822d
6 changed files with 84 additions and 28 deletions

View File

@ -479,5 +479,27 @@ namespace hex::lang {
std::vector<ASTNode*> m_params;
};
class ASTNodeStringLiteral : public ASTNode {
public:
explicit ASTNodeStringLiteral(std::string_view string) : ASTNode(), m_string(string) { }
~ASTNodeStringLiteral() override { }
ASTNodeStringLiteral(const ASTNodeStringLiteral &other) : ASTNode(other) {
this->m_string = other.m_string;
}
ASTNode* clone() const override {
return new ASTNodeStringLiteral(*this);
}
[[nodiscard]] std::string_view getString() {
return this->m_string;
}
private:
std::string m_string;
};
}

View File

@ -3,6 +3,7 @@
#include <hex.hpp>
#include "providers/provider.hpp"
#include "helpers/utils.hpp"
#include "lang/pattern_data.hpp"
#include "ast_node.hpp"
@ -29,7 +30,7 @@ namespace hex::lang {
constexpr static u32 NoParameters = 0x0000'0000;
u32 parameterCount;
std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func;
std::function<ASTNodeIntegerLiteral*(std::vector<ASTNode*>)> func;
};
private:
@ -54,7 +55,7 @@ namespace hex::lang {
return this->m_endianStack.back();
}
void addFunction(std::string_view name, u32 parameterCount, std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func) {
void addFunction(std::string_view name, u32 parameterCount, std::function<ASTNodeIntegerLiteral*(std::vector<ASTNode*>)> func) {
if (this->m_functions.contains(name.data()))
throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1);
@ -81,7 +82,17 @@ namespace hex::lang {
PatternData* evaluatePointer(ASTNodePointerVariableDecl *node);
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* name(std::vector<ASTNodeIntegerLiteral*> params)
template<typename T>
T* asType(ASTNode *param) {
if (auto evaluatedParam = dynamic_cast<T*>(param); evaluatedParam != nullptr)
return evaluatedParam;
else
throwEvaluateError("function got wrong type of parameter", 1);
}
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* TOKEN_CONCAT(builtin_, name)(std::vector<ASTNode*> params)
BUILTIN_FUNCTION(findSequence);
BUILTIN_FUNCTION(readUnsigned);

View File

@ -54,6 +54,7 @@ namespace hex::lang {
}
ASTNode* parseFunctionCall();
ASTNode* parseStringLiteral();
ASTNode* parseScopeResolution(std::vector<std::string> &path);
ASTNode* parseRValue(std::vector<std::string> &path);
ASTNode* parseFactor();

View File

@ -2,12 +2,12 @@
namespace hex::lang {
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::name(std::vector<ASTNodeIntegerLiteral*> params)
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::TOKEN_CONCAT(builtin_, name)(std::vector<ASTNode*> params)
#define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal)
BUILTIN_FUNCTION(findSequence) {
auto& occurrenceIndex = params[0]->getValue();
auto& occurrenceIndex = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
std::vector<u8> sequence;
for (u32 i = 1; i < params.size(); i++) {
sequence.push_back(std::visit([](auto &&value) -> u8 {
@ -15,7 +15,7 @@ namespace hex::lang {
return value;
else
throwEvaluateError("sequence bytes need to fit into 1 byte", 1);
}, params[i]->getValue()));
}, asType<ASTNodeIntegerLiteral>(params[i])->getValue()));
}
std::vector<u8> bytes(sequence.size(), 0x00);
@ -37,8 +37,8 @@ namespace hex::lang {
}
BUILTIN_FUNCTION(readUnsigned) {
auto address = params[0]->getValue();
auto size = params[1]->getValue();
auto address = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto size = asType<ASTNodeIntegerLiteral>(params[1])->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range", 1);
@ -62,8 +62,8 @@ namespace hex::lang {
}
BUILTIN_FUNCTION(readSigned) {
auto address = params[0]->getValue();
auto size = params[1]->getValue();
auto address = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto size = asType<ASTNodeIntegerLiteral>(params[1])->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range", 1);
@ -76,13 +76,13 @@ namespace hex::lang {
this->m_provider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid read size", 1);
}
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid read size", 1);
}
}, address, size);
}

View File

@ -14,15 +14,15 @@ namespace hex::lang {
: m_provider(provider), m_defaultDataEndian(defaultDataEndian) {
this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) {
return this->findSequence(params);
return this->builtin_findSequence(params);
});
this->addFunction("readUnsigned", 2, [this](auto params) {
return this->readUnsigned(params);
return this->builtin_readUnsigned(params);
});
this->addFunction("readSigned", 2, [this](auto params) {
return this->readSigned(params);
return this->builtin_readSigned(params);
});
}
@ -126,14 +126,18 @@ namespace hex::lang {
}
ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) {
std::vector<ASTNodeIntegerLiteral*> evaluatedParams;
std::vector<ASTNode*> evaluatedParams;
ScopeExit paramCleanup([&] {
for (auto &param : evaluatedParams)
delete param;
});
for (auto &param : node->getParams())
evaluatedParams.push_back(this->evaluateMathematicalExpression(static_cast<ASTNodeNumericExpression*>(param)));
for (auto &param : node->getParams()) {
if (auto numericExpression = dynamic_cast<ASTNodeNumericExpression*>(param); numericExpression != nullptr)
evaluatedParams.push_back(this->evaluateMathematicalExpression(numericExpression));
else if (auto stringLiteral = dynamic_cast<ASTNodeStringLiteral*>(param); stringLiteral != nullptr)
evaluatedParams.push_back(stringLiteral->clone());
}
if (!this->m_functions.contains(node->getFunctionName().data()))
throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber());
@ -278,8 +282,14 @@ namespace hex::lang {
return evaluateScopeResolution(exprScopeResolution);
else if (auto exprTernary = dynamic_cast<ASTNodeTernaryExpression*>(node); exprTernary != nullptr)
return evaluateTernaryExpression(exprTernary);
else if (auto exprFunctionCall = dynamic_cast<ASTNodeFunctionCall*>(node); exprFunctionCall != nullptr)
return evaluateFunctionCall(exprFunctionCall);
else if (auto exprFunctionCall = dynamic_cast<ASTNodeFunctionCall*>(node); exprFunctionCall != nullptr) {
auto returnValue = evaluateFunctionCall(exprFunctionCall);
if (returnValue == nullptr)
throwEvaluateError("function returning void used in expression", node->getLineNumber());
else
return returnValue;
}
else
throwEvaluateError("invalid operand", node->getLineNumber());
}
@ -675,6 +685,9 @@ namespace hex::lang {
this->m_globalMembers.push_back(this->evaluatePointer(pointerDeclNode));
} else if (auto typeDeclNode = dynamic_cast<ASTNodeTypeDecl*>(node); typeDeclNode != nullptr) {
this->m_types[typeDeclNode->getName().data()] = typeDeclNode->getType();
} else if (auto functionCallNode = dynamic_cast<ASTNodeFunctionCall*>(node); functionCallNode != nullptr) {
auto result = this->evaluateFunctionCall(functionCallNode);
delete result;
}
this->m_endianStack.pop_back();

View File

@ -28,7 +28,10 @@ namespace hex::lang {
});
while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) {
params.push_back(parseMathematicalExpression());
if (MATCHES(sequence(STRING)))
params.push_back(parseStringLiteral());
else
params.push_back(parseMathematicalExpression());
if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE)))
throwParseError("unexpected ',' at end of function parameter list", -1);
@ -41,7 +44,11 @@ namespace hex::lang {
paramCleanup.release();
return TO_NUMERIC_EXPRESSION(new ASTNodeFunctionCall(functionName, params));
return new ASTNodeFunctionCall(functionName, params);
}
ASTNode* Parser::parseStringLiteral() {
return new ASTNodeStringLiteral(getValue<std::string>(-1));
}
// Identifier::<Identifier[::]...>
@ -86,7 +93,7 @@ namespace hex::lang {
this->m_curr--;
return this->parseScopeResolution(path);
} else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) {
return this->parseFunctionCall();
return TO_NUMERIC_EXPRESSION(this->parseFunctionCall());
} else if (MATCHES(sequence(IDENTIFIER))) {
std::vector<std::string> path;
return this->parseRValue(path);
@ -591,6 +598,8 @@ namespace hex::lang {
statement = parseEnum();
else if (MATCHES(sequence(KEYWORD_BITFIELD, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN)))
statement = parseBitfield();
else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN)))
statement = parseFunctionCall();
else throwParseError("invalid sequence", 0);
if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION)))