/*-
 * Copyright (c) 2007 Robert N. M. Watson
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $FreeBSD$
 */

/*
 * A few regression tests for UNIX domain socket descriptor passing and
 * garbage collection.  Run from single-user mode: each test compares the
 * kern.openfiles sysctl before and after to detect leaks, and activity by
 * other processes would perturb that count.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/sysctl.h>

#include <err.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * When non-zero, trigger_gc() creates and immediately closes a UNIX domain
 * socket before sleeping, as part of its effort to get the socket garbage
 * collector to run before results are checked.  Defaults to on.
 */
int forcegc = 1;

/*
 * Return the value of the kern.openfiles sysctl (system-wide count of open
 * files).  Exits with a diagnostic if the sysctl cannot be read.
 */
static int
getopenfiles(void)
{
	int nfiles;
	size_t nfileslen = sizeof(nfiles);

	if (sysctlbyname("kern.openfiles", &nfiles, &nfileslen, NULL, 0) < 0)
		err(-1, "kern.openfiles");
	return (nfiles);
}

/*
 * Return the value of the net.local.inflight sysctl -- presumably the number
 * of UNIX domain descriptors currently in flight via SCM_RIGHTS (sendfd()
 * expects it to grow by one per descriptor sent).  Exits on failure.
 */
static int
getinflight(void)
{
	int count;
	size_t countlen = sizeof(count);

	if (sysctlbyname("net.local.inflight", &count, &countlen, NULL, 0) < 0)
		err(-1, "net.local.inflight");
	return (count);
}

/*
 * Send descriptor fdtosend as an SCM_RIGHTS control message over the UNIX
 * domain socket fd, then verify that net.local.inflight rose by exactly one.
 * errstr prefixes every diagnostic; any failure terminates the program.
 */
static void
sendfd(const char *errstr, int fd, int fdtosend)
{
	struct msghdr mh;
	struct cmsghdr *cmsg;
	/*
	 * The control buffer must be laid out per CMSG_SPACE()/CMSG_DATA().
	 * A plain { struct cmsghdr; int; } struct places the int directly
	 * after the header, but CMSG_DATA() may be padded beyond that (e.g.
	 * on LP64 BSDs), in which case the kernel would see no descriptor.
	 * The union provides cmsghdr alignment for the byte buffer.
	 */
	union {
		struct cmsghdr	hdr;
		char		buf[CMSG_SPACE(sizeof(int))];
	} cmsgbuf;
	ssize_t len;
	int after_inflight, before_inflight;

	before_inflight = getinflight();

	memset(&mh, 0, sizeof(mh));
	memset(&cmsgbuf, 0, sizeof(cmsgbuf));
	mh.msg_control = cmsgbuf.buf;
	mh.msg_controllen = sizeof(cmsgbuf.buf);
	cmsg = CMSG_FIRSTHDR(&mh);
	cmsg->cmsg_len = CMSG_LEN(sizeof(int));
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	/* CMSG_DATA() need not be int-aligned; copy rather than store. */
	memcpy(CMSG_DATA(cmsg), &fdtosend, sizeof(fdtosend));
	len = sendmsg(fd, &mh, 0);
	if (len < 0)
		err(-1, "%s: sendmsg", errstr);
	after_inflight = getinflight();
	/* errx() supplies its own trailing newline. */
	if (after_inflight != before_inflight + 1)
		errx(-1, "%s: sendfd: before %d after %d", errstr,
		    before_inflight, after_inflight);
}

/*
 * Close two descriptors, fd1 before fd2.  close(2) failures are ignored.
 */
static void
close2(int fd1, int fd2)
{

	(void)close(fd1);
	(void)close(fd2);
}

/*
 * Close three descriptors in argument order.  close(2) failures are ignored.
 */
static void
close3(int fd1, int fd2, int fd3)
{
	const int order[] = { fd1, fd2, fd3 };
	size_t i;

	for (i = 0; i < sizeof(order) / sizeof(order[0]); i++)
		(void)close(order[i]);
}

/*
 * Close five descriptors in argument order, fd1 first and fd5 last.
 */
static void
close5(int fd1, int fd2, int fd3, int fd4, int fd5)
{

	close2(fd1, fd2);
	close3(fd3, fd4, fd5);
}

/*
 * Create a connected pair of UNIX domain stream sockets into sv[0] and
 * sv[1].  On failure, exit with a message prefixed by errstr.
 */
static void
alloc2fds(const char *errstr, int *sv)
{
	int error;

	error = socketpair(PF_UNIX, SOCK_STREAM, 0, sv);
	if (error < 0)
		err(-1, "%s: socketpair", errstr);
}

/*
 * Create one unconnected UNIX domain stream socket into *s plus a connected
 * pair into sv[0]/sv[1].  On failure, exit with a message prefixed by errstr.
 */
static void
alloc3fds(const char *errstr, int *s, int *sv)
{
	int fd;

	fd = socket(PF_UNIX, SOCK_STREAM, 0);
	if (fd < 0)
		err(-1, "%s: socket", errstr);
	*s = fd;

	if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0)
		err(-1, "%s: socketpair", errstr);
}

/*
 * Create one unconnected UNIX domain stream socket into *s and two connected
 * pairs into sva[0..1] and svb[0..1], in that order.  On failure, exit with
 * a message prefixed by errstr.
 */
static void
alloc5fds(const char *errstr, int *s, int *sva, int *svb)
{
	int *pairs[2];
	int i;

	if ((*s = socket(PF_UNIX, SOCK_STREAM, 0)) < 0)
		err(-1, "%s: socket", errstr);

	pairs[0] = sva;
	pairs[1] = svb;
	for (i = 0; i < 2; i++) {
		if (socketpair(PF_UNIX, SOCK_STREAM, 0, pairs[i]) < 0)
			err(-1, "%s: socketpair", errstr);
	}
}

/*
 * Snapshot net.local.inflight and kern.openfiles (read in that order) into
 * the caller's variables, for later comparison by test_sysctls().
 */
static void
save_sysctls(int *before_inflight, int *before_openfiles)
{
	int inflight = getinflight();
	int openfiles = getopenfiles();

	*before_inflight = inflight;
	*before_openfiles = openfiles;
}

/*
 * Try hard to make sure that the GC does in fact run before we test the
 * condition of things.  When forcegc is set, a throwaway UNIX domain socket
 * is created and closed (presumably enough to prod the kernel into
 * scheduling a GC pass); in every case we then sleep for one second to give
 * the collector time to finish.
 */
static void
trigger_gc(void)
{
	int sock;

	if (forcegc != 0) {
		sock = socket(PF_UNIX, SOCK_STREAM, 0);
		if (sock < 0)
			err(-1, "trigger_gc: socket");
		(void)close(sock);
	}
	(void)sleep(1);
}

/*
 * Give the GC a chance to run, then compare net.local.inflight and
 * kern.openfiles with the snapshot taken before the test.  A mismatch is
 * reported with warnx() (prefixed by errstr); it does not abort the run.
 */
static void
test_sysctls(const char *errstr, int before_inflight, int before_openfiles)
{
	int now;

	trigger_gc();

	now = getinflight();
	if (now != before_inflight)
		warnx("%s: before inflight: %d, after inflight: %d",
		    errstr, before_inflight, now);

	now = getopenfiles();
	if (now != before_openfiles)
		warnx("%s: before: %d, after: %d", errstr, before_openfiles,
		    now);
}

/*
 * Create a socket pair and close it without sending anything, once in each
 * close order, checking that no files or in-flight references are leaked.
 */
static void
twosome_nothing(void)
{
	int inflight, openfiles;
	int sv[2];
	const char *test;

	/*
	 * Create a pair, close in one order.
	 */
	test = "twosome_nothing1";
	printf("%s\n", test);
	save_sysctls(&inflight, &openfiles);
	alloc2fds(test, sv);
	close2(sv[0], sv[1]);
	test_sysctls(test, inflight, openfiles);

	/*
	 * Create a pair, close in the other order.  (Previously this repeated
	 * the sv[0]-first order, so the reverse order was never exercised.)
	 */
	test = "twosome_nothing2";
	printf("%s\n", test);
	save_sysctls(&inflight, &openfiles);
	alloc2fds(test, sv);
	close2(sv[1], sv[0]);
	test_sysctls(test, inflight, openfiles);
}

/*
 * One twosome_drop case: make a socket pair, send endpoint sv[tosend] over
 * endpoint sv[sendvia], then close both endpoints -- sv[0] first when
 * closefirst is 0, sv[1] first otherwise -- and check for leaks.
 */
static void
twosome_drop_work(const char *test, int sendvia, int tosend, int closefirst)
{
	int inflight, openfiles;
	int pair[2];
	int first, second;

	printf("%s\n", test);
	save_sysctls(&inflight, &openfiles);
	alloc2fds(test, pair);
	sendfd(test, pair[sendvia], pair[tosend]);

	first = (closefirst == 0) ? pair[0] : pair[1];
	second = (closefirst == 0) ? pair[1] : pair[0];
	close2(first, second);

	test_sysctls(test, inflight, openfiles);
}

/*
 * Run every combination -- some wastefully symmetric -- of which endpoint
 * sends, which endpoint is sent, and which endpoint is closed first.
 */
static void
twosome_drop(void)
{
	char name[32];
	int i;

	/*
	 * Bit 2 of i selects sendvia, bit 1 tosend and bit 0 closefirst, so
	 * i == 0 is twosome_drop1 (0, 0, 0) and i == 7 is twosome_drop8
	 * (1, 1, 1).
	 */
	for (i = 0; i < 8; i++) {
		snprintf(name, sizeof(name), "twosome_drop%d", i + 1);
		twosome_drop_work(name, (i >> 2) & 1, (i >> 1) & 1, i & 1);
	}
}

/*
 * Create a spare socket and a socket pair, close all three without sending
 * anything, and check for leaks.
 */
static void
threesome_nothing(void)
{
	const char *test = "threesome_nothing";
	int inflight, openfiles;
	int spare, pair[2];

	puts(test);
	save_sysctls(&inflight, &openfiles);
	alloc3fds(test, &spare, pair);
	close3(spare, pair[0], pair[1]);
	test_sysctls(test, inflight, openfiles);
}

/*
 * Roles of the three descriptors in a threesome_drop case, used to say in
 * which order they are closed: the spare that is sent (TS_SENT), the pair
 * endpoint it is sent over (TS_SEND), and the other pair endpoint, which
 * holds it in its receive buffer (TS_RECV).
 */
enum threesome_role { TS_SENT, TS_SEND, TS_RECV };

/*
 * One threesome_drop case: create a pair and a spare, send the spare over
 * the pair, close the three descriptors in the given order, and check that
 * they all went away.  Every case performs the send; previously the sixth
 * case omitted it and so did not exercise in-flight descriptors at all.
 */
static void
threesome_drop_work(const char *test, enum threesome_role first,
    enum threesome_role second, enum threesome_role third)
{
	int inflight, openfiles;
	int fds[3], sv[2];

	printf("%s\n", test);
	save_sysctls(&inflight, &openfiles);
	alloc3fds(test, &fds[TS_SENT], sv);
	fds[TS_SEND] = sv[0];
	fds[TS_RECV] = sv[1];
	sendfd(test, fds[TS_SEND], fds[TS_SENT]);
	close3(fds[first], fds[second], fds[third]);
	test_sysctls(test, inflight, openfiles);
}

/*
 * threesome_drop: create a pair and a spare, send the spare over the pair,
 * then close the three in each of the six possible orders and make sure
 * all the descriptors went away.
 */
static void
threesome_drop(void)
{

	threesome_drop_work("threesome_drop1", TS_SENT, TS_SEND, TS_RECV);
	threesome_drop_work("threesome_drop2", TS_SENT, TS_RECV, TS_SEND);
	threesome_drop_work("threesome_drop3", TS_RECV, TS_SENT, TS_SEND);
	threesome_drop_work("threesome_drop4", TS_RECV, TS_SEND, TS_SENT);
	threesome_drop_work("threesome_drop5", TS_SEND, TS_RECV, TS_SENT);
	threesome_drop_work("threesome_drop6", TS_SEND, TS_SENT, TS_RECV);
}

/*
 * Fivesome tests: create two socket pairs and a spare, send the spare over
 * the first socket pair, then send the first socket pair over the second
 * socket pair, and GC.  Do various closes at various points to exercise
 * various cases.
 *
 * fivesome_nothing is the baseline: allocate all five descriptors and close
 * them without sending anything.
 */
static void
fivesome_nothing(void)
{
	const char *test = "fivesome_nothing";
	int inflight, openfiles;
	int extra, inner[2], outer[2];

	puts(test);
	save_sysctls(&inflight, &openfiles);
	alloc5fds(test, &extra, inner, outer);
	close5(extra, inner[0], inner[1], outer[0], outer[1]);
	test_sysctls(test, inflight, openfiles);
}

/*
 * One fivesome_drop case.  The spare is sent over the first pair (sva),
 * then both sva endpoints are sent over the second pair (svb), and svb is
 * closed.  The flags choose whether the spare and the sva pair are closed
 * immediately after being sent, or only after svb has been closed.
 */
static void
fivesome_drop_work(const char *test, int close_spare_after_send,
    int close_sva_after_send)
{
	int inflight, openfiles;
	int extra, inner[2], outer[2];

	printf("%s\n", test);
	save_sysctls(&inflight, &openfiles);
	alloc5fds(test, &extra, inner, outer);

	/* Stage 1: the spare goes into the inner (sva) pair. */
	sendfd(test, inner[0], extra);
	if (close_spare_after_send)
		close(extra);

	/* Stage 2: both inner endpoints go into the outer (svb) pair. */
	sendfd(test, outer[0], inner[0]);
	sendfd(test, outer[0], inner[1]);
	if (close_sva_after_send)
		close2(inner[0], inner[1]);

	close2(outer[0], outer[1]);

	/* Late closes for whatever was not closed right after sending. */
	if (close_sva_after_send == 0)
		close2(inner[0], inner[1]);
	if (close_spare_after_send == 0)
		close(extra);

	test_sysctls(test, inflight, openfiles);
}

/*
 * Run fivesome_drop_work() for all four combinations of closing the spare
 * and the first pair early or late.
 */
static void
fivesome_drop(void)
{
	static const struct {
		const char	*name;
		int		 close_spare;
		int		 close_sva;
	} cases[] = {
		{ "fivesome_drop1", 0, 0 },
		{ "fivesome_drop2", 0, 1 },
		{ "fivesome_drop3", 1, 0 },
		{ "fivesome_drop4", 1, 1 },
	};
	size_t i;

	for (i = 0; i < sizeof(cases) / sizeof(cases[0]); i++)
		fivesome_drop_work(cases[i].name, cases[i].close_spare,
		    cases[i].close_sva);
}

/*
 * Build a deliberately nasty reference graph out of two socket pairs and a
 * spare, intended to upset the garbage collector if its mark-and-sweep is
 * wrong.  Each pair is sent over the other (forming a cycle), and the spare
 * is sent over both.  All descriptors are then closed and leaks checked.
 */
static void
complex_cycles(void)
{
	enum { SPARE, A0, A1, B0, B1, NFDS };
	/* { socket to send over, descriptor to send }, in send order. */
	static const int sends[][2] = {
		{ A0, B0 },
		{ A0, B1 },
		{ B0, A0 },
		{ B0, A1 },
		{ B0, SPARE },
		{ A0, SPARE },
	};
	const char *test = "complex_cycles";
	int inflight, openfiles;
	int fd[NFDS];
	size_t i;

	printf("%s\n", test);
	save_sysctls(&inflight, &openfiles);
	/* fd[A0..A1] and fd[B0..B1] are contiguous, forming the two pairs. */
	alloc5fds(test, &fd[SPARE], &fd[A0], &fd[B0]);
	for (i = 0; i < sizeof(sends) / sizeof(sends[0]); i++)
		sendfd(test, fd[sends[i][0]], fd[sends[i][1]]);
	close5(fd[SPARE], fd[A0], fd[A1], fd[B0], fd[B1]);
	test_sysctls(test, inflight, openfiles);
}

int
main(int argc, char *argv[])
{

	printf("Open files at start: %d\n", getopenfiles());

	twosome_nothing();
	twosome_drop();

	threesome_nothing();
	threesome_drop();

	fivesome_nothing();
	fivesome_drop();

	complex_cycles();

	printf("Open files at finish: %d\n", getopenfiles());
	return (0);
}
