diff --git a/src/updater/mod.rs b/src/updater/mod.rs index 0a8981c..c15c3f9 100644 --- a/src/updater/mod.rs +++ b/src/updater/mod.rs @@ -1,7 +1,6 @@ mod external; use std::{ - ffi::CStr, io::{self, Read}, mem::{self}, path::Path, @@ -289,107 +288,97 @@ pub fn self_update(module: &LibraryHandle) -> Result { let section_header_offset = (&(*nt_header).OptionalHeader as *const _ as *const u8) .byte_add((*nt_header).FileHeader.SizeOfOptionalHeader as usize) as *const IMAGE_SECTION_HEADER; + let section_header = (0..number_of_sections) + .find_map(|i| { + let header = *section_header_offset + .byte_add(mem::size_of::() * i as usize); - for i in 0..number_of_sections { - let section_header = *section_header_offset.byte_add(40 * i as usize); - let section_name = CStr::from_bytes_until_nul(§ion_header.Name) - .unwrap() - .to_str() - .unwrap(); + if &header.Name == b".rtext\0\0" { + Some(header) + } else { + None + } + }) + .ok_or(SelfUpdateError::NoUpdaterCodeSection)?; + let section_addr = module + .handle() + .byte_add(section_header.VirtualAddress as usize) as *mut u8; + let section_size = *section_header.Misc.VirtualSize() as usize; + let dst_addr = VirtualAlloc( + ptr::null_mut(), + section_size, + MEM_COMMIT | MEM_RESERVE, + PAGE_READWRITE, + ) as *mut u8; - if section_name != ".rtext" { - continue; - } - - let src_addr = module - .handle() - .byte_add(section_header.VirtualAddress as usize) - as *mut u8; - let section_size = *section_header.Misc.VirtualSize() as usize; - - let dst_addr = VirtualAlloc( - ptr::null_mut(), - section_size, - MEM_COMMIT | MEM_RESERVE, - PAGE_READWRITE, - ) as *mut u8; - - if dst_addr.is_null() { - return Err(SelfUpdateError::FailedToAllocateMemory); - } - - debug!( - "Copying updater code section from {:p} to {:p}", - src_addr, dst_addr - ); - std::ptr::copy_nonoverlapping(src_addr, dst_addr, section_size); - - let updater_start_address = (replace_with_new_library as PROC) - .byte_add(dst_addr as usize) - .byte_sub(src_addr as usize); - - debug!("Making updater code executable"); - let mut old_protect = 0u32; - let result = VirtualProtect( - dst_addr as *mut _, - section_size, - PAGE_EXECUTE_READ, - &mut old_protect, - ); - - if result == 0 { - return Err(SelfUpdateError::FailedVirtualProtect { - errno: GetLastError(), - }); - } - - let process_heap = GetProcessHeap(); - let heap = HeapAlloc( - process_heap, - HEAP_ZERO_MEMORY, - mem::size_of::(), - ) as *mut ReplaceArgs; - - debug!("Allocated heap for updater code at {heap:p}"); - - (*heap).module = module.handle(); - let old = U16CString::from_str_truncate(module_filename); - let new = U16CString::from_str_truncate(new_module_filename); - std::ptr::copy_nonoverlapping( - old.as_ptr(), - (*heap).old.as_mut_ptr(), - old.as_slice().len(), - ); - std::ptr::copy_nonoverlapping( - new.as_ptr(), - (*heap).new.as_mut_ptr(), - new.as_slice().len(), - ); - - debug!("Executing updater code at {updater_start_address:p}"); - let handle = CreateThread( - ptr::null_mut(), - 0, - Some(std::mem::transmute::< - PROC, - unsafe extern "system" fn(*mut winapi::ctypes::c_void) -> u32, - >(updater_start_address)), - heap as *mut _, - 0, - ptr::null_mut(), - ); - - if handle.is_null() { - error!("Could not execute updater code: {}", GetLastError()); - return Err(SelfUpdateError::FailedCreateThread { - errno: GetLastError(), - }); - } - - return Ok(true); + if dst_addr.is_null() { + return Err(SelfUpdateError::FailedToAllocateMemory); } - Err(SelfUpdateError::NoUpdaterCodeSection) + debug!( + "Copying updater code section from {:p} to {:p}", + section_addr, dst_addr + ); + std::ptr::copy_nonoverlapping(section_addr, dst_addr, section_size); + + debug!("Making updater code executable"); + let mut old_protect = 0u32; + let result = VirtualProtect( + dst_addr as *mut _, + section_size, + PAGE_EXECUTE_READ, + &mut old_protect, + ); + + if result == 0 { + return Err(SelfUpdateError::FailedVirtualProtect { + errno: GetLastError(), + }); + } + + let process_heap = GetProcessHeap(); + let heap = HeapAlloc( + process_heap, + HEAP_ZERO_MEMORY, + mem::size_of::(), + ) as *mut ReplaceArgs; + + debug!("Allocated heap for updater code at {heap:p}"); + + (*heap).module = module.handle(); + for (i, c) in module_filename.encode_utf16().enumerate() { + (*heap).old[i] = c; + } + for (i, c) in new_module_filename.encode_utf16().enumerate() { + (*heap).new[i] = c; + } + + let updater_start_address = (replace_with_new_library as PROC) + .byte_add(dst_addr as usize) + .byte_sub(section_addr as usize); + + debug!("Executing updater code at {updater_start_address:p}"); + + let handle = CreateThread( + ptr::null_mut(), + 0, + Some(std::mem::transmute::< + PROC, + unsafe extern "system" fn(*mut winapi::ctypes::c_void) -> u32, + >(updater_start_address)), + heap as *mut _, + 0, + ptr::null_mut(), + ); + + if handle.is_null() { + error!("Could not execute updater code: {}", GetLastError()); + return Err(SelfUpdateError::FailedCreateThread { + errno: GetLastError(), + }); + } + + Ok(true) } }