# Copyright 2005 Lars Wirzenius <liw@iki.fi>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


"""Read, manipulate, and write .newsrc style files.

A .newsrc file is the traditional Unix format for keeping track of
subscribed newsgroups and which articles in the groups have been
read. Each group is represented by one line in the file, of the
following format:

    group.name: 1-2, 5
    
where "group.name" is the name of the group, ":" indicates that the
group is subscribed ("!" marks an unsubscribed group), and the rest
of the line is a comma-separated list of article numbers and
inclusive ranges (such as "1-2") of articles in that group that have
been read.

Lars Wirzenius <liw@iki.fi>
"""


import types
import unittest
import random
import StringIO


class NewsrcSet:

    """A set of article numbers."""

    # self.set is a list of integers and tuples of integers. A single
    # integer means that integer. A tuple (a,b) means all integers from
    # a to b, inclusive. The items in the list are sorted in ascending
    # order, never overlap, and are never adjacent (an item ending at n
    # is never followed by one starting at n+1), so every maximal run of
    # numbers is stored as a single item. This is a significant space
    # optimization for .newsrc files.

    def __init__(self):
        self.set = []

    def _unpack(self, item):
        """Return (lowest, highest) for an item in self.set."""
        # Checking for tuple (rather than for a specific integer type)
        # also handles Python 2 longs correctly.
        if isinstance(item, tuple):
            return item
        return item, item

    def _pack(self, lowest, highest):
        """Create an item for self.set from a pair of integers.

        Returns None if lowest > highest (an empty range).
        """
        if lowest == highest:
            return lowest
        elif lowest < highest:
            return (lowest, highest)
        else:
            return None

    def add(self, number):
        """Add an integer to the set."""
        self.add_range(number, number)

    def add_range(self, lowest, highest):
        """Add all integers lowest...highest, inclusive, to the set.

        Existing items that overlap or touch the new range are merged
        into it, so the set stays sorted and compressed. An empty range
        (lowest > highest) is ignored.
        """
        if lowest > highest:
            return
        items = self.set
        # Skip items that end before lowest-1: they neither overlap nor
        # touch the new range and stay where they are.
        start = 0
        while start < len(items) and self._unpack(items[start])[1] + 1 < lowest:
            start += 1
        # Absorb every following item that overlaps or touches the
        # (possibly growing) range.
        end = start
        while end < len(items):
            (a, b) = self._unpack(items[end])
            if a > highest + 1:
                break
            lowest = min(lowest, a)
            highest = max(highest, b)
            end += 1
        # Replace the absorbed items (possibly none) with the merged one.
        items[start:end] = [self._pack(lowest, highest)]

    def remove(self, number):
        """Remove an integer from the set (no-op if it is not present)."""
        self.remove_range(number, number)

    def remove_range(self, lowest, highest):
        """Remove all integers in lowest...highest, inclusive, from the set.

        Every item overlapping the range is trimmed or split, not just the
        first one. An empty range (lowest > highest) is ignored.
        """
        if lowest > highest:
            return
        kept = []
        for item in self.set:
            (a, b) = self._unpack(item)
            if b < lowest or highest < a:
                # Entirely outside the removed range.
                kept.append(item)
                continue
            # Keep whatever parts of the item lie outside the range; they
            # remain sorted and separated by at least the removed range.
            if a < lowest:
                kept.append(self._pack(a, lowest - 1))
            if highest < b:
                kept.append(self._pack(highest + 1, b))
        self.set[:] = kept

    def __contains__(self, number):
        """Does the set contain a particular integer?"""
        for item in self.set:
            (a, b) = self._unpack(item)
            if a <= number <= b:
                return True
        return False

    def _format_item(self, item):
        """Format one item as "n" or "a-b"."""
        (a, b) = self._unpack(item)
        if a == b:
            return "%d" % a
        else:
            return "%d-%d" % (a, b)

    def format(self):
        """Format a set to textual output, e.g. "1-20, 22, 26-28"."""
        return ", ".join([self._format_item(item) for item in self.set])

    def parse_and_add(self, line):
        """Parse textual output (as made by format) and add the integers.

        Empty items are skipped, so an empty or whitespace-only line (as
        written for a group with no read articles) adds nothing. Raises
        ValueError for items that are not integers or "a-b" ranges.
        """
        for item in line.split(","):
            item = item.strip()
            if not item:
                continue
            if "-" in item:
                (a, b) = item.split("-")
                self.add_range(int(a), int(b))
            else:
                self.add(int(item))


class NewsrcSetTestCases(unittest.TestCase):

    """Unit tests for NewsrcSet."""

    MAX = 100

    def shuffled_numbers(self):
        """Return the integers 0...MAX-1 in random order."""
        numbers = list(range(self.MAX))
        random.shuffle(numbers)
        return numbers

    def is_compressed(self, articles):
        """Check the internal representation of a NewsrcSet.

        Return True if every item is a valid number or range and the items
        are sorted with a gap of at least one number between consecutive
        items, i.e. no two items overlap or could be merged.
        """
        prev_b = None
        for item in articles.set:
            (a, b) = articles._unpack(item)
            if a > b:
                return False
            if prev_b is not None and prev_b + 1 >= a:
                return False
            prev_b = b
        return True

    def testCreate(self):
        articles = NewsrcSet()
        self.assertEqual(articles.set, [])
        for i in range(self.MAX):
            self.assertFalse(i in articles)
        self.assertTrue(self.is_compressed(articles))

    def testContains(self):
        articles = NewsrcSet()
        for i in range(self.MAX):
            self.assertFalse(i in articles)
        self.assertTrue(self.is_compressed(articles))

    def testAdd(self):
        articles = NewsrcSet()
        self.assertTrue(self.is_compressed(articles))
        for i in self.shuffled_numbers():
            self.assertFalse(i in articles)
            articles.add(i)
            self.assertTrue(i in articles)
            self.assertTrue(self.is_compressed(articles))

    def testAddRange(self):
        articles = NewsrcSet()
        for i in range(self.MAX):
            self.assertFalse(i in articles)
        self.assertTrue(self.is_compressed(articles))
        articles.add_range(0, self.MAX-1)
        for i in range(self.MAX):
            self.assertTrue(i in articles)
        self.assertFalse(self.MAX in articles)
        self.assertTrue(self.is_compressed(articles))

    def testRemove(self):
        articles = NewsrcSet()
        articles.add_range(0, self.MAX-1)
        for i in self.shuffled_numbers():
            self.assertTrue(i in articles)
            articles.remove(i)
            self.assertFalse(i in articles)
            self.assertTrue(self.is_compressed(articles))

    # This test used to be a second "testAddRange", which silently
    # shadowed the one above so that it never ran.
    def testRemoveRange(self):
        articles = NewsrcSet()
        articles.add_range(0, self.MAX-1)
        numbers = self.shuffled_numbers()
        for i in numbers:
            self.assertTrue(i in articles)
        self.assertTrue(self.is_compressed(articles))
        articles.remove_range(0, self.MAX-1)
        for i in numbers:
            self.assertFalse(i in articles)
        self.assertTrue(self.is_compressed(articles))

    def testFormat(self):
        articles = NewsrcSet()
        articles.add_range(1, 20)
        articles.add_range(26, 28)
        articles.add(22)
        articles.add(24)
        articles.add(30)
        self.assertEqual(articles.format(), "1-20, 22, 24, 26-28, 30")

    def testParseAndAdd(self):
        articles = NewsrcSet()
        articles.parse_and_add("1-20, 22, 24, 26-28, 30")
        for i in list(range(1, 21)) + [22, 24, 26, 27, 28, 30]:
            self.assertTrue(i in articles)
            articles.remove(i)
        self.assertEqual(articles.set, [])

    def testRoundTrip(self):
        articles = NewsrcSet()
        orig_string = "1-20, 22, 24, 26-28, 30"
        articles.parse_and_add(orig_string)
        self.assertEqual(articles.format(), orig_string)


class Newsrc:

    """Class to represent a .newsrc file."""

    def __init__(self):
        # List of (group, is_subscribed) pairs, in the order the groups
        # should appear in the file.
        self.groups = []
        # NewsrcSet of read article numbers, indexed by group name.
        self.articles = {}

    def add_group(self, group, is_subscribed):
        """Add a group, with no read articles, to the end of the file.

        The group must not already be in the file (checked by an assert).
        """
        assert group not in self.articles
        self.groups.append((group, is_subscribed))
        self.articles[group] = NewsrcSet()

    def remove_group(self, group):
        """Remove a group from the file.

        Raises KeyError if the group is not in the file.
        """
        # Deleting from self.articles first raises KeyError for an unknown
        # group before anything has been modified.
        del self.articles[group]
        # self.groups holds (name, is_subscribed) pairs, so match by name.
        self.groups = [(name, is_subscribed)
                       for (name, is_subscribed) in self.groups
                       if name != group]

    def mark_as_read(self, group, article):
        """Mark an article as read."""
        self.articles[group].add(article)

    def mark_as_unread(self, group, article):
        """Mark an article as unread."""
        self.articles[group].remove(article)

    def mark_range_as_read(self, group, lowest, highest):
        """Mark a range of articles as read."""
        self.articles[group].add_range(lowest, highest)

    def mark_range_as_unread(self, group, lowest, highest):
        """Mark a range of articles as unread."""
        self.articles[group].remove_range(lowest, highest)

    def set_subscription(self, group, is_subscribed):
        """Change the subscription state of a group (no-op if unknown)."""
        for i in range(len(self.groups)):
            if self.groups[i][0] == group:
                self.groups[i] = (group, is_subscribed)
                return

    def subscribe(self, group):
        """Mark the group as subscribed."""
        self.set_subscription(group, True)

    def unsubscribe(self, group):
        """Mark the group as unsubscribed."""
        self.set_subscription(group, False)

    def get_groups(self):
        """Return a copy of the list of (group, is_subscribed) pairs."""
        return self.groups[:]

    def get_articles(self, group):
        """Get set of article numbers for a given group."""
        return self.articles[group]

    def set_group_order(self, groups):
        """Set order of groups in file.

        'groups' is a list of group names. If the file contains more
        groups than are listed in 'groups', they are put after the ones
        in 'groups', in the order they were before this method was
        called. Raises KeyError if 'groups' names a group that is not in
        the file, or names the same group twice.
        """
        remaining = {}
        for group, is_subscribed in self.groups:
            remaining[group] = is_subscribed

        new_groups = []
        for group in groups:
            new_groups.append((group, remaining[group]))
            del remaining[group]

        for group, is_subscribed in self.groups:
            if group in remaining:
                new_groups.append((group, is_subscribed))

        self.groups = new_groups

    def load(self, file):
        """Read an open .newsrc file.

        Each non-blank line has the form "name: articles" for a subscribed
        group or "name! articles" for an unsubscribed one; the first ":"
        or "!" on the line ends the group name. Blank lines are skipped
        and an empty article list is allowed.

        Raises SyntaxError for a line with neither ":" nor "!".
        """
        for line in file:
            if not line.strip():
                continue
            positions = [pos for pos in (line.find(":"), line.find("!"))
                         if pos >= 0]
            if not positions:
                # File-like objects such as StringIO may have no .name.
                raise SyntaxError("syntax error in file %s" %
                                  (getattr(file, "name", "<unknown>"),))
            pos = min(positions)
            name = line[:pos].strip()
            is_subscribed = (line[pos] == ":")
            articles = line[pos+1:]
            self.add_group(name, is_subscribed)
            # A group with no read articles has nothing after the separator.
            if articles.strip():
                self.get_articles(name).parse_and_add(articles)

    def save(self, file):
        """Write all groups to an open .newsrc file, one line per group."""
        separators = {
            True: ": ",
            False: "! ",
        }
        for group, is_subscribed in self.groups:
            file.write("%s%s%s\n" %
                       (group,
                        separators[is_subscribed],
                        self.get_articles(group).format()))


class NewsrcTestCases(unittest.TestCase):

    """Unit tests for creating, loading and saving Newsrc objects."""

    # One unsubscribed and one subscribed group; save() must reproduce
    # this text exactly.
    file_data = """\
foo! 1-2
bar: 1-2, 4-6
"""

    def testCreate(self):
        empty = Newsrc()
        self.assertEqual(empty.groups, [])
        self.assertEqual(empty.articles, {})

    def testLoadAndSave(self):
        newsrc = Newsrc()
        newsrc.load(StringIO.StringIO(self.file_data))
        self.assertEqual(sorted(newsrc.get_groups()),
                         [("bar", True), ("foo", False)])

        output = StringIO.StringIO()
        newsrc.save(output)
        self.assertEqual(output.getvalue(), self.file_data)

if "__main__" == __name__:
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
