/* Copyright (C) 2004 MySQL AB

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA */


#if defined(__WIN__) || defined(_WIN32) || defined(_WIN64)
# include <winsock2.h>
# define MSG_DONTWAIT 0
#else
# include <sys/types.h>
# include <sys/time.h>
# include <sys/socket.h>
# include <netinet/in.h>
# include <arpa/inet.h>
# include <sys/fcntl.h>
# include <netdb.h>
# include <unistd.h>
# define closesocket close
#endif
#include <glib.h>

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <openssl/rand.h>
#include <openssl/err.h>

#include "net_line_client.h"

/* Initial size, in bytes, of the client's input and output buffers. */
#define INITIAL_BUFFER_SIZE 4096
/* Growth step, in bytes, used when an input/output buffer must be enlarged. */
#define BUFFER_INCR_SIZE 1024
/* Default number of queued lines.  NOTE(review): the line-queue size is
   normally supplied by the caller of mnet_init_client(); confirm intended use. */
#define INITIAL_LINES_COUNT 64


/* One-time OpenSSL setup: register the SSL/TLS algorithms and the error
   strings, then mix a few rand()-derived bytes (seeded from the current
   time) into OpenSSL's PRNG. */
static void init_ssl()
{
  static char seed[8];
  size_t pos;

  SSLeay_add_ssl_algorithms();
  SSL_load_error_strings();
  SSL_library_init();

  srand(time(NULL));
  pos= 0;
  while (pos < sizeof(seed))
  {
    seed[pos]= (char)rand();
    ++pos;
  }
  RAND_seed(seed, sizeof(seed));
}


/* Log the earliest error in OpenSSL's error queue (removing it from the
   queue), prefixed with func, the name of the OpenSSL call that failed.
   Uses g_warning rather than g_error: g_error aborts the process, which made
   every caller's cleanup-and-return path unreachable. */
static void show_ssl_error(const char *func)
{
  char buf[256];
  unsigned long err;   /* ERR_get_error() returns unsigned long */

  err= ERR_get_error();
  /* bounded variant: never writes more than sizeof(buf) bytes */
  ERR_error_string_n(err, buf, sizeof(buf));
  g_warning("SSL error in %s: %s (%lu)", func, buf, err);
}



MNetLineClient *mnet_init_client(int line_buffer_size)
{
  static int ssl_init_needed= 1;
  MNetLineClient *client= g_malloc0(sizeof(MNetLineClient));

  if (ssl_init_needed)
  {
    init_ssl();
    ssl_init_needed= 0;
  }

  client->state= MSDisconnected;

  client->socket= -1;

  client->out_buffer= g_malloc(sizeof(char)*INITIAL_BUFFER_SIZE);
  client->out_buffer_len= 0;
  client->out_buffer_alloced= INITIAL_BUFFER_SIZE;

  client->in_buffer= g_malloc(sizeof(char)*INITIAL_BUFFER_SIZE);
  client->in_buffer_len= 0;
  client->in_buffer_alloced= INITIAL_BUFFER_SIZE;

  client->in_lines_used= 0;
  client->in_lines_alloced= line_buffer_size;
  client->in_lines= g_malloc0(sizeof(char*)*line_buffer_size);

  return client;
}


MNetLineClient *mnet_init_ssl_client(int line_buffer_size)
{
  MNetLineClient *client= mnet_init_client(line_buffer_size);
  
  client->ssl_ctx= SSL_CTX_new(SSLv3_client_method());
  if (!client->ssl_ctx)
  {
    show_ssl_error("SSL_CTX_new");
    //mnet_free_client(client)
    return NULL;
  }

  return client;
}


/* Resolve host and open a TCP connection to the first address that accepts
   it, storing the socket in client->socket.  A fresh socket is created for
   every attempt: POSIX leaves a socket's state unspecified after a failed
   connect(), so it must not be reused for the next address.
   Returns 0 on success; -1 with MSResolveError or MSConnectError, or -2 with
   MSError if a socket cannot be created.  client->socket is -1 on failure. */
static int client_open_tcp(MNetLineClient *client, const char *host, int port)
{
  struct sockaddr_in addr;
  struct hostent *hptr;
  char **aptr;
  int saved_errno= 0;

  hptr= gethostbyname(host);
  if (!hptr || hptr->h_addrtype != AF_INET)
  {
    client->socket= -1;
    client->state= MSResolveError;
    return -1;
  }

  memset(&addr, 0, sizeof(addr));
  addr.sin_family= AF_INET;
  addr.sin_port= htons(port);

  for (aptr= hptr->h_addr_list; *aptr != NULL; aptr++)
  {
    client->socket= socket(PF_INET, SOCK_STREAM, 0);
    if (client->socket < 0)
    {
      g_warning("could not create socket: %s", g_strerror(errno));
      client->socket= -1;
      client->state= MSError;
      return -2;
    }

    memcpy(&addr.sin_addr, *aptr, sizeof(addr.sin_addr));
    if (connect(client->socket, (struct sockaddr*)&addr, sizeof(addr)) == 0)
      return 0;

    /* closesocket() may clobber errno; keep the connect() failure reason */
    saved_errno= errno;
    closesocket(client->socket);
    client->socket= -1;
  }

  g_warning("could not connect socket to %s:%i: %s",
            host, port, g_strerror(saved_errno));
  client->state= MSConnectError;
  return -1;
}


/* Create an SSL object on client->socket and run the client-side handshake.
   Returns 0 on success.  On failure client->ssl is freed and reset to NULL,
   client->state is MSError (setup) or MSSSLError (handshake), -1 is
   returned, and the socket is left open for the caller to close. */
static int client_start_tls(MNetLineClient *client)
{
  client->ssl= SSL_new(client->ssl_ctx);
  if (!client->ssl)
  {
    show_ssl_error("SSL_new");
    client->state= MSError;
    return -1;
  }

  if (!SSL_set_fd(client->ssl, client->socket))
  {
    show_ssl_error("SSL_set_fd");
    SSL_free(client->ssl);
    client->ssl= NULL;
    client->state= MSError;
    return -1;
  }
  SSL_set_connect_state(client->ssl);

  /* SSL_connect() returns 1 on success, 0 on a controlled failure and a
     negative value on a fatal error, so only 1 means the handshake worked. */
  if (SSL_connect(client->ssl) != 1)
  {
    show_ssl_error("SSL_connect");
    SSL_free(client->ssl);
    client->ssl= NULL;
    client->state= MSSSLError;
    return -1;
  }

  g_message("ssl handshake ok");
  return 0;
}


/* Connect client to host:port and, when the client was created with an SSL
   context, perform the TLS handshake.  The socket is left in blocking mode.
   Returns 0 (state MSConnected) on success, -1 on resolve/connect/TLS
   failure, -2 if no socket could be created; client->state records the
   failure kind and client->socket is -1 after any failure. */
int mnet_client_connect(MNetLineClient *client, const char *host, int port)
{
  int rc= client_open_tcp(client, host, port);

  if (rc != 0)
    return rc;

  if (client->ssl_ctx && client_start_tls(client) != 0)
  {
    closesocket(client->socket);
    client->socket= -1;
    return -1;
  }

  client->state= MSConnected;
  return 0;
}


/* Copy the one-line subject and issuer names of the peer's certificate
   into info (truncated to the size of those fields).
   Returns 0 on success, -1 if the client has no TLS session or the peer
   presented no certificate. */
int mnet_client_get_certificate_info(MNetLineClient *client,
                                     MNetCertificateInfo *info)
{
  X509 *cert;

  if (!client->ssl)
    return -1;
  if (!(cert= SSL_get_peer_certificate(client->ssl)))
    return -1;

  X509_NAME_oneline(X509_get_subject_name(cert), info->subject,
                    sizeof(info->subject));
  X509_NAME_oneline(X509_get_issuer_name(cert), info->issuer,
                    sizeof(info->issuer));
  /* TODO: the certificate's validity period is not filled in. */

  /* SSL_get_peer_certificate() takes a reference that must be released */
  X509_free(cert);
  return 0;
}


/* Shut down and free the TLS session (if any), close the socket (if open)
   and mark the client MSDisconnected.  Afterwards client->ssl is NULL and
   client->socket is -1, so calling it again is harmless.  Returns 0.
   (Previously the SSL object leaked, the socket was reset to 0 - a valid
   descriptor - and -1 was returned even on success.) */
int mnet_client_disconnect(MNetLineClient *client)
{
  if (client->ssl)
  {
    SSL_shutdown(client->ssl);
    SSL_free(client->ssl);
    client->ssl= NULL;
  }

  if (client->socket >= 0)
    closesocket(client->socket);
  client->socket= -1;

  client->state= MSDisconnected;

  return 0;
}


/* Release everything owned by client: any TLS session, the SSL context,
   queued but unread lines, both buffers and the structure itself.
   The socket is NOT closed here; call mnet_client_disconnect() first.
   A NULL client is accepted and ignored.  Returns 0. */
int mnet_free_client(MNetLineClient *client)
{
  int i;

  if (!client)
    return 0;

  /* free the SSL object before the context it was created from */
  if (client->ssl)
    SSL_free(client->ssl);
  if (client->ssl_ctx)
    SSL_CTX_free(client->ssl_ctx);

  for (i= 0; i < client->in_lines_used; i++)
    g_free(client->in_lines[i]);
  g_free(client->in_lines);
  g_free(client->in_buffer);
  g_free(client->out_buffer);
  g_free(client);

  return 0;
}


/* Append len bytes of data to the output buffer; nothing is transmitted
   until mnet_flush_data() runs.  data may contain arbitrary bytes, including
   NULs (it is copied with memcpy).  The buffer is kept NUL-terminated right
   after the queued bytes.
   Returns 0 on success, or -2 with client->state= MSError if the total would
   not fit in an int or the buffer cannot be grown. */
int mnet_send(MNetLineClient *client, const char *data, unsigned int len)
{
  size_t needed;

  /* Check overflow.  len is compared against INT_MAX first so the
     subtraction cannot wrap (INT_MAX - len in unsigned arithmetic used to
     wrap for len > INT_MAX and let huge sizes through). */
  if (len > (unsigned int)INT_MAX - 1 ||
      client->out_buffer_len > INT_MAX - 1 - (int)len)
  {
    client->state= MSError;
    return -2;
  }

  /* +1 for the NUL written after the data (previously not reserved, so the
     terminator could land one byte past the allocation) */
  needed= (size_t)client->out_buffer_len + len + 1;
  if (needed > (size_t)client->out_buffer_alloced)
  {
    size_t new_size= (size_t)client->out_buffer_len +
                     MAX(len + 1, (unsigned int)BUFFER_INCR_SIZE);
    char *grown;

    if (new_size > INT_MAX)
      new_size= needed;
    /* g_try_realloc returns NULL instead of aborting; the old buffer
       stays valid (and owned by client) on failure */
    grown= g_try_realloc(client->out_buffer, new_size);
    if (!grown)
    {
      client->state= MSError;
      return -2;
    }
    client->out_buffer= grown;
    client->out_buffer_alloced= new_size;
  }

  memcpy(client->out_buffer + client->out_buffer_len, data, len);
  client->out_buffer_len+= len;
  client->out_buffer[client->out_buffer_len]= 0;

  return 0;
}


/* Queue the NUL-terminated string line followed by "\r\n" in the output
   buffer (transmitted by the next mnet_flush_data()).  The buffer stays
   NUL-terminated after the queued bytes.
   Returns 0 on success, or -2 with client->state= MSError if the total would
   not fit in an int or the buffer cannot be grown (queued data is kept). */
int mnet_send_line(MNetLineClient *client, const char *line)
{
  size_t line_len= strlen(line);
  size_t needed;

  /* line + "\r\n" + NUL must fit and keep the counters within int range */
  if (line_len > (size_t)INT_MAX - 3 ||
      (size_t)client->out_buffer_len > (size_t)INT_MAX - 3 - line_len)
  {
    client->state= MSError;
    return -2;
  }

  /* the old check omitted the NUL byte and could overflow by one */
  needed= (size_t)client->out_buffer_len + line_len + 3;
  if (needed > (size_t)client->out_buffer_alloced)
  {
    size_t new_size= (size_t)client->out_buffer_alloced +
                     MAX(line_len + 2, (size_t)BUFFER_INCR_SIZE);
    char *grown;

    if (new_size < needed || new_size > INT_MAX)
      new_size= needed;
    /* assign through a temporary: on failure the old buffer stays valid
       instead of being leaked and replaced by NULL */
    grown= g_try_realloc(client->out_buffer, new_size);
    if (!grown)
    {
      client->state= MSError;
      return -2;
    }
    client->out_buffer= grown;
    client->out_buffer_alloced= new_size;
  }

  memcpy(client->out_buffer + client->out_buffer_len, line, line_len);
  memcpy(client->out_buffer + client->out_buffer_len + line_len, "\r\n", 3);
  client->out_buffer_len+= line_len + 2;

  return 0;
}


/* Pop the oldest queued line (CR/LF already stripped).
   Returns a heap string that the caller now owns (the queue holds
   g_strdup() copies, so release it with g_free()), or NULL when no line
   is queued. */
char *mnet_get_line(MNetLineClient *client)
{
  char *first;

  if (client->in_lines_used <= 0)
    return NULL;

  first= client->in_lines[0];
  client->in_lines_used-= 1;

  /* slide the remaining entries down by one slot */
  memmove(&client->in_lines[0], &client->in_lines[1],
          client->in_lines_used * sizeof(client->in_lines[0]));

  return first;
}


/* Write as much of the queued output as the socket (or TLS session)
   accepts in one call.  Bytes that were written are removed from the front
   of the output buffer and the unsent tail is moved to the front. The old
   code only shrank the length, so the already-sent head was sent again on
   the next flush.
   Returns 0 on success or when the write must be retried later (would
   block, TLS wants I/O, peer sent close_notify).  Returns -1 on error, with
   client->state set to MSError or MSSSLError. */
int mnet_flush_data(MNetLineClient *client)
{
  int ret;

  if (client->out_buffer_len <= 0)
    return 0;

  for (;;)
  {
    if (client->ssl)
    {
      ret= SSL_write(client->ssl, client->out_buffer, client->out_buffer_len);
      switch (SSL_get_error(client->ssl, ret))
      {
      case SSL_ERROR_NONE:
        break;
      case SSL_ERROR_SSL:
        show_ssl_error("SSL_write");
        client->state= MSSSLError;
        return -1;
      case SSL_ERROR_WANT_READ:
      case SSL_ERROR_WANT_WRITE:
      case SSL_ERROR_ZERO_RETURN:
        /* nothing written this time; keep everything queued */
        return 0;
      default:
        /* SSL_ERROR_SYSCALL and anything unexpected: errno decides below */
        ret= -1;
        break;
      }
    }
    else
    {
      ret= send(client->socket, client->out_buffer, client->out_buffer_len,
                MSG_DONTWAIT);
    }

    if (ret < 0)
    {
      if (errno == EINTR)
        continue;
      /* socket buffer full: return instead of busy-spinning; retry once
         mnet_wait_ready() reports MWWriteOK */
      if (errno == EAGAIN)
        return 0;

      g_warning("error sending data: %s", g_strerror(errno));
      client->state= MSError;
      return -1;
    }
    break;
  }

  if (ret > 0)
  {
    client->out_buffer_len-= ret;
    memmove(client->out_buffer, client->out_buffer + ret,
            client->out_buffer_len);
    client->out_buffer[client->out_buffer_len]= 0;
  }

  return 0;
}


/* Wait with select() until client->socket is readable or, when output is
   queued, writable.  A negative max_wait_msec waits indefinitely.
   Interrupted waits (EINTR) are restarted with the full timeout.
   Returns a bit mask of MWReadOK / MWWriteOK, 0 on timeout, or -1 (with
   client->state= MSError) if select() fails. */
int mnet_wait_ready(MNetLineClient *client, int max_wait_msec)
{
  fd_set readable, writable;
  struct timeval tv;
  struct timeval *tv_ptr;
  int n_ready;
  int mask= 0;

  do
  {
    FD_ZERO(&readable);
    FD_ZERO(&writable);
    FD_SET(client->socket, &readable);
    if (client->out_buffer_len > 0)
      FD_SET(client->socket, &writable);

    tv_ptr= NULL;
    if (max_wait_msec >= 0)
    {
      tv.tv_sec= max_wait_msec / 1000;
      tv.tv_usec= (max_wait_msec % 1000) * 1000;
      tv_ptr= &tv;
    }

    n_ready= select(client->socket + 1, &readable, &writable, NULL, tv_ptr);
  } while (n_ready < 0 && errno == EINTR);

  if (n_ready < 0)
  {
    client->state= MSError;
    return -1;
  }
  if (n_ready == 0)
    return 0;

  if (FD_ISSET(client->socket, &readable))
    mask|= MWReadOK;
  if (FD_ISSET(client->socket, &writable))
    mask|= MWWriteOK;

  return mask;
}


/* Split the receive buffer into lines.  Each '\n'-terminated line, with a
   single trailing '\r' removed, is g_strdup()'d onto the end of
   client->in_lines until the queue is full.  The consumed bytes are then
   removed from the front of the buffer; a trailing partial line (and any
   lines that did not fit in the queue) stay buffered. */
static void break_lines(MNetLineClient *client)
{
  char *const base= client->in_buffer;
  int consumed= 0;

  for (;;)
  {
    char *start= base + consumed;
    char *eol= memchr(start, '\n', client->in_buffer_len - consumed);

    if (eol == NULL || client->in_lines_used >= client->in_lines_alloced)
      break;

    *eol= '\0';
    if (eol != start && eol[-1] == '\r')
      eol[-1]= '\0';

    client->in_lines[client->in_lines_used]= g_strdup(start);
    client->in_lines_used++;

    /* line bytes plus the '\n' */
    consumed+= (int)(eol - start) + 1;
  }

  memmove(base, base + consumed, client->in_buffer_len - consumed);
  client->in_buffer_len-= consumed;
}


/* Receive available data into the input buffer and split it into queued
   lines (retrieved with mnet_get_line()).
   Returns:
    0  on success, or when nothing can be done now (line queue full, read
       would block or was interrupted, TLS wants more I/O);
   -1  on a read/TLS error (state MSError / MSSSLError) or when the peer
       closed the connection (state MSDisconnected);
   -2  if the input buffer could not be grown (state MSError). */
int mnet_read_data(MNetLineClient *client)
{
  int count;

  /* Complete lines can be left in the buffer by an earlier call that
     stopped because the queue was full.  Queue them now; otherwise they
     would only surface once the peer sends more data, which may never
     happen. */
  break_lines(client);

  if (client->in_lines_used == client->in_lines_alloced)
  {
    /* line buffer is full! */
    return 0;
  }

  if (client->in_buffer_len == client->in_buffer_alloced)
  {
    /* increase buffer size; the old buffer stays valid if this fails */
    char *grown= g_try_realloc(client->in_buffer,
                               client->in_buffer_alloced + BUFFER_INCR_SIZE);
    if (!grown)
    {
      g_warning("could not allocate memory for buffer");
      client->state= MSError;
      return -2;
    }
    client->in_buffer= grown;
    client->in_buffer_alloced+= BUFFER_INCR_SIZE;
  }

  if (client->ssl)
  {
    count= SSL_read(client->ssl,
                    client->in_buffer+client->in_buffer_len,
                    client->in_buffer_alloced-client->in_buffer_len);

    switch (SSL_get_error(client->ssl, count))
    {
    case SSL_ERROR_NONE:
      break;
    case SSL_ERROR_SSL:
      show_ssl_error("SSL_read");
      client->state= MSSSLError;
      return -1;
    case SSL_ERROR_WANT_READ:
    case SSL_ERROR_WANT_WRITE:
      /* no application data yet; errno is meaningless here */
      return 0;
    case SSL_ERROR_ZERO_RETURN:
      count= 0;   /* peer closed the TLS session */
      break;
    default:
      /* SSL_ERROR_SYSCALL and anything unexpected: errno decides below */
      count= -1;
      break;
    }
  }
  else
  {
    count= recv(client->socket,
                client->in_buffer+client->in_buffer_len,
                client->in_buffer_alloced-client->in_buffer_len,
                MSG_DONTWAIT);
  }

  if (count < 0)
  {
    if (errno == EAGAIN || errno == EINTR)
      return 0;

    g_warning("error reading data: %s", g_strerror(errno));
    client->state= MSError;
    return -1;
  }
  if (count == 0)
  {
    client->state= MSDisconnected;
    return -1;
  }

  client->in_buffer_len+= count;

  /* break input in lines */
  break_lines(client);

  return 0;
}


