From 7ba844763cd16bd5bba7173a02c54b6b92e100ea Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 14 Sep 2024 16:34:43 +0530 Subject: [PATCH] AuthManager: updated brute force protection code to use network instead of IP address so as to protect from ipv6 source. Code refactoring done. --- DnsServerCore/Auth/AuthManager.cs | 86 ++++++++++++++++++------------- 1 file changed, 51 insertions(+), 35 deletions(-) diff --git a/DnsServerCore/Auth/AuthManager.cs b/DnsServerCore/Auth/AuthManager.cs index 55ffad2d..6215c017 100644 --- a/DnsServerCore/Auth/AuthManager.cs +++ b/DnsServerCore/Auth/AuthManager.cs @@ -22,9 +22,11 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Net; +using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; +using TechnitiumLibrary.Net; namespace DnsServerCore.Auth { @@ -39,11 +41,11 @@ namespace DnsServerCore.Auth readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(1, 10); - readonly ConcurrentDictionary _failedLoginAttempts = new ConcurrentDictionary(1, 10); + readonly ConcurrentDictionary _failedLoginAttemptNetworks = new ConcurrentDictionary(1, 10); const int MAX_LOGIN_ATTEMPTS = 5; - readonly ConcurrentDictionary _blockedAddresses = new ConcurrentDictionary(1, 10); - const int BLOCK_ADDRESS_INTERVAL = 5 * 60 * 1000; + readonly ConcurrentDictionary _blockedNetworks = new ConcurrentDictionary(1, 10); + const int BLOCK_NETWORK_INTERVAL = 5 * 60 * 1000; readonly string _configFolder; readonly LogManager _log; @@ -324,7 +326,7 @@ namespace DnsServerCore.Auth for (int i = 0; i < count; i++) { Group group = new Group(bR); - _groups.TryAdd(group.Name.ToLower(), group); + _groups.TryAdd(group.Name.ToLowerInvariant(), group); } } @@ -398,35 +400,50 @@ namespace DnsServerCore.Auth session.WriteTo(bW); } - private void FailedLoginAttempt(IPAddress address) + private static IPAddress GetClientNetwork(IPAddress address) { - _failedLoginAttempts.AddOrUpdate(address, 1, delegate (IPAddress key, int attempts) + switch (address.AddressFamily) + { + case AddressFamily.InterNetwork: + return address.GetNetworkAddress(32); + + case AddressFamily.InterNetworkV6: + return address.GetNetworkAddress(64); + + default: + throw new InvalidOperationException(); + } + } + + private void MarkFailedLoginAttempt(IPAddress network) + { + _failedLoginAttemptNetworks.AddOrUpdate(network, 1, delegate (IPAddress key, int attempts) { return attempts + 1; }); } - private bool LoginAttemptsExceedLimit(IPAddress address, int limit) + private bool HasLoginAttemptExceedLimit(IPAddress network, int limit) { - if (!_failedLoginAttempts.TryGetValue(address, out int attempts)) + if (!_failedLoginAttemptNetworks.TryGetValue(network, out int attempts)) return false; return attempts >= limit; } - private void ResetFailedLoginAttempt(IPAddress address) + private void ResetFailedLoginAttempts(IPAddress network) { - _failedLoginAttempts.TryRemove(address, out _); + _failedLoginAttemptNetworks.TryRemove(network, out _); } - private void BlockAddress(IPAddress address, int interval) + private void BlockNetwork(IPAddress network, int interval) { - _blockedAddresses.TryAdd(address, DateTime.UtcNow.AddMilliseconds(interval)); + _blockedNetworks.TryAdd(network, DateTime.UtcNow.AddMilliseconds(interval)); } - private bool IsAddressBlocked(IPAddress address) + private bool IsNetworkBlocked(IPAddress network) { - if (!_blockedAddresses.TryGetValue(address, out DateTime expiry)) + if (!_blockedNetworks.TryGetValue(network, out DateTime expiry)) return false; if (expiry > DateTime.UtcNow) @@ -435,16 +452,16 @@ namespace DnsServerCore.Auth } else { - UnblockAddress(address); - ResetFailedLoginAttempt(address); + UnblockNetwork(network); + ResetFailedLoginAttempts(network); return false; } } - private void UnblockAddress(IPAddress address) + private void UnblockNetwork(IPAddress network) { - _blockedAddresses.TryRemove(address, out _); + _blockedNetworks.TryRemove(network, out _); } #endregion @@ -453,7 +470,7 @@ namespace DnsServerCore.Auth public User GetUser(string username) { - if (_users.TryGetValue(username.ToLower(), out User user)) + if (_users.TryGetValue(username.ToLowerInvariant(), out User user)) return user; return null; @@ -464,7 +481,7 @@ namespace DnsServerCore.Auth if (_users.Count >= byte.MaxValue) throw new DnsWebServiceException("Cannot create more than 255 users."); - username = username.ToLower(); + username = username.ToLowerInvariant(); User user = new User(displayName, username, password, iterations); @@ -496,7 +513,7 @@ namespace DnsServerCore.Auth public bool DeleteUser(string username) { - if (_users.TryRemove(username.ToLower(), out User deletedUser)) + if (_users.TryRemove(username.ToLowerInvariant(), out User deletedUser)) { //delete all sessions foreach (UserSession session in GetSessions(deletedUser)) @@ -517,7 +534,7 @@ namespace DnsServerCore.Auth public Group GetGroup(string name) { - if (_groups.TryGetValue(name.ToLower(), out Group group)) + if (_groups.TryGetValue(name.ToLowerInvariant(), out Group group)) return group; return null; @@ -557,7 +574,7 @@ namespace DnsServerCore.Auth Group group = new Group(name, description); - if (_groups.TryAdd(name.ToLower(), group)) + if (_groups.TryAdd(name.ToLowerInvariant(), group)) return group; throw new DnsWebServiceException("Group already exists: " + name); @@ -574,13 +591,13 @@ namespace DnsServerCore.Auth string oldGroupName = group.Name; group.Name = newGroupName; - if (!_groups.TryAdd(group.Name.ToLower(), group)) + if (!_groups.TryAdd(group.Name.ToLowerInvariant(), group)) { group.Name = oldGroupName; //revert throw new DnsWebServiceException("Group already exists: " + newGroupName); } - _groups.TryRemove(oldGroupName.ToLower(), out _); + _groups.TryRemove(oldGroupName.ToLowerInvariant(), out _); //update users foreach (KeyValuePair user in _users) @@ -589,7 +606,7 @@ namespace DnsServerCore.Auth public bool DeleteGroup(string name) { - name = name.ToLower(); + name = name.ToLowerInvariant(); switch (name) { @@ -643,8 +660,10 @@ namespace DnsServerCore.Auth public async Task CreateSessionAsync(UserSessionType type, string tokenName, string username, string password, IPAddress remoteAddress, string userAgent) { - if (IsAddressBlocked(remoteAddress)) - throw new DnsWebServiceException("Max limit of " + MAX_LOGIN_ATTEMPTS + " attempts exceeded. Access blocked for " + (BLOCK_ADDRESS_INTERVAL / 1000) + " seconds."); + IPAddress network = GetClientNetwork(remoteAddress); + + if (IsNetworkBlocked(network)) + throw new DnsWebServiceException("Max limit of " + MAX_LOGIN_ATTEMPTS + " attempts exceeded. Access blocked for " + (BLOCK_NETWORK_INTERVAL / 1000) + " seconds."); User user = GetUser(username); @@ -652,10 +671,10 @@ namespace DnsServerCore.Auth { if (password != "admin") { - FailedLoginAttempt(remoteAddress); + MarkFailedLoginAttempt(network); - if (LoginAttemptsExceedLimit(remoteAddress, MAX_LOGIN_ATTEMPTS)) - BlockAddress(remoteAddress, BLOCK_ADDRESS_INTERVAL); + if (HasLoginAttemptExceedLimit(network, MAX_LOGIN_ATTEMPTS)) + BlockNetwork(network, BLOCK_NETWORK_INTERVAL); await Task.Delay(1000); } @@ -663,7 +682,7 @@ namespace DnsServerCore.Auth throw new DnsWebServiceException("Invalid username or password for user: " + username); } - ResetFailedLoginAttempt(remoteAddress); + ResetFailedLoginAttempts(network); if (user.Disabled) throw new DnsWebServiceException("User account is disabled. Please contact your administrator."); @@ -680,9 +699,6 @@ namespace DnsServerCore.Auth public UserSession CreateApiToken(string tokenName, string username, IPAddress remoteAddress, string userAgent) { - if (IsAddressBlocked(remoteAddress)) - throw new DnsWebServiceException("Max limit of " + MAX_LOGIN_ATTEMPTS + " attempts exceeded. Access blocked for " + (BLOCK_ADDRESS_INTERVAL / 1000) + " seconds."); - User user = GetUser(username); if (user is null) throw new DnsWebServiceException("No such user exists: " + username);