diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs index ac285b79ac..585b1dab1c 100644 --- a/crates/gpui/src/platform/windows/directx_renderer.rs +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -4,15 +4,16 @@ use ::util::ResultExt; use anyhow::{Context, Result}; use windows::{ Win32::{ - Foundation::{HMODULE, HWND}, + Foundation::{FreeLibrary, HMODULE, HWND}, Graphics::{ Direct3D::*, Direct3D11::*, DirectComposition::*, Dxgi::{Common::*, *}, }, + System::LibraryLoader::LoadLibraryA, }, - core::Interface, + core::{Interface, PCSTR}, }; use crate::{ @@ -1618,17 +1619,32 @@ pub(crate) mod shader_resources { } } +fn with_dll_library(dll_name: PCSTR, f: F) -> Result +where + F: FnOnce(HMODULE) -> Result, +{ + let library = unsafe { + LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))? + }; + let result = f(library); + unsafe { + FreeLibrary(library) + .with_context(|| format!("Freeing dll: {}", dll_name.display())) + .log_err(); + } + result +} + mod nvidia { use std::{ ffi::CStr, os::raw::{c_char, c_int, c_uint}, }; - use anyhow::{Context, Result}; - use windows::{ - Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, - core::s, - }; + use anyhow::Result; + use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s}; + + use crate::platform::windows::directx_renderer::with_dll_library; // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 const NVAPI_SHORT_STRING_MAX: usize = 64; @@ -1645,13 +1661,12 @@ mod nvidia { ) -> c_int; pub(super) fn get_driver_version() -> Result { - unsafe { - // Try to load the NVIDIA driver DLL - #[cfg(target_pointer_width = "64")] - let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; - #[cfg(target_pointer_width = "32")] - let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; + #[cfg(target_pointer_width = "64")] + let nvidia_dll_name = s!("nvapi64.dll"); + #[cfg(target_pointer_width = "32")] + let nvidia_dll_name = s!("nvapi.dll"); + with_dll_library(nvidia_dll_name, |nvidia_dll| unsafe { let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); @@ -1686,18 +1701,17 @@ mod nvidia { minor, branch_string.to_string_lossy() )) - } + }) } } mod amd { use std::os::raw::{c_char, c_int, c_void}; - use anyhow::{Context, Result}; - use windows::{ - Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, - core::s, - }; + use anyhow::Result; + use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s}; + + use crate::platform::windows::directx_renderer::with_dll_library; // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); @@ -1731,14 +1745,12 @@ mod amd { type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; pub(super) fn get_driver_version() -> Result { - unsafe { - #[cfg(target_pointer_width = "64")] - let amd_dll = - LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; - #[cfg(target_pointer_width = "32")] - let amd_dll = - LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; + #[cfg(target_pointer_width = "64")] + let amd_dll_name = s!("amd_ags_x64.dll"); + #[cfg(target_pointer_width = "32")] + let amd_dll_name = s!("amd_ags_x86.dll"); + with_dll_library(amd_dll_name, |amd_dll| unsafe { let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) @@ -1784,7 +1796,7 @@ mod amd { ags_deinitialize(context); Ok(format!("{} ({})", software_version, driver_version)) - } + }) } }