#!/usr/bin/env python
# Copyright (c) 2014 Pivotal inc.


import datetime
import subprocess
import os
import platform
import sys
import re
from contextlib import closing
from distutils.version import LooseVersion
from optparse import OptionParser
from pygresql import pgdb

# Version banner handed to OptionParser; optparse expands %prog to the
# program name when --version is requested.
gpsd_version = '%prog 1.0'

# Parenthesised SQL list of namespaces that the statistics queries exclude
# (it is appended directly after "nspname NOT IN ").
sysnslist = ("('pg_toast', 'pg_bitmapindex', 'pg_temp_1', "
             "'pg_catalog', 'information_schema')")

# Set by main() from the detected server version; read by dumpGlobals().
orca = False

# Base PGOPTIONS value: utility-mode session with an empty search_path
# (make search path safe).
pgoptions = '-c gp_session_role=utility -c search_path='


def ResultIter(cursor, arraysize=1000):
    """Yield the rows of an executed cursor one at a time.

    Rows are pulled from the cursor with fetchmany() in batches of
    ``arraysize`` so the whole result set is never held in memory at once.
    Iteration ends at the first empty batch.
    """
    batch = cursor.fetchmany(arraysize)
    while batch:
        for row in batch:
            yield row
        batch = cursor.fetchmany(arraysize)


def getVersion(envOpts):
    """Return the output of ``select version()`` run via psql on template1.

    envOpts is the environment mapping passed to the psql child process.
    If psql exits with a non-zero status, its stderr is echoed to our stderr
    and the program exits with status 1.
    """
    cmd = subprocess.Popen('psql --pset footer -Atqc "select version()" template1',
                           shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                           env=envOpts,
                           # text-mode pipes, so the output can be concatenated
                           # with str literals on both Python 2 and Python 3
                           universal_newlines=True)
    # communicate() drains both pipes while waiting for the child; calling
    # wait() first can deadlock once a pipe buffer fills up.
    out, err = cmd.communicate()
    # compare by value: "is not 0" relies on an interpreter implementation detail
    if cmd.returncode != 0:
        sys.stderr.write('\nError while trying to find HAWQ/GPDB version.\n\n' + err + '\n\n')
        sys.exit(1)
    return out


def dumpSchema(connectionInfo, envOpts):
    """Run pg_dump in schema-only mode, letting it write to our stdout.

    connectionInfo is a (host, port, user, database) tuple that is
    interpolated into the pg_dump command line; envOpts is the environment
    mapping for the child process. If pg_dump exits with a non-zero status,
    its stderr is echoed to our stderr and the program exits with status 1.
    """
    # NOTE(review): the values are interpolated into a shell command unquoted,
    # so names containing spaces or shell metacharacters will not work.
    dmp_cmd = 'pg_dump -h %s -p %s -U %s -s -x --gp-syntax -O %s' % connectionInfo
    p = subprocess.Popen(dmp_cmd, shell=True, stderr=subprocess.PIPE, env=envOpts,
                         # text-mode pipe, so stderr can be concatenated with
                         # str literals on both Python 2 and Python 3
                         universal_newlines=True)
    # stdout is inherited (the dump goes straight to our stdout); stderr is
    # drained by communicate() so a chatty child cannot block on a full pipe.
    err = p.communicate()[1]
    # compare by value: "is not 0" relies on an interpreter implementation detail
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping schema.\n\n' + err + '\n\n')
        sys.exit(1)


def dumpGlobals(connectionInfo, envOpts):
    """Run pg_dumpall for global objects only (-g), writing to our stdout.

    connectionInfo is a (host, port, user, database) tuple that is
    interpolated into the pg_dumpall command line (the database is passed
    with -l); envOpts is the environment mapping for the child process.
    The module-level ``orca`` flag selects --no-gp-syntax (when true) or
    --gp-syntax. If pg_dumpall exits with a non-zero status, its stderr is
    echoed to our stderr and the program exits with status 1.
    """
    syntaxFlag = '--no-gp-syntax' if orca else '--gp-syntax'
    # NOTE(review): the values are interpolated into a shell command unquoted,
    # so names containing spaces or shell metacharacters will not work.
    dmp_cmd = ('pg_dumpall -h %s -p %s -U %s -l %s -g ' % connectionInfo) + syntaxFlag
    p = subprocess.Popen(dmp_cmd, shell=True, stderr=subprocess.PIPE, env=envOpts,
                         # text-mode pipe, so stderr can be concatenated with
                         # str literals on both Python 2 and Python 3
                         universal_newlines=True)
    # stdout is inherited; stderr is drained by communicate() so a chatty
    # child cannot block on a full pipe the way it could with wait().
    err = p.communicate()[1]
    # compare by value: "is not 0" relies on an interpreter implementation detail
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping globals.\n\n' + err + '\n\n')
        sys.exit(1)


def dumpTupleCount(cur):
    """Print an UPDATE pg_class statement for every non-system relation.

    Executes a query on ``cur`` that reads relname, nspname, relpages,
    reltuples and relallvisible for all relations whose namespace is not in
    ``sysnslist``, and prints one UPDATE per row that sets the three
    statistics columns (cast to int, real and int respectively) for the
    relation identified by its name and schema name.
    """
    stmt = "SELECT pgc.relname, pgn.nspname, pgc.relpages, pgc.reltuples, pgc.relallvisible FROM pg_class pgc, pg_namespace pgn WHERE pgc.relnamespace = pgn.oid and pgn.nspname NOT IN " + \
        sysnslist
    # {0} = SET assignments, {1} = raw table name (comment line only),
    # {2} / {3} = table / schema name escaped for use inside '...' literals
    setStmt = '-- Table: {1}\n' \
        'UPDATE pg_class\nSET\n' \
        '{0}\n' \
        'WHERE relname = \'{2}\' AND relnamespace = ' \
        '(SELECT oid FROM pg_namespace WHERE nspname = \'{3}\');\n\n'

    def quoteLiteral(name):
        # A single quote inside a SQL string constant is written as two
        # single quotes; names without quotes are returned unchanged.
        return name.replace("'", "''")

    cur.execute(stmt)
    # a real list (not a map object) so it can be sliced on Python 3 as well
    columns = [desc[0] for desc in cur.description]
    types = ['int', 'real', 'int']
    for vals in ResultIter(cur):
        # the first two columns identify the relation, the rest are the stats
        assignments = ',\n'.join('\t%s = %s::%s' % t
                                 for t in zip(columns[2:], vals[2:], types))
        print(setStmt.format(assignments, vals[0],
                             quoteLiteral(vals[0]), quoteLiteral(vals[1])))


def dumpStats(cur, inclHLL):
    """Print an INSERT INTO pg_statistic statement for each statistics row.

    Executes a query on ``cur`` joining pg_statistic with pg_class,
    pg_namespace, pg_attribute and pg_type for relations whose namespace is
    not in ``sysnslist``, and prints one INSERT per row with every value
    cast to an explicit type.

    inclHLL controls rows whose stakind5 is 98 or 99 (HLL stats): when it is
    True the stavalues5 value is emitted cast to bytea[]; when it is False
    stakind5 is written as 0 and stavalues5 as NULL::int4[].
    """
    # the outline of the DML used to create stats for one column in pg_statistic
    insertTemplate = ('--\n'
                      '-- Table: {0}, Attribute: {1}\n'
                      '--\n'
                      'INSERT INTO pg_statistic VALUES (\n'
                      '{2});\n\n')

    # the types of the columns in the pg_statistic table, except for starelid
    # and stavalues[1-5]; the first entry is staattnum
    fixedTypes = (['smallint', 'boolean', 'real', 'integer', 'real'] +
                  ['smallint'] * 5 +
                  ['oid'] * 5 +
                  ['real[]'] * 5)

    statsQuery = ''.join([
        'SELECT pgc.relname, pgn.nspname, pga.attname, pgtn.nspname, pgt.typname, pgs.* ',
        'FROM pg_class pgc, pg_statistic pgs, pg_namespace pgn, pg_attribute pga, pg_type pgt, pg_namespace pgtn ',
        'WHERE pgc.relnamespace = pgn.oid and pgn.nspname NOT IN ',
        sysnslist,
        ' and pgc.oid = pgs.starelid ',
        'and pga.attrelid = pgc.oid ',
        'and pga.attnum = pgs.staattnum ',
        'and pga.atttypid = pgt.oid ',
        'and pgt.typnamespace = pgtn.oid ',
        'ORDER BY pgc.relname, pgs.staattnum',
    ])

    cur.execute(statsQuery)

    for row in ResultIter(cur):
        # first VALUES entry: starelid, written as 'schema.table'::regclass
        rowVals = ["\t'%s.%s'::regclass" % (row[1], row[0])]

        typeSchema = row[3]
        elemType = row[4]
        if typeSchema != "pg_catalog":
            # a type that is not built-in, qualify it with its schema name
            # and play it safe by double-quoting the identifiers
            elemType = '"%s"."%s"' % (typeSchema, elemType)

        # determine the actual types of stavalues[1-5], the columns are
        # defined as "anyarray": a type name starting with '_' is already an
        # array type, anything else is turned into one
        if row[4][0] != '_':
            elemType = elemType + '[]'
        rowTypes = fixedTypes + [elemType] * 5

        sawHLL = False
        # create the content of the VALUES clause; pos counts from 1, so
        # position 10 is stakind5 and position 25 is stavalues5
        for pos, (val, typ) in enumerate(zip(row[6:], rowTypes), 1):
            # quoted form of the value as fetched, before any rewriting below
            quotedRaw = "'%s'" % val

            if pos == 10 and (val == 98 or val == 99):
                # column stakind5 has value 98 or 99 (HLL stats)
                sawHLL = True
                if inclHLL == False:
                    val = 0

            if val is None:
                val = 'NULL'
            elif isinstance(val, (str, unicode)) and val[0] == '{':
                # array literal: double the single quotes, then the
                # backslashes, and wrap it in an E'...' string
                val = "E'" + val.replace("'", "''").replace('\\', '\\\\') + "'"

            if sawHLL and pos == 25:
                # the actual HLL value in stavalues5
                if inclHLL == True:
                    rowVals.append('\t{0}::bytea[]'.format(quotedRaw))
                else:
                    rowVals.append('\tNULL::int4[]')
            else:
                rowVals.append('\t{0}::{1}'.format(val, typ))

        print(insertTemplate.format(row[0], row[2], ',\n'.join(rowVals)))


def parseCmdLine():
    """Build and return the OptionParser for this tool.

    The parser is only constructed here; the caller runs parse_args().
    conflict_handler="resolve" lets -h be reused for --host, with help moved
    to -? / --help.
    """
    epilogText = ("WARNING: This tool collects statistics about your data, including most common values, "
                  "which requires some data elements to be included in the output file. "
                  "Please review output file to ensure it is within corporate policy to transport the output file.")
    parser = OptionParser(usage='Usage: %prog [options] <database>',
                          version=gpsd_version,
                          conflict_handler="resolve",
                          epilog=epilogText)

    # (flags, keyword arguments) for each option, in registration order
    optionSpecs = [
        (('-?', '--help'),
         dict(action='help', help='Show this help message and exit')),
        (('-h', '--host'),
         dict(action='store', dest='host', help='Specify a remote host')),
        (('-p', '--port'),
         dict(action='store', dest='port', help='Specify a port other than 5432')),
        (('-U', '--user'),
         dict(action='store', dest='user', help='Connect as someone other than current user')),
        (('-s', '--stats-only'),
         dict(action='store_false', dest='dumpSchema', default=True,
              help='Just dump the stats, do not do a schema dump')),
        (('-l', '--hll'),
         dict(action='store_true', dest='dumpHLL', default=False,
              help='Include HLL stats')),
    ]
    for flags, settings in optionSpecs:
        parser.add_option(*flags, **settings)
    return parser


def isOrcaPresent(versionString):
    """Tell from a server version string whether the ORCA path should be used.

    Returns True when the string mentions HAWQ. Otherwise the version number
    following "Greenplum Database " is compared against 4.3 and the result of
    that comparison is returned. When neither is found, an error is written
    to stderr and the program exits with status 1.
    """
    if 'HAWQ' not in versionString:
        found = re.search('(?<=Greenplum Database )[0-9.]+', versionString)
        if not (found and len(found.group())):
            # No HAWQ and we can't find the GPDB Version
            sys.stderr.write('Cannot determine HAWQ/GPDB version.  Exiting.\n')
            sys.exit(1)
        minimum = LooseVersion('4.3')
        return LooseVersion(found.group()) >= minimum
    return True


def main():
    """Entry point: write a SQL script with schema and statistics to stdout.

    Parses the command line (exactly one positional argument, the database
    name), detects the server version, optionally dumps globals and schema
    through pg_dumpall / pg_dump, and then prints the statements that
    recreate the pg_class tuple counts and the pg_statistic rows.
    Exits with status 1 if the statistics dump raises pgdb.DatabaseError.
    """
    # function-scope import: only needed for the login-name fallback below
    import getpass

    parser = parseCmdLine()
    options, args = parser.parse_args()
    if len(args) != 1:
        # parser.error() prints the usage plus this message and exits itself,
        # so nothing after it runs
        parser.error("No database specified")

    # OK - now let's setup all the arguments & options
    db = args[0]
    host = options.host or platform.node()
    user = options.user
    if not user:
        try:
            user = os.getlogin()
        except OSError:
            # os.getlogin() fails when the process has no controlling
            # terminal (e.g. when started from cron); fall back to the
            # environment / password database lookup
            user = getpass.getuser()
    port = options.port or '5432'
    inclSchema = options.dumpSchema
    inclHLL = options.dumpHLL

    # NOTE: this is os.environ itself, not a copy, so PGOPTIONS is also set
    # in this process's own environment
    envOpts = os.environ
    envOpts['PGOPTIONS'] = pgoptions

    version = getVersion(envOpts)

    # Let's update the global variable so we can tell if we're using ORCA aware or not
    global orca
    orca = isOrcaPresent(version)

    if orca:
        envOpts['PGOPTIONS'] = pgoptions + ' -c optimizer=off'

    timestamp = datetime.datetime.today()

    # setup the connection info tuple with options
    connectionInfo = (host, port, user, db)
    # colon-separated connection string handed to pgdb.connect()
    connectionString = ':'.join([host, port, db, user, '', pgoptions, ''])

    sys.stdout.writelines(['\n-- Greenplum database Statistics Dump',
                           '\n-- Copyright (C) 2007 - 2014 Pivotal',
                           '\n-- Database: ' + db,
                           '\n-- Date:     ' + timestamp.date().isoformat(),
                           '\n-- Time:     ' + timestamp.time().isoformat(),
                           '\n-- CmdLine:  ' + ' '.join(sys.argv),
                           '\n-- Version:  ' + version + '\n\n'])

    # flush before each child process runs so our buffered output and the
    # child's direct writes to stdout stay in order
    sys.stdout.flush()
    if inclSchema:
        dumpGlobals(connectionInfo, envOpts)

    # Now be sure that when we load the rest we are doing it in the right
    # database
    print('\\connect ' + db + '\n')
    sys.stdout.flush()
    if inclSchema:
        dumpSchema(connectionInfo, envOpts)

    # Now we have to explicitly allow editing of these pg_class &
    # pg_statistic tables
    sys.stdout.writelines(['\n-- ',
                           '\n-- Allow system table modifications',
                           '\n-- ',
                           '\nset allow_system_table_mods=true;\n\n'])
    sys.stdout.flush()

    try:
        with closing(pgdb.connect(connectionString)) as connection:
            with closing(connection.cursor()) as cursor:
                dumpTupleCount(cursor)
                dumpStats(cursor, inclHLL)
        # Upon success, leave a warning message about data collected
        sys.stderr.writelines(['WARNING: This tool collects statistics about your data, including most common values, '
                               'which requires some data elements to be included in the output file.\n',
                               'Please review output file to ensure it is within corporate policy to transport the output file.\n'])
    except pgdb.DatabaseError as err:  # only database errors are handled here
        sys.stderr.write('Error while dumping statistics:\n')
        # str(err) instead of the deprecated BaseException.message attribute
        sys.stderr.write(str(err))
        sys.exit(1)


# Run the dump only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
