Skip to content

Instantly share code, notes, and snippets.

@ktsaou
Created October 17, 2024 17:46
Show Gist options
  • Select an option

  • Save ktsaou/0a5146b62f312946a6139ddc928b0e72 to your computer and use it in GitHub Desktop.

Select an option

Save ktsaou/0a5146b62f312946a6139ddc928b0e72 to your computer and use it in GitHub Desktop.

Revisions

  1. ktsaou created this gist Oct 17, 2024.
    229 changes: 229 additions & 0 deletions proc-net-tcp-test.c
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,229 @@
    /*
    Save as: proc-net-tcp-test.c
    Compile with:
    gcc -O2 -o proc-net-tcp-test proc-net-tcp-test.c -lmnl
    Run with:
    ./proc-net-tcp-test 1 >/dev/null
    */

    #define _GNU_SOURCE
    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>
    #include <fcntl.h>
    #include <sched.h>
    #include <errno.h>
    #include <string.h>
    #include <time.h>
    #include <sys/resource.h>

    #include <linux/rtnetlink.h>
    #include <linux/inet_diag.h>
    #include <linux/sock_diag.h>
    #include <linux/unix_diag.h>
    #include <linux/netlink.h>
    #include <libmnl/libmnl.h>

    #include <linux/tcp.h>
    #include <linux/genetlink.h>
    #include <arpa/inet.h>
    #include <sys/stat.h>

    // Function to switch network namespace
    int switch_namespace_if_needed(int target_pid) {
    char current_ns_path[64], target_ns_path[64];
    struct stat current_ns_stat, target_ns_stat;

    snprintf(current_ns_path, sizeof(current_ns_path), "/proc/self/ns/net");
    snprintf(target_ns_path, sizeof(target_ns_path), "/proc/%d/ns/net", target_pid);

    if (stat(current_ns_path, &current_ns_stat) == -1) {
    perror("Failed to stat current network namespace");
    return -1;
    }

    if (stat(target_ns_path, &target_ns_stat) == -1) {
    perror("Failed to stat target network namespace");
    return -1;
    }

    if (current_ns_stat.st_ino == target_ns_stat.st_ino) {
    fprintf(stderr, "Already in the same network namespace as PID %d.\n", target_pid);
    return 0;
    }

    int fd = open(target_ns_path, O_RDONLY);
    if (fd == -1) {
    perror("Failed to open target network namespace");
    return -1;
    }

    if (setns(fd, CLONE_NEWNET) == -1) {
    perror("Failed to switch network namespace");
    close(fd);
    return -1;
    }

    fprintf(stderr, "Successfully switched to network namespace of PID %d.\n", target_pid);
    close(fd);
    return 0;
    }

    // Function to measure time and CPU usage
    void print_time_and_cpu_usage(struct timespec start, struct timespec end, struct rusage usage_start, struct rusage usage_end) {
    double time_elapsed = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
    long user_cpu = (usage_end.ru_utime.tv_sec - usage_start.ru_utime.tv_sec) * 1000000 +
    (usage_end.ru_utime.tv_usec - usage_start.ru_utime.tv_usec);
    long sys_cpu = (usage_end.ru_stime.tv_sec - usage_start.ru_stime.tv_sec) * 1000000 +
    (usage_end.ru_stime.tv_usec - usage_start.ru_stime.tv_usec);

    fprintf(stderr, "Time elapsed: %.6f seconds\n", time_elapsed);
    fprintf(stderr, "User CPU time: %.6f seconds\n", user_cpu / 1e6);
    fprintf(stderr, "System CPU time: %.6f seconds\n", sys_cpu / 1e6);
    }

    // Function to dump /proc/net/tcp
    void dump_proc_net_tcp() {
    fprintf(stderr, "\nDumping /proc/net/tcp\n");
    FILE *f = fopen("/proc/net/tcp", "r");
    if (!f) {
    perror("fopen /proc/net/tcp");
    return;
    }
    size_t size = 0;
    char *line = NULL;
    int socket_count = 0;

    while (getline(&line, &size, f) != -1) {
    fprintf(stdout, "%s", line);
    socket_count++;
    }
    if (line)
    free(line);
    fclose(f);

    fprintf(stderr, "Total TCP sockets (proc): %d\n", socket_count - 1); // Exclude header line
    }

    // Callback for libmnl to parse netlink messages
    static int data_cb(const struct nlmsghdr *nlh, void *data) {
    struct inet_diag_msg *diag_msg = mnl_nlmsg_get_payload(nlh);
    int *socket_count = (int *)data;

    if (diag_msg->idiag_family == AF_INET) {
    char src_addr[INET_ADDRSTRLEN];
    char dst_addr[INET_ADDRSTRLEN];

    inet_ntop(AF_INET, &diag_msg->id.idiag_src, src_addr, sizeof(src_addr));
    inet_ntop(AF_INET, &diag_msg->id.idiag_dst, dst_addr, sizeof(dst_addr));

    fprintf(stdout, "TCP Socket: %s:%d -> %s:%d (state: %d)\n",
    src_addr, ntohs(diag_msg->id.idiag_sport),
    dst_addr, ntohs(diag_msg->id.idiag_dport),
    diag_msg->idiag_state);

    (*socket_count)++;
    }

    return MNL_CB_OK;
    }

    // Function to dump sockets using libmnl
    void dump_libmnl_sockets() {
    fprintf(stderr, "\nDumping TCP sockets using libmnl\n");
    struct mnl_socket *nl = mnl_socket_open(NETLINK_INET_DIAG);
    if (nl == NULL) {
    perror("mnl_socket_open");
    return;
    }

    if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
    perror("mnl_socket_bind");
    mnl_socket_close(nl);
    return;
    }

    char buf[MNL_SOCKET_BUFFER_SIZE];
    struct nlmsghdr *nlh = mnl_nlmsg_put_header(buf);
    nlh->nlmsg_type = SOCK_DIAG_BY_FAMILY;
    nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
    nlh->nlmsg_seq = time(NULL);

    struct inet_diag_req_v2 req = {
    .sdiag_family = AF_INET,
    .sdiag_protocol = IPPROTO_TCP,
    .idiag_states = ~0,
    };
    mnl_nlmsg_put_extra_header(nlh, sizeof(req));
    memcpy(mnl_nlmsg_get_payload(nlh), &req, sizeof(req));

    if (mnl_socket_sendto(nl, nlh, nlh->nlmsg_len) < 0) {
    perror("mnl_socket_sendto");
    mnl_socket_close(nl);
    return;
    }

    int socket_count = 0;
    int ret;
    while ((ret = mnl_socket_recvfrom(nl, buf, sizeof(buf))) > 0) {
    ret = mnl_cb_run(buf, ret, 0, 0, data_cb, &socket_count);
    if (ret <= MNL_CB_STOP) {
    break;
    }
    }
    if (ret == -1) {
    perror("mnl_socket_recvfrom");
    }

    fprintf(stderr, "Total TCP sockets (libmnl): %d\n", socket_count);
    mnl_socket_close(nl);
    }

    int main(int argc, char *argv[]) {
    if (argc != 2) {
    fprintf(stderr, "Usage: %s <pid>\n", argv[0]);
    return EXIT_FAILURE;
    }

    int target_pid = atoi(argv[1]);
    if (target_pid <= 0) {
    fprintf(stderr, "Invalid PID: %s\n", argv[1]);
    return EXIT_FAILURE;
    }

    if (switch_namespace_if_needed(target_pid) == -1) {
    fprintf(stderr, "Error switching to the network namespace of PID %d\n", target_pid);
    return EXIT_FAILURE;
    }

    struct timespec start, end;
    struct rusage usage_start, usage_end;

    // Measure time and CPU for dump_proc_net_tcp
    clock_gettime(CLOCK_MONOTONIC, &start);
    getrusage(RUSAGE_SELF, &usage_start);

    dump_proc_net_tcp();

    clock_gettime(CLOCK_MONOTONIC, &end);
    getrusage(RUSAGE_SELF, &usage_end);
    print_time_and_cpu_usage(start, end, usage_start, usage_end);

    // Measure time and CPU for dump_libmnl_sockets
    clock_gettime(CLOCK_MONOTONIC, &start);
    getrusage(RUSAGE_SELF, &usage_start);

    dump_libmnl_sockets();

    clock_gettime(CLOCK_MONOTONIC, &end);
    getrusage(RUSAGE_SELF, &usage_end);
    print_time_and_cpu_usage(start, end, usage_start, usage_end);

    return EXIT_SUCCESS;
    }