# hpstats.py -- wrapper for pstats that analyses hotshot data.

# Copyright (c) 2005 Floris Bruynooghe

# All rights reserved.

# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, provided
# that the above copyright notice(s) and this permission notice appear
# in all copies of the Software and that both the above copyright
# notice(s) and this permission notice appear in supporting
# documentation.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE BE LIABLE FOR
# ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY
# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
# OF THIS SOFTWARE.

# Except as contained in this notice, the name of a copyright holder
# shall not be used in advertising or otherwise to promote the sale,
# use or other dealings in this Software without prior written
# authorization of the copyright holder.

"""hpstats - examine and print reports on code profiled by hprofile."""


import os.path
import copy
import re
import pickle
import sys
import tempfile

import hstats


# Public API of this module: only the Stats class is exported by
# ``from hpstats import *``; the underscore-prefixed _BackEnd class is
# not listed here.
__all__ = ["Stats"]


class Stats:
    """Class to analyse profiling data.

    The data is kept in self._data as a list of rows; every row is a
    list whose columns are indexed by the _D_* constants below.

    The manipulating and reporting methods (strip_dirs(), add(),
    sort_stats(), reverse_order() and the print_*() methods) return
    'self' so that calls like
    Stats('file').strip_dirs().print_stats() are valid.
    dump_stats() returns None.
    """
    # For ease of use, the indexes of the columns in a row of self._data:
    _D_CALLS = 0
    _D_CALLS_P = 1
    _D_TOTTIME = 2
    _D_PERCALL_T = 3
    _D_CUMTIME = 4
    _D_PERCALL_C = 5
    _D_NAME_MODULE = 6
    _D_NAME_LINE = 7
    _D_NAME_FUNC = 8
    _D_NAME_LINE_STR = 9
    _D_PARENTS = 10
    # Unofficial column: only present after self._collect_callees() has
    # run.  It is never saved by dump_stats().
    _D_CALLEES = 11

    def __init__(self, filename, *args):
        """Instantiate a statistics object of profiling data.

        filename - The file containing the profiling data.

        *args - More filenames to load.  Data from all files is
          merged.

        Raises ValueError when a filename is not a string and
        RuntimeError when a file can not be loaded.
        """
        self._data = []                 # The data
        self._data_order = None         # How the data is sorted.
        # self.add() is not used here as it merges duplicates
        # unconditionally.  Most of the time only a single filename is
        # given, in which case there is nothing to merge.
        for name in (filename,) + args:
            self._load_file(name)
        if args:
            self._merge_duplicate_data()

    def _load_file(self, filename):
        """Try and load a profiling file.

        Automatically tries both hotshot format and pickle.  Raises a
        ValueError when `filename' is not a string and a RuntimeError
        when the file can not be loaded in either format.

        The data gets inserted right into self._data.  If this is not
        the only data in there it may be necessary to call
        self._merge_duplicate_data().
        """
        if not isinstance(filename, str):
            raise ValueError("%s is not a filename (string)."
                             % type(filename))

        # Anything the loaders print is discarded: stdout is pointed at
        # a scratch file while loading.  Text mode is requested so that
        # writing strings to it works on every Python version.
        saved_stdout = sys.stdout
        sink = tempfile.TemporaryFile(mode="w+")
        sys.stdout = sink
        try:
            try:
                self._load_hotshot(filename)
            except Exception:
                # Not hotshot data, fall back to our own pickle format.
                # KeyboardInterrupt and SystemExit are deliberately not
                # caught here.
                try:
                    self._load_pickled(filename)
                except Exception:
                    raise RuntimeError("unable to load file: %s" % filename)
        finally:
            # Always restore stdout, whatever happened above.
            sys.stdout = saved_stdout
            sink.close()

    def _load_hotshot(self, filename):
        """Load hotshot data from `filename'.

        The data gets inserted right into self._data.  If this is not
        the only data in there it may be necessary to call
        self._merge_duplicate_data().
        """
        stats = _BackEnd(filename)
        data, _description = stats.get_data(extra=['parents'])
        # Build all rows first so a failure half way through does not
        # leave partial data behind in self._data.
        rows = []
        for row in data:
            call = row[0]
            name = row[5]
            rows.append([
                call[0],                # No calls
                call[0] - call[1],      # No primitive calls
                row[1],                 # Total time
                row[2],                 # Time per call of total time
                row[3],                 # Cumulative time
                row[4],                 # Time per call of cumulative time
                name[0],                # File name
                name[1],                # Line no
                name[2],                # Function name
                str(name[1]),           # Line no as string (used for sort)
                row[6],                 # Parents
            ])
        self._data.extend(rows)

    def _load_pickled(self, filename):
        """Load data dumped into `filename' by self.dump_stats().

        The data gets inserted right into self._data.  If this is not
        the only data in there it may be necessary to call
        self._merge_duplicate_data().

        Raises ValueError when the pickle does not contain a list.
        """
        # SECURITY: unpickling can execute arbitrary code.  Only load
        # files that come from a trusted source.
        fileobject = open(filename, "rb")
        try:
            data = pickle.load(fileobject)
        finally:
            fileobject.close()
        if not isinstance(data, list):
            raise ValueError("%s does not contain profiling data."
                             % filename)
        self._data.extend(data)

    def _merge_duplicate_data(self):
        """Merge duplicate entries in self._data.

        An entry is a duplicate when the (module_name, line_number,
        function_name) tuple is identical.

        Any sorting of the data is lost, self._data_order is reset to
        reflect this.
        """
        merged = {}
        for row in self._data:
            key = self._tuple_name_from_row(row)
            if key not in merged:       # Add row
                merged[key] = row
                continue
            target = merged[key]        # Merge row
            target[self._D_CALLS] += row[self._D_CALLS]
            target[self._D_CALLS_P] += row[self._D_CALLS_P]
            target[self._D_TOTTIME] += row[self._D_TOTTIME]
            target[self._D_CUMTIME] += row[self._D_CUMTIME]
            calls = target[self._D_CALLS]
            if calls:
                target[self._D_PERCALL_T] = target[self._D_TOTTIME] / calls
                target[self._D_PERCALL_C] = target[self._D_CUMTIME] / calls
            else:                       # Avoid a ZeroDivisionError
                target[self._D_PERCALL_T] = 0.0
                target[self._D_PERCALL_C] = 0.0
            parents = target[self._D_PARENTS]
            for parent, count in row[self._D_PARENTS].items():
                parents[parent] = parents.get(parent, 0) + count
        self._data = list(merged.values())
        self._data_order = None

    def strip_dirs(self):
        """Remove leading path names from filenames.

        This will modify the data stored, any sorting will not be
        conserved.  Also take care as two functions with the same
        name, same line number and same filename but different path
        will be merged into just one.
        """
        for row in self._data:
            row[self._D_NAME_MODULE] = \
                os.path.basename(row[self._D_NAME_MODULE])
            stripped = {}
            for parent, count in row[self._D_PARENTS].items():
                name = (os.path.basename(parent[0]),) + tuple(parent[1:])
                # Parents that only differed in path are merged.
                stripped[name] = stripped.get(name, 0) + count
            row[self._D_PARENTS] = stripped
        self._merge_duplicate_data()
        return self

    def add(self, filename, *args):
        """Merge another profiling file into the data being analysed.

        If some data has the same filename, line number and function
        name as data already in the instance the two will be merged.
        Any sorting of the data is lost.

        filename - The file containing the profiling data.

        *args - More filenames to load.  Data from all files is
          merged.
        """
        for name in (filename,) + args:
            self._load_file(name)
        self._merge_duplicate_data()
        return self

    def dump_stats(self, filename):
        """Save data into a file.

        The file will get overwritten without warning if it does
        exist already.

        filename - Name to save the data under.
        """
        # Delete any unofficial data like the dict of callees
        for row in self._data:
            del row[self._D_CALLEES:]

        # Open file and save
        file_object = open(filename, "wb")
        try:
            pickle.dump(self._data, file_object, pickle.HIGHEST_PROTOCOL)
        finally:
            file_object.close()

    def sort_stats(self, *keys):
        """Sort the data of the Stats object.

        keys - One or more strings or unambiguous abbreviations of
          'calls', 'cumulative', 'file', 'module', 'pcalls', 'line',
          'name', 'nfl', 'stdname' and 'time'.  For compatibility also
          -1, 0, 1 and 2 are permitted.  They are interpreted as
          'stdname', 'calls', 'time', and 'cumulative' respectively.

          If more than one key is given the others are secondary keys
          to sort on.  They can not be used when `key' is an old style
          numeric argument however.

        Raises ValueError for an unknown or ambiguous key and
        RuntimeError when an old style numeric key is combined with
        other keys.
        """
        keys = list(keys)

        # Check for old style sort args and replace them with new ones.
        sort_map = {-1: "stdname",
                    0: "calls",
                    1: "time",
                    2: "cumulative"}
        if [key for key in keys if isinstance(key, int)]:
            if len(keys) > 1:
                raise RuntimeError("old style sort only accepts one key.")
            key = keys[0]
            if key not in sort_map:
                raise ValueError("%d is not in range [-1, 0, 1, 2]." % key)
            keys = [sort_map[key]]

        keys = self._expand_abbrev(keys)
        order = self._make_order(keys)

        # Hack to get hstats.Stats.sort_data() to work without having
        # to load a profiling file:
        class DummyStats(hstats.Stats):
            def __init__(self):
                pass
        self._data = DummyStats().sort_data(self._data, order)
        # Only record the order once the sort succeeded.  No keys at
        # all means no particular order.
        self._data_order = keys or None
        return self

    def _expand_abbrev(self, abbrev_list):
        """Expand abbreviations of sort keywords.

        Takes a list of all the keywords to be expanded as argument.
        Returns a list with the expanded words in the same order.
        Words fully expanded already are left alone.

        Raises ValueError when a word is ambiguous or matches nothing.
        """
        valid_words = ('calls', 'cumulative', 'file', 'module', 'pcalls',
                       'line', 'name', 'nfl', 'stdname', 'time')
        result = []
        for word in abbrev_list:
            if word in valid_words:     # Full keyword, nothing to expand
                result.append(word)
                continue
            matches = [valid for valid in valid_words
                       if valid.startswith(word)]
            if len(matches) == 1:
                result.append(matches[0])
            elif matches:
                raise ValueError("%s is not a unique abbreviation." % word)
            else:
                raise ValueError("%s does not match any keyword." % word)
        return result

    def _make_order(self, keywords):
        """From a list of keywords, return the sort list.

        The sort list is a list of (column, direction) pairs where
        direction is 'a' (ascending) or 'd' (descending).
        """
        columns = {'calls': self._D_CALLS,
                   'cumulative': self._D_CUMTIME,
                   'file': self._D_NAME_MODULE,
                   'module': self._D_NAME_MODULE,
                   'pcalls': self._D_CALLS_P,
                   'line': self._D_NAME_LINE,
                   'name': self._D_NAME_FUNC,
                   'time': self._D_TOTTIME}
        # Keywords that expand to more than one column.
        compound = {'nfl': [(self._D_NAME_FUNC, 'a'),
                            (self._D_NAME_MODULE, 'a'),
                            (self._D_NAME_LINE, 'a')],
                    'stdname': [(self._D_NAME_MODULE, 'a'),
                                (self._D_NAME_LINE_STR, 'a'),
                                (self._D_NAME_FUNC, 'a')]}
        descending = ('calls', 'cumulative', 'pcalls', 'line', 'time')
        order = []
        for key in keywords:
            if key in compound:
                order.extend(compound[key])
            elif key in descending:
                order.append((columns[key], 'd'))
            else:
                order.append((columns[key], 'a'))
        return order

    def reverse_order(self):
        """Reverse sorting order of the data. """
        self._data.reverse()
        return self

    def print_stats(self, *restrictions):
        """Print a report of the profiling data.

        *restrictions can be one of:

           o Integer: only this many number of lines are printed.

           o Float x, where 0.0 <= x <= 1.0: only this fraction of
             the lines will be printed.

           o Regular expression: only lines where the name matches
             this expression will be printed.

        Multiple restrictions are applied sequentially.
        """
        data_to_print = self._restrict_data(*restrictions)
        self._print_preamble()
        self._print_data(data_to_print)
        print("")
        print("")
        return self

    def _print_preamble(self):
        """Print out the preamble of the stats data."""
        no_calls = sum([row[self._D_CALLS] for row in self._data])
        no_pcalls = sum([row[self._D_CALLS_P] for row in self._data])
        if no_calls == no_pcalls:
            pcalls_text = ""
        else:
            pcalls_text = " (%d primitive calls)" % no_pcalls
        total_time = sum([row[self._D_TOTTIME] for row in self._data])

        print("%9d function calls%s in %.3f CPU seconds"
              % (no_calls, pcalls_text, total_time))
        print("")
        print("   %s" % self._get_order_text())
        print("")
        print("   ncalls  tottime  percall  cumtime  percall"
              " filename:lineno(function)")

    def _get_order_text(self):
        """Return the text that explains the order of the data."""
        description_map = {"calls"     : "call count",
                           "cumulative": "cumulative time",
                           "file"      : "file name",
                           "line"      : "line number",
                           "module"    : "file name",
                           "name"      : "function name",
                           "nfl"       : "name/file/line",
                           "pcalls"    : "call count",
                           "stdname"   : "standard name",
                           "time"      : "internal time"}
        # Both None and an empty list mean the data is not sorted.
        if not self._data_order:
            return "Random listing order was used"
        descriptions = [description_map[key] for key in self._data_order]
        return "Ordered by: " + ", ".join(descriptions)

    def _restrict_data(self, *restrictions):
        """Apply restrictions to the data in this instance and return data.

        This is used for printing stats without showing all data.
        For a description about the arguments see the .print_stats()
        method.

        Raises ValueError for a float outside 0.0 - 1.0 or for a
        restriction of an unsupported type.
        """
        if not restrictions:
            return self._data

        data = copy.copy(self._data)
        for restriction in restrictions:
            if isinstance(restriction, int):
                data = data[:restriction]
            elif isinstance(restriction, float):
                if not 0.0 <= restriction <= 1.0:
                    raise ValueError("%s is not between 0.0 and 1.0."
                                     % restriction)
                data = data[:int(len(data) * restriction)]
            elif isinstance(restriction, str):
                regexp = re.compile(restriction)
                data = [row for row in data
                        if regexp.search(self._repr_name_from_row(row))]
            else:
                raise ValueError("%s is an invalid restriction."
                                 % (restriction,))
        return data

    def _print_data(self, data_to_print):
        """Print out the data of all the stats."""
        for row in data_to_print:
            no_calls = row[self._D_CALLS]
            no_pcalls = row[self._D_CALLS_P]
            if no_calls == no_pcalls:
                calls = str(no_calls)
            else:
                calls = str(no_calls) + '/' + str(no_pcalls)
            values = {'calls': calls,
                      'time': row[self._D_TOTTIME],
                      'percall_t': row[self._D_PERCALL_T],
                      'cumtime': row[self._D_CUMTIME],
                      'percall_c': row[self._D_PERCALL_C],
                      'name': self._repr_name_from_row(row)}
            print(" %(calls)8s %(time)8.3f %(percall_t)8.3f %(cumtime)8.3f"
                  " %(percall_c)8.3f %(name)s" % values)

    def _repr_name_from_row(self, row):
        """Return string representation of the name of a row in the data.

        This means the returned string is what is meant to be printed
        at the right hand side of a report.
        """
        return self._repr_name_from_tuple(self._tuple_name_from_row(row))

    def _tuple_name_from_row(self, row):
        """Return tuple representation of the name of a row in the data."""
        return (row[self._D_NAME_MODULE],
                row[self._D_NAME_LINE],
                row[self._D_NAME_FUNC])

    def _repr_name_from_tuple(self, name):
        """Return string representation of the name in tuple form."""
        return name[0] + ':' + str(name[1]) + '(' + name[2] + ')'

    def print_callers(self, *restrictions):
        """Print all callers for each profiled function.

        *restrictions as for the print_stats() method.

        A number in parentheses after each caller shows how many times
        this specific call was made. A second number is the cumulative
        time spent in the function at the right.
        """
        data_to_print = self._restrict_data(*restrictions)
        self._print_preamble_callers()
        self._print_data_callers(data_to_print)
        print("")
        print("")
        return self

    def _print_preamble_callers(self):
        """Print out the preamble for the callers report."""
        print("   %s" % self._get_order_text())
        print("")
        print("Function                         was called by...")

    def _print_data_callers(self, data):
        """Print out the data of the callers."""
        self._print_relations(data, self._D_PARENTS)

    def _print_relations(self, data, column):
        """Print the functions related to every row in `data'.

        column - Index of the column that holds the dictionary mapping
          the names (tuples) of the related functions to the number of
          calls, i.e. self._D_PARENTS or self._D_CALLEES.

        Raises ValueError when a related function has no row of its
        own in self._data.
        """
        # Lookup table so finding the row of a related function does
        # not require a scan of all the data every time.  On duplicate
        # names the first row wins, like self._row_from_tuple().
        rows_by_name = {}
        for row in self._data:
            rows_by_name.setdefault(self._tuple_name_from_row(row), row)

        for row in data:
            text_name = self._repr_name_from_row(row).ljust(34)
            texts = []
            for name, count in row[column].items():
                if name not in rows_by_name:
                    raise ValueError("tuple %s does not identify a row."
                                     % (name,))
                cumtime = rows_by_name[name][self._D_CUMTIME]
                texts.append("%s(%s)    %.3f"
                             % (self._repr_name_from_tuple(name),
                                str(count), cumtime))
            if not texts:
                texts = ["--"]

            # Data (strings) gathered, printing follows
            print(text_name + texts[0])
            for text in texts[1:]:
                print(' ' * 34 + text)

    def _row_from_tuple(self, tuple_name):
        """Return the row identified by `tuple_name'.

        Raises ValueError when there is no such row.
        """
        for row in self._data:
            if self._tuple_name_from_row(row) == tuple_name:
                return row
        # The name is wrapped in a 1-tuple as it is a tuple itself and
        # would otherwise be taken as the list of format arguments.
        raise ValueError("tuple %s does not identify a row." % (tuple_name,))

    def print_callees(self, *restrictions):
        """Print all functions called by indicated function.

        *restrictions as for the print_stats() method.
        """
        self._collect_callees()
        data_to_print = self._restrict_data(*restrictions)
        self._print_preamble_callees()
        self._print_data_callees(data_to_print)
        print("")
        print("")
        return self

    def _print_preamble_callees(self):
        """Print out the preamble for the callees report."""
        print("   %s" % self._get_order_text())
        print("")
        print("Function                         called...")

    def _print_data_callees(self, data):
        """Print out the data for the callees report.

        self._collect_callees() must have been called before.
        """
        self._print_relations(data, self._D_CALLEES)

    def _collect_callees(self):
        """Calculate all the callees of every function.

        The callees dictionary is stored in the self._D_CALLEES column
        of every row of self._data, replacing the result of any earlier
        call.
        """
        # Invert the parents relation in a single pass over the data.
        callees_by_name = {}
        for row in self._data:
            child = self._tuple_name_from_row(row)
            for parent, no_calls in row[self._D_PARENTS].items():
                callees_by_name.setdefault(parent, {})[child] = no_calls

        for row in self._data:
            func = self._tuple_name_from_row(row)
            del row[self._D_CALLEES:]   # Drop stale callees, if any
            # Every row gets its own dictionary.
            row.append(dict(callees_by_name.get(func, {})))


class _BackEnd(hstats.Stats):
    """Reader for hotshot (and thus hprofile) profiling files.

    Differences with hstats.Stats:

    o Time data is scaled down by a factor 1e-6 so it is stored in
      seconds.  NOTE(review): the original documentation said the
      source unit was nanoseconds, the factor used suggests
      microseconds -- confirm against hstats.

    o Bias that gets saved by hprofile gets compensated for.

    o The '<string>' call does not get removed.

    o The data description returned by .get_data() will always be:
      ['call', 'time', 'avgtime', 'cumtime', 'avgcumtime', 'name',
      'parents'] (when using the `extra=['parents']' argument).
    """
    def __init__(self, filename):
        """Load `filename' and post-process the timing data."""
        hstats.Stats.__init__(self, filename, keep_exec=True)

        # The bias is only known when hprofile stored it in the file.
        info = self.get_info()
        if 'hprofile-bias' in info.keys():
            bias = float(info['hprofile-bias'][0])
        else:
            bias = None

        for key in list(self._data.keys()):
            record = self._data[key]
            # Convert to seconds.
            record[hstats._SDATA_TIME] *= .000001
            record[hstats._SDATA_CUMTIME] *= .000001
            if bias is None:
                continue
            # Compensate bias.
            record[hstats._SDATA_TIME] -= bias
            record[hstats._SDATA_CUMTIME] -= bias
