diff --git a/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs b/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs
index 42d8b4be..718632a6 100644
--- a/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs
+++ b/DnsServerCore/Dns/Applications/DnsApplicationAssemblyLoadContext.cs
@@ -1,6 +1,6 @@
/*
Technitium DNS Server
-Copyright (C) 2022 Shreyas Zare (shreyas@technitium.com)
+Copyright (C) 2024 Shreyas Zare (shreyas@technitium.com)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
@@ -17,6 +17,7 @@ along with this program. If not, see .
*/
+using DnsServerCore.ApplicationCommon;
using System;
using System.Collections.Generic;
using System.IO;
@@ -30,7 +31,12 @@ namespace DnsServerCore.Dns
{
#region variables
- readonly string _applicationFolder;
+ readonly static Type _dnsApplicationInterface = typeof(IDnsApplication);
+
+ readonly IDnsServer _dnsServer;
+
+ readonly List _appAssemblies;
+ readonly AssemblyDependencyResolver _dependencyResolver;
readonly Dictionary _loadedUnmanagedDlls = new Dictionary();
readonly List _unmanagedDllTempPaths = new List();
@@ -39,10 +45,10 @@ namespace DnsServerCore.Dns
#region constructor
- public DnsApplicationAssemblyLoadContext(string applicationFolder)
+ public DnsApplicationAssemblyLoadContext(IDnsServer dnsServer)
: base(true)
{
- _applicationFolder = applicationFolder;
+ _dnsServer = dnsServer;
Unloading += delegate (AssemblyLoadContext obj)
{
@@ -56,6 +62,95 @@ namespace DnsServerCore.Dns
{ }
}
};
+
+ _appAssemblies = new List();
+
+ IEnumerable loadedAssemblies = Default.Assemblies;
+
+ foreach (string dllFile in Directory.GetFiles(_dnsServer.ApplicationFolder, "*.dll", SearchOption.TopDirectoryOnly))
+ {
+ string dllFileName = Path.GetFileNameWithoutExtension(dllFile);
+
+ bool isLoaded = false;
+
+ foreach (Assembly loadedAssembly in loadedAssemblies)
+ {
+ if (!string.IsNullOrEmpty(loadedAssembly.Location))
+ {
+ if (Path.GetFileNameWithoutExtension(loadedAssembly.Location).Equals(dllFileName, StringComparison.OrdinalIgnoreCase))
+ {
+ isLoaded = true;
+ break;
+ }
+ }
+ else
+ {
+ AssemblyName assemblyName = loadedAssembly.GetName();
+
+ if ((assemblyName.Name != null) && assemblyName.Name.Equals(dllFileName, StringComparison.OrdinalIgnoreCase))
+ {
+ isLoaded = true;
+ break;
+ }
+ }
+ }
+
+ if (isLoaded)
+ continue;
+
+ try
+ {
+ Assembly appAssembly;
+ string pdbFile = Path.Combine(_dnsServer.ApplicationFolder, Path.GetFileNameWithoutExtension(dllFile) + ".pdb");
+
+ if (File.Exists(pdbFile))
+ {
+ using (FileStream dllStream = new FileStream(dllFile, FileMode.Open, FileAccess.Read))
+ {
+ using (FileStream pdbStream = new FileStream(pdbFile, FileMode.Open, FileAccess.Read))
+ {
+ appAssembly = LoadFromStream(dllStream, pdbStream);
+ }
+ }
+ }
+ else
+ {
+ using (FileStream dllStream = new FileStream(dllFile, FileMode.Open, FileAccess.Read))
+ {
+ appAssembly = LoadFromStream(dllStream);
+ }
+ }
+
+ if (_dependencyResolver is null)
+ {
+ bool isMainAssembly = false;
+
+ foreach (Type classType in appAssembly.ExportedTypes)
+ {
+ foreach (Type interfaceType in classType.GetInterfaces())
+ {
+ if (interfaceType == _dnsApplicationInterface)
+ {
+ isMainAssembly = true;
+ break;
+ }
+ }
+
+ if (isMainAssembly)
+ break;
+ }
+
+ if (isMainAssembly)
+ _dependencyResolver = new AssemblyDependencyResolver(dllFile);
+ }
+
+ _appAssemblies.Add(appAssembly);
+ }
+ catch (Exception ex)
+ {
+ _dnsServer.WriteLog(ex);
+ }
+ }
}
#endregion
@@ -71,59 +166,69 @@ namespace DnsServerCore.Dns
{
string unmanagedDllPath = null;
- if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ if (_dependencyResolver is not null)
{
- string runtime = "win-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower();
- string[] prefixes = new string[] { "" };
- string[] extensions = new string[] { ".dll" };
-
- unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions);
- }
- else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
- {
- bool isAlpine = false;
-
- try
- {
- string osReleaseFile = "/etc/os-release";
-
- if (File.Exists(osReleaseFile))
- isAlpine = File.ReadAllText(osReleaseFile).Contains("alpine", StringComparison.OrdinalIgnoreCase);
- }
- catch
- { }
-
- string runtimeAlpine = "alpine-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower();
- string runtimeLinux = "linux-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower();
- string[] prefixes = new string[] { "", "lib" };
- string[] extensions = new string[] { ".so", ".so.1" };
-
- if (isAlpine)
- {
- unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeAlpine, prefixes, extensions);
- if (unmanagedDllPath is null)
- unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions);
- }
- else
- {
- unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions);
- }
- }
- else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
- {
- string runtime = "osx-" + RuntimeInformation.ProcessArchitecture.ToString().ToLower();
- string[] prefixes = new string[] { "", "lib" };
- string[] extensions = new string[] { ".dylib" };
-
- unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions);
+ string resolvedPath = _dependencyResolver.ResolveUnmanagedDllToPath(unmanagedDllName);
+ if (!string.IsNullOrEmpty(resolvedPath) && File.Exists(resolvedPath))
+ unmanagedDllPath = resolvedPath;
}
if (unmanagedDllPath is null)
- return IntPtr.Zero;
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ string runtime = "win-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();
+ string[] prefixes = new string[] { "" };
+ string[] extensions = new string[] { ".dll" };
+
+ unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions);
+ }
+ else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
+ {
+ bool isAlpine = false;
+
+ try
+ {
+ string osReleaseFile = "/etc/os-release";
+
+ if (File.Exists(osReleaseFile))
+ isAlpine = File.ReadAllText(osReleaseFile).Contains("alpine", StringComparison.OrdinalIgnoreCase);
+ }
+ catch
+ { }
+
+ string runtimeAlpine = "linux-musl-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();
+ string runtimeLinux = "linux-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();
+ string[] prefixes = new string[] { "", "lib" };
+ string[] extensions = new string[] { ".so", ".so.1" };
+
+ if (isAlpine)
+ {
+ unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeAlpine, prefixes, extensions);
+ if (unmanagedDllPath is null)
+ unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions);
+ }
+ else
+ {
+ unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtimeLinux, prefixes, extensions);
+ }
+ }
+ else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+ {
+ string runtime = "osx-" + RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();
+ string[] prefixes = new string[] { "", "lib" };
+ string[] extensions = new string[] { ".dylib" };
+
+ unmanagedDllPath = FindUnmanagedDllPath(unmanagedDllName, runtime, prefixes, extensions);
+ }
+
+ if (unmanagedDllPath is null)
+ return IntPtr.Zero;
+ }
lock (_loadedUnmanagedDlls)
{
- if (!_loadedUnmanagedDlls.TryGetValue(unmanagedDllPath.ToLower(), out IntPtr value))
+ if (!_loadedUnmanagedDlls.TryGetValue(unmanagedDllPath.ToLowerInvariant(), out IntPtr value))
{
//load the unmanaged DLL
//copy unmanaged dll into temp file for loading to allow uninstalling/updating app at runtime.
@@ -140,7 +245,7 @@ namespace DnsServerCore.Dns
_unmanagedDllTempPaths.Add(tempPath);
value = LoadUnmanagedDllFromPath(tempPath);
- _loadedUnmanagedDlls.Add(unmanagedDllPath.ToLower(), value);
+ _loadedUnmanagedDlls.Add(unmanagedDllPath.ToLowerInvariant(), value);
}
return value;
@@ -157,7 +262,7 @@ namespace DnsServerCore.Dns
{
foreach (string extension in extensions)
{
- string path = Path.Combine(_applicationFolder, "runtimes", runtime, "native", prefix + unmanagedDllName + extension);
+ string path = Path.Combine(_dnsServer.ApplicationFolder, "runtimes", runtime, "native", prefix + unmanagedDllName + extension);
if (File.Exists(path))
return path;
}
@@ -167,5 +272,12 @@ namespace DnsServerCore.Dns
}
#endregion
+
+ #region properties
+
+ public IReadOnlyList AppAssemblies
+ { get { return _appAssemblies; } }
+
+ #endregion
}
}