(* PostgreSQL database interface for mod_caml programs.
 * Copyright (C) 2003-2004 Merjis Ltd.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: dbi_postgres.ml,v 1.14 2004/08/04 09:43:32 rwmj Exp $
 *)

module Connection = Postgres.Connection
module Result = Postgres.Result

open Printf

(* PCRE patterns used to parse timestamp and interval values returned by
 * the server.  They are compiled in extended mode, so the padding and the
 * "#" comments embedded in the pattern text are ignored by the matcher.
 *
 * Timestamp capture groups: 1-3 year, month, day; 4-5 hour, minute;
 * 6 seconds; 7 fractional-second digits; 8 offset sign; 9 offset hours.
 * Everything after the date is optional.
 *)
let re_timestamp =
  let pattern =
    String.concat ""
      [ "(?:(\\d\\d\\d\\d)-(\\d\\d)-(\\d\\d))     # date (YYYY-MM-DD)\n";
        "\\s*                                     # space between date & time\n";
        "(?:(\\d\\d):(\\d\\d)                     # HH:MM\n";
        "   (?::(\\d\\d))?                        # optional :SS\n";
        "   (?:\\.(\\d+))?                        # optional .microseconds\n";
        "   (?:([+-])(\\d\\d))?                   # optional +/- offset UTC\n";
        ")?" ]
  in
  Pcre.regexp ~flags:[`EXTENDED] pattern
(* Interval pattern (extended mode) for PostgreSQL's default interval
 * output.  Capture groups: 1 years, 2 months, 3 days, 4 hours,
 * 5 minutes, 6 seconds; every part is optional.
 *
 * The hours field accepts any number of digits: the server can print hour
 * counts above 99 (e.g. 100:00:00).  With a fixed two-digit hours field
 * such a value was not captured at all and every field read as zero.
 *)
let re_interval =
  Pcre.regexp ~flags:[`EXTENDED]
    ("(?:(\\d+)\\syears?)?                     # years\n"^
     "\\s*                                     # \n"^
     "(?:(\\d+)\\smons?)?                      # months\n"^
     "\\s*                                     # \n"^
     "(?:(\\d+)\\sdays?)?                      # days\n"^
     "\\s*                                     # \n"^
     "(?:(\\d+):(\\d\\d)                       # HH:MM (hours may exceed 99)\n"^
     "   (?::(\\d\\d))?                        # optional :SS\n"^
     ")?")

(* Render a (date, time) pair as a quoted PostgreSQL timestamp literal,
 * e.g. '2004-08-04 09:43:32.000123+01'.
 *
 * [microsec] is zero-padded to six digits so that the fractional part
 * denotes microseconds: with a bare %d, 123 microseconds would be sent as
 * .123, which the server reads as 123 milliseconds.
 * When [timezone] is set it is appended as a signed two-digit UTC offset.
 *)
let string_of_timestamp (date, time) =
  let body =
    sprintf "'%04d-%02d-%02d %02d:%02d:%02d.%06d"
      date.Dbi.year date.Dbi.month date.Dbi.day
      time.Dbi.hour time.Dbi.min time.Dbi.sec time.Dbi.microsec in
  match time.Dbi.timezone with
  | None -> body ^ "'"
  | Some tz -> sprintf "%s%+03d'" body tz

(* Parse a timestamp as printed by the server (see [re_timestamp]) into a
 * (date, time) pair.
 *
 * - The fractional-second digits are a decimal fraction, not a count of
 *   microseconds: ".5" is half a second, i.e. 500000 microseconds.  They
 *   are therefore right-padded/truncated to exactly six digits.
 * - The regexp makes everything after the date optional, so missing hour,
 *   minute and second fields read as 0 instead of crashing in
 *   [int_of_string ""].
 * - The timezone is the signed +/-HH offset, or [None] when absent.
 *
 * Raises [Failure] when [str] does not contain a YYYY-MM-DD date.
 *)
let timestamp_of_string =
  let int_opt s = if s = "" then 0 else int_of_string s in
  let microsec_of_fraction s =
    int_of_string (String.sub (s ^ "000000") 0 6) in
  fun str ->
    try
      let sub = Pcre.extract ~rex:re_timestamp str in
      ({ Dbi.year = int_of_string sub.(1);
         Dbi.month = int_of_string sub.(2);
         Dbi.day = int_of_string sub.(3) },
       { Dbi.hour = int_opt sub.(4);
         Dbi.min = int_opt sub.(5);
         Dbi.sec = int_opt sub.(6);
         Dbi.microsec = microsec_of_fraction sub.(7);
         Dbi.timezone =
           if sub.(9) = "" then None
           else
             let tz = int_of_string sub.(9) in
             Some (if sub.(8) = "-" then -tz else tz) })
    with
      Not_found -> failwith ("timestamp_of_string: bad timestamp: " ^ str)

(* Render a (date, time) pair as a quoted PostgreSQL interval literal such
 * as '1 years 2 mons 3 days 04:05:06.000007'.
 *
 * [microsec] is zero-padded to six digits so the server reads it as
 * microseconds; a bare %d would make 7 microseconds mean 0.7 seconds.
 *)
let string_of_interval (date, time) =
  sprintf "'%d years %d mons %d days %02d:%02d:%02d.%06d'"
    date.Dbi.year date.Dbi.month date.Dbi.day
    time.Dbi.hour time.Dbi.min time.Dbi.sec time.Dbi.microsec

(* Parse an interval in PostgreSQL's default output style, e.g.
 * "1 year 2 mons 3 days 04:05:06", into a (date, time) pair using
 * [re_interval].  Every part is optional and missing parts read as 0.
 * Seconds fractions are not captured, so [microsec] is always 0, and
 * [timezone] is always [None].
 * NOTE(review): re_interval captures no sign, so a value such as
 * "-1 days" is not picked up by its groups and likely decodes as zero;
 * confirm whether negative intervals need to be supported.
 *)
let interval_of_string =
  let field groups i =
    match groups.(i) with
    | "" -> 0
    | digits -> int_of_string digits
  in
  fun str ->
    let groups =
      try Pcre.extract ~rex:re_interval str
      with Not_found -> failwith ("interval_of_string: bad interval: " ^ str)
    in
    let date =
      { Dbi.year = field groups 1;
        Dbi.month = field groups 2;
        Dbi.day = field groups 3 } in
    let time =
      { Dbi.hour = field groups 4;
        Dbi.min = field groups 5;
        Dbi.sec = field groups 6;
        Dbi.microsec = 0;
        Dbi.timezone = None } in
    (date, time)

(* Parse a date of the form YYYY-MM-DD into a Dbi date record. *)
let date_of_string s =
  let make year month day =
    { Dbi.year = year; Dbi.month = month; Dbi.day = day } in
  Scanf.sscanf s "%d-%d-%d" make

(* Parse a time of the form HH:MM:SS into a Dbi time record.  Only the
 * three leading integers are read; anything after the seconds (such as
 * a fractional part or a timetz offset) is ignored, so [microsec] is
 * always 0 and [timezone] always [None]. *)
let time_of_string s =
  let make hour minute second =
    { Dbi.hour = hour; Dbi.min = minute; Dbi.sec = second;
      Dbi.microsec = 0; Dbi.timezone = None } in
  Scanf.sscanf s "%d:%d:%d" make

(* Encode arbitrary bytes for a bytea literal.  Every byte becomes two
 * backslashes followed by its three-digit octal code (\\ooo), so the
 * result is exactly five times as long as [s].  Two backslashes are
 * needed because the result is placed inside an ordinary quoted string
 * literal: the string-literal parser consumes one backslash and the bytea
 * input routine decodes the remaining \ooo.
 * NOTE(review): this relies on backslash escapes being active in ordinary
 * string literals (standard_conforming_strings = off); confirm for the
 * target server.
 *
 * Built with a Buffer rather than String.create plus in-place mutation,
 * which safe-string compilers reject and OCaml 5 no longer provides.
 *)
let encode_bytea s =
  let buf = Buffer.create (String.length s * 5) in
  let octal_digit d = Char.chr (Char.code '0' + d) in
  String.iter
    (fun ch ->
       let c = Char.code ch in
       Buffer.add_string buf "\\\\";
       Buffer.add_char buf (octal_digit (c lsr 6));
       Buffer.add_char buf (octal_digit ((c lsr 3) land 7));
       Buffer.add_char buf (octal_digit (c land 7)))
    s;
  Buffer.contents buf

(* Decode a bytea value, as returned by the server in text form, back
 * into raw bytes.
 *
 * Two output formats are understood:
 *  - hex: a leading \x followed by two hexadecimal digits per byte (the
 *    server default since PostgreSQL 9.0);
 *  - escape: any byte other than a backslash is taken literally, a
 *    doubled backslash stands for one backslash, and a backslash followed
 *    by three octal digits stands for that byte.
 * In escape format a backslash is never followed by x, so a leading \x
 * unambiguously selects the hex format.
 *
 * Raises [Invalid_argument] if [s] is malformed (truncated escape, bad
 * digit, odd number of hex digits) instead of reading past the end of the
 * string.  Built with a Buffer rather than String.create plus in-place
 * mutation, which safe-string compilers reject and OCaml 5 removed.
 *)
let decode_bytea s =
  let n = String.length s in
  let buf = Buffer.create n in          (* The result is never longer. *)
  let malformed () =
    invalid_arg "Dbi_postgres.decode_bytea: malformed bytea value" in
  let hex_digit c =
    match c with
    | '0' .. '9' -> Char.code c - Char.code '0'
    | 'a' .. 'f' -> Char.code c - Char.code 'a' + 10
    | 'A' .. 'F' -> Char.code c - Char.code 'A' + 10
    | _ -> malformed ()
  in
  let octal_digit c =
    match c with
    | '0' .. '7' -> Char.code c - Char.code '0'
    | _ -> malformed ()
  in
  (* Hex format: pairs of hex digits after the two-character prefix. *)
  let decode_hex () =
    if (n - 2) mod 2 <> 0 then malformed ();
    let i = ref 2 in
    while !i < n do
      let hi = hex_digit s.[!i] and lo = hex_digit s.[!i + 1] in
      Buffer.add_char buf (Char.chr (hi * 16 + lo));
      i := !i + 2
    done
  in
  (* Escape format; tail-recursive so long values do not grow the stack. *)
  let rec decode_escape i =
    if i < n then begin
      if s.[i] <> '\\' then begin
        Buffer.add_char buf s.[i];
        decode_escape (i + 1)
      end else if i + 1 < n && s.[i + 1] = '\\' then begin
        Buffer.add_char buf '\\';
        decode_escape (i + 2)
      end else if i + 3 < n then begin
        let v =
          octal_digit s.[i + 1] * 64 + octal_digit s.[i + 2] * 8
          + octal_digit s.[i + 3] in
        if v > 255 then malformed ();
        Buffer.add_char buf (Char.chr v);
        decode_escape (i + 4)
      end else
        malformed ()
    end
  in
  if n >= 2 && s.[0] = '\\' && s.[1] = 'x' then decode_hex ()
  else decode_escape 0;
  Buffer.contents buf

(* [sql_of_float f] renders [f] as a SQL literal without losing precision.
 *
 * [string_of_float] prints only 12 significant digits, which does not
 * round-trip a double, so the shortest of %.15g, %.16g and %.17g that
 * reads back as [f] is used; 17 digits always suffice.
 * A trailing "." is added to integral results, as [string_of_float] did
 * ("1."), so the literal stays non-integer and e.g. "? / 2" is not
 * evaluated as integer division.
 * NaN and the infinities have no bare SQL spelling (nan and inf would be
 * read as identifiers), so they are sent as quoted float8 literals.
 *)
let sql_of_float f =
  match classify_float f with
  | FP_nan -> "'NaN'::float8"
  | FP_infinite ->
      if f > 0. then "'Infinity'::float8" else "'-Infinity'::float8"
  | FP_normal | FP_subnormal | FP_zero ->
      let rec shortest prec =
        let s = sprintf "%.*g" prec f in
        if prec >= 17 || float_of_string s = f then s
        else shortest (prec + 1)
      in
      let s = shortest 15 in
      if String.contains s '.' || String.contains s 'e' then s else s ^ "."

(* [encode_sql_t v] returns a string suitable for substitution of "?"
   in a SQL query.
   NOTE(review): the `Binary case relies on encode_bytea's double-backslash
   form being unescaped by an ordinary string literal, i.e. on
   standard_conforming_strings being off; confirm for the target server. *)
let encode_sql_t = function
  | `Null -> "NULL"
  | `Int i -> string_of_int i
  | `Float f -> sql_of_float f
  | `String s -> Dbi.string_escaped s
  | `Bool b -> if b then "'t'" else "'f'"
      (*  | `Bigint i -> string_of_big_int i *)
  | `Decimal d -> Dbi.Decimal.to_string d
  | `Date d -> sprintf "'%04i-%02i-%02i'" d.Dbi.year d.Dbi.month d.Dbi.day
  | `Time t -> sprintf "'%02i:%02i:%02i'" t.Dbi.hour t.Dbi.min t.Dbi.sec
  | `Timestamp t -> string_of_timestamp t
  | `Interval i -> string_of_interval i
  | `Blob s -> Dbi.string_escaped s (* FIXME *)
  | `Binary s -> "'" ^ encode_bytea s ^ "'::bytea"
  | `Unknown s -> Dbi.string_escaped s

(* [decode_sql_t is_null ty v] converts the textual column value [v] into
 * a Dbi value according to the column's type OID [ty].  [`Null] is
 * returned whenever [is_null] holds, without looking at [ty] or [v].
 * The OIDs are hard-coded from PostgreSQL's pg_type.h; any other OID is
 * passed through unchanged as [`Unknown v].
 *)
let decode_sql_t is_null ty v =
  match is_null, ty with
  | true, _ -> `Null
  | false, 16 (* bool *) -> `Bool (v = "t")
  | false, 17 (* bytea *) -> `Binary (decode_bytea v)
  | false, (18 (* char *) | 25 (* text *) | 1043 (* varchar *)) -> `String v
  (* int8 (OID 20) would ideally become a big integer, i.e.
       | false, 20 -> `Bigint (Big_int.big_int_of_string v)
     FIXME: until then it is read as a native int, which cannot hold
     every int8 value. *)
  | false, (20 (* int8 *) | 21 (* int2 *) | 23 (* int4 *)) ->
      `Int (int_of_string v)
  | false, (700 (* float4 *) | 701 (* float8 *)) -> `Float (float_of_string v)
  | false, 1082 (* date *) -> `Date (date_of_string v)
  | false, (1083 (* time *) | 1266 (* timetz *)) -> `Time (time_of_string v)
  | false, (1114 (* timestamp *) | 1184 (* timestamptz *)) ->
      `Timestamp (timestamp_of_string v)
  | false, 1186 (* interval *) -> `Interval (interval_of_string v)
  | false, 1700 (* numeric *) -> `Decimal (Dbi.Decimal.of_string v)
  | false, _ -> `Unknown v


(* A statement handle belonging to the database handle [dbh].
 *
 * [conn] is the Postgres connection the query is sent on.  [in_transaction]
 * is a flag shared by every statement of the same connection; while it is
 * false, [execute] first runs BEGIN WORK.  [original_query] may contain "?"
 * placeholders, which [execute] fills in from its arguments using
 * [encode_sql_t].
 *)
class statement dbh conn in_transaction original_query =
  let query = Dbi.split_query original_query in
object (self)
  inherit Dbi.statement dbh

  (* Result of the last query that returned rows, if any. *)
  val mutable tuples = None
  (* Column names of [tuples], computed on first use by [names]. *)
  val mutable name_list = None
  (* Index of the row the next [fetch1] will return. *)
  val mutable next_tuple = 0
  (* Row and column counts of [tuples]. *)
  val mutable ntuples = 0
  val mutable nfields = 0

  (* Run the query with [args] substituted for its placeholders.
   * Raises [Failure] on a closed handle or a COPY result, and
   * [Dbi.SQL_error] on a bad response or fatal error from the server.
   * A non-fatal error is only reported on stderr. *)
  method execute args =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: execute %s\n" dbh#id original_query;
      flush stderr
    );
    if dbh#closed then
      failwith "Dbi_postgres: execute called on a closed database handle.";
    (* Finish previous statement, if any. *)
    self#finish ();
    (* Not in a transaction yet: open one with BEGIN WORK.  The flag is set
     * first so that executing BEGIN WORK (through this same method) does
     * not recurse.  If BEGIN WORK fails no transaction was opened, so the
     * flag is cleared again before re-raising; otherwise later statements
     * would silently run outside any transaction. *)
    if not !in_transaction then (
      in_transaction := true;
      try
        let sth = dbh#prepare_cached "BEGIN WORK" in
        sth#execute []
      with exn ->
        in_transaction := false;
        raise exn
    );

    let query = (* substitute args *)
      Dbi.make_query "Dbi_postgres: execute called with wrong number of args."
        encode_sql_t query args in
    (* Send the query to the database. *)
    let res = Connection.exec conn query in

    match Result.status res with
    | Result.Empty_query -> ()
    | Result.Command_ok -> ()
    | Result.Tuples_ok ->
        tuples <- Some res;
        name_list <- None;
        next_tuple <- 0;
        ntuples <- Result.ntuples res;
        nfields <- Result.nfields res
    | Result.Copy_out
    | Result.Copy_in ->
        failwith "XXX copyin/copyout not implemented"
    | Result.Bad_response
    | Result.Fatal_error ->
        (* dbh#close (); -- used to do this, not a good idea *)
        raise (Dbi.SQL_error (Result.error res))
    | Result.Nonfatal_error ->
        prerr_endline ("Dbi_postgres: non-fatal error: " ^ Result.error res)

  (* Return the next row as a list of decoded column values.
   * Raises [Not_found] once every row has been returned, and [Failure]
   * if the last [execute] produced no rows. *)
  method fetch1 () =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: fetch1\n" dbh#id;
      flush stderr
    );
    match tuples with
    | None -> failwith "Dbi_postgres.statement#fetch1"
    | Some tuples ->
        if next_tuple >= ntuples then raise Not_found;
        (* Fetch each field in the tuple, walking the columns from last to
         * first so that consing yields the row in column order. *)
        let rec loop acc i =
          if i < 0 then acc else
            (* FIXME: what about binary tuples?? *)
            let v =
              decode_sql_t (Result.getisnull tuples next_tuple i)
                (Result.ftype tuples i)
                (Result.getvalue tuples next_tuple i) in
            loop (v :: acc) (i - 1) in
        let row = loop [] (nfields - 1) in
        next_tuple <- next_tuple + 1;
        row

  (* Column names of the current result, computed once and cached.
   * Raises [Failure] if the last [execute] produced no rows. *)
  method names =
    match tuples with
    | None -> failwith "Dbi_postgres.statement#names"
    | Some tuples ->
        begin match name_list with
        | Some l -> l
        | None ->
            let rec loop acc i =
              if i < 0 then acc
              else loop (Result.fname tuples i :: acc) (i - 1) in
            let l = loop [] (nfields - 1) in
            name_list <- Some l;
            l
        end

  (* Current value of the sequence [seq] (via currval) in this session.
   * Raises [Not_found] if the result is not a single integer. *)
  method serial seq =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: serial \"%s\"\n" dbh#id seq;
      flush stderr
    );
    let sth = dbh#prepare_cached "SELECT currval (?)" in
    sth#execute [`String seq];
    match sth#fetch1() with
    | [`Int serial] -> serial
    | _ -> raise Not_found

  (* Drop the current result, if any. *)
  method finish () =
    if dbh#debug then (
      eprintf "Dbi_postgres: dbh %d: finish %s\n" dbh#id original_query;
      flush stderr
    );
    (* XXX PQclear is not exposed through the Postgres library, so the
     * result cannot be freed explicitly; just drop our reference. *)
    tuples <- None
end

and connection ?host ?port ?user ?password database =

  (* XXX Not sure if this allows you to pass arbitrary conninfo stuff in the
   * database field. It should do. Otherwise we should use an assoc list
   * to pass arbitrary parameters to the underlying database.
   *)
  let conninfo =
    Postgres.conninfo ?host ?port ?user ?password ~dbname:database () in
  let conn = Connection.connect conninfo in

  (* We pass this reference around to the statement class so that all
   * statements belonging to this connection can keep track of our
   * transaction state and issue the appropriate BEGIN WORK command at
   * the right time.
   *)
  let in_transaction = ref false in

object (self)
  inherit Dbi.connection ?host ?port ?user ?password database as super

  method host = Some (Connection.host conn)
  method port = Some (Connection.port conn)
  method user = Some (Connection.user conn)
  method password = Some (Connection.pass conn)
  method database = Connection.db conn

  method database_type = "postgres"

  (* Create a new statement handle for [query] on this connection.
   * Raises [Failure] if the handle has been closed. *)
  method prepare query =
    if self#debug then (
      eprintf "Dbi_postgres: dbh %d: prepare %s\n" self#id query;
      flush stderr
    );
    if self#closed then
      failwith "Dbi_postgres: prepare called on closed database handle.";
    new statement
      (self : #Dbi.connection :> Dbi.connection) conn in_transaction query

  (* Commit the current transaction.  If COMMIT fails, the server has still
   * ended the transaction block, so the flag is cleared on both paths;
   * otherwise later statements would skip BEGIN WORK and run unprotected. *)
  method commit () =
    super#commit ();
    let sth = self#prepare_cached "commit work" in
    (try sth#execute []
     with exn -> in_transaction := false; raise exn);
    in_transaction := false

  (* Roll back the current transaction.  The flag is cleared even if
   * ROLLBACK fails: a redundant BEGIN WORK later is harmless, while a
   * stale true flag would let statements run outside a transaction. *)
  method rollback () =
    let sth = self#prepare_cached "rollback work" in
    (try sth#execute []
     with exn -> in_transaction := false; raise exn);
    in_transaction := false;
    super#rollback ()

  (* NOTE(review): calling close twice calls Connection.finish twice;
   * confirm the Postgres binding tolerates that. *)
  method close () =
    Connection.finish conn;
    super#close ()

  (* A failed connection attempt still owns a connection object, which
   * Connection.finish (presumably PQfinish) must release.  The error
   * message is read before finishing, then reported as Dbi.SQL_error. *)
  initializer
    if Connection.status conn = Connection.Bad then begin
      let msg = Connection.error_message conn in
      Connection.finish conn;
      raise (Dbi.SQL_error msg)
    end
end

(* Function-style wrappers around the [connection] class and its methods. *)

(* Open a new connection; the arguments are those of [connection]. *)
let connect ?host ?port ?user ?password database =
  new connection ?host ?port ?user ?password database

let close = fun (dbh : connection) -> dbh#close ()
let closed = fun (dbh : connection) -> dbh#closed
let commit = fun (dbh : connection) -> dbh#commit ()
let ping = fun (dbh : connection) -> dbh#ping ()
let rollback = fun (dbh : connection) -> dbh#rollback ()
