From 0186f2f456fbb653a718d3319c9ea56c66343006 Mon Sep 17 00:00:00 2001 From: WerWolv Date: Sun, 10 Mar 2024 22:05:26 +0100 Subject: [PATCH] feat: Added support for adding custom providers through C# --- .../include/hex/api/content_registry.hpp | 2 +- .../dotnet/AssemblyLoader/Program.cs | 86 +++++++++++++++---- .../include/loaders/dotnet/dotnet_loader.hpp | 3 +- .../source/loaders/dotnet/dotnet_loader.cpp | 27 ++++-- .../source/plugin_script_loader.cpp | 7 +- .../source/script_api/v1/mem.cpp | 61 +++++++++++++ .../templates/CSharp/ImHexLibrary/Memory.cs | 56 ++++++++++++ .../templates/CSharp/ImHexScript/Program.cs | 11 ++- 8 files changed, 221 insertions(+), 32 deletions(-) diff --git a/lib/libimhex/include/hex/api/content_registry.hpp b/lib/libimhex/include/hex/api/content_registry.hpp index 3ebf79c3a..46a808f0d 100644 --- a/lib/libimhex/include/hex/api/content_registry.hpp +++ b/lib/libimhex/include/hex/api/content_registry.hpp @@ -931,7 +931,7 @@ namespace hex { void addProviderName(const UnlocalizedString &unlocalizedName); - using ProviderCreationFunction = std::unique_ptr(*)(); + using ProviderCreationFunction = std::function()>; void add(const std::string &typeName, ProviderCreationFunction creationFunction); const std::vector& getEntries(); diff --git a/plugins/script_loader/dotnet/AssemblyLoader/Program.cs b/plugins/script_loader/dotnet/AssemblyLoader/Program.cs index 2ebcd8a23..a317b0e6f 100644 --- a/plugins/script_loader/dotnet/AssemblyLoader/Program.cs +++ b/plugins/script_loader/dotnet/AssemblyLoader/Program.cs @@ -1,5 +1,6 @@ using System.Reflection; using System.Runtime.InteropServices; +using System.Runtime.InteropServices.ComTypes; using System.Runtime.Loader; namespace ImHex @@ -12,7 +13,7 @@ namespace ImHex { try { - return ExecuteScript(Marshal.PtrToStringUTF8(arg, argLength)) ? 0 : 1; + return ExecuteScript(Marshal.PtrToStringUTF8(arg, argLength)); } catch (Exception e) { @@ -21,61 +22,108 @@ namespace ImHex } } - - private static bool ExecuteScript(string path) + private static List loadedPlugins = new(); + private static int ExecuteScript(string args) { + // Parse input in the form of "execType||path" + var splitArgs = args.Split("||"); + var type = splitArgs[0]; + var methodName = splitArgs[1]; + var path = splitArgs[2]; + + // Get the parent folder of the passed path string? basePath = Path.GetDirectoryName(path); if (basePath == null) { Console.WriteLine("[.NET Script] Failed to get base path"); - return false; + return 1; } + // Create a new assembly context AssemblyLoadContext? context = new("ScriptDomain_" + basePath, true); + int result = 0; try { + if (type is "LOAD") + { + if (loadedPlugins.Contains(path)) + { + return 0; + } + + // Check if the plugin is already loaded + loadedPlugins.Add(path); + } + + // Load all assemblies in the parent folder foreach (var file in Directory.GetFiles(basePath, "*.dll")) { + // Skip main Assembly + if (file.EndsWith("Main.dll")) + { + continue; + } + context.LoadFromStream(new MemoryStream(File.ReadAllBytes(file))); } + // Load the script assembly var assembly = context.LoadFromStream(new MemoryStream(File.ReadAllBytes(path))); + // Find a class named "Script" var entryPointType = assembly.GetType("Script"); if (entryPointType == null) { Console.WriteLine("[.NET Script] Failed to find Script type"); - return false; + return 1; } - var entryPointMethod = entryPointType.GetMethod("Main", BindingFlags.Static | BindingFlags.Public); - if (entryPointMethod == null) + if (type is "EXEC" or "LOAD") { - Console.WriteLine("[.NET Script] Failed to find ScriptMain method"); - return false; - } + // Load the function + var method = entryPointType.GetMethod(methodName, BindingFlags.Static | BindingFlags.Public); + if (method == null) + { + return 2; + } - entryPointMethod.Invoke(null, null); + // Execute it + method.Invoke(null, null); + } + else if (type == "CHECK") + { + return entryPointType.GetMethod(methodName, BindingFlags.Static | BindingFlags.Public) != null ? 0 : 1; + } + else + { + return 1; + } } catch (Exception e) { Console.WriteLine("[.NET Script] Exception in AssemblyLoader: " + e.ToString()); - return false; + return 3; } finally { - context.Unload(); - context = null; - - for (int i = 0; i < 10; i++) + if (type != "LOAD") { - GC.Collect(); - GC.WaitForPendingFinalizers(); + // Unload all assemblies associated with this script + context.Unload(); + context = null; + + // Run the garbage collector multiple times to make sure that the + // assemblies are unloaded for sure + for (int i = 0; i < 10; i++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } } } - return true; + return result; } } diff --git a/plugins/script_loader/include/loaders/dotnet/dotnet_loader.hpp b/plugins/script_loader/include/loaders/dotnet/dotnet_loader.hpp index 9bad92c93..b3159916b 100644 --- a/plugins/script_loader/include/loaders/dotnet/dotnet_loader.hpp +++ b/plugins/script_loader/include/loaders/dotnet/dotnet_loader.hpp @@ -17,7 +17,8 @@ namespace hex::script::loader { bool loadAll() override; private: - std::function m_loadAssembly; + std::function m_runMethod; + std::function m_methodExists; std::fs::path::string_type m_assemblyLoaderPathString; }; diff --git a/plugins/script_loader/source/loaders/dotnet/dotnet_loader.cpp b/plugins/script_loader/source/loaders/dotnet/dotnet_loader.cpp index c06c82f0f..93cb60e32 100644 --- a/plugins/script_loader/source/loaders/dotnet/dotnet_loader.cpp +++ b/plugins/script_loader/source/loaders/dotnet/dotnet_loader.cpp @@ -179,8 +179,19 @@ namespace hex::script::loader { continue; } - m_loadAssembly = [entryPoint](const std::fs::path &path) -> bool { - auto string = wolv::util::toUTF8String(path); + m_runMethod = [entryPoint](const std::string &methodName, bool keepLoaded, const std::fs::path &path) -> int { + auto pathString = wolv::util::toUTF8String(path); + + auto string = hex::format("{}||{}||{}", keepLoaded ? "LOAD" : "EXEC", methodName, pathString); + auto result = entryPoint(string.data(), string.size()); + + return result; + }; + + m_methodExists = [entryPoint](const std::string &methodName, const std::fs::path &path) -> bool { + auto pathString = wolv::util::toUTF8String(path); + + auto string = hex::format("CHECK||{}||{}", methodName, pathString); auto result = entryPoint(string.data(), string.size()); return result == 0; @@ -211,9 +222,15 @@ namespace hex::script::loader { if (!std::fs::exists(scriptPath)) continue; - this->addScript(entry.path().stem().string(), [this, scriptPath] { - hex::unused(m_loadAssembly(scriptPath)); - }); + if (m_methodExists("Main", scriptPath)) { + this->addScript(entry.path().stem().string(), [this, scriptPath] { + hex::unused(m_runMethod("Main", false, scriptPath)); + }); + } + + if (m_methodExists("OnLoad", scriptPath)) { + hex::unused(m_runMethod("OnLoad", true, scriptPath)); + } } } diff --git a/plugins/script_loader/source/plugin_script_loader.cpp b/plugins/script_loader/source/plugin_script_loader.cpp index cfdbe3bf2..addb4f729 100644 --- a/plugins/script_loader/source/plugin_script_loader.cpp +++ b/plugins/script_loader/source/plugin_script_loader.cpp @@ -69,10 +69,10 @@ namespace { } void addScriptsMenu() { + static std::vector scripts; static TaskHolder runnerTask, updaterTask; hex::ContentRegistry::Interface::addMenuItemSubMenu({ "hex.builtin.menu.extras" }, 5000, [] { static bool menuJustOpened = true; - static std::vector scripts; if (ImGui::BeginMenu("hex.script_loader.menu.run_script"_lang)) { if (menuJustOpened) { @@ -107,6 +107,10 @@ namespace { }, [] { return !runnerTask.isRunning(); }); + + updaterTask = TaskManager::createBackgroundTask("Updating Scripts...", [] (auto&) { + scripts = loadAllScripts(); + }); } } @@ -119,5 +123,4 @@ IMHEX_PLUGIN_SETUP("Script Loader", "WerWolv", "Script Loader plugin") { if (initializeAllLoaders()) { addScriptsMenu(); } - } diff --git a/plugins/script_loader/source/script_api/v1/mem.cpp b/plugins/script_loader/source/script_api/v1/mem.cpp index c808acc84..b7d1aa414 100644 --- a/plugins/script_loader/source/script_api/v1/mem.cpp +++ b/plugins/script_loader/source/script_api/v1/mem.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -41,4 +42,64 @@ SCRIPT_API(bool getSelection, u64 *start, u64 *end) { *end = selection->getEndAddress(); return true; +} + +class ScriptDataProvider : public hex::prv::Provider { +public: + using ReadFunction = void(*)(u64, void*, u64); + using WriteFunction = void(*)(u64, const void*, u64); + using GetSizeFunction = u64(*)(); + using GetNameFunction = std::string(*)(); + + bool open() override { return true; } + void close() override { } + + [[nodiscard]] bool isAvailable() const override { return true; } + [[nodiscard]] bool isReadable() const override { return true; } + [[nodiscard]] bool isWritable() const override { return true; } + [[nodiscard]] bool isResizable() const override { return true; } + [[nodiscard]] bool isSavable() const override { return true; } + [[nodiscard]] bool isDumpable() const override { return true; } + + void readRaw(u64 offset, void *buffer, size_t size) override { + m_readFunction(offset, buffer, size); + } + + void writeRaw(u64 offset, const void *buffer, size_t size) override { + m_writeFunction(offset, const_cast(buffer), size); + } + + void setFunctions(ReadFunction readFunc, WriteFunction writeFunc, GetSizeFunction getSizeFunc) { + m_readFunction = readFunc; + m_writeFunction = writeFunc; + m_getSizeFunction = getSizeFunc; + } + + [[nodiscard]] u64 getActualSize() const override { return m_getSizeFunction(); } + + void setTypeName(std::string typeName) { m_typeName = std::move(typeName);} + void setName(std::string name) { m_name = std::move(name);} + [[nodiscard]] std::string getTypeName() const override { return m_typeName; } + [[nodiscard]] std::string getName() const override { return m_name; } + +private: + ReadFunction m_readFunction = nullptr; + WriteFunction m_writeFunction = nullptr; + GetSizeFunction m_getSizeFunction = nullptr; + GetNameFunction m_getNameFunction = nullptr; + + std::string m_typeName, m_name; +}; + +SCRIPT_API(void registerProvider, const char *typeName, const char *name, ScriptDataProvider::ReadFunction readFunc, ScriptDataProvider::WriteFunction writeFunc, ScriptDataProvider::GetSizeFunction getSizeFunc) { + auto typeNameString = std::string(typeName); + auto nameString = std::string(name); + hex::ContentRegistry::Provider::impl::add(typeNameString, [typeNameString, nameString, readFunc, writeFunc, getSizeFunc] -> std::unique_ptr { + auto provider = std::make_unique(); + provider->setTypeName(typeNameString); + provider->setName(nameString); + provider->setFunctions(readFunc, writeFunc, getSizeFunc); + return provider; + }); + hex::ContentRegistry::Provider::impl::addProviderName(typeNameString); } \ No newline at end of file diff --git a/plugins/script_loader/templates/CSharp/ImHexLibrary/Memory.cs b/plugins/script_loader/templates/CSharp/ImHexLibrary/Memory.cs index 5f54788c4..d0460e4bd 100644 --- a/plugins/script_loader/templates/CSharp/ImHexLibrary/Memory.cs +++ b/plugins/script_loader/templates/CSharp/ImHexLibrary/Memory.cs @@ -5,8 +5,42 @@ using System.Runtime.InteropServices; namespace ImHex { + public interface IProvider + { + void readRaw(UInt64 address, IntPtr buffer, UInt64 size) + { + unsafe + { + Span data = new(buffer.ToPointer(), (int)size); + read(address, data); + } + } + + void writeRaw(UInt64 address, IntPtr buffer, UInt64 size) + { + unsafe + { + ReadOnlySpan data = new(buffer.ToPointer(), (int)size); + write(address, data); + } + } + + void read(UInt64 address, Span data); + void write(UInt64 address, ReadOnlySpan data); + + UInt64 getSize(); + + string getTypeName(); + string getName(); + } public class Memory { + private static List _registeredProviders = new(); + private static List _registeredProviderDelegates = new(); + + private delegate void DataAccessDelegate(UInt64 address, IntPtr buffer, UInt64 size); + private delegate UInt64 GetSizeDelegate(); + [DllImport(Library.Name)] private static extern void readMemoryV1(UInt64 address, UInt64 size, IntPtr buffer); @@ -15,6 +49,9 @@ namespace ImHex [DllImport(Library.Name)] private static extern bool getSelectionV1(IntPtr start, IntPtr end); + + [DllImport(Library.Name)] + private static extern int registerProviderV1([MarshalAs(UnmanagedType.LPStr)] string typeName, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr readFunction, IntPtr writeFunction, IntPtr getSizeFunction); public static byte[] Read(ulong address, ulong size) @@ -57,6 +94,25 @@ namespace ImHex return (start, end); } } + + public static int RegisterProvider() where T : IProvider, new() + { + _registeredProviders.Add(new T()); + + ref var provider = ref CollectionsMarshal.AsSpan(_registeredProviders)[^1]; + + _registeredProviderDelegates.Add(new DataAccessDelegate(provider.readRaw)); + _registeredProviderDelegates.Add(new DataAccessDelegate(provider.writeRaw)); + _registeredProviderDelegates.Add(new GetSizeDelegate(provider.getSize)); + + return registerProviderV1( + _registeredProviders[^1].getTypeName(), + _registeredProviders[^1].getName(), + Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^3]), + Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^2]), + Marshal.GetFunctionPointerForDelegate(_registeredProviderDelegates[^1]) + ); + } } } diff --git a/plugins/script_loader/templates/CSharp/ImHexScript/Program.cs b/plugins/script_loader/templates/CSharp/ImHexScript/Program.cs index 597dd1d1c..ba91efb56 100644 --- a/plugins/script_loader/templates/CSharp/ImHexScript/Program.cs +++ b/plugins/script_loader/templates/CSharp/ImHexScript/Program.cs @@ -1,10 +1,13 @@ using ImHex; -using System.Drawing; +class Script { + + public static void OnLoad() { + // This function is executed the first time the Plugin is loaded + } -class Script -{ public static void Main() { - UI.ShowMessageBox("Hello World!"); + // This function is executed when the plugin is selected in the "Run Script..." menu } + } \ No newline at end of file