#include #include "mxcrypt.h" struct msg_sess { struct mxc_sess sess[1]; int sess_fd; }; union msg_buf { struct cmsghdr align; char buf[CMSG_SPACE(sizeof(int))]; }; #define INIT_MSG(ms, msg, cmsg, iov, buf) \ do { \ iov->iov_base = ms->sess; \ iov->iov_len = sizeof ms->sess; \ \ (void)memset(msg, 0, sizeof msg); \ msg->msg_iov = iov; \ msg->msg_iovlen = 1; \ msg->msg_control = buf; \ msg->msg_controllen = sizeof(buf); \ } while (0) void send_msg_sess(struct msg_sess *ms, int fd) { struct msghdr msg[1]; struct cmsghdr *cmsg; struct iovec iov[1]; union msg_buf buf[1]; INIT_MSG(ms, msg, cmsg, iov, buf); cmsg = CMSG_FIRSTHDR(msg); cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; cmsg->cmsg_len = CMSG_LEN(sizeof(int)); (void)memcpy(CMSG_DATA(cmsg), &ms->sess_fd, sizeof(int)); assert(sendmsg(fd, msg, 0) > 0); LOG("sent session for file descriptor %d", ms->sess_fd); } void recv_msg_sess(struct msg_sess *ms, int fd) { struct msghdr msg[1]; struct cmsghdr *cmsg; struct iovec iov[1]; union msg_buf buf[1]; INIT_MSG(ms, msg, cmsg, iov, buf); assert(recvmsg(fd, msg, 0) > 0); cmsg = CMSG_FIRSTHDR(msg); assert(cmsg->cmsg_level == SOL_SOCKET); assert(cmsg->cmsg_type == SCM_RIGHTS); assert(cmsg->cmsg_len == CMSG_LEN(sizeof(int))); (void)memcpy(&ms->sess_fd, CMSG_DATA(cmsg), sizeof(int)); LOG("received session for file descriptor %d", ms->sess_fd); } void main_accept(int argc, char * const *argv, int server_fd) { struct msg_sess ms[1]; uint16_t port; int sock_fd; sock_fd = mxc_socket(argc, argv); port = mxc_port(sock_fd); LOG("listening to port %hu", port); ms->sess_fd = accept(sock_fd, NULL, NULL); assert(ms->sess_fd != -1); mxc_sess_init(ms->sess); mxc_hello_recv(ms->sess, ms->sess_fd); mxc_hello_server(ms->sess); mxc_hello_send(ms->sess, ms->sess_fd); LOG("computed session key 0x%08x", ms->sess->key); send_msg_sess(ms, server_fd); (void)close(ms->sess_fd); (void)close(sock_fd); } void main_serve(int acceptor_fd) { struct msg_sess ms[1]; char buf[256]; ssize_t len; recv_msg_sess(ms, acceptor_fd); do { len = mxc_sess_recv(ms->sess, ms->sess_fd, buf, sizeof buf); assert(len >= 0); (void)write(STDOUT_FILENO, buf, len); } while (len > 0); (void)close(ms->sess_fd); } void main(int argc, char * const *argv) { pid_t acceptor_pid, server_pid; int pair_fd[2]; assert(socketpair(PF_UNIX, SOCK_STREAM, 0, pair_fd) == 0); acceptor_pid = fork(); if (acceptor_pid == 0) { (void)close(pair_fd[0]); main_accept(argc, argv, pair_fd[1]); return; } assert(acceptor_pid != -1); server_pid = fork(); if (server_pid == 0) { (void)close(pair_fd[1]); main_serve(pair_fd[0]); return; } assert(server_pid != -1); (void)close(pair_fd[0]); (void)close(pair_fd[1]); assert(waitpid(acceptor_pid, NULL, 0) == acceptor_pid); assert(waitpid(server_pid, NULL, 0) == server_pid); }