Added function calling as well as a few builtin functions
This commit is contained in:
parent
b47736b595
commit
bef20f7808
@ -135,6 +135,7 @@ add_executable(imhex ${application_type}
|
||||
source/lang/parser.cpp
|
||||
source/lang/validator.cpp
|
||||
source/lang/evaluator.cpp
|
||||
source/lang/builtin_functions.cpp
|
||||
|
||||
source/providers/file_provider.cpp
|
||||
|
||||
|
@ -442,4 +442,39 @@ namespace hex::lang {
|
||||
std::vector<ASTNode*> m_trueBody, m_falseBody;
|
||||
};
|
||||
|
||||
class ASTNodeFunctionCall : public ASTNode {
|
||||
public:
|
||||
explicit ASTNodeFunctionCall(std::string_view functionName, std::vector<ASTNode*> params)
|
||||
: ASTNode(), m_functionName(functionName), m_params(std::move(params)) { }
|
||||
|
||||
~ASTNodeFunctionCall() override {
|
||||
for (auto ¶m : this->m_params)
|
||||
delete param;
|
||||
}
|
||||
|
||||
ASTNodeFunctionCall(const ASTNodeFunctionCall &other) : ASTNode(other) {
|
||||
this->m_functionName = other.m_functionName;
|
||||
|
||||
for (auto ¶m : other.m_params)
|
||||
this->m_params.push_back(param->clone());
|
||||
}
|
||||
|
||||
ASTNode* clone() const override {
|
||||
return new ASTNodeFunctionCall(*this);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string_view getFunctionName() {
|
||||
return this->m_functionName;
|
||||
}
|
||||
|
||||
[[nodiscard]] const std::vector<ASTNode*>& getParams() const {
|
||||
return this->m_params;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string m_functionName;
|
||||
std::vector<ASTNode*> m_params;
|
||||
};
|
||||
|
||||
|
||||
}
|
@ -21,6 +21,17 @@ namespace hex::lang {
|
||||
|
||||
const std::pair<u32, std::string>& getError() { return this->m_error; }
|
||||
|
||||
|
||||
struct Function {
|
||||
constexpr static u32 UnlimitedParameters = 0xFFFF'FFFF;
|
||||
constexpr static u32 MoreParametersThan = 0x8000'0000;
|
||||
constexpr static u32 LessParametersThan = 0x4000'0000;
|
||||
constexpr static u32 NoParameters = 0x0000'0000;
|
||||
|
||||
u32 parameterCount;
|
||||
std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func;
|
||||
};
|
||||
|
||||
private:
|
||||
std::map<std::string, ASTNode*> m_types;
|
||||
prv::Provider* &m_provider;
|
||||
@ -28,6 +39,7 @@ namespace hex::lang {
|
||||
u64 m_currOffset = 0;
|
||||
std::optional<std::endian> m_currEndian;
|
||||
std::vector<std::vector<PatternData*>*> m_currMembers;
|
||||
std::map<std::string, Function> m_functions;
|
||||
|
||||
std::pair<u32, std::string> m_error;
|
||||
|
||||
@ -41,8 +53,16 @@ namespace hex::lang {
|
||||
return this->m_currEndian.value_or(this->m_defaultDataEndian);
|
||||
}
|
||||
|
||||
void addFunction(std::string_view name, u32 parameterCount, std::function<ASTNodeIntegerLiteral*(std::vector<ASTNodeIntegerLiteral*>)> func) {
|
||||
if (this->m_functions.contains(name.data()))
|
||||
throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1);
|
||||
|
||||
this->m_functions[name.data()] = { parameterCount, func };
|
||||
}
|
||||
|
||||
ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node);
|
||||
ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node);
|
||||
ASTNodeIntegerLiteral* evaluateFunctionCall(ASTNodeFunctionCall *node);
|
||||
ASTNodeIntegerLiteral* evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op);
|
||||
ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node);
|
||||
ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node);
|
||||
@ -59,6 +79,14 @@ namespace hex::lang {
|
||||
PatternData* evaluateArray(ASTNodeArrayVariableDecl *node);
|
||||
PatternData* evaluatePointer(ASTNodePointerVariableDecl *node);
|
||||
|
||||
|
||||
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* name(std::vector<ASTNodeIntegerLiteral*> params)
|
||||
|
||||
BUILTIN_FUNCTION(findSequence);
|
||||
BUILTIN_FUNCTION(readUnsigned);
|
||||
BUILTIN_FUNCTION(readSigned);
|
||||
|
||||
#undef BUILTIN_FUNCTION
|
||||
};
|
||||
|
||||
}
|
@ -53,6 +53,7 @@ namespace hex::lang {
|
||||
return this->m_curr[index].type;
|
||||
}
|
||||
|
||||
ASTNode* parseFunctionCall();
|
||||
ASTNode* parseScopeResolution(std::vector<std::string> &path);
|
||||
ASTNode* parseRValue(std::vector<std::string> &path);
|
||||
ASTNode* parseFactor();
|
||||
|
92
source/lang/builtin_functions.cpp
Normal file
92
source/lang/builtin_functions.cpp
Normal file
@ -0,0 +1,92 @@
|
||||
#include "lang/evaluator.hpp"
|
||||
|
||||
namespace hex::lang {
|
||||
|
||||
#define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::name(std::vector<ASTNodeIntegerLiteral*> params)
|
||||
|
||||
#define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal)
|
||||
|
||||
BUILTIN_FUNCTION(findSequence) {
|
||||
auto& occurrenceIndex = params[0]->getValue();
|
||||
std::vector<u8> sequence;
|
||||
for (u32 i = 1; i < params.size(); i++) {
|
||||
sequence.push_back(std::visit([](auto &&value) -> u8 {
|
||||
if (value <= 0xFF)
|
||||
return value;
|
||||
else
|
||||
throwEvaluateError("sequence bytes need to fit into 1 byte", 1);
|
||||
}, params[i]->getValue()));
|
||||
}
|
||||
|
||||
std::vector<u8> bytes(sequence.size(), 0x00);
|
||||
u32 occurrences = 0;
|
||||
for (u64 offset = 0; offset < this->m_provider->getSize() - sequence.size(); offset++) {
|
||||
this->m_provider->read(offset, bytes.data(), bytes.size());
|
||||
|
||||
if (bytes == sequence) {
|
||||
if (LITERAL_COMPARE(occurrenceIndex, occurrenceIndex < occurrences)) {
|
||||
occurrences++;
|
||||
continue;
|
||||
}
|
||||
|
||||
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, offset });
|
||||
}
|
||||
}
|
||||
|
||||
throwEvaluateError("failed to find sequence", 1);
|
||||
}
|
||||
|
||||
BUILTIN_FUNCTION(readUnsigned) {
|
||||
auto address = params[0]->getValue();
|
||||
auto size = params[1]->getValue();
|
||||
|
||||
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
|
||||
throwEvaluateError("address out of range", 1);
|
||||
|
||||
return std::visit([this](auto &&address, auto &&size) {
|
||||
if (size <= 0 || size > 16)
|
||||
throwEvaluateError("invalid read size", 1);
|
||||
|
||||
u8 value[(u8)size];
|
||||
this->m_provider->read(address, value, size);
|
||||
|
||||
switch ((u8)size) {
|
||||
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, this->getCurrentEndian()) });
|
||||
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, this->getCurrentEndian()) });
|
||||
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, this->getCurrentEndian()) });
|
||||
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, this->getCurrentEndian()) });
|
||||
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, this->getCurrentEndian()) });
|
||||
default: throwEvaluateError("invalid rvalue size", 1);
|
||||
}
|
||||
}, address, size);
|
||||
}
|
||||
|
||||
BUILTIN_FUNCTION(readSigned) {
|
||||
auto address = params[0]->getValue();
|
||||
auto size = params[1]->getValue();
|
||||
|
||||
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
|
||||
throwEvaluateError("address out of range", 1);
|
||||
|
||||
return std::visit([this](auto &&address, auto &&size) {
|
||||
if (size <= 0 || size > 16)
|
||||
throwEvaluateError("invalid read size", 1);
|
||||
|
||||
u8 value[(u8)size];
|
||||
this->m_provider->read(address, value, size);
|
||||
|
||||
switch ((u8)size) {
|
||||
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) });
|
||||
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) });
|
||||
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) });
|
||||
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) });
|
||||
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) });
|
||||
default: throwEvaluateError("invalid rvalue size", 1);
|
||||
}
|
||||
}, address, size);
|
||||
}
|
||||
|
||||
|
||||
#undef BUILTIN_FUNCTION
|
||||
|
||||
}
|
@ -12,6 +12,18 @@ namespace hex::lang {
|
||||
|
||||
Evaluator::Evaluator(prv::Provider* &provider, std::endian defaultDataEndian)
|
||||
: m_provider(provider), m_defaultDataEndian(defaultDataEndian) {
|
||||
|
||||
this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) {
|
||||
return this->findSequence(params);
|
||||
});
|
||||
|
||||
this->addFunction("readUnsigned", 2, [this](auto params) {
|
||||
return this->readUnsigned(params);
|
||||
});
|
||||
|
||||
this->addFunction("readSigned", 2, [this](auto params) {
|
||||
return this->readSigned(params);
|
||||
});
|
||||
}
|
||||
|
||||
ASTNodeIntegerLiteral* Evaluator::evaluateScopeResolution(ASTNodeScopeResolution *node) {
|
||||
@ -84,8 +96,6 @@ namespace hex::lang {
|
||||
u8 value[enumPattern->getSize()];
|
||||
this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize());
|
||||
|
||||
|
||||
|
||||
switch (enumPattern->getSize()) {
|
||||
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, this->getCurrentEndian()) });
|
||||
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, this->getCurrentEndian()) });
|
||||
@ -98,6 +108,37 @@ namespace hex::lang {
|
||||
throwEvaluateError("tried to use non-integer value in numeric expression", node->getLineNumber());
|
||||
}
|
||||
|
||||
ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) {
|
||||
std::vector<ASTNodeIntegerLiteral*> evaluatedParams;
|
||||
ScopeExit paramCleanup([&] {
|
||||
for (auto ¶m : evaluatedParams)
|
||||
delete param;
|
||||
});
|
||||
|
||||
for (auto ¶m : node->getParams())
|
||||
evaluatedParams.push_back(this->evaluateMathematicalExpression(static_cast<ASTNodeNumericExpression*>(param)));
|
||||
|
||||
if (!this->m_functions.contains(node->getFunctionName().data()))
|
||||
throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber());
|
||||
|
||||
auto &function = this->m_functions[node->getFunctionName().data()];
|
||||
|
||||
if (function.parameterCount == Function::UnlimitedParameters) {
|
||||
; // Don't check parameter count
|
||||
}
|
||||
else if (function.parameterCount & Function::LessParametersThan) {
|
||||
if (evaluatedParams.size() >= (function.parameterCount & ~Function::LessParametersThan))
|
||||
throwEvaluateError(hex::format("too many parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~Function::LessParametersThan), node->getLineNumber());
|
||||
} else if (function.parameterCount & Function::MoreParametersThan) {
|
||||
if (evaluatedParams.size() <= (function.parameterCount & ~Function::MoreParametersThan))
|
||||
throwEvaluateError(hex::format("too few parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~Function::MoreParametersThan), node->getLineNumber());
|
||||
} else if (function.parameterCount != evaluatedParams.size()) {
|
||||
throwEvaluateError(hex::format("invalid number of parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount), node->getLineNumber());
|
||||
}
|
||||
|
||||
return function.func(evaluatedParams);
|
||||
}
|
||||
|
||||
#define FLOAT_BIT_OPERATION(name) \
|
||||
auto name(std::floating_point auto left, auto right) { throw std::runtime_error(""); return 0; } \
|
||||
auto name(auto left, std::floating_point auto right) { throw std::runtime_error(""); return 0; } \
|
||||
@ -220,6 +261,8 @@ namespace hex::lang {
|
||||
return evaluateScopeResolution(exprScopeResolution);
|
||||
else if (auto exprTernary = dynamic_cast<ASTNodeTernaryExpression*>(node); exprTernary != nullptr)
|
||||
return evaluateTernaryExpression(exprTernary);
|
||||
else if (auto exprFunctionCall = dynamic_cast<ASTNodeFunctionCall*>(node); exprFunctionCall != nullptr)
|
||||
return evaluateFunctionCall(exprFunctionCall);
|
||||
else
|
||||
throwEvaluateError("invalid operand", node->getLineNumber());
|
||||
}
|
||||
|
@ -54,12 +54,14 @@ namespace hex::lang {
|
||||
} else if (numberData.ends_with("LL")) {
|
||||
type = Token::ValueType::Signed128Bit;
|
||||
numberData.remove_suffix(2);
|
||||
} else if (numberData.ends_with('F')) {
|
||||
type = Token::ValueType::Float;
|
||||
numberData.remove_suffix(1);
|
||||
} else if (numberData.ends_with('D')) {
|
||||
type = Token::ValueType::Double;
|
||||
numberData.remove_suffix(1);
|
||||
} else if (!numberData.starts_with("0x") && !numberData.starts_with("0b")) {
|
||||
if (numberData.ends_with('F')) {
|
||||
type = Token::ValueType::Float;
|
||||
numberData.remove_suffix(1);
|
||||
} else if (numberData.ends_with('D')) {
|
||||
type = Token::ValueType::Double;
|
||||
numberData.remove_suffix(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (numberData.starts_with("0x")) {
|
||||
|
@ -18,6 +18,32 @@ namespace hex::lang {
|
||||
|
||||
/* Mathematical expressions */
|
||||
|
||||
// Identifier([(parseMathematicalExpression)|<(parseMathematicalExpression),...>(parseMathematicalExpression)]
|
||||
ASTNode* Parser::parseFunctionCall() {
|
||||
auto functionName = getValue<std::string>(-2);
|
||||
std::vector<ASTNode*> params;
|
||||
ScopeExit paramCleanup([&]{
|
||||
for (auto ¶m : params)
|
||||
delete param;
|
||||
});
|
||||
|
||||
while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) {
|
||||
params.push_back(parseMathematicalExpression());
|
||||
|
||||
if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE)))
|
||||
throwParseError("unexpected ',' at end of function parameter list", -1);
|
||||
else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE)))
|
||||
break;
|
||||
else if (!MATCHES(sequence(SEPARATOR_COMMA)))
|
||||
throwParseError("missing ',' between parameters", -1);
|
||||
|
||||
}
|
||||
|
||||
paramCleanup.release();
|
||||
|
||||
return TO_NUMERIC_EXPRESSION(new ASTNodeFunctionCall(functionName, params));
|
||||
}
|
||||
|
||||
// Identifier::<Identifier[::]...>
|
||||
ASTNode* Parser::parseScopeResolution(std::vector<std::string> &path) {
|
||||
if (peek(IDENTIFIER, -1))
|
||||
@ -59,12 +85,12 @@ namespace hex::lang {
|
||||
std::vector<std::string> path;
|
||||
this->m_curr--;
|
||||
return this->parseScopeResolution(path);
|
||||
}
|
||||
else if (MATCHES(sequence(IDENTIFIER))) {
|
||||
} else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) {
|
||||
return this->parseFunctionCall();
|
||||
} else if (MATCHES(sequence(IDENTIFIER))) {
|
||||
std::vector<std::string> path;
|
||||
return this->parseRValue(path);
|
||||
}
|
||||
else
|
||||
} else
|
||||
throwParseError("expected integer or parenthesis");
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user