windows: Fix the issue where ags.dll couldn’t be replaced during update (#35877)

Release Notes:

- N/A

---------

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
This commit is contained in:
张小白 2025-08-08 22:42:20 +08:00 committed by GitHub
parent db901278f2
commit 2a310d78e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,15 +4,16 @@ use ::util::ResultExt;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use windows::{ use windows::{
Win32::{ Win32::{
Foundation::{HMODULE, HWND}, Foundation::{FreeLibrary, HMODULE, HWND},
Graphics::{ Graphics::{
Direct3D::*, Direct3D::*,
Direct3D11::*, Direct3D11::*,
DirectComposition::*, DirectComposition::*,
Dxgi::{Common::*, *}, Dxgi::{Common::*, *},
}, },
System::LibraryLoader::LoadLibraryA,
}, },
core::Interface, core::{Interface, PCSTR},
}; };
use crate::{ use crate::{
@ -1618,17 +1619,32 @@ pub(crate) mod shader_resources {
} }
} }
fn with_dll_library<R, F>(dll_name: PCSTR, f: F) -> Result<R>
where
F: FnOnce(HMODULE) -> Result<R>,
{
let library = unsafe {
LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))?
};
let result = f(library);
unsafe {
FreeLibrary(library)
.with_context(|| format!("Freeing dll: {}", dll_name.display()))
.log_err();
}
result
}
mod nvidia { mod nvidia {
use std::{ use std::{
ffi::CStr, ffi::CStr,
os::raw::{c_char, c_int, c_uint}, os::raw::{c_char, c_int, c_uint},
}; };
use anyhow::{Context, Result}; use anyhow::Result;
use windows::{ use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA},
core::s, use crate::platform::windows::directx_renderer::with_dll_library;
};
// https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180
const NVAPI_SHORT_STRING_MAX: usize = 64; const NVAPI_SHORT_STRING_MAX: usize = 64;
@ -1645,13 +1661,12 @@ mod nvidia {
) -> c_int; ) -> c_int;
pub(super) fn get_driver_version() -> Result<String> { pub(super) fn get_driver_version() -> Result<String> {
unsafe { #[cfg(target_pointer_width = "64")]
// Try to load the NVIDIA driver DLL let nvidia_dll_name = s!("nvapi64.dll");
#[cfg(target_pointer_width = "64")] #[cfg(target_pointer_width = "32")]
let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; let nvidia_dll_name = s!("nvapi.dll");
#[cfg(target_pointer_width = "32")]
let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?;
with_dll_library(nvidia_dll_name, |nvidia_dll| unsafe {
let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface"))
.ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?;
let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr);
@ -1686,18 +1701,17 @@ mod nvidia {
minor, minor,
branch_string.to_string_lossy() branch_string.to_string_lossy()
)) ))
} })
} }
} }
mod amd { mod amd {
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use anyhow::{Context, Result}; use anyhow::Result;
use windows::{ use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA},
core::s, use crate::platform::windows::directx_renderer::with_dll_library;
};
// https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145
const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12);
@ -1731,14 +1745,12 @@ mod amd {
type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int;
pub(super) fn get_driver_version() -> Result<String> { pub(super) fn get_driver_version() -> Result<String> {
unsafe { #[cfg(target_pointer_width = "64")]
#[cfg(target_pointer_width = "64")] let amd_dll_name = s!("amd_ags_x64.dll");
let amd_dll = #[cfg(target_pointer_width = "32")]
LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; let amd_dll_name = s!("amd_ags_x86.dll");
#[cfg(target_pointer_width = "32")]
let amd_dll =
LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?;
with_dll_library(amd_dll_name, |amd_dll| unsafe {
let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize"))
.ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?;
let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize"))
@ -1784,7 +1796,7 @@ mod amd {
ags_deinitialize(context); ags_deinitialize(context);
Ok(format!("{} ({})", software_version, driver_version)) Ok(format!("{} ({})", software_version, driver_version))
} })
} }
} }