diff --git a/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs b/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs index 42d8b4be..718632a6 100644 --- a/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs +++ b/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs @@ -1,6 +1,6 @@ /* Technitium DNS Server -Copyright (C) 2022 Shreyas Zare (shreyas@technitium.com) +Copyright (C) 2024 Shreyas Zare (shreyas@technitium.com) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -17,6 +17,7 @@ along with this program. If not, see . */ +using DnsServerCore.ApplicationCommon; using System; using System.Collections.Generic; using System.IO; @@ -30,7 +31,12 @@ namespace DnsServerCore.Dns { #region variables - readonly string _applicationFolder; + readonly static Type _dnsApplicationInterface = typeof(IDnsApplication); + + readonly IDnsServer _dnsServer; + + readonly List _appAssemblies; + readonly AssemblyDependencyResolver _dependencyResolver; readonly Dictionary _loadedUnmanagedDlls = new Dictionary(); readonly List _unmanagedDllTempPaths = new List(); @@ -39,10 +45,10 @@ namespace DnsServerCore.Dns #region constructor - public DnsApplicationAssemblyLoadContext(string applicationFolder) + public DnsApplicationAssemblyLoadContext(IDnsServer dnsServer) : base(true) { - _applicationFolder = applicationFolder; + _dnsServer = dnsServer; Unloading += delegate (AssemblyLoadContext obj) { @@ -56,6 +62,95 @@ namespace DnsServerCore.Dns { } } }; + + _appAssemblies = new List(); + + IEnumerable loadedAssemblies = Default.Assemblies; + + foreach (string dllFile in Directory.GetFiles(_dnsServer.ApplicationFolder, "*.dll", SearchOption.TopDirectoryOnly)) + { + string dllFileName = Path.GetFileNameWithoutExtension(dllFile); + + bool isLoaded = false; + + foreach (Assembly loadedAssembly in loadedAssemblies) + { + if (!string.IsNullOrEmpty(loadedAssembly.Location)) + { + if (Path.GetFileNameWithoutExtension(loadedAssembly.Location).Equals(dllFileName, StringComparison.OrdinalIgnoreCase)) + { + isLoaded = true; + break; + } + } + else + { + AssemblyName assemblyName = loadedAssembly.GetName(); + + if ((assemblyName.Name != null) && assemblyName.Name.Equals(dllFileName, StringComparison.OrdinalIgnoreCase)) + { + isLoaded = true; + break; + } + } + } + + if (isLoaded) + continue; + + try + { + Assembly appAssembly; + string pdbFile = Path.Combine(_dnsServer.ApplicationFolder, Path.GetFileNameWithoutExtension(dllFile) + ".pdb"); + + if (File.Exists(pdbFile)) + { + using (FileStream dllStream = new FileStream(dllFile, FileMode.Open, FileAccess.Read)) + { + using (FileStream pdbStream = new FileStream(pdbFile, FileMode.Open, FileAccess.Read)) + { + appAssembly = LoadFromStream(dllStream, pdbStream); + } + } + } + else + { + using (FileStream dllStream = new FileStream(dllFile, FileMode.Open, FileAccess.Read)) + { + appAssembly = LoadFromStream(dllStream); + } + } + + if (_dependencyResolver is null) + { + bool isMainAssembly = false; + + foreach (Type classType in appAssembly.ExportedTypes) + { + foreach (Type interfaceType in classType.GetInterfaces()) + { + if (interfaceType == _dnsApplicationInterface) + { + isMainAssembly = true; + break; + } + } + + if (isMainAssembly) + break; + } + + if (isMainAssembly) + _dependencyResolver = new AssemblyDependencyResolver(dllFile); + } + + _appAssemblies.Add(appAssembly); + } + catch (Exception ex) + { + _dnsServer.WriteLog(ex); + } + } } #endregion @@ -71,59 +166,69 @@ namespace DnsServerCore.Dns { string unmanagedDllPath = null; - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + if (_dependencyResolver is not null) { - string runtime = "win-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower(); - string[] prefixes = new string[] { "" }; - string[] extensions = new string[] { ".dll" }; - - unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions); - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - bool isAlpine = false; - - try - { - string osReleaseFile = "/etc/os-release"; - - if (File.Exists(osReleaseFile)) - isAlpine = File.ReadAllText(osReleaseFile).Contains("alpine", StringComparison.OrdinalIgnoreCase); - } - catch - { } - - string runtimeAlpine = "alpine-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower(); - string runtimeLinux = "linux-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower(); - string[] prefixes = new string[] { "", "lib" }; - string[] extensions = new string[] { ".so", ".so.1" }; - - if (isAlpine) - { - unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeAlpine, prefixes, extensions); - if (unmanagedDllPath is null) - unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions); - } - else - { - unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions); - } - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - string runtime = "osx-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower(); - string[] prefixes = new string[] { "", "lib" }; - string[] extensions = new string[] { ".dylib" }; - - unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions); + string resolvedPath = _dependencyResolver.ResolveUnmanagedDllToPath(unmanagedDllName); + if (!string.IsNullOrEmpty(resolvedPath) && File.Exists(resolvedPath)) + unmanagedDllPath = resolvedPath; } if (unmanagedDllPath is null) - return IntPtr.Zero; + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + string runtime = "win-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant(); + string[] prefixes = new string[] { "" }; + string[] extensions = new string[] { ".dll" }; + + unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions); + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + bool isAlpine = false; + + try + { + string osReleaseFile = "/etc/os-release"; + + if (File.Exists(osReleaseFile)) + isAlpine = File.ReadAllText(osReleaseFile).Contains("alpine", StringComparison.OrdinalIgnoreCase); + } + catch + { } + + string runtimeAlpine = "linux-musl-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant(); + string runtimeLinux = "linux-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant(); + string[] prefixes = new string[] { "", "lib" }; + string[] extensions = new string[] { ".so", ".so.1" }; + + if (isAlpine) + { + unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeAlpine, prefixes, extensions); + if (unmanagedDllPath is null) + unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions); + } + else + { + unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions); + } + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + string runtime = "osx-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant(); + string[] prefixes = new string[] { "", "lib" }; + string[] extensions = new string[] { ".dylib" }; + + unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions); + } + + if (unmanagedDllPath is null) + return IntPtr.Zero; + } lock (_loadedUnmanagedDlls) { - if (!_loadedUnmanagedDlls.TryGetValue(unmanagedDllPath.ToLower(), out IntPtr value)) + if (!_loadedUnmanagedDlls.TryGetValue(unmanagedDllPath.ToLowerInvariant(), out IntPtr value)) { //load the unmanaged DLL //copy unmanaged dll into temp file for loading to allow uninstalling/updating app at runtime. @@ -140,7 +245,7 @@ namespace DnsServerCore.Dns _unmanagedDllTempPaths.Add(tempPath); value = LoadUnmanagedDllFromPath(tempPath); - _loadedUnmanagedDlls.Add(unmanagedDllPath.ToLower(), value); + _loadedUnmanagedDlls.Add(unmanagedDllPath.ToLowerInvariant(), value); } return value; @@ -157,7 +262,7 @@ namespace DnsServerCore.Dns { foreach (string extension in extensions) { - string path = Path.Combine(_applicationFolder, "runtimes", runtime, "native", prefix + unmanagedDllName + extension); + string path = Path.Combine(_dnsServer.ApplicationFolder, "runtimes", runtime, "native", prefix + unmanagedDllName + extension); if (File.Exists(path)) return path; } @@ -167,5 +272,12 @@ namespace DnsServerCore.Dns } #endregion + + #region properties + + public IReadOnlyList AppAssemblies + { get { return _appAssemblies; } } + + #endregion } }