"""Module implementing a SQL based repository for cfvers"""

# Copyright 2003 Iustin Pop
#
# This file is part of cfvers.
#
# cfvers is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# cfvers is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cfvers; if not, write to the Free Software Foundation,
# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

import base64
import quopri
import os
import re
import bz2
import zlib
import sys

__all__ = ["RSql",]

from cfvers.main import *

class RSql(object):
    """Abstract base class for SQL-backed cfvers repositories.

    Instantiating ``RSql`` itself raises TypeError; only subclasses may be
    created.  ``__init__`` here does not open a connection: every method
    uses ``self.conn`` (a DB-API connection), so a subclass must set it
    before calling them.  All queries use ``%s`` placeholders, so the
    driver must accept the DB-API ``format``/``pyformat`` parameter style.

    Payloads stored in ``revisions.filecontents`` are paired with an
    ``encoding`` string: a colon-separated list of transformations,
    outermost first (e.g. ``"base64:bzip2:"``), which ``_decode_payload``
    undoes left to right.
    """

    # Query shared by areas() and getArea(): one row per area, including
    # its item count and highest area revision number.
    _AREA_SELECT = ("select name, root, ctime, "
                    "(select count(*) from items "
                    "where items.area = areas.name) as nitems, "
                    "(select max(revno) from arearevs "
                    "where arearevs.area = areas.name) as revno, "
                    "description from areas")

    def __init__(self, create=False, cnxargs=None, createopts=None):
        """Refuse direct instantiation; subclasses supply the connection.

        The arguments are accepted for subclass signature compatibility
        and are not used here.

        Raises:
            TypeError: if called on ``RSql`` itself rather than a subclass.
        """
        if type(self) is RSql:
            raise TypeError("You can't instantiate a RSql!")

    def _create(self, createopts=None):
        """Create the repository schema and, optionally, its initial data.

        Args:
            createopts: None, or an object with boolean ``force`` and
                ``doarea`` attributes.  With ``force`` set, existing tables
                are dropped first (errors are reported and ignored).  The
                default area is created unless ``doarea`` is false.
        """
        if createopts is not None and createopts.force:
            sys.stderr.write("Please ignore errors while cleaning old items...\n")
            self._delete_schema(self.conn.cursor())
            sys.stderr.write("Old items have been deleted\n")
            self.commit()
        # Obtain the working cursor only after the cleanup's commits and
        # rollbacks have finished.
        cursor = self.conn.cursor()
        self._init_schema(cursor)
        if createopts is None or createopts.doarea:
            self._init_data(cursor)
        self.commit()

    def _init_schema(self, cursor):
        """Create the areas, arearevs, items and revisions tables.

        NOTE(review): ``items.id`` is a plain ``INTEGER PRIMARY KEY`` and
        addItem() does not supply an id, so this relies on the backend
        auto-assigning it (as SQLite does); confirm for other backends.
        """
        cursor.execute("CREATE TABLE areas ( "
                       "name TEXT PRIMARY KEY, "
                       "root TEXT, "
                       "ctime TIMESTAMP WITH TIME ZONE, "
                       "description TEXT)")
        cursor.execute("CREATE TABLE arearevs ("
                       "area TEXT, "
                       "revno INTEGER CHECK(revno > 0), "
                       "server TEXT, "
                       "logmsg TEXT, "
                       "ctime TIMESTAMP WITH TIME ZONE, "
                       "uid INTEGER, "
                       "gid INTEGER, "
                       "commiter TEXT, "
                       "PRIMARY KEY(area, revno))")
        cursor.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, "
                       "area TEXT NOT NULL, "
                       "name TEXT NOT NULL, "
                       "ctime TIMESTAMP WITH TIME ZONE, "
                       "dirname TEXT)")
        cursor.execute("CREATE UNIQUE INDEX items_an_idx ON "
                       "items (area, name)")
        cursor.execute("CREATE INDEX items_dir_idx ON "
                       "items (dirname)")
        cursor.execute("CREATE TABLE revisions (item INTEGER, "
                       "revno INTEGER, filename TEXT, filetype INTEGER, "
                       "filecontents TEXT, sha1sum TEXT, "
                       "size INTEGER, "
                       "mode INTEGER, "
                       "mtime INTEGER, "
                       "atime INTEGER, "
                       "ctime INTEGER, "
                       "inode INTEGER, "
                       "device INTEGER, "
                       "nlink INTEGER, "
                       "uid INTEGER, gid INTEGER, "
                       "rdev INTEGER, "
                       "blocks INTEGER, "
                       "blksize INTEGER, "
                       "encoding TEXT, "
                       "PRIMARY KEY(item, revno))")

    def _delete_schema(self, cursor):
        """Drop all repository tables, tolerating tables that don't exist.

        Tables are dropped in the same order as before: revisions,
        arearevs, items, areas.
        """
        for table in ("revisions", "arearevs", "items", "areas"):
            self._try_stm(cursor, "DROP TABLE %s" % table)

    def _try_stm(self, cursor, stm, vals=None):
        """Execute *stm* and commit, reporting but tolerating failure.

        This is deliberately best-effort.  On any error the message is
        written to stderr and the transaction is rolled back, so callers
        (e.g. schema cleanup) can continue.

        Args:
            cursor: DB-API cursor to execute on.
            stm: SQL statement.
            vals: optional query parameters.  None (the default) executes
                the statement without parameters.
        """
        try:
            if vals is None:
                cursor.execute(stm)
            else:
                cursor.execute(stm, vals)
            self.commit()
        except Exception:
            e = sys.exc_info()[1]
            sys.stderr.write("Info: An error has occured: '%s'. Continuing\n" % e)
            self.rollback()

    def _init_data(self, cursor):
        """Insert the initial "default" area, rooted at "/".

        *cursor* is unused; addArea() opens its own cursor.
        """
        self.addArea(Area(name="default", description="Default area", root="/"))

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.conn.rollback()

    def getItem(self, id):
        """Return the Item with primary key *id*, or None if absent.

        The item's area is loaded with getArea().
        """
        cursor = self.conn.cursor()
        cursor.execute("select id, area, name, ctime, dirname from items where id = %s", (id,))
        row = cursor.fetchone()
        if row is None:
            return None
        return Item(id=row[0], area=self.getArea(row[1]), name=row[2],
                    ctime=row[3], dirname=row[4])

    def getItemByName(self, area, name):
        """Return the Item called *name* in Area *area*, or None if absent."""
        cursor = self.conn.cursor()
        cursor.execute("select id, area, name, ctime, dirname from items where area = %s and name = %s",
                       (area.name, name))
        row = cursor.fetchone()
        if row is None:
            return None
        return Item(row[0], area, row[2], row[3], row[4])

    def getItemsByDirname(self, area, name):
        """Return all Items in *area* whose ``dirname`` equals *name*."""
        cursor = self.conn.cursor()
        cursor.execute("select id, name, ctime from items where area = %s and dirname = %s",
                       (area.name, name))
        return [Item(row[0], area, row[1], row[2], name) for row in cursor.fetchall()]

    def addItem(self, item):
        """Insert *item* into the items table.

        The id is not supplied; the database is expected to assign it.
        """
        self.conn.cursor().execute("insert into items (area, name, ctime, dirname) values (%s, %s, %s, %s)",
                                   (item.area.name, item.name, item.ctime, item.dirname))

    def updItem(self, item):
        """Update an item.  Not implemented: currently a no-op."""
        pass

    def addEntry(self, entry):
        """Insert revision *entry*, encoding its file contents for storage.

        The contents pass through _encode_payload() with its defaults, and
        the resulting encoding string is stored alongside the payload.
        """
        cursor = self.conn.cursor()
        payload, encoding = self._encode_payload(entry.filecontents)
        cursor.execute("""insert into revisions
        (item, revno, filename, filetype, filecontents,
        mode, mtime, atime, uid, gid, rdev, encoding,
        sha1sum, size, ctime, inode, device, nlink,
        blocks, blksize
        ) values (%s, %s, %s, %s, %s, %s, %s, %s,
        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
                       (entry.item, entry.revno, entry.filename,
                        entry.filetype, payload, entry.mode,
                        entry.mtime, entry.atime, entry.uid, entry.gid,
                        entry.rdev, encoding, entry.sha1sum, entry.size,
                        entry.ctime, entry.inode, entry.device,
                        entry.nlink, entry.blocks, entry.blksize)
                       )

    def items(self, area=None):
        """Return all Items, or only those in *area* when it is given.

        Each Item carries its ``dirname``, like the other item getters.
        When *area* is None, each item's Area object comes from areas().
        """
        cursor = self.conn.cursor()
        if area is None:
            areas_by_name = {}
            for a in self.areas():
                areas_by_name[a.name] = a
            cursor.execute("select id, area, name, ctime, dirname from items")
            return [Item(row[0], areas_by_name[row[1]], row[2], row[3], row[4])
                    for row in cursor.fetchall()]
        cursor.execute("select id, area, name, ctime, dirname from items where area = %s",
                       (area.name,))
        return [Item(row[0], area, row[2], row[3], row[4])
                for row in cursor.fetchall()]

    def areas(self):
        """Return every Area, with its item count and latest revision number."""
        cursor = self.conn.cursor()
        cursor.execute(self._AREA_SELECT)
        return [Area(name=row[0], root=row[1], ctime=row[2], numitems=row[3],
                     revno=row[4], description=row[5])
                for row in cursor.fetchall()]

    def getEntry(self, item, revno):
        """Return the newest revision of *item* numbered at most *revno*.

        When *revno* is None, the latest revision is returned.  Returns
        None if no such revision exists.  The stored payload is decoded
        into ``filecontents`` with _decode_payload().
        """
        query = ("select item, revno, filename, filecontents, filetype, "
                 "mode, mtime, atime, uid, gid, rdev, encoding, sha1sum, "
                 "size, ctime, inode, device, nlink, blocks, blksize "
                 "from revisions where item = %s")
        params = (item.id,)
        if revno is not None:
            # Match the newest revision at or before revno, not exactly revno.
            query += " and revno <= %s"
            params += (revno,)
        query += " order by revno desc limit 1"
        cursor = self.conn.cursor()
        cursor.execute(query, params)
        row = cursor.fetchone()
        if row is None:
            return None
        rev = RevEntry()
        (rev.item, rev.revno, rev.filename, payload, rev.filetype,
         rev.mode, rev.mtime, rev.atime, rev.uid, rev.gid, rev.rdev,
         encoding, rev.sha1sum, rev.size, rev.ctime, rev.inode,
         rev.device, rev.nlink, rev.blocks, rev.blksize,
         ) = row
        rev.filecontents = self._decode_payload(payload, encoding)
        return rev

    def _encode_payload(self, payload, dobzip2=False,
                        doquote=True):
        """Encode *payload* for storage in a TEXT column.

        Args:
            payload: raw file contents (a byte string).
            dobzip2: try bzip2 compression, keeping it only if smaller.
            doquote: if the (possibly compressed) data contains NUL bytes,
                wrap it in base64 or quoted-printable, whichever is shorter.

        Returns:
            A ``(payload, encoding)`` tuple.  *encoding* lists the applied
            transformations outermost first, each followed by ":" (e.g.
            ``"base64:bzip2:"``), or is "" when nothing was applied.
        """
        encoding = ""
        if dobzip2:
            compressed = bz2.compress(payload)
            if len(compressed) < len(payload):
                payload = compressed
                encoding = "bzip2:%s" % encoding
        if doquote and "\0" in payload:
            # Only NUL bytes trigger quoting; other bytes are stored as-is.
            b64 = base64.encodestring(payload)
            qp = quopri.encodestring(payload, quotetabs=1)
            if len(b64) < len(qp):  # the file is mostly binary
                encoding = "base64:%s" % encoding
                payload = b64
            else:
                encoding = "quoted-printable:%s" % encoding
                payload = qp
        return payload, encoding

    def _decode_payload(self, payload, encoding):
        """Undo the transformations named in *encoding*, outermost first.

        Args:
            payload: stored data.
            encoding: colon-separated transformation list, as produced by
                _encode_payload().  None or "" means the payload is stored
                verbatim.  Empty segments (such as the trailing one) are
                skipped.

        Returns:
            The decoded payload.

        Raises:
            ValueError: if *encoding* names an unknown transformation.
        """
        if not encoding:
            # Rows with a NULL encoding column hold the payload verbatim.
            return payload
        for enc in encoding.split(":"):
            if enc == "":
                continue
            elif enc == "base64":
                payload = base64.decodestring(payload)
            elif enc == "quoted-printable":
                payload = quopri.decodestring(payload)
            elif enc == "bzip2":
                payload = bz2.decompress(payload)
            elif enc == "gzip":
                # wbits=15+32 auto-detects a zlib or a gzip header, so real
                # gzip data decodes as well as plain zlib streams.
                payload = zlib.decompress(payload, 15 + 32)
            else:
                raise ValueError("Unknown encoding '%s'!" % enc)
        return payload

    def getRevList(self, item):
        """Return *item*'s revisions (metadata only), newest first."""
        cursor = self.conn.cursor()
        cursor.execute("select item, revno, filename, filetype, mode, mtime, atime, uid, gid, rdev, size from revisions where item = %s order by revno desc",
                       (item.id,))
        revs = []
        for row in cursor.fetchall():
            rev = RevEntry()
            (rev.item, rev.revno, rev.filename, rev.filetype, rev.mode,
             rev.mtime, rev.atime, rev.uid, rev.gid, rev.rdev, rev.size) = row
            revs.append(rev)
        return revs

    def getRevNumbers(self, item):
        """Return the revision numbers of *item* in ascending order."""
        cursor = self.conn.cursor()
        cursor.execute("select revno from revisions where item = %s order by revno", (item.id,))
        return [row[0] for row in cursor.fetchall()]

    def addArea(self, a):
        """Insert Area *a* into the areas table."""
        self.conn.cursor().execute("insert into areas (name, root, ctime, description) values (%s, %s, %s, %s)",
                                   (a.name, a.root, a.ctime, a.description))

    def updArea(self, a):
        """Update the description and root of the area named ``a.name``.

        The area is identified by its name, the table's primary key, so
        renaming an area is not supported.
        """
        self.conn.cursor().execute("update areas set description = %s, root = %s where name = %s",
                                   (a.description, a.root, a.name))

    def getArea(self, name):
        """Return the Area called *name*, or None if it does not exist."""
        cursor = self.conn.cursor()
        cursor.execute(self._AREA_SELECT + " where name = %s", (name,))
        row = cursor.fetchone()
        if row is None:
            return None
        return Area(name=row[0], root=row[1], ctime=row[2],
                    numitems=row[3], revno=row[4], description=row[5])

    def getAreaRevs(self, area):
        """Return the AreaRev records of *area*, newest first."""
        cursor = self.conn.cursor()
        cursor.execute("select area, revno, logmsg, ctime, uid, gid, commiter, server from arearevs where area = %s order by revno desc",
                       (area.name,))
        ars = []
        for row in cursor.fetchall():
            r = AreaRev()
            (r.area, r.revno, r.logmsg, r.ctime, r.uid, r.gid,
             r.commiter, r.server) = row
            ars.append(r)
        return ars

    def getAreaRevItems(self, ar):
        """Return ids of items in area ``ar.area`` with a revision ``ar.revno``."""
        cursor = self.conn.cursor()
        cursor.execute("select items.id from items, revisions where items.area = %s and revisions.revno = %s and revisions.item = items.id",
                       (ar.area, ar.revno))
        return [row[0] for row in cursor.fetchall()]

    def putAreaRev(self, ar):
        """Insert AreaRev *ar* into the arearevs table."""
        cursor = self.conn.cursor()
        cursor.execute("insert into arearevs (area, revno, logmsg, ctime, uid, gid, commiter, server) values (%s, %s, %s, %s, %s, %s, %s, %s)",
                       (ar.area, ar.revno, ar.logmsg, ar.ctime,
                        ar.uid, ar.gid, ar.commiter, ar.server))

    def numAreas(self):
        """Return the number of areas as an int."""
        cursor = self.conn.cursor()
        cursor.execute("select count(1) from areas")
        return int(cursor.fetchone()[0])
