#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2001-2004 Free Software Foundation
#
# $Id: SQL.py 5614 2004-04-02 10:53:54Z johannes $
#

from gnue.common.schema.scripter.processors.Base import BaseProcessor
from string import join
from types import UnicodeType

# =============================================================================
# Baseclass for SQL processors
# =============================================================================
class SQLProcessor (BaseProcessor):
  """
  This class implements a GNUe Schema Definition processor for SQL-like
  backends. It turns table, index, constraint and data definitions into
  SQL statements (CREATE TABLE, ALTER TABLE ... ADD, CREATE INDEX, INSERT).

  Backend specific processors customize the output by overriding the class
  attributes below and the datatype translation methods (string, date, time,
  datetime) or data transformation services (dts_*).

  NOTE(review): ALTER_MULTIPLE, comment (), quoteString () and _dts_type ()
  are used here but not defined in this class; presumably BaseProcessor or a
  backend subclass provides them.
  """

  END_COMMAND = ";"             # Symbol used to terminate a command
  END_BATCH   = ""              # Symbol used to terminate a command-sequence

  QUOTECHAR   = "'"             # Character used to quote a string
  ESCAPECHAR  = "'"             # Character used to escape a quote character


  # ---------------------------------------------------------------------------
  # Process the fields collection of a table definition
  # ---------------------------------------------------------------------------

  def _processFields (self, tableDef):
    """
    Populate the sequences of the table definition's first phase.

    For the action 'create' a CREATE TABLE statement is built around the
    fields. For any other action the fields are added to an existing table:
    with a single multi-column ALTER TABLE statement if ALTER_MULTIPLE is
    true, otherwise with one ALTER TABLE statement per field.

    @param tableDef: table definition whose fields should be processed
    """
    phase = tableDef.getPhase (0)
    # Both branches start with a blank separator line in the prologue
    phase.prologue.append (u"")

    if tableDef.action == 'create':
      phase.prologue.extend (self.comment ("Create table '%s'" %
                                            tableDef.name))
      phase.body.append (u"CREATE TABLE %s (" % tableDef.name)
      phase.footer.append (u")%s" % self.END_COMMAND)

      # presumably calls _processField () for every field of tableDef
      BaseProcessor._processFields (self, tableDef)

    else:
      phase.prologue.extend (self.comment ("Alter table '%s'" %
                                            tableDef.name))

      if self.ALTER_MULTIPLE:
        self._alterMultiple (tableDef)
      else:
        self._alterSingle (tableDef)


  # ---------------------------------------------------------------------------
  # Add new fields to an existing table
  # ---------------------------------------------------------------------------

  def _alterMultiple (self, tableDef):
    """
    Add all fields of tableDef with a single statement of the form
    'ALTER TABLE <name> ADD (<field>, <field>, ...)'.

    @param tableDef: table definition holding the fields to add
    """
    phase = tableDef.getPhase (0)
    phase.body.append (u"ALTER TABLE %s ADD (" % tableDef.name)
    phase.footer.append (u")%s" % self.END_COMMAND)

    BaseProcessor._processFields (self, tableDef)


  def _alterSingle (self, tableDef):
    """
    Add every field of tableDef with its own 'ALTER TABLE <name> ADD <field>'
    statement. Each field is processed as the last (and only) field so no
    separator is appended to it.

    @param tableDef: table definition holding the fields to add
    """
    phase = tableDef.getPhase (0)

    for field in tableDef.fields:
      phase.body.append (u"ALTER TABLE %s ADD " % tableDef.name)
      self._processField (tableDef, field, True)
      phase.body.append ("%s" % self.END_COMMAND)


  # ---------------------------------------------------------------------------
  # A single field is usually added to the definitions body
  # ---------------------------------------------------------------------------

  def _processField (self, tableDef, gsField, isLast):
    """
    Default implementation: add the qualified field to the body of the table
    definition's first phase.

    @param tableDef: table definition the field belongs to
    @param gsField: field definition, qualified via self._qualify ()
    @param isLast: if False a separator ', ' is appended to the field line
    """
    field = "  %s" % self._qualify (gsField)

    if not isLast:
      field += ", "

    tableDef.getPhase (0).body.append (field)


  # ---------------------------------------------------------------------------
  # Primary key definition comes after the last field of a table
  # ---------------------------------------------------------------------------

  def _processPrimaryKey (self, tableDef):
    """
    Extend the body of the table definition's first phase by a
    'CONSTRAINT <name> PRIMARY KEY (<fields>)' clause.

    @param tableDef: table definition owning the primary key
    """
    pkDef = tableDef.primaryKey
    flist = ", ".join ([pkf.name for pkf in pkDef.fields])

    phase = tableDef.getPhase (0)
    # The last field line presumably carries no separator yet (_processField
    # omits it for the last field), so add one before the constraint clause.
    if phase.body:
      phase.body [-1] += u","

    phase.body.append (u"  CONSTRAINT %s PRIMARY KEY (%s)" %
                        (pkDef.name, flist))


  # ---------------------------------------------------------------------------
  # Integrate index definitions into tableDef
  # ---------------------------------------------------------------------------

  def _processIndices (self, tableDef):
    """
    Process all indices (via BaseProcessor) and then append the merged code
    of every index definition to the epilogue of the table definition's
    first phase.

    @param tableDef: table definition owning the indices
    """
    BaseProcessor._processIndices (self, tableDef)

    for index in tableDef.indices.values ():
      tableDef.getPhase (0).epilogue.extend (index.merge ())


  # ---------------------------------------------------------------------------
  # Process a single index definition
  # ---------------------------------------------------------------------------

  def _processIndex (self, tableDef, indexDef):
    """
    Translate an index definition object into a
    'CREATE [UNIQUE ]INDEX <name> ON <table> (<fields>)' statement, stored in
    the index definition's own prologue/header/body/epilogue sequences.

    @param tableDef: table definition the index belongs to
    @param indexDef: index definition to translate
    """
    indexDef.prologue.append (u"")
    indexDef.prologue.extend (self.comment ("Create index '%s'" %
                              indexDef.name))

    if indexDef.unique:
      uniq = u"UNIQUE "
    else:
      uniq = ""

    indexDef.header.append (u"CREATE %sINDEX %s ON %s" %
                            (uniq, indexDef.name, tableDef.name))

    fields = ", ".join ([fld.name for fld in indexDef.fields])
    indexDef.body.append ("  (%s)%s" % (fields, self.END_COMMAND))

    indexDef.epilogue.append (u"")


  # ---------------------------------------------------------------------------
  # Integrate constraints into table definition
  # ---------------------------------------------------------------------------

  def _processConstraint (self, tableDef, constraint):
    """
    Translate a foreign key constraint into an 'ALTER TABLE ... ADD
    CONSTRAINT ... FOREIGN KEY ... REFERENCES ...' statement. These
    statements are put into the second phase (phase 1) of the table
    definition. Constraints of any other kind are ignored here.

    @param tableDef: table definition owning the constraint
    @param constraint: constraint definition to translate
    """
    if constraint.kind == "foreignkey":
      phase = tableDef.getPhase (1)

      fields    = ", ".join ([cf.name for cf in constraint.fields])
      refFields = ", ".join ([rf.name for rf in constraint.reffields])

      phase.body.append (u"")
      phase.body.extend (self.comment ("CONSTRAINT '%s'" % constraint.name))
      phase.body.append ("ALTER TABLE %s ADD" % tableDef.name)

      phase.body.append ("  CONSTRAINT %s FOREIGN KEY (%s)" %
        (constraint.name, fields))
      phase.body.append ("    REFERENCES %s (%s)%s" %
        (constraint.reftable, refFields, self.END_COMMAND))


  # ---------------------------------------------------------------------------
  # Translate a data definition
  # ---------------------------------------------------------------------------

  def _processDataRows (self, dataDef, tableDef):
    """
    Add a header comment to the data definition's prologue and call
    _processDataRow () for every row of the data definition.

    @param dataDef: data definition holding the rows
    @param tableDef: matching table definition, passed on to _processDataRow
    """
    dataDef.prologue.append (u"")
    dataDef.prologue.extend (self.comment ("Data for '%s'" % dataDef.name))

    for row in dataDef.rows:
      self._processDataRow (row, dataDef, tableDef)


  # ---------------------------------------------------------------------------
  # Process a single data row
  # ---------------------------------------------------------------------------

  def _processDataRow (self, row, dataDef, tableDef):
    """
    Create an INSERT statement for the given row definition and append it to
    the data definition's body.

    Values of None become NULL; every other value is translated through
    self._dts_type () (presumably dispatching to the matching dts_* method).
    If the row has a column list, the statement names those columns
    explicitly. The tableDef argument is not used by this implementation.

    @param row: row definition with 'values' and 'columns'
    @param dataDef: data definition receiving the INSERT statement
    @param tableDef: table definition (unused here)
    """
    values = []

    for item in row.values:
      if item.value is None:
        values.append (u"NULL")
        continue

      res = self._dts_type (item)

      # dts_* functions may return non-strings (e.g. dts_boolean returns an
      # int); everything that is not already unicode is converted by str ()
      if isinstance (res, UnicodeType):
        values.append (res)
      else:
        values.append (str (res))

    # an explicit column list is only emitted if the row provides one
    if row.columns:
      cols = " (%s)" % ", ".join (row.columns)
    else:
      cols = ""

    # and create an insert statement
    dataDef.body.append (u"INSERT INTO %s%s VALUES (%s)%s" %
       (dataDef.name, cols, ", ".join (values), self.END_COMMAND))


  # ===========================================================================
  # Datatype translation
  # ===========================================================================

  # ---------------------------------------------------------------------------
  # String becomes 'varchar (length)' or 'text'
  # ---------------------------------------------------------------------------

  def string (self, gsField):
    """
    Returns 'varchar (length)' if gsField has a length property, otherwise
    'text'.
    """
    if hasattr (gsField, "length"):
      res = "varchar (%s)" % gsField.length
    else:
      res = "text"

    return res


  # ---------------------------------------------------------------------------
  # Keep date as 'date'
  # ---------------------------------------------------------------------------

  def date (self, gsField):
    """
    Keep date as 'date'
    """
    return "date"


  # ---------------------------------------------------------------------------
  # Keep time as 'time'
  # ---------------------------------------------------------------------------

  def time (self, gsField):
    """
    Keep time as 'time'
    """
    return "time"


  # ---------------------------------------------------------------------------
  # Keep datetime as 'datetime'
  # ---------------------------------------------------------------------------

  def datetime (self, gsField):
    """
    Keep datetime as 'datetime'
    """
    return "datetime"


  # ===========================================================================
  # Data transformation services
  # ===========================================================================

  # ---------------------------------------------------------------------------
  # escape the quote character
  # ---------------------------------------------------------------------------

  def escapeString (self, aString):
    """
    Escape every QUOTECHAR in aString by prefixing it with ESCAPECHAR.

    With the default settings (both are "'") a quote is simply doubled. If
    ESCAPECHAR differs from QUOTECHAR (e.g. a backslash), existing escape
    characters are doubled first so they are not misread as escapes.
    """
    res = aString
    if self.QUOTECHAR != self.ESCAPECHAR:
      res = res.replace (self.ESCAPECHAR, self.ESCAPECHAR * 2)
    res = res.replace (self.QUOTECHAR, self.ESCAPECHAR + self.QUOTECHAR)
    return res


  # ---------------------------------------------------------------------------
  # quote a string
  # ---------------------------------------------------------------------------

  def dts_string (self, gsValue):
    """
    Return the value as a quoted string (via self.quoteString ()).
    """
    return self.quoteString (gsValue.value)


  # ---------------------------------------------------------------------------
  # take only the date-part of an mx.DateTime.Date instance
  # ---------------------------------------------------------------------------

  def dts_date (self, gsValue):
    """
    Return the quoted date-part of the date object, i.e. everything before
    the first blank of its string representation.
    """
    return self.quoteString (str (gsValue.value).split (" ") [0])


  # ---------------------------------------------------------------------------
  # get a quoted time
  # ---------------------------------------------------------------------------

  def dts_time (self, gsValue):
    """
    Quote the string representation of the time object
    """
    return self.quoteString (str (gsValue.value))


  # ---------------------------------------------------------------------------
  # datetime object only needs quotes
  # ---------------------------------------------------------------------------

  def dts_datetime (self, gsValue):
    """
    Quote the string representation of the datetime object
    """
    return self.quoteString (str (gsValue.value))


  # ---------------------------------------------------------------------------
  # booleans are represented by 0 (FALSE) and 1 (TRUE)
  # ---------------------------------------------------------------------------

  def dts_boolean (self, gsValue):
    """
    Use 1 for a true value and 0 for a false one (returned as int)
    """
    if gsValue.value:
      return 1
    else:
      return 0
