Source code for roster_core.db_access

#!/usr/bin/python

# Copyright (c) 2009, Purdue University
# All rights reserved.
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 
# Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation and/or
# other materials provided with the distribution.
# 
# Neither the name of the Purdue University nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
# 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""This module is an API to access the dnsManagement database.

This module should only be run by servers with authentication layers
that are active. This module does not include authentication, but does
include authorization.

The api that will be exposed by this module is meant for use in a web
application or rpc server. This module is not for use in command line tools.

The two primary uses of this class are: 
1. to use convenience functions to get large amounts of data out of the db
  without large amounts of db queries. For usage on this consult the pydoc
  on the individual functions.

2. to Make/Remove/List rows in the database. The method that is used in this
  class is based on generic Make/Remove/List functions that take specific
  dictionaries that correspond to the table that is being referenced. 

  Here is an example of how to remove rows from the acls table:

  acls_dict = db_instance.GetEmptyRowDict('acls')
  acls_dict['acl_name'] = 'test_acl'
  db_instance.StartTransaction()
  try:
    matching_rows = db_instance.ListRow('acls', acls_dict)
    for row in matching_rows:
      db_instance.RemoveRow('acls', row)
  except Exception:
    db_instance.EndTransaction(rollback=True)
    raise
  else:
    db_instance.EndTransaction()

Note: MySQLdb.Error can be raised in almost any function in this module. Please
      keep that in mind when using this module.
"""

__copyright__ = 'Copyright (C) 2009, Purdue University'
__license__ = 'BSD'
__version__ = '#TRUNK#'


import Queue
import threading
import time
import uuid
import warnings

import MySQLdb

import constants
import data_validation
import embedded_files
import errors
import helpers_lib
import codecs

class dbAccess(object):
  """This class provides the primary interface for connecting and
  interacting with the roster database.
  """

  def __init__(self, db_host, db_user, db_passwd, db_name, big_lock_timeout,
               big_lock_wait, thread_safe=True, ssl=False, ssl_ca=None,
               ssl_cert=None, ssl_key=None, ssl_capath=None, ssl_cipher=None,
               db_debug=False, db_debug_log=None):
    """Instantiates the db_access class.

    Inputs:
      db_host: string of the database host name
      db_user: string of the user name used to connect to mysql
      db_passwd: string of password used to connect to mysql
      db_name: string of name of database in mysql server to use
      big_lock_timeout: integer of how long the big lock should be valid for
      big_lock_wait: integer of how long to wait for processes to finish
                     before locking the database
      thread_safe: boolean of if db_access should be thread safe
      ssl: boolean of if the mysql connection should use ssl
      ssl_ca: string of the ssl CA file; required when ssl is set
      ssl_cert: string of the client certificate file (optional)
      ssl_key: string of the client key file (optional)
      ssl_capath: string of a directory of CA certificates (optional)
      ssl_cipher: string of allowed ssl ciphers (optional)
      db_debug: boolean of if every mysql command should be echoed
      db_debug_log: string of a file the echoed commands are appended to;
                    when not set they are printed to stdout instead

    Raises:
      ConfigError: ssl_ca not specified in config file.
    """
    # Do some better checking of these args
    self.db_host = db_host
    self.db_user = db_user
    self.db_passwd = db_passwd
    self.db_name = db_name
    self.big_lock_timeout = big_lock_timeout
    self.big_lock_wait = big_lock_wait
    self.ssl = ssl
    self.ssl_ca = ssl_ca
    self.ssl_settings = {}
    self.db_debug = db_debug
    self.db_debug_log = db_debug_log
    if( self.ssl ):
      if( self.ssl_ca ):
        self.ssl_settings['ca'] = ssl_ca
      else:
        raise errors.ConfigError('ssl_ca not specified in config file.')
      # These arguments used to be accepted and then silently ignored; pass
      # whichever ones were supplied through to MySQLdb's ssl dict.
      for setting_name, setting_value in (('cert', ssl_cert),
                                          ('key', ssl_key),
                                          ('capath', ssl_capath),
                                          ('cipher', ssl_cipher)):
        if( setting_value ):
          self.ssl_settings[setting_name] = setting_value
    self.transaction_init = False
    self.connection = None
    self.cursor = None
    # This is generated only when ListRow is called and is then cached for
    # the life of the object.
    self.foreign_keys = []
    self.data_validation_instance = None
    self.locked_db = False
    self.thread_safe = thread_safe
    # When thread_safe is set, StartTransaction enqueues a ticket and waits
    # until now_serving equals it; EndTransaction hands now_serving on.
    self.queue = Queue.Queue()
    self.now_serving = None
    self.queue_update_lock = threading.Lock()

  def close(self):
    """Closes a connection that has been opened.
    A new connection will be created on StartTransaction.
    """
    if( self.connection is not None ):
      self.connection.close()
    self.connection = None

  def _DebugLog(self, message):
    """Echoes a mysql command to the debug log file, or stdout if no log
    file is configured.

    Inputs:
      message: string of the (already formatted) command
    """
    if( self.db_debug_log ):
      # The command may contain unicode characters, so write through codecs
      # to a utf-8 file rather than the ASCII file a plain open() gives.
      with codecs.open(self.db_debug_log, encoding='utf-8',
                       mode='a') as debug_log_handle:
        debug_log_handle.write(message)
        debug_log_handle.write('\n')
    else:
      print(message)

  def cursor_execute(self, execution_string, values=None):
    """This function allows for the capture of every mysql command that
    is run in this class.

    Inputs:
      execution_string: mysql command string
      values: dictionary of values for mysql command; defaults to an
              empty dictionary

    Raises:
      DatabaseError: a MySQLdb.Error whose code is in
                     errors.PARSABLE_MYSQL_ERRORS
    """
    if( values is None ):
      # A fresh dict per call; a {} default would be shared by every call.
      values = {}
    if( self.db_debug ):
      # Only format when debugging; the string is otherwise never built.
      self._DebugLog(execution_string % values)
    try:
      self.cursor.execute(execution_string, values)
    except MySQLdb.ProgrammingError:
      raise
    except MySQLdb.Error as e:
      if( e.args and e.args[0] in errors.PARSABLE_MYSQL_ERRORS ):
        raise errors.DatabaseError(e)
      raise

  def _WaitForTurn(self):
    """Enqueues a ticket for this request and blocks until it is served."""
    unique_id = uuid.uuid4()
    self.queue.put(unique_id)
    while_sleep = 0
    while( unique_id != self.now_serving ):
      time.sleep(while_sleep)
      with self.queue_update_lock:
        if( self.now_serving is None ):
          self.now_serving = self.queue.get()
      while_sleep = 0.005

  def _ReleaseTurn(self):
    """Hands now_serving to the next queued ticket, or clears it."""
    with self.queue_update_lock:
      if( not self.queue.empty() ):
        self.now_serving = self.queue.get()
      else:
        self.now_serving = None

  def _OpenCursor(self):
    """Sets self.cursor to a fresh DictCursor, reconnecting when there is no
    connection or the existing one fails a no-op query.
    """
    if( self.connection is not None ):
      try:
        self.cursor = self.connection.cursor(MySQLdb.cursors.DictCursor)
        self.cursor_execute('DO 0')  # NOOP to test connection
      except MySQLdb.OperationalError:
        self.connection = None
    if( self.connection is None ):
      connect_args = {'host': self.db_host, 'user': self.db_user,
                      'passwd': self.db_passwd, 'db': self.db_name,
                      'use_unicode': True, 'charset': 'utf8'}
      if( self.ssl ):
        connect_args['ssl'] = self.ssl_settings
      self.connection = MySQLdb.connect(**connect_args)
      self.cursor = self.connection.cursor(MySQLdb.cursors.DictCursor)

  @staticmethod
  def _SecondsBetween(start, end):
    """Returns whole seconds from start to end, including the days part.

    timedelta.seconds on its own wraps every 24 hours, so a lock that had
    been stale for over a day could otherwise look freshly updated.

    Inputs:
      start: datetime of the earlier moment
      end: datetime of the later moment

    Outputs:
      int: seconds elapsed (fractions of a second are dropped)
    """
    delta = end - start
    return delta.days * 86400 + delta.seconds

  def _WaitForBigLock(self):
    """Waits while the db_lock_lock row in `locks` is marked locked and has
    been updated within big_lock_timeout seconds.
    """
    while_sleep = 0
    db_lock_locked = 1
    while( db_lock_locked ):
      time.sleep(while_sleep)
      try:
        self.cursor_execute('SELECT `locked`, `lock_last_updated`, '
                            'NOW() as `now` from `locks` WHERE '
                            '`lock_name`="db_lock_lock"')
        rows = self.cursor.fetchall()
      except MySQLdb.ProgrammingError:
        # NOTE(review): presumably the locks table does not exist yet
        # (e.g. before CreateRosterDatabase) -- confirm.
        break
      if( not rows ):
        break
      lock_last_updated = rows[0]['lock_last_updated']
      db_lock_locked = rows[0]['locked']
      now = rows[0]['now']
      if( self._SecondsBetween(lock_last_updated, now) >
          self.big_lock_timeout ):
        break
      while_sleep = 1

  def StartTransaction(self):
    """Starts a transaction.

    Also it starts a db connection if none exists or it times out.
    Always creates a new cursor.

    This function also serializes all requests on this object and if the
    big lock has been activated will wait for it to be released.

    Raises:
      TransactionError: Cannot start new transaction last transaction not
                        committed or rolled-back.
    """
    if( self.thread_safe ):
      self._WaitForTurn()
    elif( self.transaction_init ):
      raise errors.TransactionError('Cannot start new transaction last '
                                    'transaction not committed or '
                                    'rolled-back.')
    try:
      self._OpenCursor()
      self._WaitForBigLock()
    except BaseException:
      # Without this, a failed connect would leave our ticket in
      # now_serving forever and every later StartTransaction would hang.
      if( self.thread_safe ):
        self._ReleaseTurn()
      raise
    self.transaction_init = True

  def EndTransaction(self, rollback=False):
    """Ends a transaction.

    Also does some simple checking to make sure a connection was open first
    and releases itself from the current queue.

    Inputs:
      rollback: boolean of if the transaction should be rolled back

    Raises:
      TransactionError: Must run StartTansaction before EndTransaction.
    """
    if( not self.thread_safe and not self.transaction_init ):
      raise errors.TransactionError('Must run StartTansaction before '
                                    'EndTransaction.')
    try:
      self.cursor.close()
      if( rollback ):
        self.connection.rollback()
      else:
        self.connection.commit()
    finally:
      self.transaction_init = False
      if( self.thread_safe ):
        self._ReleaseTurn()

  def CheckMaintenanceFlag(self):
    """Checks the maintenance flag in the database.

    Outputs:
      bool: boolean of maintenance mode
    """
    row = self.ListRow('locks', {'lock_name': u'maintenance', 'locked': None})
    return bool(row[0]['locked'])

  def LockDb(self):
    """This function is to lock the whole database for consistent data
    retrieval.

    This function expects self.cursor to be instantiated and valid.

    Raises:
      TransactionError: Must unlock tables before re-locking them.
    """
    if( self.locked_db is True ):
      raise errors.TransactionError('Must unlock tables before re-locking them')
    self.cursor_execute('UPDATE `locks` SET `locked`=1 WHERE '
                        '`lock_name`="db_lock_lock"')
    # Give transactions already past the big-lock check time to finish.
    time.sleep(self.big_lock_wait)
    self.cursor_execute(
        'LOCK TABLES %s READ' % ' READ, '.join(self.ListTableNames()))
    self.locked_db = True

  def UnlockDb(self):
    """This function is to unlock the whole database.

    This function expects self.cursor to be instantiated and valid.
    It also expects all tables to be locked.

    Raises:
      TransactionError: Must lock tables before unlocking them.
    """
    if( self.locked_db is False ):
      raise errors.TransactionError('Must lock tables before unlocking them')
    self.cursor_execute('UNLOCK TABLES')
    self.cursor_execute('UPDATE `locks` SET `locked`=0 WHERE '
                        '`lock_name`="db_lock_lock"')
    self.locked_db = False

  def InitDataValidation(self):
    """Get all reserved words and record types and init the
    data_validation_instance.
    """
    cursor = self.connection.cursor()
    try:
      if( self.db_debug ):
        self._DebugLog('SELECT reserved_word FROM reserved_words')
      cursor.execute('SELECT reserved_word FROM reserved_words')
      reserved_words_rows = cursor.fetchall()
      if( self.db_debug ):
        self._DebugLog('SELECT record_type FROM record_types')
      cursor.execute('SELECT record_type FROM record_types')
      record_types_rows = cursor.fetchall()
    finally:
      cursor.close()
    words = [row[0] for row in reserved_words_rows]
    record_types = [row[0] for row in record_types_rows]
    self.data_validation_instance = data_validation.DataValidation(
        words, record_types)

  def MakeRow(self, table_name, row_dict):
    """Creates a row in the database using the table name and row dict

    Inputs:
      table_name: string of valid table name from constants
      row_dict: dictionary that corresponds to table_name

    Raises:
      InvalidInputError: Table name not valid
      TransactionError: Must run StartTansaction before inserting

    Outputs:
      int: last insert id
    """
    if( table_name not in helpers_lib.GetValidTables() ):
      raise errors.InvalidInputError('Table name not valid: %s' % table_name)
    if( not self.transaction_init ):
      raise errors.TransactionError('Must run StartTansaction before '
                                    'inserting.')
    if( self.data_validation_instance is None ):
      self.InitDataValidation()
    self.data_validation_instance.ValidateRowDict(table_name, row_dict)

    column_names = []
    column_assignments = []
    for k in row_dict:
      column_names.append(k)
      column_assignments.append('%%(%s)s' % k)

    query = 'INSERT INTO %s (%s) VALUES (%s)' % (table_name,
                                                 ','.join(column_names),
                                                 ','.join(column_assignments))
    self.cursor_execute(query, row_dict)
    return self.cursor.lastrowid

  def TableRowCount(self, table_name):
    """Counts the amount of records in a table and returns it.

    Inputs:
      table_name: string of valid table name from constants

    Raises:
      InvalidInputError: Table name not valid
      TransactionError: Must run StartTansaction before getting row count.

    Outputs:
      int: number of rows found
    """
    if( table_name not in helpers_lib.GetValidTables() ):
      raise errors.InvalidInputError('Table name not valid: %s' % table_name)
    if( not self.transaction_init ):
      raise errors.TransactionError('Must run StartTansaction before getting '
                                    'row count.')
    self.cursor_execute('SELECT COUNT(*) FROM %s' % table_name)
    row_count = self.cursor.fetchone()
    return row_count['COUNT(*)']

  def RemoveRow(self, table_name, row_dict):
    """Removes a row in the database using the table name and row dict

    Inputs:
      table_name: string of valid table name from constants
      row_dict: dictionary that corresponds to table_name

    Raises:
      InvalidInputError: Table name not valid
      TransactionError: Must run StartTansaction before deleting

    Outputs:
      int: number of rows affected
    """
    if( table_name not in helpers_lib.GetValidTables() ):
      raise errors.InvalidInputError('Table name not valid: %s' % table_name)
    if( not self.transaction_init ):
      raise errors.TransactionError('Must run StartTansaction before deleting.')
    if( self.data_validation_instance is None ):
      self.InitDataValidation()
    self.data_validation_instance.ValidateRowDict(table_name, row_dict)

    where_list = ['%s=%%(%s)s' % (k, k) for k in row_dict]
    query = 'DELETE FROM %s WHERE %s' % (table_name, ' AND '.join(where_list))
    self.cursor_execute(query, row_dict)
    return self.cursor.rowcount

  def UpdateRow(self, table_name, search_row_dict, update_row_dict):
    """Updates a row in the database using search and update dictionaries.

    Inputs:
      table_name: string of valid table name from constants
      search_row_dict: dictionary that corresponds to table_name containing
                       search args; None values are ignored
      update_row_dict: dictionary that corresponds to table_name containing
                       update args; None values are ignored

    Raises:
      InvalidInputError: Table name not valid
      TransactionError: Must run StartTansaction before updating

    Outputs:
      int: number of rows affected
    """
    if( table_name not in helpers_lib.GetValidTables() ):
      raise errors.InvalidInputError('Table name not valid: %s' % table_name)
    if( not self.transaction_init ):
      raise errors.TransactionError('Must run StartTansaction before updating.')
    if( self.data_validation_instance is None ):
      self.InitDataValidation()
    self.data_validation_instance.ValidateRowDict(table_name, search_row_dict,
                                                  none_ok=True)
    self.data_validation_instance.ValidateRowDict(table_name, update_row_dict,
                                                  none_ok=True)

    query_updates = []
    query_searches = []
    # Prefixes keep a column that is both searched and updated from
    # colliding in the single values dict.
    combined_dict = {}
    for k, v in update_row_dict.items():
      if( v is not None ):
        query_updates.append('%s=%%(update_%s)s' % (k, k))
        combined_dict['update_%s' % k] = v

    for k, v in search_row_dict.items():
      if( v is not None ):
        query_searches.append('%s=%%(search_%s)s' % (k, k))
        combined_dict['search_%s' % k] = v

    query = 'UPDATE %s SET %s WHERE %s' % (table_name, ','.join(query_updates),
                                           ' AND '.join(query_searches))
    self.cursor_execute(query, combined_dict)
    return self.cursor.rowcount

  def ListRow(self, *args, **kwargs):
    """Lists rows in the database using a dictionary of tables. Then returns
    the rows found. Joins are auto generated on the fly based on foreign keys
    in the database.

    Inputs:
      args: pairs of string of table name and dict of rows
      kwargs: lock_rows: default False
              column: column to search range on, if using multiple
                      tables, the column must be in the first table
                      in args.
              range_values: range tuple of values to search within for on
                            column
              is_date: boolean of if range is of dates

      example usage: ListRow('users', user_row_dict,
                             'user_group_assignments', user_assign_row_dict,
                             lock_rows=True)

    Raises:
      TransactionError: Must run StartTansaction before inserting
      UnexpectedDataError: If is_date is specified you must specify column
                           and range
      UnexpectedDataError: If column or range is specified both are needed
      InvalidInputError: Found unknown option(s)
      UnexpectedDataError: No args given, must at least have a pair of table
                           name and row dict
      UnexpectedDataError: Number of unnamed args is not even.
                           Args should be entered in pairs of table name and
                           row dict.
      InvalidInputError: Table name not valid
      InvalidInputError: Column not found in row
      UnexpectedDataError: Column in table is not a DateTime type
      UnexpectedDataError: Date from range is not a valid datetime object
      InvalidInputError: Range must be int if is_date is not set
      InvalidInputError: Multiple tables were passed in but no joins were
                         found

    Outputs:
      tuple of row dicts consisting of all the tables that were in the input.
      all column names in the db are unique so no collisions occur
        example: ({'user_name': 'sharrell', 'access_level': 10,
                   'user_group_assignments_group_name: 'cs',
                   'user_group_assignments_user_name: 'sharrell'},
                  {'user_name': 'sharrell', 'access_level': 10,
                   'user_group_assignments_group_name: 'eas',
                   'user_group_assignments_user_name: 'sharrell'})
    """
    if( not self.transaction_init ):
      raise errors.TransactionError('Must run StartTansaction before getting '
                                    'data.')
    if( self.data_validation_instance is None ):
      self.InitDataValidation()
    valid_tables = helpers_lib.GetValidTables()
    tables = {}
    table_names = []

    lock_rows = kwargs.pop('lock_rows', False)
    column = kwargs.pop('column', None)
    range_values = kwargs.pop('range_values', ())
    is_date = kwargs.pop('is_date', None)
    if( column is None and is_date is not None ):
      raise errors.UnexpectedDataError('If is_date is specified you must '
                                       'specify column and range')
    if( bool(column) ^ bool(range_values) ):
      raise errors.UnexpectedDataError('If column or range is specified '
                                       'both are needed')
    if( kwargs ):
      raise errors.InvalidInputError('Found unknown option(s): '
                                     '%s' % list(kwargs))

    if( not args ):
      raise errors.UnexpectedDataError('No args given, must at least have a '
                                       'pair of table name and row dict')
    if( len(args) % 2 ):
      raise errors.UnexpectedDataError(
          'Number of unnamed args is not even. Args '
          'should be entered in pairs of table name '
          'and row dict.')
    for table_name, row_dict in zip(args[0::2], args[1::2]):
      if( table_name not in valid_tables ):
        raise errors.InvalidInputError('Table name not valid: %s' % table_name)
      # do checking in validate row dict to check if it is a dict
      self.data_validation_instance.ValidateRowDict(table_name, row_dict,
                                                    none_ok=True,
                                                    all_none_ok=True)
      tables[table_name] = row_dict
      table_names.append(table_name)

    if( range_values ):
      if( column not in args[1] ):
        raise errors.InvalidInputError('Column %s not found in row '
                                       'dictionary: %s' % (column, args[1]))

      if( is_date ):
        if( constants.TABLES[args[0]][column] != 'DateTime' ):
          raise errors.UnexpectedDataError('column: %s in table %s is not a '
                                           'DateTime type' % (column, args[0]))

        for date in range_values:
          if( not self.data_validation_instance.isDateTime(date) ):
            raise errors.UnexpectedDataError(
                'Date: %s from range is not a valid '
                'datetime object' % (date,))
      else:
        for value in range_values:
          if( not self.data_validation_instance.isUnsignedInt(value) ):
            raise errors.InvalidInputError('Range must be int if is_date '
                                           'is not set')
    query_where = []
    if( len(tables) > 1 ):
      if( not self.foreign_keys ):
        # The AS aliases pin the result keys to lower case; newer MySQL
        # servers report information_schema column labels in upper case.
        self.cursor_execute(
            'SELECT table_name AS table_name, column_name AS column_name, '
            'referenced_table_name AS referenced_table_name, '
            'referenced_column_name AS referenced_column_name '
            'FROM information_schema.key_column_usage WHERE '
            'referenced_table_name IS NOT NULL AND '
            'referenced_table_schema=%(db_name)s', {'db_name': self.db_name})
        self.foreign_keys = self.cursor.fetchall()

      for foreign_key in self.foreign_keys:
        if( foreign_key['table_name'] in table_names and
            foreign_key['referenced_table_name'] in table_names ):
          query_where.append('(%(table_name)s.%(column_name)s='
                             '%(referenced_table_name)s.'
                             '%(referenced_column_name)s)' % foreign_key)
      if( not query_where ):
        raise errors.InvalidInputError('Multiple tables were passed in but no '
                                       'joins were found')
    column_names = []
    search_dict = {}
    for table_name, row_dict in tables.items():
      for key, value in row_dict.items():
        column_names.append('%s.%s' % (table_name, key))
        if( value is not None ):
          search_dict[key] = value
          query_where.append('%s=%%(%s)s' % (key, key))

    if( range_values ):
      search_dict['start'] = range_values[0]
      search_dict['end'] = range_values[1]
      query_where.append('%s>=%%(start)s AND %s<=%%(end)s' % (column, column))

    query_end = ''
    if( query_where ):
      query_end = 'WHERE %s' % ' AND '.join(query_where)
    if( lock_rows ):
      query_end = '%s FOR UPDATE' % query_end

    query = 'SELECT %s FROM %s %s' % (','.join(column_names),
                                      ','.join(table_names), query_end)
    self.cursor_execute(query, search_dict)
    return self.cursor.fetchall()

  def GetEmptyRowDict(self, table_name):
    """Gives a dict that has all the members needed to interact with the
    given table using the Make/Remove/ListRow functions.

    Inputs:
      table_name: string of valid table name from constants

    Raises:
      InvalidInputError: Table name not valid

    Outputs:
      dictionary: of empty row for specified table.
        example acls dict:
        {'acl_name': None
         'acl_range_allowed: None,
         'acl_cidr_block': None }
    """
    row_dict = helpers_lib.GetRowDict(table_name)
    if( not row_dict ):
      raise errors.InvalidInputError('Table name not valid: %s' % table_name)
    for key in row_dict:
      row_dict[key] = None
    return row_dict

  # Not sure this is needed, buuuuut.
  def GetValidTables(self):
    """Export this function to the top level of the db_access stuff so
    it can be used without importing un-needed classes.

    Outputs:
      list: valid table names
    """
    return helpers_lib.GetValidTables()

  def GetRecordArgsDict(self, record_type):
    """Get args for a specific record type from the db and shove them into
    a dictionary. Runs its own StartTransaction/EndTransaction pair.

    Inputs:
      record_type: string of record type

    Raises:
      InvalidInputError: Unknown record type

    Outputs:
      dictionary: keyed by argument name with values of data type of that arg
        example: {'mail_host': 'Hostname'
                  'priority': 'UnsignedInt'}
    """
    search_record_arguments_dict = self.GetEmptyRowDict('record_arguments')
    search_record_arguments_dict['record_arguments_type'] = record_type

    self.StartTransaction()
    try:
      record_arguments = self.ListRow('record_arguments',
                                      search_record_arguments_dict)
    finally:
      self.EndTransaction()

    if( not record_arguments ):
      raise errors.InvalidInputError('Unknown record type: %s' % record_type)
    record_arguments_dict = {}
    for record_argument in record_arguments:
      record_arguments_dict[record_argument['argument_name']] = (
          record_argument['argument_data_type'])

    return record_arguments_dict

  def GetEmptyRecordArgsDict(self, record_type):
    """Gets empty args dict for a specific record type

    Inputs:
      record_type: string of record type

    Outputs:
      dictionary: keyed by argument name with values of None
        example: {'mail_host': None
                  'priority': None}
    """
    return dict.fromkeys(self.GetRecordArgsDict(record_type))

  def ValidateRecordArgsDict(self, record_type, record_args_dict,
                             none_ok=False):
    """Type checks record args dynamically.

    Each argument is checked with the DataValidation method named
    'is' + the argument's data type.

    Inputs:
      record_type: string of record_type
      record_args_dict: dictionary for args keyed by arg name.
                        a filled out dict from GetEmptyRecordArgsDict()
      none_ok: boolean of if None types should be accepted.

    Raises:
      InvalidInputError: dict for record type should have these keys
      FucntionError: No function to check data type
      UnexpectedDataError: Invalid data type
    """
    record_type_dict = self.GetRecordArgsDict(record_type)
    if( set(record_type_dict) != set(record_args_dict) ):
      raise errors.InvalidInputError('dict for record type %s should have '
                                     'these keys: %s' % (record_type,
                                                         record_type_dict))
    if( self.data_validation_instance is None ):
      self.InitDataValidation()

    for record_arg_name, record_arg_value in record_args_dict.items():
      data_type = record_type_dict[record_arg_name]
      check_function = getattr(self.data_validation_instance,
                               'is%s' % data_type, None)
      if( check_function is None ):
        # NOTE(review): exception name spelled as used elsewhere; verify it
        # matches the errors module.
        raise errors.FucntionError('No function to check data type %s' %
                                   data_type)
      if( none_ok and record_arg_value is None ):
        continue
      if( not check_function(record_arg_value) ):
        raise errors.UnexpectedDataError('Invalid data type %s: %s' % (
            data_type, record_arg_value))

  def ListTableNames(self):
    """Lists all tables in the database.

    Outputs:
      List: List of tables
    """
    self.cursor_execute('SHOW TABLES')
    # Each SHOW TABLES row is a one-entry dict; collect the values.
    return [table for table_dict in self.cursor.fetchall()
            for table in table_dict.values()]

  def GetCurrentTime(self):
    """Returns datetime object of current time in database.

    Outputs:
      datetime: current time in the database
    """
    self.cursor_execute('SELECT NOW()')
    return self.cursor.fetchone()['NOW()']

  @staticmethod
  def _SplitSchema(schema):
    """Splits a schema string into statements terminated by ';'.

    Lines whose first non-blank character is '#' are skipped; text after
    the last ';' is discarded.

    Inputs:
      schema: string of sql schema

    Outputs:
      list: strings, one per statement
    """
    execute_lines = []
    continued_line = []
    for line in schema.split('\n'):
      if( line.lstrip().startswith('#') ):
        continue
      continued_line.append(line)
      # rstrip() so trailing whitespace after ';' still ends a statement.
      if( line.rstrip().endswith(';') ):
        execute_lines.append('\n'.join(continued_line))
        continued_line = []
    return execute_lines

  def CreateRosterDatabase(self, schema=None):
    """Destroys existing table structure in database and replaces it
    with schema that is passed in(or default schema).

    DO NOT RUN THIS AGAINST A DATABASE THAT IS NOT READY TO BE CLEARED

    This function is used because of a poorly understood bug in MySQLdb
    that does not allow our schema to be executed as one big query. The
    work around is splitting the whole thing up and committing each piece
    separately.

    Inputs:
      schema: string of sql schema
    """
    if( schema is None ):
      schema = embedded_files.SCHEMA_FILE
    # Scope the filter so it does not silence warnings for the rest of the
    # process.
    with warnings.catch_warnings():
      warnings.filterwarnings('ignore', 'Unknown table.*')
      for line in self._SplitSchema(schema):
        self.StartTransaction()
        try:
          self.cursor_execute(line)
        finally:
          self.EndTransaction()

  def DumpDatabase(self):
    """This will dump the entire database to memory.

    This would be done by mysqldump but it needs to be done in the same lock
    as other processes. So this is a simple mysqldump function.

    Outputs:
      Dictionary: Dictionary with keys of table name and schema/data for each
                  table as values.
    """
    table_data = {}
    self.cursor_execute('SHOW TABLES')
    table_name_rows = self.cursor.fetchall()
    # The OPTION modifier was removed in MySQL 5.6; plain SET works on all
    # versions.
    self.cursor_execute('SET SQL_QUOTE_SHOW_CREATE=1')
    for table_name_row in table_name_rows:
      table_name = list(table_name_row.values())[0]
      table_data[table_name] = {}
      self.cursor_execute('SHOW CREATE TABLE `%s`' % table_name)
      table_data[table_name]['schema'] = self.cursor.fetchone()['Create Table']
      self.cursor_execute('DESCRIBE `%s`' % table_name)
      table_data[table_name]['columns'] = [
          table_description['Field']
          for table_description in self.cursor.fetchall()]
      quoted_columns = ','.join(
          '`%s`' % column for column in table_data[table_name]['columns'])
      self.cursor_execute('SELECT %s FROM `%s`' % (quoted_columns, table_name))
      table_data[table_name]['rows'] = []
      for row in self.cursor.fetchall():
        row_dict = {}
        for key, value in row.items():
          # literal() gives the SQL-escaped form of the value.
          literal_value = self.connection.literal(value)
          if( isinstance(literal_value, str) ):
            literal_value = unicode(literal_value, 'utf-8')
          row_dict[key] = literal_value
        table_data[table_name]['rows'].append(row_dict)

    return table_data

  ### These functions are for the user class
  def GetUserAuthorizationInfo(self, user):
    """Grabs authorization data from the db and returns a dict.

    This function does two selects on the db, one for forward and one for
    reverse zones. It also parses the data into a dict for ease of use.

    Inputs:
      user: string of username

    Raises:
      RecordError: Returned row is corrupt.

    Outputs:
      dict: dict with all the relevant information; {} if the user does not
            exist
        example:
        {'user_access_level': '2',
         'user_name': 'shuey',
         'forward_zones': [
             {'zone_name': 'cs.university.edu', 'group_permission':
                 ['a', 'aaaa']},
             {'zone_name': 'eas.university.edu', 'group_permission':
                 ['a', 'aaaa', 'cname']},
             {'zone_name': 'bio.university.edu', 'group_permission':
                 ['a', 'ns']}],
         'groups': ['cs', 'bio'],
         'reverse_ranges': [
             {'cidr_block': '192.168.0.0/24',
              'group_permission': ['ptr', 'cname']},
             {'cidr_block': '192.168.0.0/24',
              'group_permission': ['ptr']},
             {'cidr_block': '192.168.1.0/24',
              'group_permission': ['ptr', 'cname']}]}
    """
    auth_info_dict = {}
    db_data = []

    users_dict = self.GetEmptyRowDict('users')
    users_dict['user_name'] = user
    groups_dict = self.GetEmptyRowDict('groups')
    user_group_assignments_dict = self.GetEmptyRowDict('user_group_assignments')
    forward_zone_permissions_dict = self.GetEmptyRowDict(
        'forward_zone_permissions')
    reverse_range_permissions_dict = self.GetEmptyRowDict(
        'reverse_range_permissions')
    group_forward_permissions_dict = self.GetEmptyRowDict(
        'group_forward_permissions')
    group_reverse_permissions_dict = self.GetEmptyRowDict(
        'group_reverse_permissions')

    auth_info_dict['user_name'] = user
    auth_info_dict['groups'] = []
    auth_info_dict['forward_zones'] = []
    auth_info_dict['reverse_ranges'] = []

    self.StartTransaction()
    try:
      db_data.extend(self.ListRow('users', users_dict,
                                  'groups', groups_dict,
                                  'user_group_assignments',
                                  user_group_assignments_dict,
                                  'forward_zone_permissions',
                                  forward_zone_permissions_dict,
                                  'group_forward_permissions',
                                  group_forward_permissions_dict))
      db_data.extend(self.ListRow('users', users_dict,
                                  'groups', groups_dict,
                                  'user_group_assignments',
                                  user_group_assignments_dict,
                                  'reverse_range_permissions',
                                  reverse_range_permissions_dict,
                                  'group_reverse_permissions',
                                  group_reverse_permissions_dict))
      if( not db_data ):
        # No permission rows; fall back to the bare access level. The user
        # name is bound as a parameter: formatting it into the SQL text
        # allowed SQL injection.
        self.cursor_execute('SELECT access_level FROM users '
                            'WHERE user_name=%(user_name)s',
                            {'user_name': user})
        db_data.extend(self.cursor.fetchall())
        if( db_data ):
          auth_info_dict['user_access_level'] = db_data[0]['access_level']
          return auth_info_dict
        else:
          return {}
    finally:
      self.EndTransaction()

    auth_info_dict['user_access_level'] = db_data[0]['access_level']
    for row in db_data:
      if( 'group_forward_permissions_group_permission' in row ):
        if( row['user_group_assignments_group_name'] not in
            auth_info_dict['groups'] ):
          auth_info_dict['groups'].append(
              row['user_group_assignments_group_name'])

        forward_zone = {
            'zone_name': row['forward_zone_permissions_zone_name'],
            'group_permission': row[
                'group_forward_permissions_group_permission']}
        if( forward_zone not in auth_info_dict['forward_zones'] ):
          auth_info_dict['forward_zones'].append(forward_zone)

      elif( 'group_reverse_permissions_group_permission' in row ):
        if( row['user_group_assignments_group_name'] not in
            auth_info_dict['groups'] ):
          auth_info_dict['groups'].append(
              row['user_group_assignments_group_name'])

        reverse_range = {
            'cidr_block': row['reverse_range_permissions_cidr_block'],
            'group_permission': row[
                'group_reverse_permissions_group_permission']}
        if( reverse_range not in auth_info_dict['reverse_ranges'] ):
          auth_info_dict['reverse_ranges'].append(reverse_range)

      elif( 'forward_zones' in auth_info_dict and
            auth_info_dict['user_access_level'] >= 64 ):
        forward_zone = {'zone_name': row['forward_zone_permissions_zone_name'],
                        'group_permission': []}
        if( forward_zone not in auth_info_dict['forward_zones'] ):
          auth_info_dict['forward_zones'].append(forward_zone)

      elif( 'reverse_ranges' in auth_info_dict and
            auth_info_dict['user_access_level'] >= 64 ):
        # NOTE(review): 'forward_zones' is always present above, so this
        # branch looks unreachable -- confirm the intended conditions.
        reverse_range = {
            'cidr_block': row['reverse_range_permissions_cidr_block'],
            'group_permission': []}
        if( reverse_range not in auth_info_dict['reverse_ranges'] ):
          auth_info_dict['reverse_ranges'].append(reverse_range)

      else:
        raise errors.RecordError('Returned row is corrupt.')

    return auth_info_dict

  def GetZoneOrigin(self, zone_name, view_name):
    """Returns zone origin of zone_name that is passed in.
    If no zone origin found, return None

    Inputs:
      zone_name: string of zone_name
      view_name: string of view_name

    Outputs:
      string of zone origin or None
    """
    zone_view_assignments_dict = self.GetEmptyRowDict(
        'zone_view_assignments')
    zone_view_assignments_dict['zone_view_assignments_zone_name'] = zone_name
    zone_view_assignments_dict[
        'zone_view_assignments_view_dependency'] = view_name

    zone_view_assignment_rows = self.ListRow(
        'zone_view_assignments', zone_view_assignments_dict)
    if( zone_view_assignment_rows ):
      return zone_view_assignment_rows[0]['zone_origin']
    return None

# vi: set ai aw sw=2: