/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2022 Brett Sheffield <bacs@librecast.net> */

#include "test_libmld.h"
#include <librecast.h>
#include <mld.h>
#include <unistd.h>
#include <pthread.h>
#include <semaphore.h>

/* MLD handle created in main() and also used by thread_mld_wait() */
static mld_t *mld;
/* multicast group address under test, obtained from lc_channel_in6addr() in main() */
static struct in6_addr *addr;
/* multicast-capable interface returned by get_multicast_if() */
static unsigned int ifx;
/* posted by thread_mld_wait() once its mld_wait() call has returned */
static sem_t sem;

/*
 * Waiter thread body: call mld_wait() (blocking, no flags) on the shared
 * mld/ifx/addr globals, assert that it returns 0, then increment the caller's
 * counter and post `sem` so the main thread knows mld_wait() returned.
 *
 * arg: pointer to an int counter owned by the caller; incremented once.
 * Returns arg unchanged.
 */
void *thread_mld_wait(void *arg)
{
	int *waitcount = arg;
	int rc;

	rc = mld_wait(mld, ifx, addr, 0);
	test_assert(rc == 0, "mld_wait() returns 0");
	/* bump the counter *before* signalling: sem_post()/sem_wait() synchronize
	 * memory, so a thread woken by sem is guaranteed to see the new value, and
	 * no work is left for a pthread_cancel() issued right after the wakeup */
	(*waitcount)++;
	sem_post(&sem);
	return arg;
}

/*
 * Exercise mld_wait():
 *  1. without MLD_DONTWAIT it blocks while the group is not in the filter;
 *  2. with MLD_DONTWAIT it returns -1 / EWOULDBLOCK for an unknown group;
 *  3. once the group is added to the filter it returns 0, both directly and
 *     from a waiter thread.
 * Returns the number of failed assertions (`fails`).
 */
int main(void)
{
	lc_ctx_t *lctx;
	lc_channel_t *chan;
	pthread_t tid;
	struct timespec ts = {0};
	int waitcount = 0;
	int rc;

	test_name("mld_watch()");

	ifx = get_multicast_if();
	test_assert(ifx, "get_multicast_if() - find multicast capable interface");
	mld = mld_init(0);
	test_assert(mld != NULL, "mld_t allocated");
	if (!mld) return fails; /* nothing below can run without a handle */

	mld_start(mld);
	/* ensure all threads created */
	for (int i = 0; i < MLD_THREADS; i++) assert(mld->q[i]);

	/* generate a random multicast address */
	lctx = lc_ctx_new();
	test_assert(lctx != NULL, "lc_ctx_new()");
	if (!lctx) goto out_mld;
	chan = lc_channel_random(lctx);
	test_assert(chan != NULL, "lc_channel_random()");
	if (!chan) goto out_ctx;
	addr = lc_channel_in6addr(chan);
	test_assert(addr != NULL, "lc_channel_in6addr()");
	if (!addr) goto out_ctx;

	sem_init(&sem, 0, 0);

	/* mld_wait() will block unless MLD_DONTWAIT */
	rc = pthread_create(&tid, NULL, thread_mld_wait, &waitcount);
	test_assert(rc == 0, "pthread_create()");
	if (rc) goto out_sem; /* tid is unset: must not cancel/join it */
	clock_gettime(CLOCK_REALTIME, &ts);
	ts.tv_sec++;
	/* retry if a signal interrupts the wait; errno is only meaningful on -1 */
	do {
		errno = 0;
		rc = sem_timedwait(&sem, &ts);
	} while (rc == -1 && errno == EINTR);
	test_assert(rc == -1 && errno == ETIMEDOUT, "sem_wait() blocks without MLD_DONTWAIT");
	pthread_cancel(tid);
	pthread_join(tid, NULL);

	/* MLD_DONTWAIT causes mld_wait() to be non-blocking */
	errno = 0; /* clear the ETIMEDOUT above so only mld_wait() can set errno */
	rc = mld_wait(mld, ifx, addr, MLD_DONTWAIT);
	test_assert(errno == EWOULDBLOCK, "sem_wait() - EWOULDBLOCK");
	test_assert(rc == -1, "mld_wait with MLD_DONTWAIT returns -1 - group not added");

	/* now add a group to the filter */
	rc = mld_filter_grp_add(mld, ifx, addr);
	test_assert(rc == 0, "add group to filter");
	sleep(1);
	rc = mld_wait(mld, ifx, addr, 0);
	test_assert(rc == 0, "mld_wait() returns 0 when grp in filter");

	waitcount = 0;
	rc = pthread_create(&tid, NULL, thread_mld_wait, &waitcount);
	test_assert(rc == 0, "pthread_create()");
	if (rc) goto out_sem;
	/* check the sem_wait() result itself, not a stale rc */
	while ((rc = sem_wait(&sem)) == -1 && errno == EINTR)
		;
	test_assert(rc == 0, "sem_wait() blocks, then returns when grp added");
	pthread_cancel(tid);
	pthread_join(tid, NULL);
	test_assert(waitcount == 1, "waitcount = %i", waitcount);

out_sem:
	sem_destroy(&sem);
out_ctx:
	lc_ctx_free(lctx);
out_mld:
	mld_stop(mld);
	mld_free(mld);
	return fails;
}
