# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
"""This module extends the ast "scoped" nodes, i.e. the nodes which open a
new local scope in the language definition: Module, Class, Function (and,
to some extent, Lambda).

The new methods and attributes added to each class are documented
below.


:version:   $Revision: 1.6 $  
:author:    Sylvain Thenault
:copyright: 2003-2005 LOGILAB S.A. (Paris, FRANCE)
:contact:   http://www.logilab.fr/ -- mailto:python-projects@logilab.org
:copyright: 2003-2005 Sylvain Thenault
:contact:   mailto:thenault@gmail.com
"""

from __future__ import generators

# CVS/RCS keyword string identifying the revision of this file
__revision__ = "$Id: scoped_nodes.py,v 1.6 2005/11/02 11:56:45 syt Exp $"
# markup (and language) used by the docstrings of this module
__doctype__ = "restructuredtext en"

import sys
from compiler.ast import Module, Class, Function, Lambda, Dict, Tuple, Raise, \
     Pass, Raise, Return

from logilab.common.compat import chain        

from logilab.astng._exceptions import ResolveError, NotFoundError, NoDefault
from logilab.astng.utils import extend_class

# module class dict/iterator interface ########################################

class LocalsDictMixIn:
    r""" this class provides locals handling common to Module, Function
    and Class nodes, including a dict like interface for direct access
    to locals information

    /!\ this class should not be used directly /!\ it is only used as a
    container of methods and attributes, which update the original classes
    from the compiler.ast module through their dictionary (see the
    `extend_class` calls below the class definition)
    """

    # attributes below are set by the builder module or by raw factories

    # dictionary of locals with name as key and, as value, the list of the
    # nodes defining the local (in the order they have been registered)
    locals = None

    def frame(self):
        """return the first node defining a new local scope (i.e. Module,
        Function or Class)
        """
        return self

    def set_local(self, name, stmt):
        """define <name> in locals (<stmt> is the node defining the name)

        Every defining node is recorded: if the name is already defined,
        <stmt> is appended to the list of nodes already associated with it.
        """
        self.locals.setdefault(name, []).append(stmt)

    __setitem__ = set_local

    def add_local_node(self, child_node):
        """append a child which should alter locals to the given node, and
        register it in locals under its `name` attribute
        """
        self._append_node(child_node)
        self.set_local(child_node.name, child_node)

    def _append_node(self, child_node):
        """append a child to the `code.nodes` list, linking it in the tree"""
        self.code.nodes.append(child_node)
        child_node.parent = self

    def __getitem__(self, item):
        """method from the `dict` interface returning the first node
        associated with the given name in the locals dictionary

        :type item: str
        :param item: the name of the locally defined object
        :raises KeyError: if the name is not defined
        """
        return self.locals[item][0]

    def __iter__(self):
        """method from the `dict` interface returning an iterator on
        `self.keys()`
        """
        return iter(self.keys())

    def keys(self):
        """method from the `dict` interface returning a list containing the
        locally defined names
        """
        return list(self.locals.keys())

    def values(self):
        """method from the `dict` interface returning a list containing, for
        each locally defined name, the first node defining it
        """
        return [self[key] for key in self.keys()]

    def items(self):
        """method from the `dict` interface returning a list of tuples
        containing each locally defined name with the first node defining it
        """
        # single pass over the names so each key is paired with its own node
        return [(key, self[key]) for key in self.keys()]

    def has_key(self, name):
        """method from the `dict` interface returning True if the given name
        is defined in the locals dictionary
        """
        return name in self.locals

    __contains__ = has_key
    
# update the compiler.ast scoped node classes with the LocalsDictMixIn
# methods and attributes (the mixin itself is never instantiated)
extend_class(Module, LocalsDictMixIn)
extend_class(Class, LocalsDictMixIn)
extend_class(Function, LocalsDictMixIn)
extend_class(Lambda, LocalsDictMixIn)


# Module  #####################################################################

class ModuleNG:
    r"""/!\ this class should not be used directly /!\ it is only used as a
    container of methods and attributes, which update the original class
    from the compiler.ast module through its dictionary (see the
    `extend_class` call below the class definition)
    """

    # attributes below are set by the builder module or by raw factories

    # the file from which as been extracted the astng representation. It may
    # be None if the representation has been built from a built-in module
    file = None
    # the module name
    name = None
    # boolean for astng built from source (i.e. ast)
    pure_python = None
    # boolean for package module
    package = None
    # dictionary of globals with name as key and node defining the global
    # as value
    globals = None

    def _append_node(self, child_node):
        """append a child version specific to Module node (children are
        stored in `self.node.nodes`)
        """
        self.node.nodes.append(child_node)
        child_node.parent = self

    def source_line(self):
        """return the source line number, 0 on a module"""
        return 0

    def statement(self):
        """return the first parent node marked as statement node
        consider a module as a statement...
        """
        return self

    def wildcard_import_names(self):
        """return the list of imported names when this module is 'wildcard
        imported'

        It doesn't include the '__builtins__' name which is added by the
        current CPython implementation of wildcard imports.

        If the module is already imported, its living `__all__` (or its
        public names) is used. Otherwise names are computed statically from
        a literal `__all__` sequence, falling back to the public locals when
        there is no `__all__` or it can't be statically evaluated.
        """
        # take advantage of a living module if it exists. sys.modules may map
        # a name to None (Python 2 marker for a failed implicit relative
        # import), which must not be mistaken for a module
        living = sys.modules.get(self.name)
        if living is not None:
            try:
                # copy, so callers can't mutate the living module's __all__
                return list(living.__all__)
            except AttributeError:
                return [name for name in living.__dict__.keys()
                        if not name.startswith('_')]
        try:
            explicit = self['__all__'].rhs()
        except KeyError:
            explicit = None
        # __all__ should be a tuple or list of constant strings; anything
        # else (concatenation, imported names...) can't be evaluated here
        nodes = getattr(explicit, 'nodes', None)
        if nodes is not None:
            try:
                return [const.value for const in nodes]
            except AttributeError:
                # a non constant item: fall back to the public names
                pass
        return [name for name in self.keys()
                if not name.startswith('_')]

# update the compiler.ast Module class with the ModuleNG members
extend_class(Module, ModuleNG)

# Function  ###################################################################

class FunctionNG:
    r"""/!\ this class should not be used directly /!\ it is only used as a
    container of methods and attributes, which update the original class
    from the compiler.ast module through its dictionary (see the
    `extend_class` call below the class definition)
    """

    # attributes below are set by the builder module or by raw factories

    # function's type, 'function' | 'method' | 'staticmethod' | 'classmethod'
    type = 'function'
    # list of argument names. MAY BE NONE on some builtin functions where
    # arguments are unknown
    argnames = None

    def is_method(self):
        """return true if the function node should be considered as a method"""
        return self.type != 'function'

    def is_abstract(self, pass_is_abstract=True):
        """return true if the method is abstract

        Only the first statement of the body is inspected: the method is
        considered abstract if it is a raise of NotImplementedError or, if
        `pass_is_abstract` is true, a pass statement. An empty body is
        handled like a body made of a single pass statement.
        """
        for child_node in self.code.getChildNodes():
            if isinstance(child_node, Raise) and child_node.expr1:
                names = child_node.expr1.names()
                if names and names[0] == 'NotImplementedError':
                    return True
            if pass_is_abstract and isinstance(child_node, Pass):
                return True
            # only the first statement is relevant
            return False
        # empty function is the same as function with a single "pass" statement
        if pass_is_abstract:
            return True
        return False

    def self_resolve(self, name, frame):
        """overridden from NodeNG class:
        self resolve on Function check we did not catch a function's argument
        """
        if self is frame:
            # the searched name has been found in the function locals,
            # and resolved to the function itself, so it must be one
            # of the function's arguments
            try:
                value = self.default_value(name)
            except NoDefault:
                raise ResolveError(name)
            return value.self_resolve(name, self.parent)
        return self

    def format_args(self):
        """return arguments formatted as string, e.g. 'a, b=1, *args, **kw'

        An empty string is returned when the argument names are unknown.
        """
        if self.argnames is None: # information is missing
            return ''
        args, kwargs, last, default_idx = self._pos_information()
        # *args is the last argument, or the one before **kwargs if any
        vararg_idx = last
        if kwargs:
            vararg_idx = last - 1

        def format_name(name):
            # unpacked tuple arguments may be nested: ('a', ('b', 'c'))
            if isinstance(name, tuple):
                return '(%s)' % ','.join([format_name(sub) for sub in name])
            return name

        result = []
        for i in range(len(self.argnames)):
            name = format_name(self.argnames[i])
            if kwargs and i == last:
                name = '**%s' % name
            elif args and i == vararg_idx:
                name = '*%s' % name
            elif i >= default_idx:
                name = '%s=%s' % (name, self.defaults[i - default_idx].as_string())
            result.append(name)
        return ', '.join(result)

    def default_value(self, argname):
        """return the default value for an argument

        An empty `Dict` node is returned for the **kwargs argument and an
        empty `Tuple` node for the *args argument.

        :raise `NoDefault`: if there is no default value defined
        """
        if self.argnames is None: # information is missing
            raise NoDefault()
        args, kwargs, last, default_idx = self._pos_information()
        # *args is the last argument, or the one before **kwargs if any
        vararg_idx = last
        if kwargs:
            vararg_idx = last - 1
        for i in range(len(self.argnames)):
            if self.argnames[i] != argname:
                continue
            if kwargs and i == last:
                val_node = Dict([])
                val_node.parent = self
                return val_node
            if args and i == vararg_idx:
                val_node = Tuple([])
                val_node.parent = self
                return val_node
            if i >= default_idx:
                return self.defaults[i - default_idx]
            break
        raise NoDefault()

    def _pos_information(self):
        """return a 4-uple with positional information about arguments:
        (true if * is used,
         true if ** is used,
         index of the last argument,
         index of the first argument having a default value)
        """
        # 4 and 8 are the CO_VARARGS and CO_VARKEYWORDS code flags
        args = self.flags & 4
        kwargs = self.flags & 8
        last = len(self.argnames) - 1
        default_idx = len(self.argnames) - (len(self.defaults) +
                                            (args and 1 or 0) +
                                            (kwargs and 1 or 0))
        return args, kwargs, last, default_idx

# update the compiler.ast Function class with the FunctionNG members
extend_class(Function, FunctionNG)

# lambda nodes may also need some of the function members: the underlying
# plain functions (im_func) of FunctionNG's methods are installed on Lambda
Lambda._pos_information = FunctionNG._pos_information.im_func
Lambda.format_args = FunctionNG.format_args.im_func
Lambda.default_value = FunctionNG.default_value.im_func
# a lambda is never considered as a method (see FunctionNG.is_method)
Lambda.type = 'function'

# Class ######################################################################

def _class_type(klass):
    """return a Class node type to differ metaclass, interface and exception
    from 'regular' classes
    """
    if klass._type is not None:
        return klass._type
    if klass.name == 'type':
        klass._type = 'metaclass'
    elif klass.name.endswith('Interface'):
        klass._type = 'interface'
    elif klass.name.endswith('Exception'):
        klass._type = 'exception'
    else:
        for base in klass.ancestors(recurs=False):
            if base.type != 'class':
                klass._type = base.type
                break
    if klass._type is None:
        klass._type = 'class'
    return klass._type

def _iface_hdlr(iface_node):
    """a handler function used by interfaces to handle suspicious
    interface nodes
    """
    try:
        yield iface_node.resolve_dotted(iface_node.as_string())
    except:
        return

class ClassNG:
    r"""/!\ this class should not be used directly /!\ it is only used as a
    container of methods and attributes, which update the original class
    from the compiler.ast module through its dictionary (see the
    `extend_class` call below the class definition)
    """

    _type = None
    type = property(_class_type,
                    doc="class'type, possible values are 'class' | "
                    "'metaclass' | 'interface' | 'exception'")

    # attributes below are set by the builder module or by raw factories

    # a dictionary of class instances attributes
    instance_attrs = None
    # list of parent class as a list of string (ie names as they appears
    # in the class definition)
    basenames = None

    def ancestors(self, recurs=True):
        """return an iterator on the node base classes in a prefixed
        depth first order

        :param recurs:
          boolean indicating if it should recurse or return direct
          ancestors only
        """
        # FIXME: should be possible to choose the resolution order
        for baseobj in self.parent.frame().resolve_all(self.basenames):
            if not isinstance(baseobj, Class):
                # base names resolving to something else than a class
                # are ignored
                continue
            yield baseobj
            if recurs:
                for grandpa in baseobj.ancestors(True):
                    yield grandpa

    def local_attr_ancestors(self, name):
        """return an iterator on astng representation of parent classes
        which have <name> defined in their locals
        """
        for astng in self.ancestors():
            if name in astng.locals:
                yield astng

    def instance_attr_ancestors(self, name):
        """return an iterator on astng representation of parent classes
        which have <name> defined in their instance attribute dictionary
        """
        for astng in self.ancestors():
            if name in astng.instance_attrs:
                yield astng

    def local_attr(self, name):
        """return the astng associated to name in this class locals or
        in its parents

        :raises `NotFoundError`:
          if no attribute with this name has been found in this class or
          its parent classes
        """
        try:
            return self[name]
        except KeyError:
            # get it from the first parent implementing it if any
            for class_node in self.local_attr_ancestors(name):
                return class_node[name]
        raise NotFoundError(name)

    def instance_attr(self, name):
        """return the astng associated to name in this class instance
        attributes dictionary or in its parents

        :raises `NotFoundError`:
          if no attribute with this name has been found in this class or
          its parent classes
        """
        try:
            return self.instance_attrs[name]
        except KeyError:
            # get it from the first parent implementing it if any
            for class_node in self.instance_attr_ancestors(name):
                return class_node.instance_attrs[name]
        raise NotFoundError(name)

    def attr(self, name):
        """return the astng associated to name, looking first in the
        instance attributes (of this class, then of its parents), then in
        the locals (of this class, then of its parents)

        :raises `NotFoundError`:
          if name is found neither as an instance attribute nor in locals
        """
        try:
            return self.instance_attr(name)
        except NotFoundError:
            return self.local_attr(name)

    def methods(self):
        """return an iterator on all methods defined in the class and
        its ancestors (a method overridden in a class shadows the parent's
        one with the same name)
        """
        done = {}
        for astng in chain(iter((self,)), self.ancestors()):
            for meth in astng.mymethods():
                if meth.name in done:
                    continue
                done[meth.name] = None
                yield meth

    def mymethods(self):
        """return an iterator on all methods defined in the class"""
        for member in self.values():
            if isinstance(member, Function):
                yield member

    def interfaces(self, herited=True, handler_func=_iface_hdlr):
        """return an iterator on interfaces implemented by the given
        class node

        :param herited:
          if false, `__implements__` is ignored unless it is defined in this
          class itself
        :param handler_func:
          function called on each node of the `__implements__` value and
          returning an iterator on the interfaces it stands for
        """
        # FIXME: what if __implements__ = (MyIFace, MyParent.__implements__)...
        try:
            implements = self.local_attr('__implements__')
        except NotFoundError:
            return
        if not herited and implements.frame() is not self:
            return
        implements = implements.rhs()
        # __implements__ is either a sequence node or a single expression
        if hasattr(implements, 'nodes'):
            iface_nodes = implements.nodes
        else:
            iface_nodes = (implements,)
        for iface_node in iface_nodes:
            # let the handler function take care of this....
            for iface in handler_func(iface_node):
                yield iface


# update the compiler.ast Class class with the ClassNG members
extend_class(Class, ClassNG)
