streamline section-searching code

This commit is contained in:
beerpiss 2024-06-30 02:48:34 +07:00
parent c28815b544
commit a55c2bdd55

View File

@ -1,7 +1,6 @@
mod external; mod external;
use std::{ use std::{
ffi::CStr,
io::{self, Read}, io::{self, Read},
mem::{self}, mem::{self},
path::Path, path::Path,
@ -289,24 +288,22 @@ pub fn self_update(module: &LibraryHandle) -> Result<bool, SelfUpdateError> {
let section_header_offset = (&(*nt_header).OptionalHeader as *const _ as *const u8) let section_header_offset = (&(*nt_header).OptionalHeader as *const _ as *const u8)
.byte_add((*nt_header).FileHeader.SizeOfOptionalHeader as usize) .byte_add((*nt_header).FileHeader.SizeOfOptionalHeader as usize)
as *const IMAGE_SECTION_HEADER; as *const IMAGE_SECTION_HEADER;
let section_header = (0..number_of_sections)
.find_map(|i| {
let header = *section_header_offset
.byte_add(mem::size_of::<IMAGE_SECTION_HEADER>() * i as usize);
for i in 0..number_of_sections { if &header.Name == b".rtext\0\0" {
let section_header = *section_header_offset.byte_add(40 * i as usize); Some(header)
let section_name = CStr::from_bytes_until_nul(&section_header.Name) } else {
.unwrap() None
.to_str()
.unwrap();
if section_name != ".rtext" {
continue;
} }
})
let src_addr = module .ok_or(SelfUpdateError::NoUpdaterCodeSection)?;
let section_addr = module
.handle() .handle()
.byte_add(section_header.VirtualAddress as usize) .byte_add(section_header.VirtualAddress as usize) as *mut u8;
as *mut u8;
let section_size = *section_header.Misc.VirtualSize() as usize; let section_size = *section_header.Misc.VirtualSize() as usize;
let dst_addr = VirtualAlloc( let dst_addr = VirtualAlloc(
ptr::null_mut(), ptr::null_mut(),
section_size, section_size,
@ -320,13 +317,9 @@ pub fn self_update(module: &LibraryHandle) -> Result<bool, SelfUpdateError> {
debug!( debug!(
"Copying updater code section from {:p} to {:p}", "Copying updater code section from {:p} to {:p}",
src_addr, dst_addr section_addr, dst_addr
); );
std::ptr::copy_nonoverlapping(src_addr, dst_addr, section_size); std::ptr::copy_nonoverlapping(section_addr, dst_addr, section_size);
let updater_start_address = (replace_with_new_library as PROC)
.byte_add(dst_addr as usize)
.byte_sub(src_addr as usize);
debug!("Making updater code executable"); debug!("Making updater code executable");
let mut old_protect = 0u32; let mut old_protect = 0u32;
@ -353,20 +346,19 @@ pub fn self_update(module: &LibraryHandle) -> Result<bool, SelfUpdateError> {
debug!("Allocated heap for updater code at {heap:p}"); debug!("Allocated heap for updater code at {heap:p}");
(*heap).module = module.handle(); (*heap).module = module.handle();
let old = U16CString::from_str_truncate(module_filename); for (i, c) in module_filename.encode_utf16().enumerate() {
let new = U16CString::from_str_truncate(new_module_filename); (*heap).old[i] = c;
std::ptr::copy_nonoverlapping( }
old.as_ptr(), for (i, c) in new_module_filename.encode_utf16().enumerate() {
(*heap).old.as_mut_ptr(), (*heap).new[i] = c;
old.as_slice().len(), }
);
std::ptr::copy_nonoverlapping( let updater_start_address = (replace_with_new_library as PROC)
new.as_ptr(), .byte_add(dst_addr as usize)
(*heap).new.as_mut_ptr(), .byte_sub(section_addr as usize);
new.as_slice().len(),
);
debug!("Executing updater code at {updater_start_address:p}"); debug!("Executing updater code at {updater_start_address:p}");
let handle = CreateThread( let handle = CreateThread(
ptr::null_mut(), ptr::null_mut(),
0, 0,
@ -386,10 +378,7 @@ pub fn self_update(module: &LibraryHandle) -> Result<bool, SelfUpdateError> {
}); });
} }
return Ok(true); Ok(true)
}
Err(SelfUpdateError::NoUpdaterCodeSection)
} }
} }