/*
 * Copyright (C) 2012 - David Goulet <dgoulet@efficios.com>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License, version 2 only, as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#define _GNU_SOURCE
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <errno.h>

#include <common/defaults.h>
#include <common/error.h>

#include "inet.h"

/*
 * Operation table for the INET protocol family: binds each generic lttcomm
 * socket operation to its PF_INET implementation.
 */
static const struct lttcomm_proto_ops inet_ops = {
	/* Connection setup. */
	.bind = lttcomm_bind_inet_sock,
	.listen = lttcomm_listen_inet_sock,
	.accept = lttcomm_accept_inet_sock,
	.connect = lttcomm_connect_inet_sock,
	/* Data transfer. */
	.recvmsg = lttcomm_recvmsg_inet_sock,
	.sendmsg = lttcomm_sendmsg_inet_sock,
	/* Teardown. */
	.close = lttcomm_close_inet_sock,
};

/*
 * Create a PF_INET socket of the given type and protocol and store its file
 * descriptor in sock->fd. The INET operation table is attached to the socket
 * and the SO_REUSEADDR option is enabled on it.
 *
 * Return 0 on success. On error, return -1: no file descriptor is left open
 * and sock->fd holds -1.
 */
__attribute__((visibility("hidden")))
int lttcomm_create_inet_sock(struct lttcomm_sock *sock, int type, int proto)
{
	int val = 1, ret;

	/* Create server socket */
	sock->fd = socket(PF_INET, type, proto);
	if (sock->fd < 0) {
		PERROR("socket inet");
		goto error;
	}

	sock->ops = &inet_ops;

	/*
	 * Set socket option to reuse the address.
	 */
	ret = setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
	if (ret < 0) {
		PERROR("setsockopt inet");
		goto error_close;
	}

	return 0;

error_close:
	/*
	 * Do not leak the descriptor that was just created. Mark the socket as
	 * invalid so that a later close operation on it is a no-op instead of
	 * closing an unrelated descriptor.
	 */
	if (close(sock->fd)) {
		PERROR("close inet");
	}
	sock->fd = -1;
error:
	return -1;
}

/*
 * Bind the socket to the address stored in sock->sockaddr.addr.sin.
 *
 * Return the result of bind(2): 0 on success, a negative value on error, in
 * which case an error message is printed.
 */
__attribute__((visibility("hidden")))
int lttcomm_bind_inet_sock(struct lttcomm_sock *sock)
{
	int ret;

	/*
	 * bind(2) takes a generic struct sockaddr pointer. Cast explicitly, as
	 * done for connect(2) and accept(2), instead of relying on a libc
	 * prototype that happens to accept any socket address type.
	 */
	ret = bind(sock->fd, (const struct sockaddr *) &sock->sockaddr.addr.sin,
			sizeof(sock->sockaddr.addr.sin));
	if (ret < 0) {
		PERROR("bind inet");
	}

	return ret;
}

/*
 * Connect the PF_INET socket to the address stored in
 * sock->sockaddr.addr.sin.
 *
 * Return 0 on success. On error, return the negative value given by
 * connect(2) with errno as set by connect(2); the socket's file descriptor
 * is closed and sock->fd is set to -1.
 */
__attribute__((visibility("hidden")))
int lttcomm_connect_inet_sock(struct lttcomm_sock *sock)
{
	int ret, closeret, saved_errno;

	ret = connect(sock->fd, (struct sockaddr *) &sock->sockaddr.addr.sin,
			sizeof(sock->sockaddr.addr.sin));
	if (ret < 0) {
		/*
		 * Don't print message on connect error, because connect is used in
		 * normal execution to detect if sessiond is alive.
		 */
		goto error_connect;
	}

	return ret;

error_connect:
	/* close(2) and PERROR() may overwrite errno; keep connect(2)'s value. */
	saved_errno = errno;

	closeret = close(sock->fd);
	if (closeret) {
		PERROR("close inet");
	}

	/*
	 * Mark the socket as closed so a later close operation does not close
	 * the same descriptor number a second time.
	 */
	sock->fd = -1;
	errno = saved_errno;

	return ret;
}

/*
 * Accept a connection on sock with accept(2) and hand it back as a newly
 * allocated lttcomm socket. The socket MUST have been bound with bind(2)
 * beforehand.
 *
 * UDP has no accept(2): for a UDP socket, the socket passed in is returned
 * unchanged. Return NULL on error.
 */
__attribute__((visibility("hidden")))
struct lttcomm_sock *lttcomm_accept_inet_sock(struct lttcomm_sock *sock)
{
	struct lttcomm_sock *accepted;
	socklen_t addr_len;
	int fd;

	/* accept(2) does not exist for UDP so simply return the passed socket. */
	if (sock->proto == LTTCOMM_SOCK_UDP) {
		return sock;
	}

	accepted = lttcomm_alloc_sock(sock->proto);
	if (!accepted) {
		return NULL;
	}

	addr_len = sizeof(accepted->sockaddr.addr.sin);

	/* Blocking call; the peer address is written into the new socket. */
	fd = accept(sock->fd,
			(struct sockaddr *) &accepted->sockaddr.addr.sin, &addr_len);
	if (fd < 0) {
		PERROR("accept inet");
		free(accepted);
		return NULL;
	}

	accepted->ops = &inet_ops;
	accepted->fd = fd;

	return accepted;
}

/*
 * Put the socket in listening mode with the given backlog. A backlog that is
 * zero or negative selects the default, LTTNG_SESSIOND_COMM_MAX_LISTEN.
 *
 * UDP has no listen(2): for a UDP socket nothing is done and success is
 * reported. Return 0 on success or the negative result of listen(2).
 */
__attribute__((visibility("hidden")))
int lttcomm_listen_inet_sock(struct lttcomm_sock *sock, int backlog)
{
	int status;

	/* listen(2) does not exist for UDP so simply return success. */
	if (sock->proto == LTTCOMM_SOCK_UDP) {
		return 0;
	}

	status = listen(sock->fd,
			backlog > 0 ? backlog : LTTNG_SESSIOND_COMM_MAX_LISTEN);
	if (status < 0) {
		PERROR("listen inet");
	}

	return status;
}

/*
 * Receive len bytes from the socket into buf, using the recvmsg(2) API.
 *
 * recvmsg(2) is called again as long as the buffer is not full, and is
 * retried when interrupted by a signal (EINTR). The source address reported
 * by recvmsg(2) is written to sock->sockaddr.addr.sin.
 *
 * Return len once the buffer has been filled, 0 if recvmsg(2) returned 0
 * (orderly shutdown), or a negative value on error, in which case an error
 * message is printed.
 */
__attribute__((visibility("hidden")))
ssize_t lttcomm_recvmsg_inet_sock(struct lttcomm_sock *sock, void *buf,
		size_t len, int flags)
{
	struct msghdr msg;
	struct iovec iov[1];
	ssize_t ret = -1;
	size_t len_last;

	memset(&msg, 0, sizeof(msg));

	iov[0].iov_base = buf;
	iov[0].iov_len = len;
	msg.msg_iov = iov;
	msg.msg_iovlen = 1;

	msg.msg_name = (struct sockaddr *) &sock->sockaddr.addr.sin;
	msg.msg_namelen = sizeof(sock->sockaddr.addr.sin);

	do {
		/* Space left in the buffer before this call. */
		len_last = iov[0].iov_len;
		ret = recvmsg(sock->fd, &msg, flags);
		if (ret > 0) {
			/* Check before shrinking iov_len so it cannot wrap around. */
			assert((size_t) ret <= len_last);
			/*
			 * Advance through a char pointer: arithmetic on void * is a
			 * compiler extension, not standard C.
			 */
			iov[0].iov_base = (char *) iov[0].iov_base + ret;
			iov[0].iov_len -= (size_t) ret;
		}
		/* Loop on a partial read, or when interrupted by a signal. */
	} while ((ret > 0 && (size_t) ret < len_last) ||
			(ret < 0 && errno == EINTR));
	if (ret < 0) {
		PERROR("recvmsg inet");
	} else if (ret > 0) {
		/* The last call filled the buffer: report the whole length. */
		ret = (ssize_t) len;
	}
	/* Else ret = 0 meaning an orderly shutdown. */

	return ret;
}

/*
 * Send len bytes from buf through the socket, using the sendmsg(2) API. For
 * a UDP socket, the destination is the address stored in
 * sock->sockaddr.addr.sin; other protocols send without an explicit address.
 *
 * Return the value given by sendmsg(2): the size of sent data, or a negative
 * value on error.
 */
__attribute__((visibility("hidden")))
ssize_t lttcomm_sendmsg_inet_sock(struct lttcomm_sock *sock, void *buf,
		size_t len, int flags)
{
	struct iovec vec;
	struct msghdr hdr;
	ssize_t sent;

	vec.iov_base = buf;
	vec.iov_len = len;

	memset(&hdr, 0, sizeof(hdr));
	hdr.msg_iov = &vec;
	hdr.msg_iovlen = 1;

	if (sock->proto == LTTCOMM_SOCK_UDP) {
		hdr.msg_name = (struct sockaddr *) &sock->sockaddr.addr.sin;
		hdr.msg_namelen = sizeof(sock->sockaddr.addr.sin);
	}

	/* Retry for as long as the call is interrupted by a signal. */
	for (;;) {
		sent = sendmsg(sock->fd, &hdr, flags);
		if (sent >= 0 || errno != EINTR) {
			break;
		}
	}

	/*
	 * EPIPE is considered expected: it is the only error silenced, and only
	 * when quiet mode is active.
	 */
	if (sent < 0 && !(errno == EPIPE && lttng_opt_quiet)) {
		PERROR("sendmsg inet");
	}

	return sent;
}

/*
 * Close the socket's file descriptor and mark the socket as closed by
 * setting sock->fd to -1. A socket already marked that way is left alone.
 *
 * Return 0 on success or when there was nothing to close, otherwise the
 * non-zero result of close(2).
 */
__attribute__((visibility("hidden")))
int lttcomm_close_inet_sock(struct lttcomm_sock *sock)
{
	int status = 0;

	/* Only act on a socket that is not marked invalid. */
	if (sock->fd != -1) {
		status = close(sock->fd);
		if (status != 0) {
			PERROR("close inet");
		}

		/* Mark socket, whether or not close(2) succeeded. */
		sock->fd = -1;
	}

	return status;
}
