-
Notifications
You must be signed in to change notification settings - Fork 1
/
dnsserver
executable file
·205 lines (173 loc) · 7.76 KB
/
dnsserver
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#!/usr/bin/env python3
import json
import threading
import time
from argparse import Namespace, ArgumentParser
from dataclasses import dataclass
from urllib import request
from dnslib.dns import RR
from dnslib.server import DNSServer, DNSLogger, DNSError
from geo import find_best_replica
from replicas import REPLICAS
from utils import get_local_ip
import sys
# Name this server answers for; main() overrides it from the -n command-line argument.
CDN_NAME = "cs5700cdn.example.com"
# TTL value written into every A record the resolver returns.
TTL = 45
class Resolver:
    """
    Resolver handed to the dnslib server; answers A queries for CDN_NAME only.

    Both attributes are class-level state shared with the measurement code
    elsewhere in this file.
    """
    # IP addresses of every client that has queried CDN_NAME so far
    CLIENT_IPS: set = set()
    # client IP -> {'replica': <replica IP>, 'rtt': <measured RTT>}; the whole
    # dict is rebound by DNSProxy._set_best_replicas after each measurement round
    CLIENT_REPLICA_MAP: dict = dict()

    @staticmethod
    def resolve(request, handler):
        """
        Resolves queries for the configured CDN name.

        :param request: the parsed DNS query
        :param handler: the request handler; handler.client_address[0] is the client's IP
        :return: the reply built from request.reply() with one A record pointing at a replica
        :raises DNSError: if the queried name is anything other than CDN_NAME
        """
        if request.q.qname != CDN_NAME:
            raise DNSError(f"Not authoritative for {request.q.qname}")

        client = handler.client_address[0]
        Resolver.CLIENT_IPS.add(client)  # remembered so replicas can measure this client
        response = request.reply()

        # Read the shared map exactly once: the measurement thread may rebind
        # CLIENT_REPLICA_MAP at any time, so a separate membership test followed
        # by a subscript could raise KeyError between the two reads.
        measured = Resolver.CLIENT_REPLICA_MAP.get(client)
        if measured is not None:
            replica_ip = measured['replica']  # use active measurements
        else:
            replica_ip = find_best_replica(client)  # fall back to the GeoIP lookup

        response.add_answer(
            *RR.fromZone(f"{CDN_NAME}. {TTL} A {replica_ip}"))
        return response
@dataclass()
class DNSProxy:
    """
    Runs the DNS server and keeps Resolver.CLIENT_REPLICA_MAP up to date with
    the replica measurements it collects in the background.
    """
    server: DNSServer
    resolver: Resolver
    # port used when contacting each replica's /measure endpoint
    port: int

    # Seconds to wait for one replica before giving up on it for this round.
    # Not annotated on purpose, so it stays a plain class attribute rather than
    # becoming a dataclass field.
    MEASUREMENT_TIMEOUT = 30

    def run(self) -> None:
        """
        Starts the DNS server and the measurement loop, then blocks until interrupted.
        """
        try:
            self.server.start_thread()  # thread responsible for keeping the DNS server running
            print("DNS server running...")
            # Daemon thread: the loop never returns, so a non-daemon thread
            # would keep the process alive after Ctrl-C has stopped the server.
            measurements_thread = threading.Thread(
                target=self._get_all_measurements,
                daemon=True,
            )
            measurements_thread.start()
            measurements_thread.join()
        except KeyboardInterrupt:
            self.stop()

    def stop(self) -> None:
        """
        Shuts down the DNS server.
        """
        self.server.stop()
        print("\nDNS server stopped.")

    def _get_all_measurements(self) -> None:
        """
        Loops forever: once at least one client is known, asks every replica
        for measurements in parallel and, if any replica answered, recomputes
        the best replica for each client.
        """
        while True:
            if not self.resolver.CLIENT_IPS:  # no known clients yet
                time.sleep(3)  # wait 3 seconds and retry
                continue

            measurements = []  # filled in by the worker threads
            lock = threading.Lock()
            workers = [
                threading.Thread(
                    target=self._request_measurement,
                    args=(replica_ip, measurements, lock),
                    daemon=True,
                )
                for replica_ip in REPLICAS.keys()
            ]
            # Connection errors surface inside the worker, not in start(), so
            # they are handled in _request_measurement rather than here.
            for worker in workers:
                worker.start()
            for worker in workers:  # wait for all replicas to answer or time out
                worker.join()

            if measurements:
                self._set_best_replicas(measurements)
            # wait 12 seconds until requesting the next round of active measurements
            time.sleep(12)

    def _request_measurement(self, replica_ip: str, measurement_queue: list, lock: threading.Lock) -> None:
        """
        POSTs the known client IPs to one replica's /measure endpoint and
        appends the JSON object it answers with to measurement_queue.

        A replica that is unreachable, times out, or answers with something
        other than a JSON object containing an 'rtts' object is reported on
        stderr and skipped for this round.

        :param replica_ip: the http replica server's IP address
        :param measurement_queue: shared list that collects one result per replica
        :param lock: guards measurement_queue across the worker threads
        """
        from http.client import HTTPException  # local import: only needed here

        req = request.Request(f"http://{replica_ip}:{self.port}/measure")
        req.add_header("Content-Type", "application/json; charset=utf-8")
        body_bytes = json.dumps(list(self.resolver.CLIENT_IPS)).encode('utf-8')

        try:
            # Passing a body makes this a POST; the timeout keeps one hung
            # replica from blocking the whole round in join().
            with request.urlopen(req, body_bytes, timeout=self.MEASUREMENT_TIMEOUT) as res:
                json_data = json.loads(res.read().decode('utf-8'))
        except (OSError, HTTPException, ValueError) as err:
            # OSError covers URLError and timeouts; ValueError covers bad JSON or bad UTF-8.
            print(f"Measurement from {replica_ip} failed: {err}", file=sys.stderr)
            return

        if not isinstance(json_data, dict) or not isinstance(json_data.get('rtts'), dict):
            print(f"Measurement from {replica_ip} ignored: unexpected payload", file=sys.stderr)
            return

        json_data["replica"] = replica_ip  # tag the result with its source replica
        with lock:
            measurement_queue.append(json_data)

    @staticmethod
    def _set_best_replicas(replica_measurements: list) -> None:
        """
        Picks, for each client, the replica that reported the lowest RTT and
        publishes the result as Resolver.CLIENT_REPLICA_MAP.

        :param replica_measurements: one dict per replica, each holding the
            replica's IP under 'replica' and a client IP -> RTT mapping under 'rtts'
        """
        client_replica_map = {}
        for measurement in replica_measurements:
            for client_key, raw_rtt in measurement['rtts'].items():
                try:
                    rtt = int(raw_rtt)
                except (TypeError, ValueError, OverflowError):
                    continue  # an unusable value must not kill the measurement loop
                # sys.maxsize is the fallback value for "no measurement"; skip it
                if rtt == sys.maxsize:
                    continue
                current = client_replica_map.get(client_key)
                # first replica seen for this client, or a strictly faster one
                if current is None or rtt < current['rtt']:
                    client_replica_map[client_key] = {
                        'replica': measurement['replica'],
                        'rtt': rtt,
                    }
        # Rebind in one assignment so the resolver never sees a half-built map.
        Resolver.CLIENT_REPLICA_MAP = client_replica_map
def parse_args(argv=None) -> Namespace:
    """
    Parses the command line arguments.

    :param argv: list of argument strings to parse; None (the default) makes
        argparse read sys.argv[1:]
    :return: namespace with p (the port, an int) and n (the CDN name, or None
        when -n is not given)
    """
    parser = ArgumentParser()
    # Required: there is no usable default port, and a missing value would
    # otherwise only fail later when the server tries to bind.
    parser.add_argument(
        "-p", type=int, required=True,
        help="the port number the server will bind to")
    parser.add_argument(
        "-n", type=str,
        help="the CDN-specific name that the server translates to an IP.")
    return parser.parse_args(argv)
def main():
    """
    Reads the command line, builds the DNS server around a Resolver and runs it.
    """
    global CDN_NAME
    args = parse_args()
    port = args.p
    # Keep the built-in default when -n is omitted instead of overwriting it
    # with None. A trailing dot is dropped because the resolver appends its
    # own when it builds the zone string.
    cdn_name = (args.n or "").rstrip(".")
    if cdn_name:
        CDN_NAME = cdn_name
    resolver = Resolver()
    logger = DNSLogger(prefix=False)
    server = DNSServer(resolver, logger=logger,
                       port=port, address=get_local_ip())
    proxy = DNSProxy(server, resolver, port)
    proxy.run()


# Run the server only when this file is executed as a script.
if __name__ == "__main__":
    main()