From d769041916df4efb48fd212683dc9842decc10f1 Mon Sep 17 00:00:00 2001
From: Chris Robinson <chris.kcat@gmail.com>
Date: Fri, 2 Jun 2023 19:12:48 -0700
Subject: Start the WASAPI COM thread when initializing the backend

COM doesn't make this easy. We want to be able to get device change
notifications without an open device, but we need an IMMDeviceEnumerator object
to register the notification client, which requires COM to be initialized. COM
must then stay initialized while we have the IMMDeviceEnumerator object, which
we can't assume for the calling thread so it has to be done in the COM thread.

Consequently, the COM thread must stay alive and can't quit while the DLL is
loaded if we want to get those notifications without an open device, and as
there's no reliable way to make the thread quit during DLL unload, the DLL must
stay pinned until process exit.
---
 alc/backends/wasapi.cpp | 163 +++++++++++-------------------------------------
 1 file changed, 38 insertions(+), 125 deletions(-)

(limited to 'alc/backends/wasapi.cpp')

diff --git a/alc/backends/wasapi.cpp b/alc/backends/wasapi.cpp
index a554d1b4..333c0dc7 100644
--- a/alc/backends/wasapi.cpp
+++ b/alc/backends/wasapi.cpp
@@ -222,7 +222,6 @@ struct DeviceHandle
 using EventRegistrationToken = Windows::Foundation::EventRegistrationToken;
 #else
 using DeviceHandle           = ComPtr<IMMDevice>;
-using EventRegistrationToken = void*;
 #endif
 
 #if defined(ALSOFT_UWP)
@@ -231,7 +230,6 @@ struct DeviceHelper final : public IActivateAudioInterfaceCompletionHandler
 struct DeviceHelper final : private IMMNotificationClient
 #endif
 {
-public:
     DeviceHelper()
     {
 #if defined(ALSOFT_UWP)
@@ -259,13 +257,6 @@ public:
                         msg);
                 }
             });
-#else
-        HRESULT hr{CoCreateInstance(CLSID_MMDeviceEnumerator, nullptr, CLSCTX_INPROC_SERVER,
-            IID_IMMDeviceEnumerator, al::out_ptr(mEnumerator))};
-        if(SUCCEEDED(hr))
-            mEnumerator->RegisterEndpointNotificationCallback(this);
-        else
-            WARN("Failed to create IMMDeviceEnumerator instance: 0x%08lx\n", hr);
 #endif
     }
     ~DeviceHelper()
@@ -287,13 +278,7 @@ public:
     /** -------------------------- IUnkonwn ----------------------------- */
     std::atomic<ULONG> mRefCount{1};
     STDMETHODIMP_(ULONG) AddRef() noexcept override { return mRefCount.fetch_add(1u) + 1u; }
-
-    STDMETHODIMP_(ULONG) Release() noexcept override
-    {
-        auto ret = mRefCount.fetch_sub(1u) - 1u;
-        if(!ret) delete this;
-        return ret;
-    }
+    STDMETHODIMP_(ULONG) Release() noexcept override { return mRefCount.fetch_sub(1u) - 1u; }
 
     STDMETHODIMP QueryInterface(const IID& IId, void **UnknownPtrPtr) noexcept override
     {
@@ -394,7 +379,22 @@ public:
 #endif
 
     /** -------------------------- DeviceHelper ----------------------------- */
-    HRESULT OpenDevice(LPCWSTR devid, EDataFlow flow, DeviceHandle& device) 
+    HRESULT init()
+    {
+#if !defined(ALSOFT_UWP)
+        HRESULT hr{CoCreateInstance(CLSID_MMDeviceEnumerator, nullptr, CLSCTX_INPROC_SERVER,
+            IID_IMMDeviceEnumerator, al::out_ptr(mEnumerator))};
+        if(SUCCEEDED(hr))
+            mEnumerator->RegisterEndpointNotificationCallback(this);
+        else
+            WARN("Failed to create IMMDeviceEnumerator instance: 0x%08lx\n", hr);
+        return hr;
+#else
+        return S_OK;
+#endif
+    }
+
+    HRESULT openDevice(LPCWSTR devid, EDataFlow flow, DeviceHandle& device)
     {
 #if !defined(ALSOFT_UWP)
         HRESULT hr{E_POINTER};
@@ -425,7 +425,7 @@ public:
 #endif
     }
 
-    HRESULT ActivateAudioClient(_In_ DeviceHandle& device, void **ppv)
+    static HRESULT ActivateAudioClient(_In_ DeviceHandle& device, void **ppv)
     {
 #if !defined(ALSOFT_UWP)
         HRESULT hr{device->Activate(__uuidof(IAudioClient3), CLSCTX_INPROC_SERVER, nullptr, ppv)};
@@ -570,7 +570,6 @@ public:
             WARN("Unexpected PROPVARIANT type: 0x%04x\n", pvprop->vt);
             guid = UnknownGuid;
         }
-
 #else
         auto devInfo     = device.value;
         std::string name = wstr_to_utf8(devInfo->Name->Data());
@@ -798,7 +797,7 @@ struct WasapiProxy {
 
     virtual HRESULT resetProxy() = 0;
     virtual HRESULT startProxy() = 0;
-    virtual void  stopProxy() = 0;
+    virtual void stopProxy() = 0;
 
     struct Msg {
         MsgType mType;
@@ -808,14 +807,11 @@ struct WasapiProxy {
 
         explicit operator bool() const noexcept { return mType != MsgType::QuitThread; }
     };
-    static std::thread sThread;
     static std::deque<Msg> mMsgQueue;
     static std::mutex mMsgQueueLock;
     static std::condition_variable mMsgQueueCond;
-    static std::mutex sThreadLock;
-    static size_t sInitCount;
 
-    static ComPtr<DeviceHelper> sDeviceHelper;
+    static std::optional<DeviceHelper> sDeviceHelper;
 
     std::future<HRESULT> pushMessage(MsgType type, const char *param=nullptr)
     {
@@ -851,45 +847,11 @@ struct WasapiProxy {
     }
 
     static int messageHandler(std::promise<HRESULT> *promise);
-
-    static HRESULT InitThread()
-    {
-        std::lock_guard<std::mutex> _{sThreadLock};
-        HRESULT res{S_OK};
-        if(!sThread.joinable())
-        {
-            std::promise<HRESULT> promise;
-            auto future = promise.get_future();
-
-            sThread = std::thread{&WasapiProxy::messageHandler, &promise};
-            res = future.get();
-            if(FAILED(res))
-            {
-                sThread.join();
-                return res;
-            }
-        }
-        ++sInitCount;
-        return res;
-    }
-
-    static void DeinitThread()
-    {
-        std::lock_guard<std::mutex> _{sThreadLock};
-        if(!--sInitCount && sThread.joinable())
-        {
-            pushMessageStatic(MsgType::QuitThread);
-            sThread.join();
-        }
-    }
 };
-std::thread WasapiProxy::sThread;
 std::deque<WasapiProxy::Msg> WasapiProxy::mMsgQueue;
 std::mutex WasapiProxy::mMsgQueueLock;
 std::condition_variable WasapiProxy::mMsgQueueCond;
-std::mutex WasapiProxy::sThreadLock;
-ComPtr<DeviceHelper> WasapiProxy::sDeviceHelper;
-size_t WasapiProxy::sInitCount{0};
+std::optional<DeviceHelper> WasapiProxy::sDeviceHelper;
 
 int WasapiProxy::messageHandler(std::promise<HRESULT> *promise)
 {
@@ -902,8 +864,12 @@ int WasapiProxy::messageHandler(std::promise<HRESULT> *promise)
         promise->set_value(hr);
         return 0;
     }
-    promise->set_value(S_OK);
+
+    hr = sDeviceHelper.emplace().init();
+    promise->set_value(hr);
     promise = nullptr;
+    if(FAILED(hr))
+        goto skip_loop;
 
     TRACE("Starting message loop\n");
     while(Msg msg{popMessage()})
@@ -956,6 +922,9 @@ int WasapiProxy::messageHandler(std::promise<HRESULT> *promise)
         msg.mPromise.set_value(E_FAIL);
     }
     TRACE("Message loop finished\n");
+
+skip_loop:
+    sDeviceHelper.reset();
     CoUninitialize();
 
     return 0;
@@ -1005,10 +974,7 @@ struct WasapiPlayback final : public BackendBase, WasapiProxy {
 WasapiPlayback::~WasapiPlayback()
 {
     if(SUCCEEDED(mOpenStatus))
-    {
         pushMessage(MsgType::CloseDevice).wait();
-        DeinitThread();
-    }
     mOpenStatus = E_FAIL;
 
     if(mNotifyEvent != nullptr)
@@ -1121,13 +1087,6 @@ void WasapiPlayback::open(const char *name)
             "Failed to create notify events"};
     }
 
-    HRESULT hr{InitThread()};
-    if(FAILED(hr))
-    {
-        throw al::backend_exception{al::backend_error::DeviceError,
-            "Failed to init COM thread: 0x%08lx", hr};
-    }
-
     if(name)
     {
         if(PlaybackDevices.empty())
@@ -1142,11 +1101,8 @@ void WasapiPlayback::open(const char *name)
 
     mOpenStatus = pushMessage(MsgType::OpenDevice, name).get();
     if(FAILED(mOpenStatus))
-    {
-        DeinitThread();
         throw al::backend_exception{al::backend_error::DeviceError, "Device init failed: 0x%08lx",
             mOpenStatus};
-    }
 }
 
 HRESULT WasapiPlayback::openProxy(const char *name)
@@ -1173,14 +1129,14 @@ HRESULT WasapiPlayback::openProxy(const char *name)
         devid = iter->devid.c_str();
     }
 
-    HRESULT hr{sDeviceHelper->OpenDevice(devid, eRender, mMMDev)};
-    if (FAILED(hr))
+    HRESULT hr{sDeviceHelper->openDevice(devid, eRender, mMMDev)};
+    if(FAILED(hr))
     {
         WARN("Failed to open device \"%s\"\n", name ? name : "(default)");
         return hr;
     }
     mClient = nullptr;
-    if (name)
+    if(name)
         mDevice->DeviceName = std::string{DevNameHead} + name;
     else
         mDevice->DeviceName = DevNameHead + DeviceHelper::get_device_name_and_guid(mMMDev).first;
@@ -1657,10 +1613,7 @@ struct WasapiCapture final : public BackendBase, WasapiProxy {
 WasapiCapture::~WasapiCapture()
 {
     if(SUCCEEDED(mOpenStatus))
-    {
         pushMessage(MsgType::CloseDevice).wait();
-        DeinitThread();
-    }
     mOpenStatus = E_FAIL;
 
     if(mNotifyEvent != nullptr)
@@ -1775,13 +1728,6 @@ void WasapiCapture::open(const char *name)
             "Failed to create notify events"};
     }
 
-    HRESULT hr{InitThread()};
-    if(FAILED(hr))
-    {
-        throw al::backend_exception{al::backend_error::DeviceError,
-            "Failed to init COM thread: 0x%08lx", hr};
-    }
-
     if(name)
     {
         if(CaptureDevices.empty())
@@ -1796,13 +1742,10 @@ void WasapiCapture::open(const char *name)
 
     mOpenStatus = pushMessage(MsgType::OpenDevice, name).get();
     if(FAILED(mOpenStatus))
-    {
-        DeinitThread();
         throw al::backend_exception{al::backend_error::DeviceError, "Device init failed: 0x%08lx",
             mOpenStatus};
-    }
 
-    hr = pushMessage(MsgType::ResetDevice).get();
+    HRESULT hr{pushMessage(MsgType::ResetDevice).get()};
     if(FAILED(hr))
     {
         if(hr == E_OUTOFMEMORY)
@@ -1835,7 +1778,7 @@ HRESULT WasapiCapture::openProxy(const char *name)
         devid = iter->devid.c_str();
     }
 
-    HRESULT hr{sDeviceHelper->OpenDevice(devid, eCapture, mMMDev)};
+    HRESULT hr{sDeviceHelper->openDevice(devid, eCapture, mMMDev)};
     if (FAILED(hr))
     {
         WARN("Failed to open device \"%s\"\n", name ? name : "(default)");
@@ -2210,34 +2153,13 @@ uint WasapiCapture::availableSamples()
 bool WasapiBackendFactory::init()
 {
     static HRESULT InitResult{E_FAIL};
-
     if(FAILED(InitResult)) try
     {
-        auto res = std::async(std::launch::async, []() -> HRESULT
-        {
-            HRESULT hr{CoInitializeEx(nullptr, COINIT_MULTITHREADED)};
-            if(FAILED(hr))
-            {
-                WARN("Failed to initialize COM: 0x%08lx\n", hr);
-                return hr;
-            }
-#if !defined(ALSOFT_UWP)
-            ComPtr<IMMDeviceEnumerator> enumerator;
-            hr = CoCreateInstance(CLSID_MMDeviceEnumerator, nullptr, CLSCTX_INPROC_SERVER,
-                IID_IMMDeviceEnumerator, al::out_ptr(enumerator));
-            if(FAILED(hr))
-                WARN("Failed to create IMMDeviceEnumerator instance: 0x%08lx\n", hr);
-            enumerator = nullptr;
-#endif
-            if(SUCCEEDED(hr))
-                WasapiProxy::sDeviceHelper.reset(new DeviceHelper{});
-
-            CoUninitialize();
-
-            return hr;
-        });
+        std::promise<HRESULT> promise;
+        auto future = promise.get_future();
 
-        InitResult = res.get();
+        std::thread{&WasapiProxy::messageHandler, &promise}.detach();
+        InitResult = future.get();
     }
     catch(...) {
     }
@@ -2250,16 +2172,7 @@ bool WasapiBackendFactory::querySupport(BackendType type)
 
 std::string WasapiBackendFactory::probe(BackendType type)
 {
-    struct ProxyControl {
-        HRESULT mResult{};
-        ProxyControl() { mResult = WasapiProxy::InitThread(); }
-        ~ProxyControl() { if(SUCCEEDED(mResult)) WasapiProxy::DeinitThread(); }
-    };
-    ProxyControl proxy;
-
     std::string outnames;
-    if(FAILED(proxy.mResult))
-        return outnames;
 
     switch(type)
     {
-- 
cgit v1.2.3