From 756d231ff90a63b216385fcd9dd013cb8aee5b64 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Lam?= <leo@leolam.fr>
Date: Mon, 25 Nov 2019 20:49:48 +0100
Subject: [PATCH] file_sys: Add support for BPS patches

The BPS format allows distributing patches that are smaller and that do
not contain copyrighted content if data is relocated
(unlike non-trivial IPS patches).

This is essential for games such as MM3D that have three barely
different code revisions. Supporting all three versions would
demand an unreasonable amount of work; with BPS patches only one
version has to be supported.
---
 src/core/file_sys/ncch_container.cpp |  32 ++--
 src/core/file_sys/patch.cpp          | 213 +++++++++++++++++++++++++++
 src/core/file_sys/patch.h            |   2 +
 3 files changed, 236 insertions(+), 11 deletions(-)

diff --git a/src/core/file_sys/ncch_container.cpp b/src/core/file_sys/ncch_container.cpp
index ec6bb3827..f0687fa9e 100644
--- a/src/core/file_sys/ncch_container.cpp
+++ b/src/core/file_sys/ncch_container.cpp
@@ -508,21 +508,31 @@ Loader::ResultStatus NCCHContainer::LoadSectionExeFS(const char* name, std::vect
 }
 
 Loader::ResultStatus NCCHContainer::ApplyCodePatch(std::vector<u8>& code) const {
-    const std::string override_ips = filepath + ".exefsdir/code.ips";
+    struct PatchLocation {
+        std::string path;
+        bool (*patch_fn)(const std::vector<u8>& patch, std::vector<u8>& code);
+    };
+    const std::array<PatchLocation, 2> patch_paths{{
+        {filepath + ".exefsdir/code.ips", Patch::ApplyIpsPatch},
+        {filepath + ".exefsdir/code.bps", Patch::ApplyBpsPatch},
+    }};
 
-    FileUtil::IOFile ips_file{override_ips, "rb"};
-    if (!ips_file)
-        return Loader::ResultStatus::ErrorNotUsed;
+    for (const PatchLocation& info : patch_paths) {
+        FileUtil::IOFile file{info.path, "rb"};
+        if (!file)
+            continue;
 
-    std::vector<u8> ips(ips_file.GetSize());
-    if (ips_file.ReadBytes(ips.data(), ips.size()) != ips.size())
-        return Loader::ResultStatus::Error;
+        std::vector<u8> patch(file.GetSize());
+        if (file.ReadBytes(patch.data(), patch.size()) != patch.size())
+            return Loader::ResultStatus::Error;
 
-    LOG_INFO(Service_FS, "File {} patching code.bin", override_ips);
-    if (!Patch::ApplyIpsPatch(ips, code))
-        return Loader::ResultStatus::Error;
+        LOG_INFO(Service_FS, "File {} patching code.bin", info.path);
+        if (!info.patch_fn(patch, code))
+            return Loader::ResultStatus::Error;
 
-    return Loader::ResultStatus::Success;
+        return Loader::ResultStatus::Success;
+    }
+    return Loader::ResultStatus::ErrorNotUsed;
 }
 
 Loader::ResultStatus NCCHContainer::LoadOverrideExeFSSection(const char* name,
diff --git a/src/core/file_sys/patch.cpp b/src/core/file_sys/patch.cpp
index 4bd7d7246..c93767165 100644
--- a/src/core/file_sys/patch.cpp
+++ b/src/core/file_sys/patch.cpp
@@ -2,8 +2,13 @@
 // Licensed under GPLv2 or any later version
 // Refer to the license.txt file included.
 
+#include <array>
 #include <cstring>
+#include <optional>
 #include <string>
+#include <string_view>
+#include <type_traits>
+#include <boost/crc.hpp>
 #include "common/logging/log.h"
 #include "core/file_sys/patch.h"
 
@@ -52,4 +57,212 @@ bool ApplyIpsPatch(const std::vector<u8>& ips, std::vector<u8>& buffer) {
     return true;
 }
 
+namespace Bps {
+
+// The BPS format uses variable length encoding for all integers.
+// Realistically uint32s are more than enough for code patching.
+using Number = u32;
+
+constexpr std::size_t FooterSize = 12;
+
+// The BPS format uses CRC32 checksums.
+static u32 crc32(const u8* data, std::size_t size) {
+    boost::crc_32_type result;
+    result.process_bytes(data, size);
+    return result.checksum();
+}
+
+// Utility class to make keeping track of offsets and bound checks less error prone.
+template <typename T>
+class Stream {
+public:
+    Stream(T* ptr, std::size_t size) : m_ptr{ptr}, m_size{size} {}
+
+    bool Read(void* buffer, std::size_t length) {
+        if (m_offset + length > m_size)
+            return false;
+        std::memcpy(buffer, m_ptr + m_offset, length);
+        m_offset += length;
+        return true;
+    }
+
+    template <typename OtherType>
+    bool CopyFrom(Stream<OtherType>& other, std::size_t length) {
+        if (m_offset + length > m_size)
+            return false;
+        if (!other.Read(m_ptr + m_offset, length))
+            return false;
+        m_offset += length;
+        return true;
+    }
+
+    template <typename ValueType>
+    std::optional<ValueType> Read() {
+        static_assert(std::is_pod_v<ValueType>);
+        ValueType val{};
+        if (!Read(&val, sizeof(val)))
+            return std::nullopt;
+        return val;
+    }
+
+    Number ReadNumber() {
+        Number data = 0, shift = 1;
+        std::optional<u8> x;
+        while ((x = Read<u8>())) {
+            data += (*x & 0x7f) * shift;
+            if (*x & 0x80)
+                break;
+            shift <<= 7;
+            data += shift;
+        }
+        return data;
+    }
+
+    auto data() const {
+        return m_ptr;
+    }
+
+    std::size_t size() const {
+        return m_size;
+    }
+
+    std::size_t Tell() const {
+        return m_offset;
+    }
+
+    bool Seek(size_t offset) {
+        if (offset > m_size)
+            return false;
+        m_offset = offset;
+        return true;
+    }
+
+private:
+    T* m_ptr = nullptr;
+    std::size_t m_size = 0;
+    std::size_t m_offset = 0;
+};
+
+class PatchApplier {
+public:
+    PatchApplier(Stream<const u8> source, Stream<u8> target, Stream<const u8> patch)
+        : m_source{source}, m_target{target}, m_patch{patch} {}
+
+    bool Apply() {
+        const auto magic = *m_patch.Read<std::array<char, 4>>();
+        if (std::string_view(magic.data(), magic.size()) != "BPS1") {
+            LOG_ERROR(Service_FS, "Invalid BPS magic");
+            return false;
+        }
+
+        const Bps::Number source_size = m_patch.ReadNumber();
+        const Bps::Number target_size = m_patch.ReadNumber();
+        const Bps::Number metadata_size = m_patch.ReadNumber();
+        if (source_size > m_source.size() || target_size > m_target.size() || metadata_size != 0) {
+            LOG_ERROR(Service_FS, "Invalid sizes");
+            return false;
+        }
+
+        const std::size_t command_start_offset = m_patch.Tell();
+        const std::size_t command_end_offset = m_patch.size() - FooterSize;
+        m_patch.Seek(command_end_offset);
+        const u32 source_crc32 = *m_patch.Read<u32>();
+        const u32 target_crc32 = *m_patch.Read<u32>();
+        m_patch.Seek(command_start_offset);
+
+        if (crc32(m_source.data(), source_size) != source_crc32) {
+            LOG_ERROR(Service_FS, "Unexpected source hash");
+            return false;
+        }
+
+        // Process all patch commands.
+        std::memset(m_target.data(), 0, m_target.size());
+        while (m_patch.Tell() < command_end_offset) {
+            if (!HandleCommand())
+                return false;
+        }
+
+        if (crc32(m_target.data(), target_size) != target_crc32) {
+            LOG_ERROR(Service_FS, "Unexpected target hash");
+            return false;
+        }
+
+        return true;
+    }
+
+private:
+    bool HandleCommand() {
+        const std::size_t offset = m_patch.Tell();
+        const Number data = m_patch.ReadNumber();
+        const Number command = data & 3;
+        const Number length = (data >> 2) + 1;
+
+        const bool ok = [&] {
+            switch (command) {
+            case 0:
+                return SourceRead(length);
+            case 1:
+                return TargetRead(length);
+            case 2:
+                return SourceCopy(length);
+            case 3:
+                return TargetCopy(length);
+            default:
+                return false;
+            }
+        }();
+        if (!ok)
+            LOG_ERROR(Service_FS, "Failed to process command {} at 0x{:x}", command, offset);
+        return ok;
+    }
+
+    bool SourceRead(Number length) {
+        return m_source.Seek(m_target.Tell()) && m_target.CopyFrom(m_source, length);
+    }
+
+    bool TargetRead(Number length) {
+        return m_target.CopyFrom(m_patch, length);
+    }
+
+    bool SourceCopy(Number length) {
+        const Number data = m_patch.ReadNumber();
+        m_source_relative_offset += (data & 1 ? -1 : +1) * int(data >> 1);
+        if (!m_source.Seek(m_source_relative_offset) || !m_target.CopyFrom(m_source, length))
+            return false;
+        m_source_relative_offset += length;
+        return true;
+    }
+
+    bool TargetCopy(Number length) {
+        const Number data = m_patch.ReadNumber();
+        m_target_relative_offset += (data & 1 ? -1 : +1) * int(data >> 1);
+        if (m_target.Tell() + length > m_target.size())
+            return false;
+        if (m_target_relative_offset + length > m_target.size())
+            return false;
+        // Byte by byte copy.
+        for (size_t i = 0; i < length; ++i)
+            m_target.data()[m_target.Tell() + i] = m_target.data()[m_target_relative_offset++];
+        m_target.Seek(m_target.Tell() + length);
+        return true;
+    }
+
+    std::size_t m_source_relative_offset = 0;
+    std::size_t m_target_relative_offset = 0;
+    Stream<const u8> m_source;
+    Stream<u8> m_target;
+    Stream<const u8> m_patch;
+};
+
+} // namespace Bps
+
+bool ApplyBpsPatch(const std::vector<u8>& patch, std::vector<u8>& buffer) {
+    const std::vector<u8> source = buffer;
+    Bps::Stream source_stream{source.data(), source.size()};
+    Bps::Stream target_stream{buffer.data(), buffer.size()};
+    Bps::Stream patch_stream{patch.data(), patch.size()};
+    Bps::PatchApplier applier{source_stream, target_stream, patch_stream};
+    return applier.Apply();
+}
+
 } // namespace FileSys::Patch
diff --git a/src/core/file_sys/patch.h b/src/core/file_sys/patch.h
index b77ca2bcc..9a8118475 100644
--- a/src/core/file_sys/patch.h
+++ b/src/core/file_sys/patch.h
@@ -12,4 +12,6 @@ namespace FileSys::Patch {
 
 bool ApplyIpsPatch(const std::vector<u8>& patch, std::vector<u8>& buffer);
 
+bool ApplyBpsPatch(const std::vector<u8>& patch, std::vector<u8>& buffer);
+
 } // namespace FileSys::Patch