pyopenssl icon indicating copy to clipboard operation
pyopenssl copied to clipboard

Issue with DTLS: Unable to Achieve Handshake Between Client and Serve

Open hamma96 opened this issue 1 year ago • 1 comments

I have been working for a week on creating a DTLS (Datagram Transport Layer Security) client-server setup, but I am consistently failing to achieve a successful handshake. Despite multiple attempts and configurations, the handshake process does not complete as expected.

`import socket import logging from OpenSSL import SSL from openssl_psk import patch_context import time import threading import hashlib

patch_context()

logging.basicConfig(level=logging.INFO)

def psk_client_callback(connection, hint): logging.info(f"[TLSClient] PSK client callback called with hint: {hint}") identity = b'client-identity' key = b'1a2b3c4d5e6f' logging.info(f"[TLSClient] Returning identity: {identity}, key: {key}") return (identity, key)

class TLSClient: def init(self, config): self.context = SSL.Context(SSL.DTLS_METHOD) self.context.set_cipher_list(b'PSK-AES256-CBC-SHA') self.context.set_psk_client_callback(psk_client_callback) self.context.set_options(SSL.OP_NO_RENEGOTIATION) self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSClient] Info: where={where}, ret={ret}, state={conn.get_state_string()}")) self.client_socket = None self.config = config self.ssl_conn = None self.callback_running = False self._running = False

def log_handshake_progress(self, conn):
    state = conn.get_state_string()
    pending = conn.pending()
    cipher_name = conn.get_cipher_name()
    version = conn.get_protocol_version_name()
    logging.info(f"[TLSClient] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")


def start_client(self):
    try:
        self._running = True
        self.client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        #self.client_socket.setblocking(False)
        self.client_socket.connect(self.config['address'])
        
        self.context.set_timeout(30)
        self.ssl_conn = SSL.Connection(self.context, self.client_socket)            
        self.ssl_conn.set_connect_state()

        logging.info("[TLSClient] Starting DTLS handshake...")
        while self._running:
            try:   

                self.log_handshake_progress(self.ssl_conn)
                self.ssl_conn.do_handshake()
            
            except SSL.WantReadError:
                self.log_handshake_progress(self.ssl_conn)
                pass
            else:
                logging.info("[TLSClient] else handshake.")
                self._running = False
            
        self.log_handshake_progress(self.ssl_conn)
        logging.info("[TLSClient] DTLS handshake completed.")

        # Send a message to the server
        message = b"Hello from Client!"
        self.ssl_conn.send(message)
        logging.info(f"[TLSClient] Sent to server: {message}")

        # Receive a response from the server
        data = self.ssl_conn.recv(self.config['buffer_size'])
        logging.info(f"[TLSClient] Received from server: {data.decode()}")

        self.ssl_conn.shutdown()
        self.ssl_conn.close()

    except SSL.Error as e:
        logging.error(f"[TLSClient] SSL error: {e}")
    except Exception as e:
        logging.error(f"[TLSClient] Error: {e}")
    finally:
        self.callback_running = False  # Stop callback thread
        if self.client_socket:
            self.client_socket.close()
        logging.info("[TLSClient] Client stopped")

def psk_server_callback(connection, identity): logging.info(f"[TLSServer] PSK server callback called with identity: {identity}") if identity == b'client-identity': key = b'1a2b3c4d5e6f' logging.info(f"[TLSServer] Returning key: {key}") return key return None

class TLSServer: def init(self, config): self.context = SSL.Context(SSL.DTLS_METHOD) self.context.set_cipher_list(b'PSK-AES256-CBC-SHA') self.context.set_psk_server_callback(psk_server_callback) self.context.set_options(SSL.OP_NO_QUERY_MTU) self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSServer] Info: where={where}, ret={ret}, state={conn.get_state_string()}")) # Setup cookie generation and verification self.context.set_cookie_generate_callback(self.generate_cookie) self.context.set_cookie_verify_callback(self.verify_cookie)

    self.server_socket = None
    self._running = False
    self.config = config
    self.ssl_conn = None

def generate_cookie(self, ssl):
        logging.info("[TLSServer] generate_cookie")
        return b"xyzzy"

def verify_cookie(self, ssl, cookie):
        logging.info("[TLSServer] verify_cookie")
        return cookie == b"xyzzy"

def log_handshake_progress(self, conn: SSL.Connection):
    state = conn.get_state_string()
    pending = conn.pending()
    cipher_name = conn.get_cipher_name()
    version = conn.get_protocol_version_name()
    logging.info(f"[TLSServer] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")

def start_server(self):
    try:
        self._running = True
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        #self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        #self.server_socket.setblocking(False)
        self.server_socket.bind(self.config['address'])

        logging.info("[TLSServer] Server is running and waiting for connections...")
        s_handshaking = False
        self.context.set_timeout(30)
        s_listening = True

        import select
        while self._running:
            try:
                #ready_sockets, _, _ = select.select([self.server_socket], [], [])
                #sock = self.server_socket
                #for sock in ready_sockets:
                data, addr = self.server_socket.recvfrom(self.config['buffer_size'])
                ssl_conn = SSL.Connection(self.context, self.server_socket)          
                ssl_conn.set_accept_state()
                ssl_conn.set_tlsext_host_name(self.config['address'][0].encode())
                ssl_conn.set_ciphertext_mtu(1500)
                #self.invoke_client_callback(data, addr)
                self.log_handshake_progress(ssl_conn)
                if len(data) > 0 and data[0] == 22 and data[13] == 1:
                    logging.info("[TLSServer] Received ClientHello from client")
                    logging.info(f"[TLSServer] Received initial data from {addr}: {data}")
                    if s_listening:
                        try:
                            ssl_conn.DTLSv1_listen()
                            logging.info("[TLSServer] After DTLSv1_listen")
                        except SSL.WantReadError:
                            logging.info("[TLSServer] WantReadError during DTLSv1_listen")
                            continue
                        else:
                            s_listening = False
                            s_handshaking = True
                            logging.info("[TLSServer] s_listening=False")
                            ssl_conn.bio_write(data)
                        

                    logging.info(f"[TLSServer] Starting DTLS handshake with {addr}...")
                    while s_handshaking:
                        try:
                            self.log_handshake_progress(ssl_conn)
                            ssl_conn.do_handshake()
                            break
                        except SSL.WantReadError:
                            self.log_handshake_progress(ssl_conn)
                            self._running = False
                            s_handshaking = False
                            pass
                        except SSL.Error as e:
                            logging.error(f"[TLSServer] SSL error occurred during handshake: {e}")
                            self.log_handshake_progress(ssl_conn)
                            self._running = False
                            s_handshaking = False
                            break
                    self.log_handshake_progress(ssl_conn)
                    logging.info(f"[TLSServer] DTLS handshake with {addr} completed.")


            except SSL.Error as e:
                logging.error(f"[TLSServer] SSL error occurred: {e}")
            except Exception as e:
                logging.error(f"[TLSServer] An error occurred: {e}")

    except socket.error as e:
        logging.error(f"[TLSServer] Socket error: {e}")
    finally:
        self.cleanup()



def cleanup(self):
    self._running = False
    if self.server_socket:
        self.server_socket.close()
    logging.info("[TLSServer] Server cleaned up and stopped.")

if name == "main": server_config = { 'address': ('localhost', 4433), 'buffer_size': 4096 }

client_config = {
    'address': ('localhost', 4433),
    'buffer_size': 4096
}

server = TLSServer(server_config)
server_thread = threading.Thread(target=server.start_server)
server_thread.start()

time.sleep(1)

""" client = TLSClient(client_config)
client_thread = threading.Thread(target=client.start_client)
client_thread.start() """

time.sleep(120)

server._running = False
#client_thread.join()
server_thread.join()`

hamma96 avatar Jul 24 '24 10:07 hamma96

I have at this for alomst a week trying to understand why I canget this working with pyopenssl.

Here is the thing, server side code works when I use openSSL setting up the tunnel/negotiating via command line but using pyopenssl and others seems to not be working.

Its a weird mystery why DTLS 1.2 is not being talked about more and why its hard to find documentation on this protocol we need help

Gamechiefx avatar Jan 10 '25 06:01 Gamechiefx