nuttx icon indicating copy to clipboard operation
nuttx copied to clipboard

net/pkt: add SOCK_DGRAM support

Open zhhyu7 opened this issue 3 weeks ago • 1 comments

Summary

Following the standard semantics of PF_PACKET sockets of type SOCK_DGRAM, extend the network stack's existing pkt protocol to support SOCK_DGRAM mode.

Some third-party network libraries use AF_PACKET sockets of type SOCK_DGRAM to construct packets and to send and receive data; this patch adds support for them.

Impact

Packet (pkt) sockets can now send and receive data in SOCK_DGRAM mode.

Testing

Tested on the sim:matter configuration with the ping program below, which is implemented on a PF_PACKET socket of type SOCK_DGRAM.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <signal.h>
#include <time.h>
#include <errno.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/poll.h>
#include <arpa/inet.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <netinet/ip.h>
#include <nuttx/net/icmp.h>
#include <netpacket/packet.h>
#include <net/if_arp.h>
#include <netinet/if_ether.h>
#include <sys/time.h>

/* Buffer sizes, ICMP type codes and defaults for this ping test program. */
#define MAX_PACKET_SIZE 65536   // capacity in bytes of the send and receive buffers
#define ICMP_ECHO_REQUEST 8     // ICMP type field value written into outgoing requests
#define ICMP_ECHO_REPLY 0       // ICMP type field value accepted in received replies
#define DEFAULT_TIMEOUT 2000  // per-request reply wait in milliseconds (poll() timeout), i.e. 2 s
#define DEFAULT_COUNT 5       // number of echo requests to send before printing statistics

/*
 * Loop-continuation flag: starts at 1 and is cleared by signal_handler() on
 * SIGINT/SIGTERM so the send loop and the reply wait stop at their next check.
 * C only defines a signal handler's store to a static object when that object
 * is volatile sig_atomic_t (or a lock-free atomic), hence the type.
 */
static volatile sig_atomic_t running = 1;

/*
 * Internet checksum (RFC 1071) over `len` bytes starting at `addr`.
 *
 * The data is consumed as 16-bit words in host byte order. A trailing odd
 * byte is treated as the first byte of a zero-padded word. The 32-bit
 * accumulator has its carries folded back into the low 16 bits, and the
 * one's complement of that folded sum is returned. Because the one's
 * complement sum is byte-order independent, the result can be stored
 * directly into a header checksum field. A zero or negative `len`
 * yields 0xFFFF.
 */
unsigned short calculate_checksum(unsigned short *addr, int len) {
    unsigned long acc = 0;
    unsigned short *word = addr;
    int remaining;

    for (remaining = len; remaining > 1; remaining -= 2) {
        acc += *word++;
    }

    if (remaining == 1) {
        /* Place the odd byte at the word's lowest address, zero the other. */
        unsigned short tail = 0;

        *(unsigned char *)&tail = *(unsigned char *)word;
        acc += tail;
    }

    /* Two folds are enough: the first leaves at most one extra carry bit. */
    acc = (acc & 0xFFFF) + (acc >> 16);
    acc += acc >> 16;

    return (unsigned short)~acc;
}

/*
 * Look up the IPv4 address assigned to network interface `ifname` using
 * the SIOCGIFADDR ioctl.
 *
 *   ifname   interface name, e.g. "eth0" (may come from the command line)
 *   ip_addr  receives the address (network byte order) on success
 *
 * Returns 0 on success, or -1 after printing a diagnostic on failure.
 */
int get_ip_address(const char *ifname, struct in_addr *ip_addr) {
    struct ifreq ifr;
    size_t name_len;
    int fd;

    /* ifr_name is a fixed-size array and ifname is caller/user supplied:
     * reject names that do not fit together with their terminator instead
     * of overflowing the buffer. */
    name_len = strlen(ifname);
    if (name_len >= sizeof(ifr.ifr_name)) {
        fprintf(stderr, "Interface name too long: %s\n", ifname);
        return -1;
    }

    memset(&ifr, 0, sizeof(ifr));
    memcpy(ifr.ifr_name, ifname, name_len + 1);

    /* A throwaway AF_INET datagram socket serves only as the ioctl handle. */
    fd = socket(AF_INET, SOCK_DGRAM, 0);
    if (fd < 0) {
        perror("socket");
        return -1;
    }

    if (ioctl(fd, SIOCGIFADDR, &ifr) < 0) {
        perror("ioctl SIOCGIFADDR");
        close(fd);
        return -1;
    }

    *ip_addr = ((struct sockaddr_in *)&ifr.ifr_addr)->sin_addr;
    close(fd);
    return 0;
}

/*
 * Build an IPv4 ICMP echo request in `packet`.
 *
 * Layout: struct iphdr (IHL 5, no IP options) | struct icmp_hdr_s | data.
 * Both the IP header checksum and the ICMP checksum are filled in.
 *
 *   packet, packet_size  output buffer and its capacity in bytes; it is
 *                        accessed through struct iphdr / unsigned short
 *                        pointers, so it should be suitably aligned
 *   id, seq              host-order values; `id` is used both as the IP
 *                        identification and as the ICMP echo identifier
 *   src_ip, dst_ip       source/destination addresses (network byte order)
 *   data, data_len       payload copied after the ICMP header; `data` may
 *                        be NULL only when data_len is 0
 *
 * Returns the total packet length in bytes, or -1 (after printing a message)
 * if the arguments are invalid, or the packet would not fit in packet_size
 * or in the 16-bit IP total-length field.
 */
int build_icmp_packet(unsigned char *packet, int packet_size,
                      uint16_t id, uint16_t seq, struct in_addr src_ip,
                      struct in_addr dst_ip, unsigned char *data, int data_len) {
    const int header_len = (int)(sizeof(struct iphdr) + sizeof(struct icmp_hdr_s));
    struct iphdr *ip;
    struct icmp_hdr_s *icmp;
    int total_len;

    if (packet == NULL || data_len < 0 || (data_len > 0 && data == NULL)) {
        fprintf(stderr, "Invalid packet arguments\n");
        return -1;
    }

    /* Compare against the remaining room instead of summing first, so a huge
     * data_len cannot overflow int. 0xFFFF is the largest value tot_len can
     * carry; anything bigger would be silently truncated by htons(). */
    if (data_len > packet_size - header_len || data_len > 0xFFFF - header_len) {
        fprintf(stderr, "Packet too large\n");
        return -1;
    }
    total_len = header_len + data_len;

    // construct IP header
    ip = (struct iphdr *)packet;
    ip->version = 4;
    ip->ihl = 5;                       // header length in 32-bit words
    ip->tos = 0;
    ip->tot_len = htons(total_len);
    ip->id = htons(id);
    ip->frag_off = htons(0x4000);      // DF (don't fragment) flag, network byte order
    ip->ttl = 64;
    ip->protocol = IPPROTO_ICMP;
    ip->saddr = src_ip.s_addr;
    ip->daddr = dst_ip.s_addr;
    ip->check = 0;                     // must be zero while the checksum is computed
    ip->check = calculate_checksum((unsigned short *)ip, sizeof(struct iphdr));

    // construct ICMP header
    icmp = (struct icmp_hdr_s *)(packet + sizeof(struct iphdr));
    icmp->type = ICMP_ECHO_REQUEST;
    icmp->icode = 0;
    icmp->id = htons(id);
    icmp->seqno = htons(seq);
    icmp->icmpchksum = 0;              // must be zero while the checksum is computed

    // payload (the caller passes the send timestamp)
    if (data_len > 0) {
        memcpy(packet + header_len, data, data_len);
    }

    // ICMP checksum covers the ICMP header and the payload
    icmp->icmpchksum = calculate_checksum((unsigned short *)icmp,
                                          sizeof(struct icmp_hdr_s) + data_len);

    return total_len;
}

// parse received packet
int parse_packet(unsigned char *buffer, int len, 
                  uint16_t expected_id, uint16_t expected_seq) {
    struct iphdr *ip;
    struct icmp_hdr_s *icmp;
    int ip_header_len;

    ip = (struct iphdr *)buffer;
    ip_header_len = ip->ihl * 4;
    
    // check if it is an ICMP packet
    if (ip->protocol != IPPROTO_ICMP) {
        return -1;
    }
    
    icmp = (struct icmp_hdr_s *)(buffer + ip_header_len);
    
    // check if it is an ICMP reply
    if (icmp->type != ICMP_ECHO_REPLY) {
        return -1;
    }
    
    struct timeval tv;
    unsigned char *data;
    struct timeval *sent_time;
    
    gettimeofday(&tv, NULL);
    data = (unsigned char *)icmp + sizeof(struct icmp_hdr_s);
    sent_time = (struct timeval *)data;
    
    // calculate round-trip time
    double rtt = (tv.tv_sec - sent_time->tv_sec) * 1000.0;
    rtt += (tv.tv_usec - sent_time->tv_usec) / 1000.0;
    
    printf("Reply from %s: icmp_seq=%d ttl=%d time=%.3f ms\n",
            inet_ntoa(*(struct in_addr *)&ip->saddr),
            ntohs(icmp->seqno),
            ip->ttl, rtt);
    return 0;
}

/*
 * Handler installed for SIGINT and SIGTERM. It only clears the global
 * `running` flag; the send loop and the reply wait loop notice this at
 * their next check and wind down.
 */
void signal_handler(int sig) {
    (void)sig;  /* one handler serves every signal; the number is unused */
    running = 0;
}

int main(int argc, char *argv[]) {
    int sockfd;
    struct sockaddr_ll sockaddr;
    struct in_addr src_ip, dst_ip;
    unsigned char packet[MAX_PACKET_SIZE];
    uint16_t pid = getpid() & 0xFFFF;
    uint16_t seq = 0;
    struct pollfd fds[1];
    int timeout = DEFAULT_TIMEOUT;
    int count = DEFAULT_COUNT;
    int sent_count = 0;
    int received_count = 0;
    char *ifname = "eth0";  // default network interface
    char *dst_ip_str = "10.0.1.1";  // default destination
    
    // parse command line arguments
    if (argc >= 2) {
        ifname = argv[1];
    }
    if (argc >= 3) {
        dst_ip_str = argv[2];
    }
    
    // setup signal handling
    signal(SIGINT, signal_handler);
    signal(SIGTERM, signal_handler);
    
    // get source IP address
    if (get_ip_address(ifname, &src_ip) < 0) {
        fprintf(stderr, "Failed to get IP address for interface %s\n", ifname);
        return 1;
    }
    
    // convert destination IP address
    if (inet_pton(AF_INET, dst_ip_str, &dst_ip) <= 0) {
        fprintf(stderr, "Invalid destination IP address: %s\n", dst_ip_str);
        return 1;
    }
    
    // create raw socket
    sockfd = socket(PF_PACKET, SOCK_DGRAM, htons(ETH_P_IP));
    if (sockfd < 0) {
        perror("socket");
        return 1;
    }
    
    // bind to the specified network interface
    memset(&sockaddr, 0, sizeof(sockaddr));
    sockaddr.sll_family = AF_PACKET;
    sockaddr.sll_protocol = htons(ETH_P_IP);
    sockaddr.sll_ifindex = if_nametoindex(ifname);
    memset(sockaddr.sll_addr, 0xff, sizeof(sockaddr.sll_addr));
    sockaddr.sll_halen = ETH_ALEN;
    if (sockaddr.sll_ifindex == 0) {
        perror("if_nametoindex");
        close(sockfd);
        return 1;
    }
    
    if (bind(sockfd, (struct sockaddr *)&sockaddr, sizeof(sockaddr)) < 0) {
        perror("bind");
        close(sockfd);
        return 1;
    }
    
    printf("PING %s from %s %s\n", dst_ip_str, inet_ntoa(src_ip), ifname);
    
    // setup poll
    fds[0].fd = sockfd;
    fds[0].events = POLLIN;
    
    while (running && sent_count < count) {
        struct timeval tv;
        int packet_len;
        
        // prepare timestamp for sending
        gettimeofday(&tv, NULL);
        
        // construct ICMP packet (after Ethernet header)
        packet_len = build_icmp_packet(packet, MAX_PACKET_SIZE,
                                      pid, seq, src_ip, dst_ip,
                                      (unsigned char *)&tv, sizeof(tv));
        
        if (packet_len < 0) {
            fprintf(stderr, "Failed to build ICMP packet\n");
            break;
        }
        
        // sendto icmp request packet
        if (sendto(sockfd, packet, packet_len, 0,
                   (struct sockaddr *)&sockaddr, sizeof(sockaddr)) < 0) {
            perror("sendto");
            break;
        }
        
        printf("Sent ICMP request with seq=%d\n", seq);
        sent_count++;
        
        // Wait for reply
        while (running) {
            int ret = poll(fds, 1, timeout);

            if (ret < 0) {
                if (errno == EINTR) continue;
                perror("poll");
                break;
            } else if (ret == 0) {
                printf("Request timeout for icmp_seq=%d\n", seq);
                break;
            } else {
                if (fds[0].revents & POLLIN) {
                    unsigned char buffer[MAX_PACKET_SIZE];
                    int n;
                    
                    n = recv(sockfd, buffer, sizeof(buffer), 0);
                    if (n < 0) {
                        if (errno == EINTR) continue;
                        perror("recv");
                        break;
                    }
                    
                    // Parse the received packet
                    if (parse_packet(buffer, n, pid, seq) == 0) {
                        received_count++;

                        // Successfully received a reply, break the inner loop
                        break;
                    }
                }
            }
        }
        
        seq++;
        sleep(1);  // wait 1 second before sending the next packet
    }
    
    printf("\n--- %s ping statistics ---\n", dst_ip_str);
    printf("%d packets transmitted, %d received, %.1f%% packet loss\n",
           sent_count, received_count, 
           (sent_count > 0) ? (100.0 * (sent_count - received_count) / sent_count) : 0.0);
    
    close(sockfd);
    return 0;
}
nsh> hello
PING 10.0.1.1 from 10.0.1.2 eth0
Sent ICMP request with seq=0
Reply from 10.0.1.1: icmp_seq=0 ttl=64 time=0.324 ms
Sent ICMP request with seq=1
Reply from 10.0.1.1: icmp_seq=1 ttl=64 time=0.273 ms
Sent ICMP request with seq=2
Reply from 10.0.1.1: icmp_seq=2 ttl=64 time=0.277 ms
Sent ICMP request with seq=3
Reply from 10.0.1.1: icmp_seq=3 ttl=64 time=0.169 ms
Sent ICMP request with seq=4
Reply from 10.0.1.1: icmp_seq=4 ttl=64 time=0.329 ms

--- 10.0.1.1 ping statistics ---
5 packets transmitted, 5 received, 0.0% packet loss
nsh> 

zhhyu7 avatar Dec 12 '25 04:12 zhhyu7

sockaddr.sll_ifindex = if_nametoindex(ifname); memset(sockaddr.sll_addr, 0xff, sizeof(sockaddr.sll_addr)); sockaddr.sll_halen = ETH_ALEN;

Done

zhhyu7 avatar Dec 13 '25 01:12 zhhyu7