#!/usr/bin/python

#   Copyright (C) 2002-2003 Yannick Gingras <ygingras@ygingras.net>
#   Copyright (C) 2002-2003 Vincent Barbin <vbarbin@openbeatbox.org>

#   This file is part of Open Beat Box.

#   Open Beat Box is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.

#   Open Beat Box is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.

#   You should have received a copy of the GNU General Public License
#   along with Open Beat Box; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


from qt          import *
import os
from os          import path
from pprint      import pprint
from PixmapSet   import *
from OBBFuncts   import *
from ImageLoader import ImageLoader
from Stretcher   import Stretcher
from xml.dom.minidom import parse

class OBBSkin:
    """Compiler and loader for Open Beat Box skins.

    A skin is described by an XML file.  compile() crops the source images
    named by its PixmapSet elements (looked up in getImgDir()) into a
    directory named after the XML file minus its extension, and saves a
    heuristic mask next to each cropped image.  load() reads those cropped
    images back as pixmap sets, then gathers the parameters of every skinned
    widget so that createWidget() can pass them to widget constructors.
    """

    def __init__(self):
        # PixmapSet name (from the XML) -> loaded PixmapSet.
        self.pixmapSets      = {}
        # Widget tag name -> {widget name -> constructor keyword params}.
        self.widgets         = {}
        # Cache of stretched pixmap sets; see getStretchedPix() for the key.
        self.stretchedFrames = {}
        self.stretcher       = Stretcher()


    def compile(self, filename):
        """Crop every image referenced by the skin file *filename*.

        Cropped images and their masks are written to the directory named
        after *filename* without its extension (created if needed).
        Raises OSError if that directory cannot be created and does not
        already exist.
        """
        print("Compiling %s" % filename)
        self.baseDir = os.path.splitext(filename)[0]
        try:
            os.mkdir(self.baseDir)
        except OSError:
            # Re-running a compile is fine; any other failure (permission
            # denied, a regular file in the way, ...) must not be hidden.
            if not os.path.isdir(self.baseDir):
                raise
        self.loader = ImageLoader()
        self.doc = parse(filename)
        self.compilePixmapSets()


    def load(self, filename):
        """Load the compiled skin described by *filename*.

        Reads the cropped images from the directory produced by compile(),
        then the parameters of every supported widget type.
        """
        self.baseDir = os.path.splitext(filename)[0]
        self.doc = parse(filename)
        # Stretched sets are cached by PixmapSet name, not by skin; drop them
        # so a second load() cannot hand out pixmaps of the previous skin.
        self.stretchedFrames = {}
        self.loadPixmapSets()
        self.loadWidgets()


    def save(self, filename):
        """Not supported yet."""
        raise NotImplementedError()


    def loadWidgets(self):
        """Collect the skin parameters of every supported widget type."""
        self.loadHitButtons()
        self.loadPushButtons()
        self.loadLights()
        self.loadLabels()
        self.loadSpinBoxes()
        self.loadSliders()
        self.loadStretchFrames()


    def _readIntAttributes(self, node, names):
        """Return {name: int(attribute value)} for each name in *names*.

        Raises ValueError when an attribute is missing or not an integer.
        """
        values = {}
        for name in names:
            values[name] = int(node.getAttribute(name))
        return values


    def getStretchedPix(self, setName, orientation, params):
        """Return pixmap set *setName* stretched along *orientation*.

        *params* must provide "stretchPoint", "stretchSize" and "wantedSize";
        they are forwarded to the stretcher.  Results are cached, so
        identical requests get the very same stretched set back.
        """
        stretchPoint = params["stretchPoint"]
        stretchSize  = params["stretchSize"]
        wantedSize   = params["wantedSize"]
        # Every input of the stretch belongs in the key; otherwise requests
        # differing only by one of them would share a result.
        key = (setName, orientation, stretchPoint, stretchSize, wantedSize)
        if key in self.stretchedFrames:
            return self.stretchedFrames[key]
        stretchedSet = self.stretcher.stretchPixSet(self.pixmapSets[setName],
                                                    orientation,
                                                    stretchPoint,
                                                    stretchSize,
                                                    wantedSize)
        self.stretchedFrames[key] = stretchedSet
        # Cache misses return the stretched set too, so the first call and
        # later cached calls agree.
        return stretchedSet


    def loadHitButtons(self):
        """Collect the parameters of every "HitButton" element."""
        buttons = self.widgets["HitButton"] = {}
        for node in self.doc.getElementsByTagName("HitButton"):
            params = {}
            for param in ["buttonPixSet", "lightPixSet"]:
                params[param] = self.pixmapSets[node.getAttribute(param)]
            params.update(self._readIntAttributes(node, ["lightX", "lightY"]))
            buttons[node.getAttribute("name")] = params


    def loadSpinBoxes(self):
        """Collect the parameters of every "OBBSpinBox" element.

        The label, up and down pixmap sets are stretched along "x" using the
        element's stretch attributes.
        """
        spinBoxes = self.widgets["OBBSpinBox"] = {}
        for node in self.doc.getElementsByTagName("OBBSpinBox"):
            params = self._readIntAttributes(node, ["x",
                                                    "y",
                                                    "stretchPoint",
                                                    "stretchSize",
                                                    "wantedSize",
                                                    "borderX",
                                                    "borderY",
                                                    "fontSize"])
            for param in ["labelPixSet", "upPixSet", "downPixSet"]:
                params[param] = self.getStretchedPix(node.getAttribute(param),
                                                     "x",
                                                     params)
            spinBoxes[node.getAttribute("name")] = params


    def loadSliders(self):
        """Collect the parameters of every "OBBSlider" element."""
        sliders = self.widgets["OBBSlider"] = {}
        for node in self.doc.getElementsByTagName("OBBSlider"):
            params = self._readIntAttributes(node, ["x",
                                                    "y",
                                                    "stretchPoint",
                                                    "stretchSize",
                                                    "wantedSize",
                                                    "handleXOffset",
                                                    "handleYOffset",
                                                    "minRange",
                                                    "maxRange",
                                                    "defaultValue"])
            # NOTE(review): frame and handle are always stretched along "x"
            # even though the element has an "orientation" attribute --
            # confirm this is intended for vertical sliders.
            for param in ["framePixSet", "handlePixSet"]:
                params[param] = self.getStretchedPix(node.getAttribute(param),
                                                     "x",
                                                     params)
            params["orientation"] = node.getAttribute("orientation")
            sliders[node.getAttribute("name")] = params


    def loadStretchFrames(self):
        """Collect the parameters of every "StretchFrame" element.

        The frame's pixmap set is stretched along the element's own
        "orientation" attribute.
        """
        frames = self.widgets["StretchFrame"] = {}
        for node in self.doc.getElementsByTagName("StretchFrame"):
            params = {"orientation": node.getAttribute("orientation")}
            params.update(self._readIntAttributes(node, ["stretchPoint",
                                                         "stretchSize",
                                                         "wantedSize"]))
            params["pixmapSet"] = self.getStretchedPix(
                node.getAttribute("pixmapSet"), params["orientation"], params)
            frames[node.getAttribute("name")] = params


    def loadLabels(self):
        """Collect the parameters of every "OBBLabel" element.

        The label's pixmap set is stretched along "x".
        """
        labels = self.widgets["OBBLabel"] = {}
        for node in self.doc.getElementsByTagName("OBBLabel"):
            params = self._readIntAttributes(node, ["x",
                                                    "y",
                                                    "stretchPoint",
                                                    "stretchSize",
                                                    "wantedSize",
                                                    "borderX",
                                                    "borderY",
                                                    "fontSize"])
            params["align"] = node.getAttribute("align")
            params["pixmapSet"] = self.getStretchedPix(
                node.getAttribute("pixmapSet"), "x", params)
            labels[node.getAttribute("name")] = params


    def loadPushButtons(self):
        """Collect the pixmap set of every "PushButton" element."""
        buttons = self.widgets["PushButton"] = {}
        for node in self.doc.getElementsByTagName("PushButton"):
            pixSet = self.pixmapSets[node.getAttribute("pixmapSet")]
            buttons[node.getAttribute("name")] = {"pixmapSet": pixSet}


    def loadLights(self):
        """Collect the pixmap set of every "OBBLight" element."""
        lights = self.widgets["OBBLight"] = {}
        for node in self.doc.getElementsByTagName("OBBLight"):
            pixSet = self.pixmapSets[node.getAttribute("pixmapSet")]
            lights[node.getAttribute("name")] = {"pixmapSet": pixSet}


    def createWidget(self, widgetClass, name, params):
        """Instantiate *widgetClass* with the skin parameters of *name*.

        Skin parameters are looked up under the class name (a class named
        OBBSlider reads what loadSliders() stored) and then overridden by
        *params*; the merged dict is passed as keyword arguments.
        Raises KeyError when the skin has no such widget.
        """
        # Copy so per-call overrides never leak into the stored skin entry.
        skinParams = dict(self.widgets[widgetClass.__name__][name])
        skinParams.update(params)
        return widgetClass(**skinParams)


    def loadPixmapSets(self):
        """Load every "PixmapSet" element from the compiled image directory.

        Raises ValueError for an unknown "type" attribute.
        """
        loaders = {"triState":  self.loadTriStatePixmaps,
                   "monoState": self.loadSimplePixmaps}
        for node in self.doc.getElementsByTagName("PixmapSet"):
            setType = node.getAttribute("type")
            if setType not in loaders:
                raise ValueError("What the hell is type '%s' ?" % setType)
            load = loaders[setType]
            self.pixmapSets[node.getAttribute("name")] = \
                load(node.getAttribute("fileHint"))


    def compilePixmapSets(self):
        """Crop the source images of every "PixmapSet" element.

        Raises ValueError for an unknown "type" attribute.
        """
        compilers = {"triState":  self.compileTriStatePixmaps,
                     "monoState": self.compileSimplePixmaps}
        for node in self.doc.getElementsByTagName("PixmapSet"):
            setType = node.getAttribute("type")
            if setType not in compilers:
                raise ValueError("What the hell is type '%s' ?" % setType)
            compilers[setType](node.getAttribute("fileHint"))


    def _loadMaskedPixmap(self, stem):
        """Return <stem>.png from baseDir, masked with <stem>_mask.png."""
        pixmap = QPixmap(os.path.join(self.baseDir, "%s.png" % stem))
        mask   = QBitmap(os.path.join(self.baseDir, "%s_mask.png" % stem))
        pixmap.setMask(mask)
        return pixmap


    def loadSimplePixmaps(self, fileHint):
        """Load a single-image set, stored under the DISABLED state."""
        pixmaps = PixmapSet()
        pixmaps.addState(self._loadMaskedPixmap(fileHint), DISABLED)
        return pixmaps


    def loadTriStatePixmaps(self, fileHint):
        """Load one <fileHint>_<state> image per DISABLED/ACTIVATED/DESACTIVATED."""
        pixmaps = PixmapSet()
        for state in (DISABLED, ACTIVATED, DESACTIVATED):
            stem = "%s_%s" % (fileHint, state)
            pixmaps.addState(self._loadMaskedPixmap(stem), state)
        return pixmaps


    def _compileMaskedPixmap(self, stem):
        """Crop getImgDir()/<stem>.png into baseDir and save its mask.

        The cropped image keeps its file name; the heuristic mask is saved
        as <stem>_mask.png, both in PNG format.
        """
        filename = "%s.png" % stem
        pixmap = QPixmap(os.path.join(getImgDir(), filename))
        croppedPix = self.loader.cropPixmap(pixmap)
        croppedPix.save(os.path.join(self.baseDir, filename), "PNG")
        mask = croppedPix.createHeuristicMask()
        mask.save(os.path.join(self.baseDir, "%s_mask.png" % stem), "PNG")


    def compileSimplePixmaps(self, fileHint):
        """Crop the single image <fileHint>.png."""
        print("Cropping %s" % fileHint)
        self._compileMaskedPixmap(fileHint)


    def compileTriStatePixmaps(self, fileHint):
        """Crop the <fileHint>_<state>.png image of each of the three states."""
        print("Cropping %s" % fileHint)
        for state in (DISABLED, ACTIVATED, DESACTIVATED):
            self._compileMaskedPixmap("%s_%s" % (fileHint, state))
