# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.

############################################################################
#
# This tool acts as a TCP/UDP proxy and delays all packets coming from
# clients by 500 milliseconds (DELAY) before forwarding them to the server;
# data coming back from the server is relayed to the client immediately.
#
# We use it to check pipelining - a client sends 8 questions over a
# pipelined connection - that require asking a normal (examplea) and a
# slow-responding (exampleb) server:
# a.examplea
# a.exampleb
# b.examplea
# b.exampleb
# c.examplea
# c.exampleb
# d.examplea
# d.exampleb
#
# If pipelining works properly the answers will be returned out of order
# with all answers from examplea returned first, and then all answers
# from exampleb.
#
############################################################################

from __future__ import print_function

import datetime
import os
import select
import signal
import socket
import sys
import time
import threading
import struct

# Seconds each client-to-server packet is held back before being forwarded.
DELAY = 0.5
# Longest time a proxy thread blocks in select() while nothing is queued;
# this also bounds how long a thread takes to notice close().
POLL_INTERVAL = 0.5
# Every proxy thread started by main(), so that sigterm() can stop them.
THREADS = []


def log(msg):
    """Print msg prefixed with the current local time (microseconds)."""
    print(datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S.%f ") + msg)


def sigterm(*_):
    """Signal handler: stop all proxy threads, remove ans.pid and exit(0)."""
    log("SIGTERM received, shutting down")
    # Ask every thread to stop before waiting for any of them, so they wind
    # down concurrently rather than one select() timeout after another.
    for thread in THREADS:
        thread.close()
    for thread in THREADS:
        thread.join()
    os.remove("ans.pid")
    sys.exit(0)


def _next_timeout(queue):
    """Return how long select() may block, given a queue of (due, data).

    With an empty queue this is POLL_INTERVAL.  Otherwise it is the time
    left until the head entry is due; a value <= 0 means the head entry is
    already due and must be forwarded without waiting.
    """
    if not queue:
        return POLL_INTERVAL
    return queue[0][0] - time.time()


def _query_id(data):
    """Return the query ID: the first two bytes of data, big-endian."""
    return struct.unpack(">H", data[:2])[0]


class TCPDelayer(threading.Thread):
    """Proxy one TCP connection, delaying client-to-server data by DELAY.

    For a given TCP connection conn we open a connection to (ip, port), and
    then we delay each chunk of data read from conn by DELAY by putting it
    in a queue.  Data read from (ip, port) is relayed to conn immediately.
    In the pipelined test TCP should not be used, but it's here for
    completeness.
    """

    def __init__(self, conn, ip, port):
        threading.Thread.__init__(self)
        self.conn = conn
        self.cconn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.cconn.connect((ip, port))
        except socket.error:
            # Don't leak the upstream socket; the caller still owns conn.
            self.cconn.close()
            raise
        self.queue = []  # (due time, data) pairs, oldest first
        self.running = True

    def close(self):
        """Ask the thread to stop; run() checks this once per select()."""
        self.running = False

    def run(self):
        try:
            while self.running:
                timeout = _next_timeout(self.queue)
                if timeout > 0:
                    rfds, _, _ = select.select(
                        [self.conn, self.cconn], [], [], timeout
                    )
                    if self.conn in rfds:
                        data = self.conn.recv(65535)
                        if not data:
                            return  # client closed the connection
                        self.queue.append((time.time() + DELAY, data))
                    if self.cconn in rfds:
                        data = self.cconn.recv(65535)
                        if not data:
                            return  # server closed the connection
                        self.conn.sendall(data)
                # Forward every queued chunk whose delay has expired.
                while self.queue and self.queue[0][0] <= time.time():
                    _, data = self.queue.pop(0)
                    self.cconn.sendall(data)
        finally:
            self.conn.close()
            self.cconn.close()


class UDPDelayer(threading.Thread):
    """Proxy UDP queries to (ip, port), delaying each query by DELAY.

    Every incoming UDP packet is put in a queue for DELAY time, then it's
    sent to (ip, port).  We remember the query id to send the response we
    get to a proper source; responses are not delayed.
    """

    def __init__(self, usock, ip, port):
        threading.Thread.__init__(self)
        self.sock = usock  # client-facing socket, owned by the caller
        self.csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.dst = (ip, port)
        self.queue = []  # (due time, data) pairs, oldest first
        self.qid_mapping = {}  # query id -> client address
        self.running = True

    def close(self):
        """Ask the thread to stop; run() checks this once per select()."""
        self.running = False

    def run(self):
        try:
            while self.running:
                timeout = _next_timeout(self.queue)
                if timeout > 0:
                    rfds, _, _ = select.select(
                        [self.sock, self.csock], [], [], timeout
                    )
                    if self.sock in rfds:
                        self._receive_query()
                    if self.csock in rfds:
                        self._relay_response()
                self._send_due_queries()
        finally:
            self.csock.close()

    def _receive_query(self):
        """Queue one client datagram and remember who sent its query id."""
        data, addr = self.sock.recvfrom(65535)
        if len(data) < 2:
            # Too short to carry a query id; don't let it kill the proxy.
            log("Ignoring a %d-byte datagram from %s" % (len(data), str(addr)))
            return
        self.queue.append((time.time() + DELAY, data))
        qid = _query_id(data)
        log("Received a query from %s, queryid %d" % (str(addr), qid))
        self.qid_mapping[qid] = addr

    def _relay_response(self):
        """Send one server datagram to the client that used its query id."""
        data, addr = self.csock.recvfrom(65535)
        if len(data) < 2:
            log("Ignoring a %d-byte datagram from %s" % (len(data), str(addr)))
            return
        qid = _query_id(data)
        dst = self.qid_mapping.get(qid)
        if dst is not None:
            self.sock.sendto(data, dst)
            log(
                "Received a response from %s, queryid %d, sending to %s"
                % (str(addr), qid, str(dst))
            )

    def _send_due_queries(self):
        """Forward every queued query whose delay has expired."""
        while self.queue and self.queue[0][0] <= time.time():
            _, data = self.queue.pop(0)
            qid = _query_id(data)
            log("Sending a query to %s, queryid %d" % (str(self.dst), qid))
            self.csock.sendto(data, self.dst)


def main():
    signal.signal(signal.SIGTERM, sigterm)
    signal.signal(signal.SIGINT, sigterm)

    with open("ans.pid", "w") as pidfile:
        print(os.getpid(), file=pidfile)

    listenip = "10.53.0.5"
    serverip = "10.53.0.2"

    try:
        port = int(os.environ["PORT"])
    except KeyError:
        port = 5300

    log("Listening on %s:%d" % (listenip, port))

    usock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    usock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    usock.bind((listenip, port))
    thread = UDPDelayer(usock, serverip, port)
    thread.start()
    THREADS.append(thread)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind((listenip, port))
    sock.listen(1)
    # Wake up periodically so the accept loop never blocks indefinitely.
    sock.settimeout(1)

    while True:
        try:
            (clientsock, _) = sock.accept()
        except socket.timeout:
            continue
        log("Accepted connection from %s" % clientsock)
        try:
            thread = TCPDelayer(clientsock, serverip, port)
        except socket.error as err:
            # Upstream unreachable: drop this client, keep serving others.
            log("Cannot connect to %s:%d: %s" % (serverip, port, err))
            clientsock.close()
            continue
        thread.start()
        THREADS.append(thread)


if __name__ == "__main__":
    main()