#!/usr/bin/env python
'''
minirepro utility

USAGE

minirepro <db-name> [-h <master-host>] [-U <username>] [-p <port>]
  -q <SQL-file> -f <repo-data-file>

minirepro -?


DESCRIPTION

For any SQL commands, the minirepro utility generates Greenplum Database
information for the commands. The information can be analyzed by Pivotal
support to perform root cause analysis.

The minirepro utility reads the input SQL file, passes the input SQL
command to the server's gp_dump_query_oids() to get the dependent
object ids. Then the utility uses pg_dump to dump the object DDLs, and
queries the system catalog to collect statistics of these relations.
The information is written to the output file. The output includes a minimal
set of DDLs and statistics of relations and functions that are related
to the input SQL commands.


PARAMETERS

<db-name>
  Name of the Greenplum Database.

-h <master-host>
  Greenplum Database master host. Default is the local host name.

-U <username>
  Greenplum Database user name to log into the database and run the
  SQL command. Default is the PGUSER environment variable. If PGUSER
  is not defined, OS user name running the utility is used.

-p <port>
  Port that is used to connect to Greenplum Database.
  Default is the PGPORT environment variable. If PGPORT is not defined,
  the default value is 5432.

-q <SQL-file>
  A text file that contains SQL commands. The commands can be on
  multiple lines.

-f <repo-data-file>
  The output file that contains DDLs and statistics of relations
  and functions that are related to the SQL commands.

-? Show this help text and exit.


EXAMPLE

minirepro gptest -h localhost -U gpadmin -p 4444 -q ~/in.sql -f ~/out.sql
'''

import os, sys, re, json, platform, subprocess
from optparse import OptionParser
from pygresql import pgdb
from datetime import datetime

# Tool version; shown by --version and written to the output file header.
version = '1.13'
# Root of the scratch directory; main() appends a timestamped subdirectory.
PATH_PREFIX = '/tmp/'
# Name of the file (under PATH_PREFIX) that receives the pg_dump output.
PGDUMP_FILE = 'pg_dump_out.sql'
# SQL list of system schemas that are skipped when collecting schema names.
sysnslist = (
    "('pg_toast', 'pg_bitmapindex', 'pg_catalog', "
    "'information_schema', 'gp_toolkit')"
)
# Connection options: utility-mode session with an empty search_path
# (this makes the search path safe).
pgoptions = '-c gp_session_role=utility -c search_path='

class MRQuery(object):
    """Holder for what a minirepro run needs to dump.

    A new instance carries three independent, empty lists:
    ``schemas``, ``funcids`` and ``relids``.
    """

    def __init__(self):
        # every attribute gets its own fresh list (nothing is shared)
        self.relids = []
        self.funcids = []
        self.schemas = []

def E(query_str):
    """Return query_str after passing it through pgdb.escape_string()."""
    escaped = pgdb.escape_string(query_str)
    return escaped

def generate_timestamp():
    """Return the current local time as a YYYYmmddHHMMSS digit string."""
    return datetime.now().strftime("%Y%m%d%H%M%S")

def result_iter(cursor, arraysize=1000):
    """Yield the rows of cursor one by one, fetching them in batches.

    Rows are pulled with cursor.fetchmany(arraysize) so that only one
    batch is held in memory at a time; iteration stops at the first
    empty batch.
    """
    batch = cursor.fetchmany(arraysize)
    while batch:
        for row in batch:
            yield row
        batch = cursor.fetchmany(arraysize)

def get_server_version(cursor):
    """Return the server's version() string.

    Runs "select version()" on cursor and returns the first column of
    the first row. On a pgdb.DatabaseError the error is reported on
    stderr and the process exits with status 1.
    """
    try:
        cursor.execute("select version()")
        row = cursor.fetchone()
    except pgdb.DatabaseError as e:
        sys.stderr.write('\nError while trying to find HAWQ/GPDB version.\n\n' + str(e) + '\n\n')
        sys.exit(1)
    return row[0]

def parse_cmd_line():
    """Build and return the OptionParser for the minirepro command line.

    The parser is created with conflict_handler="resolve" because -h is
    reused for the host option; help is reachable through -? / --help.
    The arguments are not parsed here, the caller does that.
    """
    data_warning = (
        "WARNING: This tool collects statistics about your data, "
        "including most common values, which requires some data "
        "elements to be included in the output file. Please review "
        "output file to ensure it is within corporate policy to "
        "transport the output file."
    )
    parser = OptionParser(usage='Usage: %prog <database> [options]',
                          version='%prog ' + version,
                          conflict_handler="resolve",
                          epilog=data_warning)

    # (flags, keyword settings) for every option, in registration order
    option_table = [
        (('-?', '--help'),
         dict(action='help', help='Show this help message and exit')),
        (('-h', '--host'),
         dict(action='store', dest='host', help='Specify a remote host')),
        (('-p', '--port'),
         dict(action='store', dest='port', help='Specify a port other than 5432')),
        (('-U', '--user'),
         dict(action='store', dest='user',
              help='Connect as someone other than current user')),
        (('-q',),
         dict(action='store', dest='query_file',
              help='file name that contains the query')),
        (('-f',),
         dict(action='store', dest='output_file',
              help='minirepro output file name')),
        (('-l', '--hll'),
         dict(action='store_true', dest='dumpHLL', default=False,
              help='Include HLL stats')),
    ]
    for flags, settings in option_table:
        parser.add_option(*flags, **settings)
    return parser

def dump_query(connectionInfo, query_file):
    """Run gp_dump_query_oids() on the text of query_file and return its output.

    The SQL text is read from query_file, escaped, wrapped into a
    "select pg_catalog.gp_dump_query_oids('...')" statement and written
    to toolkit.sql under PATH_PREFIX. That file is then executed with psql.

    Parameters:
        connectionInfo -- (host, port, user, db) tuple
        query_file     -- path of the file holding the SQL commands

    Returns:
        whatever psql printed on stdout (the caller parses it as JSON).

    Exits the process with status 1 when psql returns a non-zero status
    or writes anything to stderr.
    """
    (host, port, user, db) = connectionInfo
    print("Extracting metadata from query file %s ..." % query_file)

    with open(query_file, 'r') as query_f:
        sql_text = query_f.read()
    query = "select pg_catalog.gp_dump_query_oids('%s')" % E(sql_text)

    toolkit_sql = PATH_PREFIX + 'toolkit.sql'
    with open(toolkit_sql, 'w') as toolkit_f:
        toolkit_f.write(query)

    # disable .psqlrc to prevent unexpected timing and format output
    # The command is an argument list (no shell), so host/user/db values
    # are passed to psql verbatim and are never interpreted by a shell.
    query_cmd = ['psql', str(db), '--pset', 'footer', '--no-psqlrc', '-Atq',
                 '-h', str(host), '-p', str(port), '-U', str(user),
                 '-f', toolkit_sql]
    print(' '.join(query_cmd))

    p = subprocess.Popen(query_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=os.environ)

    # communicate() drains both pipes while waiting; calling wait() first
    # can block forever once a pipe buffer fills up.
    outmsg, errormsg = p.communicate()
    if p.returncode != 0 or errormsg:
        sys.stderr.write('\nError when executing function gp_dump_query_oids.\n\n' + errormsg + '\n\n')
        sys.exit(1)
    return outmsg

# relation and function oids will be extracted from the dump string
def parse_oids(cursor, json_oids):
    """Parse the JSON oid dump and collect the user schemas involved.

    Parameters:
        cursor    -- open database cursor used for the catalog lookups
        json_oids -- JSON text with the keys 'relids' and 'funcids'

    Returns:
        an MRQuery whose relids/funcids hold the values from the JSON
        text ('0' when a value is empty) and whose schemas lists, without
        duplicates, the schemas (outside sysnslist) that contain the
        referenced relations and functions.
    """
    oids = json.loads(json_oids)

    result = MRQuery()
    # an empty value would yield "IN ()", which is not valid SQL; the '0'
    # placeholder keeps the list syntactically valid
    result.relids = oids['relids'] or '0'
    result.funcids = oids['funcids'] or '0'

    cat_queries = (
        "SELECT distinct(nspname) FROM pg_class c, pg_namespace n WHERE "
        "c.relnamespace = n.oid AND c.oid IN (%s) "
        "AND n.nspname NOT IN %s" % (result.relids, sysnslist),
        "SELECT distinct(nspname) FROM pg_proc p, pg_namespace n WHERE "
        "p.pronamespace = n.oid AND p.oid IN (%s) "
        "AND n.nspname NOT IN %s" % (result.funcids, sysnslist),
    )

    for cat_query in cat_queries:
        cursor.execute(cat_query)
        for vals in result_iter(cursor):
            # a schema may hold both a relation and a function; record it
            # once so that only one CREATE SCHEMA is generated for it
            if vals[0] not in result.schemas:
                result.schemas.append(vals[0])

    return result

def pg_dump_object(mr_query, connectionInfo, envOpts):
    """Dump the DDL of the collected relations and functions with pg_dump.

    Parameters:
        mr_query       -- object whose relids and funcids are passed to
                          pg_dump's --relation-oids / --function-oids
        connectionInfo -- (host, port, user, db) tuple
        envOpts        -- environment for the pg_dump process

    The dump is written to PATH_PREFIX + PGDUMP_FILE. Exits the process
    with status 1 when pg_dump returns a non-zero status.
    """
    (host, port, user, db) = connectionInfo
    out_file = PATH_PREFIX + PGDUMP_FILE
    # Argument list instead of a shell string: nothing is re-parsed by a
    # shell, and the output path is handed over unmodified (it is a file
    # name, so SQL escaping does not apply to it).
    dmp_cmd = ['pg_dump', '-h', str(host), '-p', str(port), '-U', str(user),
               '-sxO', str(db),
               '--relation-oids', str(mr_query.relids),
               '--function-oids', str(mr_query.funcids),
               '-f', out_file]
    print(' '.join(dmp_cmd))
    p = subprocess.Popen(dmp_cmd, stderr=subprocess.PIPE, env=envOpts)
    # communicate() reads stderr while waiting, so a chatty pg_dump cannot
    # block on a full pipe
    errormsg = p.communicate()[1]
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping schema.\n\n' + errormsg + '\n\n')
        sys.exit(1)

def dump_tuple_count(cur, oid_str, f_out):
    """Write UPDATE pg_class statements carrying the table-level statistics.

    Parameters:
        cur     -- database cursor used to read pg_class
        oid_str -- text placed inside the "oid in (...)" filter
        f_out   -- writable file object receiving the statements

    For every selected relation outside the pg_temp_* schemas one UPDATE
    is emitted that sets relpages, reltuples and relallvisible.
    """
    select_sql = "SELECT pgc.relname, pgn.nspname, pgc.relpages, pgc.reltuples, pgc.relallvisible FROM pg_class pgc, pg_namespace pgn " \
        "WHERE pgc.relnamespace = pgn.oid and pgc.oid in (%s) and pgn.nspname NOT LIKE 'pg_temp_%%'" % oid_str

    update_template = (
        "-- Table: {1}\n"
        "UPDATE pg_class\nSET\n"
        "{0}\n"
        "WHERE relname = '{1}' AND relnamespace = "
        "(SELECT oid FROM pg_namespace WHERE nspname = '{2}');\n\n"
    )

    cur.execute(select_sql)
    # the first two result columns are relname and nspname; the remaining
    # ones are the statistics to be assigned
    stat_columns = [desc[0] for desc in cur.description][2:]
    stat_types = ['int', 'real', 'int']
    for row in result_iter(cur):
        # i.e. relpages = 1::int, reltuples = 1.0::real
        assignments = ['\t%s = %s::%s' % triple
                       for triple in zip(stat_columns, row[2:], stat_types)]
        set_clause = E(',\n'.join(assignments))
        f_out.writelines(update_template.format(set_clause, E(row[0]), E(row[1])))

def dump_stats(cur, oid_str, f_out, inclHLL):
    """Write SQL that re-creates the pg_statistic rows of the given relations.

    Parameters:
        cur     -- database cursor used to read the catalog
        oid_str -- text placed inside the "pgc.oid in (...)" filter
        f_out   -- writable file object receiving the generated SQL
        inclHLL -- when False, a stakind5 value of 98/99 (HLL stats) is
                   written as 0 and the matching stavalues5 as NULL::int4[];
                   when True, the HLL value is written as bytea[]

    For every selected pg_statistic row one DELETE statement and one
    INSERT statement are written. The DELETE is commented out unless the
    relation lives in pg_catalog.
    """
    # vals[0..4] are relname, relation schema, attname, type schema and
    # type name; the columns of pg_statistic (pgs.*) follow from vals[5] on
    query = 'SELECT pgc.relname, pgn.nspname, pga.attname, pgtn.nspname, pgt.typname, pgs.* ' \
        'FROM pg_class pgc, pg_statistic pgs, pg_namespace pgn, pg_attribute pga, pg_type pgt, pg_namespace pgtn ' \
        'WHERE pgc.relnamespace = pgn.oid and pgc.oid in (%s) ' \
        'and pgn.nspname NOT LIKE \'pg_temp_%%\' ' \
        'and pgc.oid = pgs.starelid ' \
        'and pga.attrelid = pgc.oid ' \
        'and pga.attnum = pgs.staattnum ' \
        'and pga.atttypid = pgt.oid ' \
        'and pgt.typnamespace = pgtn.oid ' \
        'ORDER BY pgc.relname, pgs.staattnum' % (oid_str)

    # the outline of the DML used to create stats for one column in pg_statistic
    # placeholders: {0} table name, {1} attribute name, {2} optional "-- "
    # prefix for the DELETE, {3} starelid expression, {4} staattnum,
    # {5} the VALUES list
    pstring = '--\n' \
        '-- Table: {0}, Attribute: {1}\n' \
        '--\n' \
        '{2}DELETE FROM pg_statistic WHERE starelid={3} AND staattnum={4};\n' \
        'INSERT INTO pg_statistic VALUES (\n' \
        '{5});\n\n'

    # the types of the columns in the pg_statistic table, except for starelid and stavalues[1-5]
    # (20 entries; the 5 stavalues types are appended per row below, so the
    # positions 1..25 used in the loop line up with vals[6:])
    types = ['smallint',  # staattnum
             'boolean',
             'real',
             'integer',
             'real',
             'smallint',
             'smallint',
             'smallint',
             'smallint',
             'smallint',
             'oid',
             'oid',
             'oid',
             'oid',
             'oid',
             'real[]',
             'real[]',
             'real[]',
             'real[]',
             'real[]'
             ]

    cur.execute(query)

    for vals in result_iter(cur):
        # starelid is written as a regclass cast of the schema-qualified
        # name instead of the numeric value read from the source database
        starelid = "'%s.%s'::regclass" % (E(vals[1]), E(vals[0]))
        rowVals = ["\t%s" % (starelid)]
        schemaname = vals[1]
        typeschema = vals[3]
        typename = vals[4]
        if typeschema != "pg_catalog":
            # a type that is not built-in, qualify it with its schema name
            # and play it safe by double-quoting the identifiers
            typename = '"%s"."%s"' % (typeschema, typename)
        # i is the 1-based position of the current column within vals[6:]
        i = 0
        # set once stakind5 turns out to be 98 or 99
        hll = False

        # determine the actual types of stavalues[1-5], the columns are defined as "anyarray"
        if vals[4][0] == '_':
            # column is an array type, use as is
            rowTypes = types + [typename] * 5
        else:
            # non-array type, make an array type out of it
            rowTypes = types + [typename + '[]'] * 5

        # create the content of the VALUES clause
        for val, typ in zip(vals[6:], rowTypes):
            i = i + 1
            # quoted copy of the untouched value; only used for the HLL
            # output below, and taken before val is rewritten
            str_val = "'%s'" % val
            if i == 10 and (val == 98 or val == 99):
                # column stakind5 has value 98 or 99 (HLL stats)
                if inclHLL == False:
                    val = 0
                hll = True

            if val is None:
                val = 'NULL'
            elif isinstance(val, (str, unicode)) and val[0] == '{':
                # use an Escape string for array constants and also add one
                # more layer of backslash escape symbols, since the string
                # will go through two layers of escape processing
                val = "E'%s'" % E(val.replace('\\', '\\\\'))
            if i == 25 and hll == True:
                # the actual HLL value in stavalues5
                if inclHLL == True:
                    rowVals.append('\t{0}::{1}'.format(str_val, 'bytea[]'))
                else:
                    rowVals.append('\t{0}'.format('NULL::int4[]'))
            else:
                rowVals.append('\t{0}::{1}'.format(val, typ))

        # For non-catalog tables we don't need to delete stats first
        # stats need to be deleted only for catalog tables
        linecomment = ''
        if schemaname != 'pg_catalog':
            linecomment = '-- ' # This will comment out the DELETE query

        # vals[6] is staattnum (the first entry of the types list above)
        f_out.writelines(pstring.format(E(vals[0]), E(vals[2]), linecomment, starelid, vals[6], ',\n'.join(rowVals)))

def main():
    """Command-line entry point of the minirepro utility.

    Parses the command line, connects to the database, asks the server
    for the oids the query depends on, dumps their DDL with pg_dump and
    writes DDL, table statistics, column statistics and the commented-out
    query text to the output file.

    Invalid or missing arguments end the process through parser.error().
    The database connection and the output file are released even when a
    step fails.
    """
    parser = parse_cmd_line()
    options, args = parser.parse_args()
    # parser.error() prints the message and terminates the process, so no
    # explicit exit is needed after it
    if len(args) != 1:
        parser.error("No database specified")

    # setup all the arguments & options
    envOpts = os.environ
    db = args[0]
    host = options.host or platform.node()
    user = options.user or envOpts.get('PGUSER')
    if not user:
        try:
            user = os.getlogin()
        except OSError:
            # os.getlogin() needs a controlling terminal (fails e.g. under
            # cron); fall back to the environment / password database
            import getpass
            user = getpass.getuser()
    port = options.port or envOpts.get('PGPORT') or '5432'
    query_file = options.query_file
    output_file = options.output_file
    inclHLL = options.dumpHLL

    if query_file is None:
        parser.error("No query file specified.")
    if output_file is None:
        parser.error("No output file specified.")
    if not os.path.isfile(query_file):
        parser.error('Query file %s does not exist.' % query_file)
    output_file = os.path.abspath(output_file)

    # the scratch files of this run go to a timestamped directory; the
    # other functions locate it through the PATH_PREFIX global
    global PATH_PREFIX
    PATH_PREFIX = PATH_PREFIX + generate_timestamp() + '/'

    # create tmp dir if not already there
    if not os.path.isdir(PATH_PREFIX):
        os.mkdir(PATH_PREFIX)

    # setup the connection info tuple with options
    connectionInfo = (host, port, user, db)
    connectionString = ':'.join([host, port, db, user, '', pgoptions, ''])
    print("Connecting to database: host=%s, port=%s, user=%s, db=%s ..." % connectionInfo)
    conn = pgdb.connect(connectionString)
    cursor = None
    try:
        cursor = conn.cursor()

        # get server version, which is dumped to minirepro output file
        server_ver = get_server_version(cursor)

        # run gp_dump_query_oids() on the query file
        # input: query file name
        # output: json oids string
        json_str = dump_query(connectionInfo, query_file)

        # parse json oids string, collect all things that need to be dumped
        # input: json oids string
        # output: MRQuery class (self.schemas, self.funcids, self.relids)
        mr_query = parse_oids(cursor, json_str)

        # dump relations and functions
        print("Invoking pg_dump to dump DDL ...")
        pg_dump_object(mr_query, connectionInfo, envOpts)

        ### start writing out to the output file ###
        output_dir = os.path.dirname(output_file)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        with open(output_file, 'w') as f_out:
            ts = datetime.today()
            f_out.writelines(['-- MiniRepro ' + version,
                              '\n-- Copyright (C) 2007 - 2016 Pivotal',
                              '\n-- Database: ' + db,
                              '\n-- Date:     ' + ts.date().isoformat(),
                              '\n-- Time:     ' + ts.time().isoformat(),
                              '\n-- CmdLine:  ' + ' '.join(sys.argv),
                              '\n-- Version:  ' + server_ver + '\n\n'])

            # make sure we connect with the right database
            f_out.write('\\connect ' + db + '\n\n')

            # first create schema DDLs
            print("Writing schema DDLs ...")
            table_schemas = ["CREATE SCHEMA %s;\n" % E(schema)
                             for schema in mr_query.schemas if schema != 'public']
            f_out.writelines(table_schemas)

            # write relation and function DDLs
            print("Writing relation and function DDLs ...")
            with open(PATH_PREFIX + PGDUMP_FILE, 'r') as f_pgdump:
                f_out.writelines(f_pgdump)

            # explicitly allow editing of these pg_class & pg_statistic tables
            f_out.writelines(['\n-- ',
                              '\n-- Allow system table modifications',
                              '\n-- ',
                              '\nset allow_system_table_mods=true;\n\n'])

            # dump table stats
            print("Writing table statistics ...")
            dump_tuple_count(cursor, mr_query.relids, f_out)

            # dump column stats
            print("Writing column statistics ...")
            dump_stats(cursor, mr_query.relids, f_out, inclHLL)

            # attach query text, every line turned into an SQL comment
            print("Attaching raw query text ...")
            f_out.writelines(['\n-- ',
                              '\n-- Query text',
                              '\n-- \n\n'])

            with open(query_file, 'r') as query_f:
                for line in query_f:
                    f_out.write('-- ' + line)

            f_out.write('\n-- MiniRepro completed.\n')
    finally:
        # runs on success, on exceptions and on sys.exit() from the helpers
        if cursor is not None:
            cursor.close()
        conn.close()

    print("--- MiniRepro completed! ---")

    # upon success, leave a warning message about data collected
    print('WARNING: This tool collects statistics about your data, including most common values, which requires some data elements to be included in the output file.')
    print('Please review output file to ensure it is within corporate policy to transport the output file.')


# run the utility only when executed as a script, not when imported
if __name__ == "__main__":
    main()
