Switch to crochet for hooking

This commit is contained in:
beerpiss 2024-03-23 16:54:58 +07:00
parent d2a276923c
commit cac66f0afe
2 changed files with 36 additions and 98 deletions

View File

@ -12,9 +12,9 @@ use std::{ptr, thread};
use ::log::{error, warn}; use ::log::{error, warn};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use url::Url; use url::Url;
use winapi::shared::minwindef::{BOOL, DWORD, HINSTANCE, LPVOID, TRUE, FALSE}; use winapi::shared::minwindef::{BOOL, DWORD, FALSE, HINSTANCE, LPVOID, TRUE};
use winapi::um::errhandlingapi::GetLastError; use winapi::um::errhandlingapi::GetLastError;
use winapi::um::handleapi::{DuplicateHandle, CloseHandle}; use winapi::um::handleapi::{CloseHandle, DuplicateHandle};
use winapi::um::processthreadsapi::{GetCurrentProcess, GetCurrentThread}; use winapi::um::processthreadsapi::{GetCurrentProcess, GetCurrentThread};
use winapi::um::synchapi::WaitForSingleObject; use winapi::um::synchapi::WaitForSingleObject;
use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, SYNCHRONIZE}; use winapi::um::winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, SYNCHRONIZE};
@ -123,11 +123,14 @@ extern "system" fn DllMain(dll_module: HINSTANCE, call_reason: DWORD, reserved:
&mut cur_thread, &mut cur_thread,
SYNCHRONIZE, SYNCHRONIZE,
FALSE, FALSE,
0 0,
); );
if result == 0 { if result == 0 {
warn!("Failed to get current thread handle, error code: {}", GetLastError()); warn!(
"Failed to get current thread handle, error code: {}",
GetLastError()
);
} }
(ThreadHandle(cur_thread), result) (ThreadHandle(cur_thread), result)
@ -137,7 +140,7 @@ extern "system" fn DllMain(dll_module: HINSTANCE, call_reason: DWORD, reserved:
if result != 0 { if result != 0 {
unsafe { cur_thread.wait_and_close(100) }; unsafe { cur_thread.wait_and_close(100) };
} }
if let Err(err) = hook_init() { if let Err(err) = hook_init() {
error!("Failed to initialize hook: {:#}", err); error!("Failed to initialize hook: {:#}", err);
} }

View File

@ -1,5 +1,4 @@
use std::{ use std::{
ffi::CString,
fmt::Debug, fmt::Debug,
fs::File, fs::File,
io::Read, io::Read,
@ -10,18 +9,12 @@ use std::{
use ::log::{debug, error, info}; use ::log::{debug, error, info};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use log::warn; use log::warn;
use retour::static_detour;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use widestring::U16CString; use widestring::U16CString;
use winapi::{ use winapi::{
ctypes::c_void, ctypes::c_void,
shared::minwindef::{__some_function, BOOL, DWORD, FALSE, LPCVOID, LPDWORD, LPVOID, MAX_PATH}, shared::minwindef::{BOOL, DWORD, FALSE, LPCVOID, LPDWORD, LPVOID, MAX_PATH},
um::{ um::{errhandlingapi::GetLastError, winbase::GetPrivateProfileStringW, winhttp::HINTERNET},
errhandlingapi::GetLastError,
libloaderapi::{GetModuleHandleA, GetProcAddress},
winbase::GetPrivateProfileStringW,
winhttp::HINTERNET,
},
}; };
use crate::{ use crate::{
@ -42,14 +35,6 @@ use crate::{
pub static GAME_MAJOR_VERSION: AtomicU16 = AtomicU16::new(0); pub static GAME_MAJOR_VERSION: AtomicU16 = AtomicU16::new(0);
pub static PB_IMPORTED: AtomicBool = AtomicBool::new(true); pub static PB_IMPORTED: AtomicBool = AtomicBool::new(true);
type WinHttpWriteDataFunc = unsafe extern "system" fn(HINTERNET, LPCVOID, DWORD, LPDWORD) -> BOOL;
type WinHttpReadDataFunc = unsafe extern "system" fn(HINTERNET, LPVOID, DWORD, LPDWORD) -> BOOL;
static_detour! {
static DetourWriteData: unsafe extern "system" fn (HINTERNET, LPCVOID, DWORD, LPDWORD) -> BOOL;
static DetourReadData: unsafe extern "system" fn(HINTERNET, LPVOID, DWORD, LPDWORD) -> BOOL;
}
pub fn hook_init() -> Result<()> { pub fn hook_init() -> Result<()> {
if !CONFIGURATION.general.enable { if !CONFIGURATION.general.enable {
return Ok(()); return Ok(());
@ -113,12 +98,9 @@ pub fn hook_init() -> Result<()> {
let icf = decode_icf(&mut icf1_buf).map_err(|err| anyhow!("Reading ICF failed: {:#}", err))?; let icf = decode_icf(&mut icf1_buf).map_err(|err| anyhow!("Reading ICF failed: {:#}", err))?;
for entry in icf { for entry in icf {
match entry { if let IcfData::App(app) = entry {
IcfData::App(app) => { info!("Running on {} {}", app.id, app.version);
info!("Running on {} {}", app.id, app.version); GAME_MAJOR_VERSION.store(app.version.major, Ordering::Relaxed);
GAME_MAJOR_VERSION.store(app.version.major, Ordering::Relaxed);
}
_ => {}
} }
} }
@ -133,7 +115,7 @@ pub fn hook_init() -> Result<()> {
TachiResponse::Ok(resp) => { TachiResponse::Ok(resp) => {
if !resp.body.permissions.iter().any(|v| v == "submit_score") { if !resp.body.permissions.iter().any(|v| v == "submit_score") {
return Err(anyhow!( return Err(anyhow!(
"API key has insufficient permission. The permission submit_score must be set." "API key has insufficient permissions. The permission submit_score must be set."
)); ));
} }
@ -149,39 +131,13 @@ pub fn hook_init() -> Result<()> {
info!("Logged in to Tachi with userID {user_id}"); info!("Logged in to Tachi with userID {user_id}");
debug!("Acquring addresses");
let winhttpwritedata = unsafe {
let addr = get_proc_address("winhttp.dll", "WinHttpWriteData")
.map_err(|err| anyhow!("{:#}", err))?;
debug!("WinHttpWriteData: winhttp.dll!{:p}", addr);
std::mem::transmute::<_, WinHttpWriteDataFunc>(addr)
};
let winhttpreaddata = unsafe {
let addr = get_proc_address("winhttp.dll", "WinHttpReadData")
.map_err(|err| anyhow!("{:#}", err))?;
debug!("WinHttpReadData: winhttp.dll!{:p}", addr);
std::mem::transmute::<_, WinHttpReadDataFunc>(addr)
};
debug!("Initializing detours"); debug!("Initializing detours");
unsafe { crochet::enable!(winhttpwritedata_hook_wrapper)?;
debug!("Initializing WinHttpWriteData detour");
DetourWriteData
.initialize(winhttpwritedata, winhttpwritedata_hook_wrapper)?
.enable()?;
debug!("Initializing WinHttpReadData detour"); if CONFIGURATION.general.export_pbs || cfg!(debug_assertions) {
DetourReadData crochet::enable!(winhttpreaddata_hook_wrapper)?;
.initialize(winhttpreaddata, winhttpreaddata_hook_wrapper)? }
.enable()?;
};
info!("Hook successfully initialized"); info!("Hook successfully initialized");
@ -193,17 +149,18 @@ pub fn hook_release() -> Result<()> {
return Ok(()); return Ok(());
} }
if DetourWriteData.is_enabled() { if crochet::is_enabled!(winhttpreaddata_hook_wrapper) {
unsafe { DetourWriteData.disable()? }; crochet::disable!(winhttpreaddata_hook_wrapper)?;
} }
if DetourReadData.is_enabled() { if crochet::is_enabled!(winhttpwritedata_hook_wrapper) {
unsafe { DetourReadData.disable()? }; crochet::disable!(winhttpwritedata_hook_wrapper)?;
} }
Ok(()) Ok(())
} }
#[crochet::hook(compile_check, "winhttp.dll", "WinHttpReadData")]
fn winhttpreaddata_hook_wrapper( fn winhttpreaddata_hook_wrapper(
h_request: HINTERNET, h_request: HINTERNET,
lp_buffer: LPVOID, lp_buffer: LPVOID,
@ -212,14 +169,12 @@ fn winhttpreaddata_hook_wrapper(
) -> BOOL { ) -> BOOL {
debug!("hit winhttpreaddata"); debug!("hit winhttpreaddata");
let result = unsafe { let result = call_original!(
DetourReadData.call( h_request,
h_request, lp_buffer,
lp_buffer, dw_number_of_bytes_to_read,
dw_number_of_bytes_to_read, lpdw_number_of_bytes_read
lpdw_number_of_bytes_read, );
)
};
if result == FALSE { if result == FALSE {
let ec = unsafe { GetLastError() }; let ec = unsafe { GetLastError() };
@ -265,6 +220,7 @@ fn winhttpreaddata_hook_wrapper(
result result
} }
#[crochet::hook(compile_check, "winhttp.dll", "WinHttpWriteData")]
fn winhttpwritedata_hook_wrapper( fn winhttpwritedata_hook_wrapper(
h_request: HINTERNET, h_request: HINTERNET,
lp_buffer: LPCVOID, lp_buffer: LPCVOID,
@ -295,14 +251,12 @@ fn winhttpwritedata_hook_wrapper(
error!("{err:?}"); error!("{err:?}");
} }
unsafe { call_original!(
DetourWriteData.call( h_request,
h_request, lp_buffer,
lp_buffer, dw_number_of_bytes_to_write,
dw_number_of_bytes_to_write, lpdw_number_of_bytes_written
lpdw_number_of_bytes_written, )
)
}
} }
/// Common hook for WinHttpWriteData/WinHttpReadData. The flow is similar for both /// Common hook for WinHttpWriteData/WinHttpReadData. The flow is similar for both
@ -399,22 +353,3 @@ fn winhttprwdata_hook<'a, T: Debug + DeserializeOwned + ToTachiImport + 'static>
Ok(()) Ok(())
} }
fn get_proc_address(module: &str, function: &str) -> Result<*mut __some_function> {
let module_name = CString::new(module).unwrap();
let fun_name = CString::new(function).unwrap();
let module = unsafe { GetModuleHandleA(module_name.as_ptr()) };
if (module as *const c_void).is_null() {
let ec = unsafe { GetLastError() };
return Err(anyhow!("could not get module handle, error code {ec}"));
}
let addr = unsafe { GetProcAddress(module, fun_name.as_ptr()) };
if (addr as *const c_void).is_null() {
let ec = unsafe { GetLastError() };
return Err(anyhow!("could not get function address, error code {ec}"));
}
Ok(addr)
}