Skip to content

Instantly share code, notes, and snippets.

@j1lecks
Created February 17, 2025 10:20
Show Gist options
  • Select an option

  • Save j1lecks/ebd7ae9800c0d153dc0294b4ae483ed3 to your computer and use it in GitHub Desktop.

Select an option

Save j1lecks/ebd7ae9800c0d153dc0294b4ae483ed3 to your computer and use it in GitHub Desktop.
Simulating connection pool unfairness
# Pipfile: dependency specification for the simulation scripts below.
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
# Runtime dependencies (unpinned; any released version is accepted)
[packages]
ipykernel = "*"
autopep8 = "*"
matplotlib = "*"
numpy = "*"
pandas = "*"
scikit-learn = "*"
# No development-only dependencies declared
[dev-packages]
[requires]
python_version = "3.13"
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
import matplotlib.pyplot as plt
# Simulation Parameters
simulation_time = 300 # Total simulation time in seconds
num_clients = 100 # Number of clients, each with its own connection pool
num_db_nodes = 5 # Number of database replicas
dns_ttl = 15 # DNS TTL in seconds
ttl_jitter_range = 1 # ±1 second of randomness in TTL expiry per client
connection_lifetime_jitter_range = 1 # ±1 second of randomness in connection lifetime
num_trials = 10 # Number of trials per connection lifetime value
# Sweep of candidate connection lifetimes: 2s..300s inclusive, in 1s steps.
# NOTE(review): earlier comments described this as a "focused 30-40s" search,
# but the range below actually covers 2-300s.
focused_connection_lifetime_values = np.arange(2, 301, 1) # Candidate lifetimes, 2s to 300s
# Accumulators for the per-lifetime fairness metrics (one entry per lifetime)
std_dev_results_focused = []
max_min_ratio_results_focused = []
# Sweep candidate connection lifetimes; for each, run several randomized
# trials and record how evenly load spreads across the database nodes.
# (Indentation restored: the pasted source had lost all loop nesting.)
for lifetime in focused_connection_lifetime_values:
    std_dev_trials = []
    max_min_ratio_trials = []
    for _ in range(num_trials):
        # One-second simulation ticks.
        time_steps = np.arange(0, simulation_time, 1)
        # Per-node request counts at each tick.
        db_load_distribution = {i: np.zeros_like(time_steps) for i in range(num_db_nodes)}
        # Initial round-robin DNS assignment plus jittered TTLs and lifetimes.
        client_dns_assignments = np.arange(num_clients) % num_db_nodes
        client_ttl_expiries = np.random.randint(dns_ttl - ttl_jitter_range, dns_ttl + ttl_jitter_range + 1, num_clients)
        active_connections = np.random.randint(lifetime - connection_lifetime_jitter_range,
                                               lifetime + connection_lifetime_jitter_range + 1, num_clients)
        connection_dns_assignments = client_dns_assignments.copy()  # connections pin their resolved address
        # Simulation Loop
        for t_idx, t in enumerate(time_steps):
            for c in range(num_clients):
                if client_ttl_expiries[c] == 0:
                    # TTL expired: client resolves the next node (round robin) and re-arms its TTL.
                    client_dns_assignments[c] = (client_dns_assignments[c] + 1) % num_db_nodes
                    client_ttl_expiries[c] = np.random.randint(dns_ttl - ttl_jitter_range, dns_ttl + ttl_jitter_range + 1)
                if active_connections[c] == 0:
                    # Connection expired: reopen against the client's current DNS answer.
                    connection_dns_assignments[c] = client_dns_assignments[c]
                    active_connections[c] = np.random.randint(lifetime - connection_lifetime_jitter_range,
                                                              lifetime + connection_lifetime_jitter_range + 1)
            # One request per client per tick, charged to its pinned node.
            for c in range(num_clients):
                db_load_distribution[connection_dns_assignments[c]][t] += 1
            client_ttl_expiries -= 1
            active_connections -= 1
        # Convert to DataFrame for analysis.
        df_sim = pd.DataFrame(db_load_distribution, index=time_steps)
        # Two fairness metrics: mean per-tick std-dev across nodes, and mean
        # per-tick max/min node-load ratio (inf when a node gets zero load).
        std_dev_trials.append(df_sim.std(axis=1).mean())
        max_min_ratio_trials.append((df_sim.max(axis=1) / df_sim.min(axis=1)).mean())
    # Average results over trials.
    std_dev_results_focused.append(np.mean(std_dev_trials))
    max_min_ratio_results_focused.append(np.mean(max_min_ratio_trials))
# Fit quadratic (degree-2 polynomial) trends to both fairness metrics and
# locate each fitted curve's minimum on a fine lifetime grid.
X_focused = focused_connection_lifetime_values.reshape(-1, 1)
y_std_focused = np.array(std_dev_results_focused)
y_max_min_focused = np.array(max_min_ratio_results_focused)

def _fit_quadratic(features, targets):
    # Degree-2 polynomial regression pipeline fitted on the given data.
    return make_pipeline(PolynomialFeatures(degree=2), LinearRegression()).fit(features, targets)

reg_std_poly_focused = _fit_quadratic(X_focused, y_std_focused)
reg_max_min_poly_focused = _fit_quadratic(X_focused, y_max_min_focused)

# Evaluate both fits on a 100-point grid spanning the swept lifetimes.
grid_lo = min(focused_connection_lifetime_values)
grid_hi = max(focused_connection_lifetime_values)
fine_lifetime_values_focused = np.linspace(grid_lo, grid_hi, 100).reshape(-1, 1)
std_dev_predictions_focused = reg_std_poly_focused.predict(fine_lifetime_values_focused)
max_min_predictions_focused = reg_max_min_poly_focused.predict(fine_lifetime_values_focused)

# The optimum per metric is the grid point minimizing the quadratic fit.
optimal_lifetime_std_quad_focused = fine_lifetime_values_focused[np.argmin(std_dev_predictions_focused)][0]
optimal_lifetime_max_min_quad_focused = fine_lifetime_values_focused[np.argmin(max_min_predictions_focused)][0]

print("Optimal connection lifetime minimizing standard deviation (focused quadratic):", optimal_lifetime_std_quad_focused)
print("Optimal connection lifetime minimizing max/min load ratio (focused quadratic):", optimal_lifetime_max_min_quad_focused)
# Plot the measured metrics, the fitted quadratics, and the fitted optima.
plt.figure(figsize=(10, 5))
plt.scatter(focused_connection_lifetime_values, std_dev_results_focused, label="Standard Deviation", marker="o", color="blue")
plt.plot(fine_lifetime_values_focused, std_dev_predictions_focused, label="Quadratic Fit (StdDev)", linestyle="--", color="blue")
plt.scatter(focused_connection_lifetime_values, max_min_ratio_results_focused, label="Max/Min Load Ratio", marker="s", color="red")
plt.plot(fine_lifetime_values_focused, max_min_predictions_focused, label="Quadratic Fit (Max/Min)", linestyle="--", color="red")
# Vertical markers at each quadratic fit's minimum.
plt.axvline(optimal_lifetime_std_quad_focused, linestyle="--", color="blue", label=f"Optimal (StdDev): {optimal_lifetime_std_quad_focused:.2f}s")
plt.axvline(optimal_lifetime_max_min_quad_focused, linestyle="--", color="red", label=f"Optimal (Max/Min): {optimal_lifetime_max_min_quad_focused:.2f}s")
plt.xlabel("Connection Lifetime (seconds)")
plt.ylabel("Performance Metric")
# Title fixed: the sweep actually covers 2-300s, not the "30-40s" range the
# original title claimed (see focused_connection_lifetime_values above).
plt.title("Optimizing Connection Lifetime for Load Balancing (2-300s Sweep)")
plt.legend()
plt.show()
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Simulation Parameters
simulation_time = 86400 # Total simulation time in seconds (one full day)
traffic_cycle_time = 86400 # 24 hours in seconds (NOTE(review): appears unused below — the sine cycle uses simulation_time)
num_clients = 100 # Number of clients, each with its own connection pool
num_db_nodes = 5 # Number of database replicas
dns_ttl = 15 # DNS TTL in seconds
ttl_jitter_range = 1 # ±1 second of randomness in TTL expiry per client
connection_lifetime = 33 # Max connection lifetime in seconds
# NOTE(review): the original comment claimed this is "below DNS TTL", but 33s
# is more than twice the 15s dns_ttl configured above.
connection_lifetime_jitter_range = 1 # ±1 second of randomness in connection lifetime
# Traffic Load Parameters
baseline_rps = 1500 # Minimum RPS (trough of the daily cycle)
peak_rps = 2500 # Maximum RPS (peak of the daily cycle)
def apply_jitter(base_value, jitter_range):
    """Return base_value offset by a uniform random integer in [-jitter_range, jitter_range].

    A jitter_range of 0 (or less) leaves base_value unchanged.
    """
    if jitter_range <= 0:
        return base_value
    offset = np.random.randint(-jitter_range, jitter_range + 1)
    return base_value + offset
def assign_dns_round_robin(client_index, num_db_nodes):
    """Return the next DB node index in round-robin order, wrapping at num_db_nodes.

    NOTE(review): despite its name, callers pass the client's *current* node
    assignment as client_index; the return value is the following node.
    """
    successor = client_index + 1
    return successor % num_db_nodes
def assign_dns_random(client_index, num_db_nodes):
    """Return a uniformly random DB node index in [0, num_db_nodes).

    The client_index argument is ignored; it exists only so this function is
    call-compatible with assign_dns_round_robin.
    """
    return np.random.randint(num_db_nodes)
# Set the DNS assignment strategy (choose one)
assign_dns = assign_dns_round_robin # Change to assign_dns_random if needed
# One-second simulation ticks over the whole day.
time_steps = np.arange(0, simulation_time, 1)
time_scaled = (time_steps / simulation_time) * (2 * np.pi) # Map the day onto one full sine cycle
# Cyclic traffic: sine wave normalized to [0, 1], scaled between baseline and peak RPS.
traffic_variation = (np.sin(time_scaled) + 1) / 2
traffic_rps = baseline_rps + traffic_variation * (peak_rps - baseline_rps)
# Initial state: round-robin DNS assignments, jittered TTLs and connection lifetimes.
client_dns_assignments = np.arange(num_clients) % num_db_nodes # Round-robin sequence
client_ttl_expiries = np.array([apply_jitter(dns_ttl, ttl_jitter_range) for _ in range(num_clients)])
active_connections = np.array([apply_jitter(connection_lifetime, connection_lifetime_jitter_range) for _ in range(num_clients)])
connection_dns_assignments = client_dns_assignments.copy() # Each connection retains its assigned address
# Per-node load at each tick. Explicit float dtype: the original
# np.zeros_like(time_steps) produced INTEGER arrays, so the fractional
# per-client RPS shares (current_rps / num_clients) accumulated into them by
# the simulation loop were silently truncated on every in-place add.
db_load_distribution = {i: np.zeros(len(time_steps), dtype=float) for i in range(num_db_nodes)}
# Simulation Loop: each tick, expire TTLs and connections, then spread the
# current request rate evenly across each client's pinned connection.
# (Indentation restored: the pasted source had lost all loop nesting.)
for t_idx, t in enumerate(time_steps):
    # Offered load at this tick, following the cyclic traffic pattern.
    current_rps = traffic_rps[t_idx]
    for c in range(num_clients):
        if client_ttl_expiries[c] == 0:
            # DNS TTL expired: re-resolve via the chosen strategy, re-arm TTL.
            client_dns_assignments[c] = assign_dns(client_dns_assignments[c], num_db_nodes)
            client_ttl_expiries[c] = apply_jitter(dns_ttl, ttl_jitter_range)
        if active_connections[c] == 0:
            # Connection lifetime expired: reopen against the current DNS answer.
            connection_dns_assignments[c] = client_dns_assignments[c]
            active_connections[c] = apply_jitter(connection_lifetime, connection_lifetime_jitter_range)
    # Each client contributes an equal share of the offered RPS to the node
    # its connection is pinned to.
    for c in range(num_clients):
        assigned_db = connection_dns_assignments[c]
        db_load_distribution[assigned_db][t] += current_rps / num_clients
    # Count down toward the next TTL / connection expiry.
    client_ttl_expiries -= 1
    active_connections -= 1
# Convert per-node load series to a DataFrame for analysis and visualization.
df_correct_dns_ttl = pd.DataFrame(db_load_distribution, index=time_steps)
df_correct_dns_ttl.columns = [f"DB Node {i}" for i in range(num_db_nodes)]
# Fairness metrics: mean per-tick std-dev across nodes, and mean per-tick
# max/min node-load ratio (NOTE: the ratio is inf at any tick where some node
# received zero load).
overall_std_dev = df_correct_dns_ttl.std(axis=1).mean()
overall_max_min_ratio = (df_correct_dns_ttl.max(axis=1) / df_correct_dns_ttl.min(axis=1)).mean()
print("Overall Standard Deviation of Load Distribution:", overall_std_dev)
print("Overall Max/Min Load Ratio:", overall_max_min_ratio)
# Plot each node's load over the day.
# (Loop-body indentation restored: the pasted source had lost it.)
plt.figure(figsize=(10, 5))
for i in range(num_db_nodes):
    plt.plot(time_steps, df_correct_dns_ttl[f"DB Node {i}"], label=f"DB Node {i}")
plt.xlabel("Time (seconds)")
plt.ylabel("Requests per Second")
plt.title("DNS Load Balancing with Corrected TTL & Connection Expiry Behavior")
plt.legend()
plt.show()
# # Plot standard deviation over time
# plt.figure(figsize=(10, 5))
# plt.plot(time_steps, df_correct_dns_ttl.std(axis=1), label='Standard Deviation')
# plt.xlabel("Time (seconds)")
# plt.ylabel("Standard Deviation")
# plt.title("Load Distribution Standard Deviation Over Time")
# plt.legend()
# plt.show()
# # Plot max/min load ratio over time
# plt.figure(figsize=(10, 5))
# plt.plot(time_steps, df_correct_dns_ttl.max(axis=1) / df_correct_dns_ttl.min(axis=1), label='Max/Min Load Ratio')
# plt.xlabel("Time (seconds)")
# plt.ylabel("Load Ratio")
# plt.title("Max/Min Load Ratio Over Time")
# plt.legend()
# plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment