1
0
mirror of synced 2025-02-17 18:59:21 +01:00

Added function calling as well as a few builtin functions

This commit is contained in:
WerWolv 2021-01-07 15:37:37 +01:00
parent b47736b595
commit bef20f7808
8 changed files with 240 additions and 12 deletions

View File

@ -135,6 +135,7 @@ add_executable(imhex ${application_type}
source/lang/parser.cpp
source/lang/validator.cpp
source/lang/evaluator.cpp
source/lang/builtin_functions.cpp
source/providers/file_provider.cpp

View File

@ -442,4 +442,39 @@ namespace hex::lang {
std::vector<ASTNode*> m_trueBody, m_falseBody;
};
class ASTNodeFunctionCall : public ASTNode {
public:
explicit ASTNodeFunctionCall(std::string_view functionName, std::vector<ASTNode*> params)
: ASTNode(), m_functionName(functionName), m_params(std::move(params)) { }
~ASTNodeFunctionCall() override {
for (auto &param : this->m_params)
delete param;
}
ASTNodeFunctionCall(const ASTNodeFunctionCall &other) : ASTNode(other) {
this->m_functionName = other.m_functionName;
for (auto &param : other.m_params)
this->m_params.push_back(param->clone());
}
ASTNode* clone() const override {
return new ASTNodeFunctionCall(*this);
}
[[nodiscard]] std::string_view getFunctionName() {
return this->m_functionName;
}
[[nodiscard]] const std::vector<ASTNode*>& getParams() const {
return this->m_params;
}
private:
std::string m_functionName;
std::vector<ASTNode*> m_params;
};
}

View File

@ -21,6 +21,17 @@ namespace hex::lang {
const std::pair<u32, std::string>& getError() { return this->m_error; }
struct Function {
constexpr static u32 UnlimitedParameters = 0xFFFF'FFFF;
constexpr static u32 MoreParametersThan = 0x8000'0000;
constexpr static u32 LessParametersThan = 0x4000'0000;
constexpr static u32 NoParameters = 0x0000'0000;
u32 parameterCount;
std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func;
};
private:
std::map<std::string, ASTNode*> m_types;
prv::Provider* &m_provider;
@ -28,6 +39,7 @@ namespace hex::lang {
u64 m_currOffset = 0;
std::optional<std::endian> m_currEndian;
std::vector<std::vector<PatternData*>*> m_currMembers;
std::map<std::string, Function> m_functions;
std::pair<u32, std::string> m_error;
@ -41,8 +53,16 @@ namespace hex::lang {
return this->m_currEndian.value_or(this->m_defaultDataEndian);
}
void addFunction(std::string_view name, u32 parameterCount, std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func) {
if (this->m_functions.contains(name.data()))
throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1);
this->m_functions[name.data()] = { parameterCount, func };
}
ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node);
ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node);
ASTNodeIntegerLiteral* evaluateFunctionCall(ASTNodeFunctionCall *node);
ASTNodeIntegerLiteral* evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op);
ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node);
ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node);
@ -59,6 +79,14 @@ namespace hex::lang {
PatternData* evaluateArray(ASTNodeArrayVariableDecl *node);
PatternData* evaluatePointer(ASTNodePointerVariableDecl *node);
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* name(std::vector<ASTNodeIntegerLiteral*> params)
BUILTIN_FUNCTION(findSequence);
BUILTIN_FUNCTION(readUnsigned);
BUILTIN_FUNCTION(readSigned);
#undef BUILTIN_FUNCTION
};
}

View File

@ -53,6 +53,7 @@ namespace hex::lang {
return this->m_curr[index].type;
}
ASTNode* parseFunctionCall();
ASTNode* parseScopeResolution(std::vector<std::string> &path);
ASTNode* parseRValue(std::vector<std::string> &path);
ASTNode* parseFactor();

View File

@ -0,0 +1,92 @@
#include "lang/evaluator.hpp"
namespace hex::lang {
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::name(std::vector<ASTNodeIntegerLiteral*> params)
#define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal)
BUILTIN_FUNCTION(findSequence) {
auto& occurrenceIndex = params[0]->getValue();
std::vector<u8> sequence;
for (u32 i = 1; i < params.size(); i++) {
sequence.push_back(std::visit([](auto &&value) -> u8 {
if (value <= 0xFF)
return value;
else
throwEvaluateError("sequence bytes need to fit into 1 byte", 1);
}, params[i]->getValue()));
}
std::vector<u8> bytes(sequence.size(), 0x00);
u32 occurrences = 0;
for (u64 offset = 0; offset < this->m_provider->getSize() - sequence.size(); offset++) {
this->m_provider->read(offset, bytes.data(), bytes.size());
if (bytes == sequence) {
if (LITERAL_COMPARE(occurrenceIndex, occurrenceIndex < occurrences)) {
occurrences++;
continue;
}
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, offset });
}
}
throwEvaluateError("failed to find sequence", 1);
}
BUILTIN_FUNCTION(readUnsigned) {
auto address = params[0]->getValue();
auto size = params[1]->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range", 1);
return std::visit([this](auto &&address, auto &&size) {
if (size <= 0 || size > 16)
throwEvaluateError("invalid read size", 1);
u8 value[(u8)size];
this->m_provider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid rvalue size", 1);
}
}, address, size);
}
BUILTIN_FUNCTION(readSigned) {
auto address = params[0]->getValue();
auto size = params[1]->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range", 1);
return std::visit([this](auto &&address, auto &&size) {
if (size <= 0 || size > 16)
throwEvaluateError("invalid read size", 1);
u8 value[(u8)size];
this->m_provider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid rvalue size", 1);
}
}, address, size);
}
#undef BUILTIN_FUNCTION
}

View File

@ -12,6 +12,18 @@ namespace hex::lang {
Evaluator::Evaluator(prv::Provider* &provider, std::endian defaultDataEndian)
: m_provider(provider), m_defaultDataEndian(defaultDataEndian) {
this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) {
return this->findSequence(params);
});
this->addFunction("readUnsigned", 2, [this](auto params) {
return this->readUnsigned(params);
});
this->addFunction("readSigned", 2, [this](auto params) {
return this->readSigned(params);
});
}
ASTNodeIntegerLiteral* Evaluator::evaluateScopeResolution(ASTNodeScopeResolution *node) {
@ -84,8 +96,6 @@ namespace hex::lang {
u8 value[enumPattern->getSize()];
this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize());
switch (enumPattern->getSize()) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, this->getCurrentEndian()) });
@ -98,6 +108,37 @@ namespace hex::lang {
throwEvaluateError("tried to use non-integer value in numeric expression", node->getLineNumber());
}
ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) {
std::vector<ASTNodeIntegerLiteral*> evaluatedParams;
ScopeExit paramCleanup([&] {
for (auto &param : evaluatedParams)
delete param;
});
for (auto &param : node->getParams())
evaluatedParams.push_back(this->evaluateMathematicalExpression(static_cast<ASTNodeNumericExpression*>(param)));
if (!this->m_functions.contains(node->getFunctionName().data()))
throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber());
auto &function = this->m_functions[node->getFunctionName().data()];
if (function.parameterCount == Function::UnlimitedParameters) {
; // Don't check parameter count
}
else if (function.parameterCount & Function::LessParametersThan) {
if (evaluatedParams.size() >= (function.parameterCount & ~Function::LessParametersThan))
throwEvaluateError(hex::format("too many parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~Function::LessParametersThan), node->getLineNumber());
} else if (function.parameterCount & Function::MoreParametersThan) {
if (evaluatedParams.size() <= (function.parameterCount & ~Function::MoreParametersThan))
throwEvaluateError(hex::format("too few parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~Function::MoreParametersThan), node->getLineNumber());
} else if (function.parameterCount != evaluatedParams.size()) {
throwEvaluateError(hex::format("invalid number of parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount), node->getLineNumber());
}
return function.func(evaluatedParams);
}
#define FLOAT_BIT_OPERATION(name) \
auto name(std::floating_point auto left, auto right) { throw std::runtime_error(""); return 0; } \
auto name(auto left, std::floating_point auto right) { throw std::runtime_error(""); return 0; } \
@ -220,6 +261,8 @@ namespace hex::lang {
return evaluateScopeResolution(exprScopeResolution);
else if (auto exprTernary = dynamic_cast<ASTNodeTernaryExpression*>(node); exprTernary != nullptr)
return evaluateTernaryExpression(exprTernary);
else if (auto exprFunctionCall = dynamic_cast<ASTNodeFunctionCall*>(node); exprFunctionCall != nullptr)
return evaluateFunctionCall(exprFunctionCall);
else
throwEvaluateError("invalid operand", node->getLineNumber());
}

View File

@ -54,12 +54,14 @@ namespace hex::lang {
} else if (numberData.ends_with("LL")) {
type = Token::ValueType::Signed128Bit;
numberData.remove_suffix(2);
} else if (numberData.ends_with('F')) {
type = Token::ValueType::Float;
numberData.remove_suffix(1);
} else if (numberData.ends_with('D')) {
type = Token::ValueType::Double;
numberData.remove_suffix(1);
} else if (!numberData.starts_with("0x") && !numberData.starts_with("0b")) {
if (numberData.ends_with('F')) {
type = Token::ValueType::Float;
numberData.remove_suffix(1);
} else if (numberData.ends_with('D')) {
type = Token::ValueType::Double;
numberData.remove_suffix(1);
}
}
if (numberData.starts_with("0x")) {

View File

@ -18,6 +18,32 @@ namespace hex::lang {
/* Mathematical expressions */
// Identifier([(parseMathematicalExpression)|<(parseMathematicalExpression),...>(parseMathematicalExpression)]
ASTNode* Parser::parseFunctionCall() {
auto functionName = getValue<std::string>(-2);
std::vector<ASTNode*> params;
ScopeExit paramCleanup([&]{
for (auto &param : params)
delete param;
});
while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) {
params.push_back(parseMathematicalExpression());
if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE)))
throwParseError("unexpected ',' at end of function parameter list", -1);
else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE)))
break;
else if (!MATCHES(sequence(SEPARATOR_COMMA)))
throwParseError("missing ',' between parameters", -1);
}
paramCleanup.release();
return TO_NUMERIC_EXPRESSION(new ASTNodeFunctionCall(functionName, params));
}
// Identifier::<Identifier[::]...>
ASTNode* Parser::parseScopeResolution(std::vector<std::string> &path) {
if (peek(IDENTIFIER, -1))
@ -59,12 +85,12 @@ namespace hex::lang {
std::vector<std::string> path;
this->m_curr--;
return this->parseScopeResolution(path);
}
else if (MATCHES(sequence(IDENTIFIER))) {
} else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) {
return this->parseFunctionCall();
} else if (MATCHES(sequence(IDENTIFIER))) {
std::vector<std::string> path;
return this->parseRValue(path);
}
else
} else
throwParseError("expected integer or parenthesis");
}