diff --git a/src/ipc-uapi-windows.h b/src/ipc-uapi-windows.h index 1aa08c4..4d362d0 100644 --- a/src/ipc-uapi-windows.h +++ b/src/ipc-uapi-windows.h @@ -10,6 +10,7 @@ #include #include #include +#include static FILE *userspace_interface_file(const char *iface) { @@ -113,6 +114,9 @@ static FILE *userspace_interface_file(const char *iface) return NULL; } +static bool have_cached_interfaces; +static struct hashtable cached_interfaces; + static bool userspace_has_wireguard_interface(const char *iface) { char fname[MAX_PATH]; @@ -120,10 +124,13 @@ static bool userspace_has_wireguard_interface(const char *iface) HANDLE find_handle; bool ret = false; + if (have_cached_interfaces) + return hashtable_find_entry(&cached_interfaces, iface) != NULL; + snprintf(fname, sizeof(fname), "ProtectedPrefix\\Administrators\\WireGuard\\%s", iface); find_handle = FindFirstFile("\\\\.\\pipe\\*", &find_data); if (find_handle == INVALID_HANDLE_VALUE) - return -GetLastError(); + return -EIO; do { if (!strcmp(fname, find_data.cFileName)) { ret = true; @@ -139,18 +146,25 @@ static int userspace_get_wireguard_interfaces(struct string_list *list) static const char prefix[] = "ProtectedPrefix\\Administrators\\WireGuard\\"; WIN32_FIND_DATA find_data; HANDLE find_handle; + char *iface; int ret = 0; find_handle = FindFirstFile("\\\\.\\pipe\\*", &find_data); if (find_handle == INVALID_HANDLE_VALUE) - return -GetLastError(); + return -EIO; do { if (strncmp(prefix, find_data.cFileName, strlen(prefix))) continue; - ret = string_list_add(list, find_data.cFileName + strlen(prefix)); + iface = find_data.cFileName + strlen(prefix); + ret = string_list_add(list, iface); if (ret < 0) goto out; + if (!hashtable_find_or_insert_entry(&cached_interfaces, iface)) { + ret = -errno; + goto out; + } } while (FindNextFile(find_handle, &find_data)); + have_cached_interfaces = true; out: FindClose(find_handle); diff --git a/src/ipc-windows.h b/src/ipc-windows.h index d7a889f..b38515d 100644 --- a/src/ipc-windows.h +++ b/src/ipc-windows.h @@ -13,12 +13,17 @@ #include #include #include +#include #define IPC_SUPPORTS_KERNEL_INTERFACE +static bool have_cached_kernel_interfaces; +static struct hashtable cached_kernel_interfaces; + static int kernel_get_wireguard_interfaces(struct string_list *list) { HDEVINFO dev_info = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL); + bool will_have_cached_kernel_interfaces = true; if (dev_info == INVALID_HANDLE_VALUE) { errno = EACCES; @@ -33,6 +38,7 @@ static int kernel_get_wireguard_interfaces(struct string_list *list) HKEY key; GUID instance_id; char *interface_name; + struct hashtable_entry *entry; if (!SetupDiEnumDeviceInfo(dev_info, i, &dev_info_data)) { if (GetLastError() == ERROR_NO_MORE_ITEMS) @@ -105,7 +111,25 @@ static int kernel_get_wireguard_interfaces(struct string_list *list) } string_list_add(list, interface_name); + + entry = hashtable_find_or_insert_entry(&cached_kernel_interfaces, interface_name); free(interface_name); + if (!entry) + goto cleanup_entry; + + if (SetupDiGetDeviceInstanceIdW(dev_info, &dev_info_data, NULL, 0, &buf_len) || GetLastError() != ERROR_INSUFFICIENT_BUFFER) + goto cleanup_entry; + entry->value = calloc(sizeof(WCHAR), buf_len); + if (!entry->value) + goto cleanup_entry; + if (!SetupDiGetDeviceInstanceIdW(dev_info, &dev_info_data, entry->value, buf_len, &buf_len)) { + free(entry->value); + entry->value = NULL; + goto cleanup_entry; + } + +cleanup_entry: + will_have_cached_kernel_interfaces |= entry != NULL && entry->value != NULL; cleanup_buf: free(buf); cleanup_key: @@ -113,15 +137,48 @@ static int kernel_get_wireguard_interfaces(struct string_list *list) skip:; } SetupDiDestroyDeviceInfoList(dev_info); + have_cached_kernel_interfaces = will_have_cached_kernel_interfaces; return 0; } static HANDLE kernel_interface_handle(const char *iface) { - HDEVINFO dev_info = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL); + HDEVINFO dev_info; WCHAR *interfaces = NULL; HANDLE handle; + if (have_cached_kernel_interfaces) { + struct hashtable_entry *entry = hashtable_find_entry(&cached_kernel_interfaces, iface); + if (entry) { + DWORD buf_len; + if (CM_Get_Device_Interface_List_SizeW( + &buf_len, (GUID *)&GUID_DEVINTERFACE_NET, (DEVINSTID_W)entry->value, + CM_GET_DEVICE_INTERFACE_LIST_PRESENT) != CR_SUCCESS) + goto err_hash; + interfaces = calloc(buf_len, sizeof(*interfaces)); + if (!interfaces) + goto err_hash; + if (CM_Get_Device_Interface_ListW( + (GUID *)&GUID_DEVINTERFACE_NET, (DEVINSTID_W)entry->value, interfaces, buf_len, + CM_GET_DEVICE_INTERFACE_LIST_PRESENT) != CR_SUCCESS || !interfaces[0]) { + free(interfaces); + interfaces = NULL; + goto err_hash; + } + handle = CreateFileW(interfaces, GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, NULL, + OPEN_EXISTING, 0, NULL); + free(interfaces); + if (handle == INVALID_HANDLE_VALUE) + goto err_hash; + return handle; +err_hash: + errno = EACCES; + return NULL; + } + } + + dev_info = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL); if (dev_info == INVALID_HANDLE_VALUE) return NULL; diff --git a/src/wincompat/include/hashtable.h b/src/wincompat/include/hashtable.h new file mode 100644 index 0000000..efaa2e8 --- /dev/null +++ b/src/wincompat/include/hashtable.h @@ -0,0 +1,60 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. + */ + +#ifndef _HASHTABLE_H +#define _HASHTABLE_H + +#include + +enum { HASHTABLE_ENTRY_BUCKETS_POW2 = 1 << 10 }; + +struct hashtable_entry { + char *key; + void *value; + struct hashtable_entry *next; +}; + +struct hashtable { + struct hashtable_entry *entry_buckets[HASHTABLE_ENTRY_BUCKETS_POW2]; +}; + +static unsigned int hashtable_bucket(const char *str) +{ + unsigned long hash = 5381; + char c; + while ((c = *str++)) + hash = ((hash << 5) + hash) ^ c; + return hash & (HASHTABLE_ENTRY_BUCKETS_POW2 - 1); +} + +static struct hashtable_entry *hashtable_find_entry(struct hashtable *hashtable, const char *key) +{ + struct hashtable_entry *entry; + for (entry = hashtable->entry_buckets[hashtable_bucket(key)]; entry; entry = entry->next) { + if (!strcmp(entry->key, key)) + return entry; + } + return NULL; +} + +static struct hashtable_entry *hashtable_find_or_insert_entry(struct hashtable *hashtable, const char *key) +{ + struct hashtable_entry **entry; + for (entry = &hashtable->entry_buckets[hashtable_bucket(key)]; *entry; entry = &(*entry)->next) { + if (!strcmp((*entry)->key, key)) + return *entry; + } + *entry = calloc(1, sizeof(**entry)); + if (!*entry) + return NULL; + (*entry)->key = strdup(key); + if (!(*entry)->key) { + free(*entry); + return NULL; + } + return *entry; +} + +#endif