1
0
mirror of synced 2024-11-12 10:10:53 +01:00

patterns: Added support for declaring custom functions

This commit is contained in:
WerWolv 2021-06-20 21:22:31 +02:00
parent ac53b4bcab
commit 7f0bdc95da
13 changed files with 471 additions and 189 deletions

View File

@ -40,7 +40,7 @@ namespace hex::plugin::builtin {
continue;
}
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, offset });
return new ASTNodeIntegerLiteral(offset);
}
}
@ -63,11 +63,11 @@ namespace hex::plugin::builtin {
SharedData::currentProvider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, *reinterpret_cast<u8*>(value) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, *reinterpret_cast<u16*>(value) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, *reinterpret_cast<u32*>(value) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, *reinterpret_cast<u64*>(value) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, *reinterpret_cast<u128*>(value) });
case 1: return new ASTNodeIntegerLiteral(*reinterpret_cast<u8*>(value));
case 2: return new ASTNodeIntegerLiteral(*reinterpret_cast<u16*>(value));
case 4: return new ASTNodeIntegerLiteral(*reinterpret_cast<u32*>(value));
case 8: return new ASTNodeIntegerLiteral(*reinterpret_cast<u64*>(value));
case 16: return new ASTNodeIntegerLiteral(*reinterpret_cast<u128*>(value));
default: ctx.getConsole().abortEvaluation("invalid read size");
}
}, address, size);
@ -89,11 +89,11 @@ namespace hex::plugin::builtin {
SharedData::currentProvider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, *reinterpret_cast<s8*>(value) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, *reinterpret_cast<s16*>(value) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, *reinterpret_cast<s32*>(value) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, *reinterpret_cast<s64*>(value) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, *reinterpret_cast<s128*>(value) });
case 1: return new ASTNodeIntegerLiteral(*reinterpret_cast<s8*>(value));
case 2: return new ASTNodeIntegerLiteral(*reinterpret_cast<s16*>(value));
case 4: return new ASTNodeIntegerLiteral(*reinterpret_cast<s32*>(value));
case 8: return new ASTNodeIntegerLiteral(*reinterpret_cast<s64*>(value));
case 16: return new ASTNodeIntegerLiteral(*reinterpret_cast<s128*>(value));
default: ctx.getConsole().abortEvaluation("invalid read size");
}
}, address, size);
@ -127,25 +127,19 @@ namespace hex::plugin::builtin {
for (auto& param : params) {
if (auto integerLiteral = dynamic_cast<ASTNodeIntegerLiteral*>(param); integerLiteral != nullptr) {
std::visit([&](auto &&value) {
switch (integerLiteral->getType()) {
case lang::Token::ValueType::Character: message += (char)value; break;
case lang::Token::ValueType::Boolean: message += value == 0 ? "false" : "true"; break;
case lang::Token::ValueType::Unsigned8Bit:
case lang::Token::ValueType::Unsigned16Bit:
case lang::Token::ValueType::Unsigned32Bit:
case lang::Token::ValueType::Unsigned64Bit:
case lang::Token::ValueType::Unsigned128Bit:
message += std::to_string(static_cast<u64>(value));
break;
case lang::Token::ValueType::Signed8Bit:
case lang::Token::ValueType::Signed16Bit:
case lang::Token::ValueType::Signed32Bit:
case lang::Token::ValueType::Signed64Bit:
case lang::Token::ValueType::Signed128Bit:
message += std::to_string(static_cast<s64>(value));
break;
default: message += "< Custom Type >";
}
using Type = std::remove_cvref_t<decltype(value)>;
if constexpr (std::is_same_v<Type, char>)
message += (char)value;
else if constexpr (std::is_same_v<Type, bool>)
message += value == 0 ? "false" : "true";
else if constexpr (std::is_unsigned_v<Type>)
message += std::to_string(static_cast<u64>(value));
else if constexpr (std::is_signed_v<Type>)
message += std::to_string(static_cast<s64>(value));
else if constexpr (std::is_floating_point_v<Type>)
message += std::to_string(value);
else
message += "< Custom Type >";
}, integerLiteral->getValue());
}
else if (auto stringLiteral = dynamic_cast<ASTNodeStringLiteral*>(param); stringLiteral != nullptr)
@ -167,12 +161,12 @@ namespace hex::plugin::builtin {
return remainder != 0 ? u64(value) + (u64(alignment) - remainder) : u64(value);
}, alignment, value);
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, u64(result) });
return new ASTNodeIntegerLiteral(u64(result));
});
/* dataSize() */
ContentRegistry::PatternLanguageFunctions::add("dataSize", ContentRegistry::PatternLanguageFunctions::NoParameters, [](auto &ctx, auto params) -> ASTNode* {
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, u64(SharedData::currentProvider->getActualSize()) });
return new ASTNodeIntegerLiteral(u64(SharedData::currentProvider->getActualSize()));
});
}

View File

@ -87,10 +87,10 @@ namespace hex {
struct Function {
u32 parameterCount;
std::function<hex::lang::ASTNode*(hex::lang::Evaluator&, std::vector<hex::lang::ASTNode*>)> func;
std::function<hex::lang::ASTNode*(hex::lang::Evaluator&, std::vector<hex::lang::ASTNode*>&)> func;
};
static void add(std::string_view name, u32 parameterCount, const std::function<hex::lang::ASTNode*(hex::lang::Evaluator&, std::vector<hex::lang::ASTNode*>)> &func);
static void add(std::string_view name, u32 parameterCount, const std::function<hex::lang::ASTNode*(hex::lang::Evaluator&, std::vector<hex::lang::ASTNode*>&)> &func);
static std::map<std::string, ContentRegistry::PatternLanguageFunctions::Function>& getEntries();
};

View File

@ -56,11 +56,7 @@ namespace hex::lang {
}
[[nodiscard]] const auto& getValue() const {
return this->m_literal.second;
}
[[nodiscard]] Token::ValueType getType() const {
return this->m_literal.first;
return this->m_literal;
}
private:
@ -622,4 +618,105 @@ namespace hex::lang {
Token::Operator m_op;
ASTNode *m_expression;
};
class ASTNodeFunctionDefinition : public ASTNode {
public:
ASTNodeFunctionDefinition(std::string name, std::vector<std::string> params, std::vector<ASTNode*> body)
: m_name(std::move(name)), m_params(std::move(params)), m_body(std::move(body)) {
}
ASTNodeFunctionDefinition(const ASTNodeFunctionDefinition &other) : ASTNode(other) {
this->m_name = other.m_name;
this->m_params = other.m_params;
for (auto statement : other.m_body) {
this->m_body.push_back(statement->clone());
}
}
[[nodiscard]] ASTNode* clone() const override {
return new ASTNodeFunctionDefinition(*this);
}
~ASTNodeFunctionDefinition() override {
for (auto statement : this->m_body)
delete statement;
}
[[nodiscard]] std::string_view getName() const {
return this->m_name;
}
[[nodiscard]] const auto& getParams() const {
return this->m_params;
}
[[nodiscard]] const auto& getBody() const {
return this->m_body;
}
private:
std::string m_name;
std::vector<std::string> m_params;
std::vector<ASTNode*> m_body;
};
class ASTNodeAssignment : public ASTNode {
public:
ASTNodeAssignment(std::string lvalueName, ASTNode *rvalue) : m_lvalueName(std::move(lvalueName)), m_rvalue(rvalue) {
}
ASTNodeAssignment(const ASTNodeAssignment &other) : ASTNode(other) {
this->m_lvalueName = other.m_lvalueName;
this->m_rvalue = other.m_rvalue->clone();
}
[[nodiscard]] ASTNode* clone() const override {
return new ASTNodeAssignment(*this);
}
~ASTNodeAssignment() override {
delete this->m_rvalue;
}
[[nodiscard]] std::string_view getLValueName() const {
return this->m_lvalueName;
}
[[nodiscard]] ASTNode* getRValue() const {
return this->m_rvalue;
}
private:
std::string m_lvalueName;
ASTNode *m_rvalue;
};
class ASTNodeReturnStatement : public ASTNode {
public:
ASTNodeReturnStatement(ASTNode *rvalue) : m_rvalue(rvalue) {
}
ASTNodeReturnStatement(const ASTNodeReturnStatement &other) : ASTNode(other) {
this->m_rvalue = other.m_rvalue->clone();
}
[[nodiscard]] ASTNode* clone() const override {
return new ASTNodeReturnStatement(*this);
}
~ASTNodeReturnStatement() override {
delete this->m_rvalue;
}
[[nodiscard]] ASTNode* getRValue() const {
return this->m_rvalue;
}
private:
ASTNode *m_rvalue;
};
}

View File

@ -9,6 +9,7 @@
#include <hex/lang/log_console.hpp>
#include <bit>
#include <span>
#include <string>
#include <unordered_map>
#include <vector>
@ -46,12 +47,17 @@ namespace hex::lang {
std::vector<std::endian> m_endianStack;
std::vector<PatternData*> m_globalMembers;
std::vector<std::vector<PatternData*>*> m_currMembers;
std::vector<std::vector<PatternData*>*> m_localVariables;
std::vector<PatternData*> m_currMemberScope;
std::vector<u8> m_localStack;
std::map<std::string, ContentRegistry::PatternLanguageFunctions::Function> m_definedFunctions;
LogConsole m_console;
u32 m_recursionLimit;
u32 m_currRecursionDepth;
void createLocalVariable(std::string_view varName, PatternData *pattern);
void setLocalVariableValue(std::string_view varName, const void *value, size_t size);
ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node);
ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node);
@ -61,6 +67,7 @@ namespace hex::lang {
ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node);
ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node);
ASTNodeIntegerLiteral* evaluateMathematicalExpression(ASTNodeNumericExpression *node);
void evaluateFunctionDefinition(ASTNodeFunctionDefinition *node);
PatternData* findPattern(std::vector<PatternData*> currMembers, const ASTNodeRValue::Path &path);
PatternData* evaluateAttributes(ASTNode *currNode, PatternData *currPattern);

View File

@ -72,6 +72,9 @@ namespace hex::lang {
ASTNode* parseMathematicalExpression();
void parseAttribute(Attributable *currNode);
ASTNode* parseFunctionDefintion();
ASTNode* parseVariableAssignment();
ASTNode* parseReturnStatement();
ASTNode* parseConditional();
ASTNode* parseWhileStatement();
ASTNode* parseType(s32 startIndex);

View File

@ -167,10 +167,18 @@ namespace hex::lang {
this->m_hidden = hidden;
}
bool isHidden() const {
[[nodiscard]] bool isHidden() const {
return this->m_hidden;
}
void setLocal(bool local) {
this->m_local = local;
}
[[nodiscard]] bool isLocal() const {
return this->m_local;
}
protected:
void createDefaultEntry(std::string_view value) const {
ImGui::TableNextRow();
@ -217,6 +225,7 @@ namespace hex::lang {
std::string m_typeName;
PatternData *m_parent;
bool m_local = false;
};
class PatternDataPadding : public PatternData {
@ -886,7 +895,7 @@ namespace hex::lang {
}
return false;
}, entryValueLiteral.second);
}, entryValueLiteral);
if (matches)
break;
}

View File

@ -33,7 +33,9 @@ namespace hex::lang {
If,
Else,
Parent,
While
While,
Function,
Return
};
enum class Operator {
@ -107,8 +109,7 @@ namespace hex::lang {
EndOfProgram
};
using Integers = std::variant<u8, s8, u16, s16, u32, s32, u64, s64, u128, s128, float, double>;
using IntegerLiteral = std::pair<ValueType, Integers>;
using IntegerLiteral = std::variant<char, bool, u8, s8, u16, s16, u32, s32, u64, s64, u128, s128, float, double>;
using ValueTypes = std::variant<Keyword, std::string, Operator, IntegerLiteral, ValueType, Separator>;
Token(Type type, auto value, u32 lineNumber) : type(type), value(value), lineNumber(lineNumber) {
@ -131,28 +132,6 @@ namespace hex::lang {
return static_cast<u32>(type) >> 4;
}
[[nodiscard]] constexpr static inline IntegerLiteral castTo(ValueType type, const Integers &literal) {
return std::visit([type](auto &&value) {
switch (type) {
case ValueType::Signed8Bit: return IntegerLiteral(type, static_cast<s8>(value));
case ValueType::Signed16Bit: return IntegerLiteral(type, static_cast<s16>(value));
case ValueType::Signed32Bit: return IntegerLiteral(type, static_cast<s32>(value));
case ValueType::Signed64Bit: return IntegerLiteral(type, static_cast<s64>(value));
case ValueType::Signed128Bit: return IntegerLiteral(type, static_cast<s128>(value));
case ValueType::Unsigned8Bit: return IntegerLiteral(type, static_cast<u8>(value));
case ValueType::Unsigned16Bit: return IntegerLiteral(type, static_cast<u16>(value));
case ValueType::Unsigned32Bit: return IntegerLiteral(type, static_cast<u32>(value));
case ValueType::Unsigned64Bit: return IntegerLiteral(type, static_cast<u64>(value));
case ValueType::Unsigned128Bit: return IntegerLiteral(type, static_cast<u128>(value));
case ValueType::Float: return IntegerLiteral(type, static_cast<float>(value));
case ValueType::Double: return IntegerLiteral(type, static_cast<double>(value));
case ValueType::Character: return IntegerLiteral(type, static_cast<char>(value));
case ValueType::Character16: return IntegerLiteral(type, static_cast<char16_t>(value));
default: __builtin_unreachable();
}
}, literal);
}
[[nodiscard]] constexpr static auto getTypeName(const lang::Token::ValueType type) {
switch (type) {
case ValueType::Signed8Bit: return "s8";
@ -227,8 +206,10 @@ namespace hex::lang {
#define KEYWORD_ELSE COMPONENT(Keyword, Else)
#define KEYWORD_PARENT COMPONENT(Keyword, Parent)
#define KEYWORD_WHILE COMPONENT(Keyword, While)
#define KEYWORD_FUNCTION COMPONENT(Keyword, Function)
#define KEYWORD_RETURN COMPONENT(Keyword, Return)
#define INTEGER hex::lang::Token::Type::Integer, hex::lang::Token::IntegerLiteral(hex::lang::Token::ValueType::Any, u64(0))
#define INTEGER hex::lang::Token::Type::Integer, hex::lang::Token::IntegerLiteral(u64(0))
#define IDENTIFIER hex::lang::Token::Type::Identifier, ""
#define STRING hex::lang::Token::Type::String, ""

View File

@ -146,7 +146,7 @@ namespace hex {
/* Pattern Language Functions */
void ContentRegistry::PatternLanguageFunctions::add(std::string_view name, u32 parameterCount, const std::function<hex::lang::ASTNode*(hex::lang::Evaluator&, std::vector<hex::lang::ASTNode*>)> &func) {
void ContentRegistry::PatternLanguageFunctions::add(std::string_view name, u32 parameterCount, const std::function<hex::lang::ASTNode*(hex::lang::Evaluator&, std::vector<hex::lang::ASTNode*>&)> &func) {
getEntries()[name.data()] = Function{ parameterCount, func };
}

View File

@ -80,8 +80,10 @@ namespace hex::lang {
if (currPattern != nullptr) {
if (auto arrayPattern = dynamic_cast<PatternDataArray*>(currPattern); arrayPattern != nullptr) {
if (Token::isFloatingPoint(arrayIndexNode->getType()))
this->getConsole().abortEvaluation("cannot use float to index into array");
std::visit([this](auto &&arrayIndex) {
if (std::is_floating_point_v<decltype(arrayIndex)>)
this->getConsole().abortEvaluation("cannot use float to index into array");
}, arrayIndexNode->getValue());
std::visit([&](auto &&arrayIndex){
if (arrayIndex >= 0 && arrayIndex < arrayPattern->getEntries().size())
@ -110,9 +112,13 @@ namespace hex::lang {
PatternData *currPattern = nullptr;
// Local member access
if (this->m_currMembers.size() > 1)
// Local variable access
currPattern = this->findPattern(*this->m_localVariables.back(), path);
// If no local variable was found try local structure members
if (this->m_currMembers.size() > 1) {
currPattern = this->findPattern(*this->m_currMembers[this->m_currMembers.size() - 2], path);
}
// If no local member was found, try globally
if (currPattern == nullptr) {
@ -143,45 +149,55 @@ namespace hex::lang {
ASTNodeIntegerLiteral* Evaluator::evaluateRValue(ASTNodeRValue *node) {
if (node->getPath().size() == 1) {
if (auto part = std::get_if<std::string>(&node->getPath()[0]); part != nullptr && *part == "$")
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, this->m_currOffset });
return new ASTNodeIntegerLiteral(this->m_currOffset);
}
auto currPattern = this->patternFromName(node->getPath());
if (auto unsignedPattern = dynamic_cast<PatternDataUnsigned*>(currPattern); unsignedPattern != nullptr) {
u8 value[unsignedPattern->getSize()];
this->m_provider->read(unsignedPattern->getOffset(), value, unsignedPattern->getSize());
if (currPattern->isLocal())
std::memcpy(value, this->m_localStack.data() + unsignedPattern->getOffset(), unsignedPattern->getSize());
else
this->m_provider->read(unsignedPattern->getOffset(), value, unsignedPattern->getSize());
switch (unsignedPattern->getSize()) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, unsignedPattern->getEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, unsignedPattern->getEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, unsignedPattern->getEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, unsignedPattern->getEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, unsignedPattern->getEndian()) });
case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, unsignedPattern->getEndian()));
case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, unsignedPattern->getEndian()));
case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, unsignedPattern->getEndian()));
case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, unsignedPattern->getEndian()));
case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, unsignedPattern->getEndian()));
default: this->getConsole().abortEvaluation("invalid rvalue size");
}
} else if (auto signedPattern = dynamic_cast<PatternDataSigned*>(currPattern); signedPattern != nullptr) {
u8 value[signedPattern->getSize()];
this->m_provider->read(signedPattern->getOffset(), value, signedPattern->getSize());
if (currPattern->isLocal())
std::memcpy(value, this->m_localStack.data() + signedPattern->getOffset(), signedPattern->getSize());
else
this->m_provider->read(signedPattern->getOffset(), value, signedPattern->getSize());
switch (signedPattern->getSize()) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, signedPattern->getEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, signedPattern->getEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, signedPattern->getEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, signedPattern->getEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, signedPattern->getEndian()) });
case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, signedPattern->getEndian()));
case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, signedPattern->getEndian()));
case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, signedPattern->getEndian()));
case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, signedPattern->getEndian()));
case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, signedPattern->getEndian()));
default: this->getConsole().abortEvaluation("invalid rvalue size");
}
} else if (auto enumPattern = dynamic_cast<PatternDataEnum*>(currPattern); enumPattern != nullptr) {
u8 value[enumPattern->getSize()];
this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize());
if (currPattern->isLocal())
std::memcpy(value, this->m_localStack.data() + enumPattern->getOffset(), enumPattern->getSize());
else
this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize());
switch (enumPattern->getSize()) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, enumPattern->getEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, enumPattern->getEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, enumPattern->getEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, enumPattern->getEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, enumPattern->getEndian()) });
case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, enumPattern->getEndian()));
case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, enumPattern->getEndian()));
case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, enumPattern->getEndian()));
case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, enumPattern->getEndian()));
case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, enumPattern->getEndian()));
default: this->getConsole().abortEvaluation("invalid rvalue size");
}
} else
@ -204,25 +220,28 @@ namespace hex::lang {
evaluatedParams.push_back(stringLiteral->clone());
}
if (!ContentRegistry::PatternLanguageFunctions::getEntries().contains(node->getFunctionName().data()))
ContentRegistry::PatternLanguageFunctions::Function *function;
if (this->m_definedFunctions.contains(node->getFunctionName().data()))
function = &this->m_definedFunctions[node->getFunctionName().data()];
else if (ContentRegistry::PatternLanguageFunctions::getEntries().contains(node->getFunctionName().data()))
function = &ContentRegistry::PatternLanguageFunctions::getEntries()[node->getFunctionName().data()];
else
this->getConsole().abortEvaluation(hex::format("no function named '{0}' found", node->getFunctionName().data()));
auto &function = ContentRegistry::PatternLanguageFunctions::getEntries()[node->getFunctionName().data()];
if (function.parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) {
if (function->parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) {
; // Don't check parameter count
}
else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) {
if (evaluatedParams.size() >= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan))
this->getConsole().abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan));
} else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) {
if (evaluatedParams.size() <= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan))
this->getConsole().abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan));
} else if (function.parameterCount != evaluatedParams.size()) {
this->getConsole().abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function.parameterCount));
else if (function->parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) {
if (evaluatedParams.size() >= (function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan))
this->getConsole().abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan));
} else if (function->parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) {
if (evaluatedParams.size() <= (function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan))
this->getConsole().abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan));
} else if (function->parameterCount != evaluatedParams.size()) {
this->getConsole().abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount));
}
return function.func(*this, evaluatedParams);
return function->func(*this, evaluatedParams);
}
ASTNodeIntegerLiteral* Evaluator::evaluateTypeOperator(ASTNodeTypeOperator *typeOperatorNode) {
@ -231,9 +250,9 @@ namespace hex::lang {
switch (typeOperatorNode->getOperator()) {
case Token::Operator::AddressOf:
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, static_cast<u64>(pattern->getOffset()) });
return new ASTNodeIntegerLiteral(static_cast<u64>(pattern->getOffset()));
case Token::Operator::SizeOf:
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, static_cast<u64>(pattern->getSize()) });
return new ASTNodeIntegerLiteral(static_cast<u64>(pattern->getSize()));
default:
this->getConsole().abortEvaluation("invalid type operator used. This is a bug!");
}
@ -281,85 +300,55 @@ namespace hex::lang {
}
ASTNodeIntegerLiteral* Evaluator::evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op) {
auto newType = [&] {
#define CHECK_TYPE(type) if (left->getType() == (type) || right->getType() == (type)) return (type)
#define DEFAULT_TYPE(type) return (type)
if (left->getType() == Token::ValueType::Any && right->getType() != Token::ValueType::Any)
return right->getType();
if (left->getType() != Token::ValueType::Any && right->getType() == Token::ValueType::Any)
return left->getType();
CHECK_TYPE(Token::ValueType::Double);
CHECK_TYPE(Token::ValueType::Float);
CHECK_TYPE(Token::ValueType::Unsigned128Bit);
CHECK_TYPE(Token::ValueType::Signed128Bit);
CHECK_TYPE(Token::ValueType::Unsigned64Bit);
CHECK_TYPE(Token::ValueType::Signed64Bit);
CHECK_TYPE(Token::ValueType::Unsigned32Bit);
CHECK_TYPE(Token::ValueType::Signed32Bit);
CHECK_TYPE(Token::ValueType::Unsigned16Bit);
CHECK_TYPE(Token::ValueType::Signed16Bit);
CHECK_TYPE(Token::ValueType::Unsigned8Bit);
CHECK_TYPE(Token::ValueType::Signed8Bit);
CHECK_TYPE(Token::ValueType::Character);
CHECK_TYPE(Token::ValueType::Character16);
CHECK_TYPE(Token::ValueType::Boolean);
DEFAULT_TYPE(Token::ValueType::Signed32Bit);
#undef CHECK_TYPE
#undef DEFAULT_TYPE
}();
try {
return std::visit([&](auto &&leftValue, auto &&rightValue) -> ASTNodeIntegerLiteral * {
switch (op) {
case Token::Operator::Plus:
return new ASTNodeIntegerLiteral({ newType, leftValue + rightValue });
return new ASTNodeIntegerLiteral(leftValue + rightValue);
case Token::Operator::Minus:
return new ASTNodeIntegerLiteral({ newType, leftValue - rightValue });
return new ASTNodeIntegerLiteral(leftValue - rightValue);
case Token::Operator::Star:
return new ASTNodeIntegerLiteral({ newType, leftValue * rightValue });
return new ASTNodeIntegerLiteral(leftValue * rightValue);
case Token::Operator::Slash:
if (rightValue == 0)
this->getConsole().abortEvaluation("Division by zero");
return new ASTNodeIntegerLiteral({ newType, leftValue / rightValue });
return new ASTNodeIntegerLiteral(leftValue / rightValue);
case Token::Operator::Percent:
if (rightValue == 0)
this->getConsole().abortEvaluation("Division by zero");
return new ASTNodeIntegerLiteral({ newType, modulus(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(modulus(leftValue, rightValue));
case Token::Operator::ShiftLeft:
return new ASTNodeIntegerLiteral({ newType, shiftLeft(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(shiftLeft(leftValue, rightValue));
case Token::Operator::ShiftRight:
return new ASTNodeIntegerLiteral({ newType, shiftRight(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(shiftRight(leftValue, rightValue));
case Token::Operator::BitAnd:
return new ASTNodeIntegerLiteral({ newType, bitAnd(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(bitAnd(leftValue, rightValue));
case Token::Operator::BitXor:
return new ASTNodeIntegerLiteral({ newType, bitXor(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(bitXor(leftValue, rightValue));
case Token::Operator::BitOr:
return new ASTNodeIntegerLiteral({ newType, bitOr(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(bitOr(leftValue, rightValue));
case Token::Operator::BitNot:
return new ASTNodeIntegerLiteral({ newType, bitNot(leftValue, rightValue) });
return new ASTNodeIntegerLiteral(bitNot(leftValue, rightValue));
case Token::Operator::BoolEquals:
return new ASTNodeIntegerLiteral({ newType, leftValue == rightValue });
return new ASTNodeIntegerLiteral(leftValue == rightValue);
case Token::Operator::BoolNotEquals:
return new ASTNodeIntegerLiteral({ newType, leftValue != rightValue });
return new ASTNodeIntegerLiteral(leftValue != rightValue);
case Token::Operator::BoolGreaterThan:
return new ASTNodeIntegerLiteral({ newType, leftValue > rightValue });
return new ASTNodeIntegerLiteral(leftValue > rightValue);
case Token::Operator::BoolLessThan:
return new ASTNodeIntegerLiteral({ newType, leftValue < rightValue });
return new ASTNodeIntegerLiteral(leftValue < rightValue);
case Token::Operator::BoolGreaterThanOrEquals:
return new ASTNodeIntegerLiteral({ newType, leftValue >= rightValue });
return new ASTNodeIntegerLiteral(leftValue >= rightValue);
case Token::Operator::BoolLessThanOrEquals:
return new ASTNodeIntegerLiteral({ newType, leftValue <= rightValue });
return new ASTNodeIntegerLiteral(leftValue <= rightValue);
case Token::Operator::BoolAnd:
return new ASTNodeIntegerLiteral({ newType, leftValue && rightValue });
return new ASTNodeIntegerLiteral(leftValue && rightValue);
case Token::Operator::BoolXor:
return new ASTNodeIntegerLiteral({ newType, leftValue && !rightValue || !leftValue && rightValue });
return new ASTNodeIntegerLiteral(leftValue && !rightValue || !leftValue && rightValue);
case Token::Operator::BoolOr:
return new ASTNodeIntegerLiteral({ newType, leftValue || rightValue });
return new ASTNodeIntegerLiteral(leftValue || rightValue);
case Token::Operator::BoolNot:
return new ASTNodeIntegerLiteral({ newType, !rightValue });
return new ASTNodeIntegerLiteral(!rightValue);
default:
this->getConsole().abortEvaluation("invalid operator used in mathematical expression");
}
@ -419,6 +408,128 @@ namespace hex::lang {
return evaluateOperator(leftInteger, rightInteger, node->getOperator());
}
void Evaluator::createLocalVariable(std::string_view varName, PatternData *pattern) {
auto startOffset = this->m_currOffset;
ON_SCOPE_EXIT { this->m_currOffset = startOffset; };
auto endOfStack = this->m_localStack.size();
for (auto &variable : *this->m_localVariables.back()) {
if (variable->getVariableName() == varName)
this->getConsole().abortEvaluation(hex::format("redefinition of variable {}", varName));
}
this->m_localStack.resize(endOfStack + pattern->getSize());
pattern->setVariableName(std::string(varName));
pattern->setOffset(endOfStack);
pattern->setLocal(true);
this->m_localVariables.back()->push_back(pattern);
std::memset(this->m_localStack.data() + pattern->getOffset(), 0x00, pattern->getSize());
}
void Evaluator::setLocalVariableValue(std::string_view varName, const void *value, size_t size) {
PatternData *varPattern = nullptr;
for (auto &var : *this->m_localVariables.back()) {
if (var->getVariableName() == varName)
varPattern = var;
}
std::memset(this->m_localStack.data() + varPattern->getOffset(), 0x00, varPattern->getSize());
std::memcpy(this->m_localStack.data() + varPattern->getOffset(), value, std::min(varPattern->getSize(), size));
}
void Evaluator::evaluateFunctionDefinition(ASTNodeFunctionDefinition *node) {
ContentRegistry::PatternLanguageFunctions::Function function = {
(u32)node->getParams().size(),
[paramNames = node->getParams(), body = node->getBody()](Evaluator& evaluator, std::vector<ASTNode*> &params) -> ASTNode* {
// Create local variables from parameters
std::vector<PatternData*> localVariables;
evaluator.m_localVariables.push_back(&localVariables);
ON_SCOPE_EXIT {
u32 stackSizeToDrop = 0;
for (auto &localVar : *evaluator.m_localVariables.back()) {
stackSizeToDrop += localVar->getSize();
delete localVar;
}
evaluator.m_localVariables.pop_back();
evaluator.m_localStack.resize(evaluator.m_localStack.size() - stackSizeToDrop);
};
auto startOffset = evaluator.m_currOffset;
for (u32 i = 0; i < params.size(); i++) {
if (auto integerLiteralNode = dynamic_cast<ASTNodeIntegerLiteral*>(params[i]); integerLiteralNode != nullptr) {
std::visit([&](auto &&value) {
using Type = std::remove_cvref_t<decltype(value)>;
PatternData *pattern;
if constexpr (std::is_unsigned_v<Type>)
pattern = new PatternDataUnsigned(0, sizeof(value));
else if constexpr (std::is_signed_v<Type>)
pattern = new PatternDataSigned(0, sizeof(value));
else if constexpr (std::is_floating_point_v<Type>)
pattern = new PatternDataFloat(0, sizeof(value));
else return;
evaluator.createLocalVariable(paramNames[i], pattern);
evaluator.setLocalVariableValue(paramNames[i], &value, sizeof(value));
}, integerLiteralNode->getValue());
}
}
evaluator.m_currOffset = startOffset;
for (auto &statement : body) {
ON_SCOPE_EXIT { evaluator.m_currOffset = startOffset; };
if (auto functionCallNode = dynamic_cast<ASTNodeFunctionCall*>(statement); functionCallNode != nullptr) {
auto result = evaluator.evaluateFunctionCall(functionCallNode);
delete result;
} else if (auto varDeclNode = dynamic_cast<ASTNodeVariableDecl*>(statement); varDeclNode != nullptr) {
auto pattern = evaluator.evaluateVariable(varDeclNode);
evaluator.createLocalVariable(varDeclNode->getName(), pattern);
} else if (auto assignmentNode = dynamic_cast<ASTNodeAssignment*>(statement); assignmentNode != nullptr) {
if (auto numericExpressionNode = dynamic_cast<ASTNodeNumericExpression*>(assignmentNode->getRValue()); numericExpressionNode != nullptr) {
auto value = evaluator.evaluateMathematicalExpression(numericExpressionNode);
ON_SCOPE_EXIT { delete value; };
std::visit([&](auto &&value) {
evaluator.setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value));
}, value->getValue());
} else {
evaluator.getConsole().abortEvaluation("invalid rvalue used in assignment");
}
} else if (auto assignmentNode = dynamic_cast<ASTNodeAssignment*>(statement); assignmentNode != nullptr) {
if (auto numericExpressionNode = dynamic_cast<ASTNodeNumericExpression*>(assignmentNode->getRValue()); numericExpressionNode != nullptr) {
auto value = evaluator.evaluateMathematicalExpression(numericExpressionNode);
ON_SCOPE_EXIT { delete value; };
std::visit([&](auto &&value) {
evaluator.setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value));
}, value->getValue());
} else {
evaluator.getConsole().abortEvaluation("invalid rvalue used in assignment");
}
} else if (auto returnNode = dynamic_cast<ASTNodeReturnStatement*>(statement); returnNode != nullptr) {
if (auto numericExpressionNode = dynamic_cast<ASTNodeNumericExpression*>(returnNode->getRValue()); numericExpressionNode != nullptr) {
return evaluator.evaluateMathematicalExpression(numericExpressionNode);
} else {
evaluator.getConsole().abortEvaluation("invalid rvalue used in return statement");
}
}
}
return nullptr;
}
};
if (this->m_definedFunctions.contains(std::string(node->getName())))
this->getConsole().abortEvaluation(hex::format("redefinition of function {}", node->getName()));
this->m_definedFunctions.insert({ std::string(node->getName()), function });
}
PatternData* Evaluator::evaluateAttributes(ASTNode *currNode, PatternData *currPattern) {
auto attributableNode = dynamic_cast<Attributable*>(currNode);
if (attributableNode == nullptr)
@ -617,7 +728,7 @@ namespace hex::lang {
auto valueNode = evaluateMathematicalExpression(expression);
ON_SCOPE_EXIT { delete valueNode; };
entryPatterns.emplace_back(Token::castTo(builtinUnderlyingType->getType(), valueNode->getValue()), name);
entryPatterns.emplace_back(valueNode->getValue(), name);
}
this->m_currOffset += size;
@ -642,8 +753,9 @@ namespace hex::lang {
auto valueNode = evaluateMathematicalExpression(expression);
ON_SCOPE_EXIT { delete valueNode; };
auto fieldBits = std::visit([this, node, type = valueNode->getType()] (auto &&value) {
if (Token::isFloatingPoint(type))
auto fieldBits = std::visit([this] (auto &&value) {
using Type = std::remove_cvref_t<decltype(value)>;
if constexpr (std::is_floating_point_v<Type>)
this->getConsole().abortEvaluation("bitfield entry size must be an integer value");
return static_cast<s128>(value);
}, valueNode->getValue());
@ -706,9 +818,10 @@ namespace hex::lang {
auto valueNode = evaluateMathematicalExpression(offset);
ON_SCOPE_EXIT { delete valueNode; };
this->m_currOffset = std::visit([this, node, type = valueNode->getType()] (auto &&value) {
if (Token::isFloatingPoint(type))
this->getConsole().abortEvaluation("placement offset must be an integer value");
this->m_currOffset = std::visit([this] (auto &&value) {
using Type = std::remove_cvref_t<decltype(value)>;
if constexpr (std::is_floating_point_v<Type>)
this->getConsole().abortEvaluation("bitfield entry size must be an integer value");
return static_cast<u64>(value);
}, valueNode->getValue());
}
@ -740,9 +853,10 @@ namespace hex::lang {
auto valueNode = evaluateMathematicalExpression(offset);
ON_SCOPE_EXIT { delete valueNode; };
this->m_currOffset = std::visit([this, node, type = valueNode->getType()] (auto &&value) {
if (Token::isFloatingPoint(type))
this->getConsole().abortEvaluation("placement offset must be an integer value");
this->m_currOffset = std::visit([this] (auto &&value) {
using Type = std::remove_cvref_t<decltype(value)>;
if constexpr (std::is_floating_point_v<Type>)
this->getConsole().abortEvaluation("bitfield entry size must be an integer value");
return static_cast<u64>(value);
}, valueNode->getValue());
}
@ -790,9 +904,10 @@ namespace hex::lang {
auto valueNode = this->evaluateMathematicalExpression(numericExpression);
ON_SCOPE_EXIT { delete valueNode; };
auto arraySize = std::visit([this, node, type = valueNode->getType()] (auto &&value) {
if (Token::isFloatingPoint(type))
this->getConsole().abortEvaluation("array size must be an integer value");
auto arraySize = std::visit([this] (auto &&value) {
using Type = std::remove_cvref_t<decltype(value)>;
if constexpr (std::is_floating_point_v<Type>)
this->getConsole().abortEvaluation("bitfield entry size must be an integer value");
return static_cast<u64>(value);
}, valueNode->getValue());
@ -874,9 +989,10 @@ namespace hex::lang {
auto valueNode = evaluateMathematicalExpression(offset);
ON_SCOPE_EXIT { delete valueNode; };
pointerOffset = std::visit([this, node, type = valueNode->getType()] (auto &&value) {
if (Token::isFloatingPoint(type))
this->getConsole().abortEvaluation("pointer offset must be an integer value");
pointerOffset = std::visit([this] (auto &&value) {
using Type = std::remove_cvref_t<decltype(value)>;
if constexpr (std::is_floating_point_v<Type>)
this->getConsole().abortEvaluation("bitfield entry size must be an integer value");
return static_cast<s128>(value);
}, valueNode->getValue());
this->m_currOffset = pointerOffset;
@ -938,6 +1054,7 @@ namespace hex::lang {
this->m_globalMembers.clear();
this->m_types.clear();
this->m_endianStack.clear();
this->m_definedFunctions.clear();
this->m_currOffset = 0;
try {
@ -974,6 +1091,8 @@ namespace hex::lang {
} else if (auto functionCallNode = dynamic_cast<ASTNodeFunctionCall*>(node); functionCallNode != nullptr) {
auto result = this->evaluateFunctionCall(functionCallNode);
delete result;
} else if (auto functionDefNode = dynamic_cast<ASTNodeFunctionDefinition*>(node); functionDefNode != nullptr) {
this->evaluateFunctionDefinition(functionDefNode);
}
if (pattern != nullptr)

View File

@ -121,20 +121,20 @@ namespace hex::lang {
}
switch (type) {
case Token::ValueType::Unsigned32Bit: return {{ type, u32(integer) }};
case Token::ValueType::Signed32Bit: return {{ type, s32(integer) }};
case Token::ValueType::Unsigned64Bit: return {{ type, u64(integer) }};
case Token::ValueType::Signed64Bit: return {{ type, s64(integer) }};
case Token::ValueType::Unsigned128Bit: return {{ type, u128(integer) }};
case Token::ValueType::Signed128Bit: return {{ type, s128(integer) }};
case Token::ValueType::Unsigned32Bit: return { u32(integer) };
case Token::ValueType::Signed32Bit: return { s32(integer) };
case Token::ValueType::Unsigned64Bit: return { u64(integer) };
case Token::ValueType::Signed64Bit: return { s64(integer) };
case Token::ValueType::Unsigned128Bit: return { u128(integer) };
case Token::ValueType::Signed128Bit: return { s128(integer) };
default: return { };
}
} else if (Token::isFloatingPoint(type)) {
double floatingPoint = strtod(numberData.data(), nullptr);
switch (type) {
case Token::ValueType::Float: return {{ type, float(floatingPoint) }};
case Token::ValueType::Double: return {{ type, double(floatingPoint) }};
case Token::ValueType::Float: return { float(floatingPoint) };
case Token::ValueType::Double: return { double(floatingPoint) };
default: return { };
}
}
@ -381,7 +381,7 @@ namespace hex::lang {
auto [c, charSize] = character.value();
tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Character, c) ));
tokens.emplace_back(VALUE_TOKEN(Integer, c));
offset += charSize;
} else if (c == '\"') {
auto string = getStringLiteral(code.substr(offset));
@ -417,13 +417,17 @@ namespace hex::lang {
else if (identifier == "else")
tokens.emplace_back(TOKEN(Keyword, Else));
else if (identifier == "false")
tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(0))));
tokens.emplace_back(VALUE_TOKEN(Integer, bool(0)));
else if (identifier == "true")
tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(1))));
tokens.emplace_back(VALUE_TOKEN(Integer, bool(1)));
else if (identifier == "parent")
tokens.emplace_back(TOKEN(Keyword, Parent));
else if (identifier == "while")
tokens.emplace_back(TOKEN(Keyword, While));
else if (identifier == "fn")
tokens.emplace_back(TOKEN(Keyword, Function));
else if (identifier == "return")
tokens.emplace_back(TOKEN(Keyword, Return));
// Check for built-in types
else if (identifier == "u8")

View File

@ -5,7 +5,7 @@
#define MATCHES(x) (begin() && x)
#define TO_NUMERIC_EXPRESSION(node) new ASTNodeNumericExpression((node), new ASTNodeIntegerLiteral({ Token::ValueType::Any, s32(0) }), Token::Operator::Plus)
#define TO_NUMERIC_EXPRESSION(node) new ASTNodeNumericExpression((node), new ASTNodeIntegerLiteral(s32(0)), Token::Operator::Plus)
// Definition syntax:
// [A] : Either A or no token
@ -132,7 +132,7 @@ namespace hex::lang {
if (MATCHES(oneOf(OPERATOR_PLUS, OPERATOR_MINUS, OPERATOR_BOOLNOT, OPERATOR_BITNOT))) {
auto op = getValue<Token::Operator>(-1);
return new ASTNodeNumericExpression(new ASTNodeIntegerLiteral({ Token::ValueType::Any, 0 }), this->parseFactor(), op);
return new ASTNodeNumericExpression(new ASTNodeIntegerLiteral(0), this->parseFactor(), op);
}
return this->parseFactor();
@ -358,6 +358,72 @@ namespace hex::lang {
throwParseError("unfinished attribute. Expected ']]'");
}
/* Functions */
ASTNode* Parser::parseFunctionDefintion() {
const auto &functionName = getValue<std::string>(-2);
std::vector<std::string> params;
// Parse parameter list
while (MATCHES(sequence(IDENTIFIER))) {
params.push_back(getValue<std::string>(-1));
if (!MATCHES(sequence(SEPARATOR_COMMA))) {
if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE)))
break;
else
throwParseError("expected closing ')' after parameter list");
}
}
if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN)))
throwParseError("expected opening '{' after function definition");
// Parse function body
std::vector<ASTNode*> body;
auto bodyCleanup = SCOPE_GUARD {
for (auto &node : body)
delete node;
};
while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) {
ASTNode *statement;
if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN)))
statement = parseFunctionCall();
else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER, SEPARATOR_SQUAREBRACKETOPEN) && sequence<Not>(SEPARATOR_SQUAREBRACKETOPEN)))
statement = parseMemberArrayVariable();
else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER)))
statement = parseMemberVariable();
else if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT)))
statement = parseVariableAssignment();
else if (MATCHES(sequence(KEYWORD_RETURN)))
statement = parseReturnStatement();
else
throwParseError("invalid sequence", 0);
body.push_back(statement);
if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION)))
throwParseError("missing ';' at end of expression", -1);
}
bodyCleanup.release();
return new ASTNodeFunctionDefinition(functionName, params, body);
}
ASTNode* Parser::parseVariableAssignment() {
const auto &lvalue = getValue<std::string>(-2);
auto rvalue = this->parseMathematicalExpression();
return new ASTNodeAssignment(lvalue, rvalue);
}
ASTNode* Parser::parseReturnStatement() {
return new ASTNodeReturnStatement(this->parseMathematicalExpression());
}
/* Control flow */
// if ((parseMathematicalExpression)) { (parseMember) }
@ -606,9 +672,9 @@ namespace hex::lang {
ASTNode *valueExpr;
auto name = getValue<std::string>(-1);
if (enumNode->getEntries().empty())
valueExpr = lastEntry = TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, u8(0) }));
valueExpr = lastEntry = TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral(u8(0)));
else
valueExpr = new ASTNodeNumericExpression(lastEntry->clone(), new ASTNodeIntegerLiteral({ Token::ValueType::Any, s32(1) }), Token::Operator::Plus);
valueExpr = new ASTNodeNumericExpression(lastEntry->clone(), new ASTNodeIntegerLiteral(s32(1)), Token::Operator::Plus);
enumNode->addEntry(name, valueExpr);
}
@ -743,6 +809,8 @@ namespace hex::lang {
statement = parseBitfield();
else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN)))
statement = parseFunctionCall();
else if (MATCHES(sequence(KEYWORD_FUNCTION, IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN)))
statement = parseFunctionDefintion();
else throwParseError("invalid sequence", 0);
if (MATCHES(sequence(SEPARATOR_SQUAREBRACKETOPEN, SEPARATOR_SQUAREBRACKETOPEN)))

View File

@ -15,7 +15,7 @@ namespace hex {
static TextEditor::LanguageDefinition langDef;
if (!initialized) {
static const char* const keywords[] = {
"using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$", "while"
"using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$", "while", "fn", "return"
};
for (auto& k : keywords)
langDef.mKeywords.insert(k);

View File

@ -563,7 +563,7 @@ namespace hex {
void Window::initGLFW() {
glfwSetErrorCallback([](int error, const char* desc) {
fprintf(stderr, "Glfw Error %d: %s\n", error, desc);
log::error("GLFW Error [{}] : {}", error, desc);
});
if (!glfwInit())