Module postgres
[hide private]
[frames] | [no frames]

Source Code for Module postgres

  1   
  2  from c3errors import * 
  3  from configParser import C3Object 
  4  from baseStore import SimpleStore 
  5  from recordStore import SimpleRecordStore 
  6  from resultSet import SimpleResultSetItem 
  7  from documentFactory import BaseDocumentStream, BaseDocumentFactory 
  8  from resultSetStore import SimpleResultSetStore 
  9  from resultSet import SimpleResultSet 
 10  from baseObjects import IndexStore 
 11  import dynamic 
 12  from utils import elementType, getFirstData 
 13  from utils import nonTextToken 
 14  from index import SimpleIndex 
 15  from PyZ3950 import SRWDiagnostics 
 16   
 17   
 18  # Consider psycopg 
 19  # ...psycopg2 has segfaults as of 2006/07/10 
 20  import pg 
 21  import time 
 22  from PyZ3950 import CQLParser as cql 
 23   
 24  # Idea is to take the results of an SQL search and XMLify them into documents. 
class PostgresDocumentStream(BaseDocumentStream):
    """Stream of documents built from an SQL result set (not yet implemented)."""

    def __init__(self, stream=None, format='', tag='', codec=''):
        # Placeholder: SQL-to-document streaming has not been written yet.
        raise NotImplementedError

    def find_documents(self, cache=0):
        # No-op until the stream itself exists.
        pass
class PostgresDocumentFactory(BaseDocumentFactory):
    """DocumentFactory configured with PostgreSQL connection details."""
    database = ''
    host = ''
    port = 0

    def __init__(self, session, config, parent):
        BaseDocumentFactory.__init__(self, session, config, parent)
        # Connection coordinates come from the object's settings.
        self.host = self.get_setting(session, 'host', 'localhost')
        self.port = int(self.get_setting(session, 'port', '5432'))
        self.database = self.get_setting(session, 'database', '')
        # XXX add PostgresDocumentStream to stream types
        # query info to come in .load()
46 -class PostgresStore(SimpleStore):
47 cxn = None 48 relations = {} 49
50 - def __init__(self, session, config, parent):
51 C3Object.__init__(self, session, config, parent) 52 self.database = self.get_path(session, 'databaseName', 'cheshire3') 53 self.table = self.get_path(session, 'tableName', parent.id + '_' + self.id) 54 self.idNormalizer = self.get_path(session, 'idNormalizer', None) 55 self._verifyDatabases(session)
56
57 - def _handleConfigNode(self, session, node):
58 if (node.nodeType == elementType and node.localName == 'relations'): 59 self.relations = {} 60 for rel in node.childNodes: 61 if (rel.nodeType == elementType and rel.localName == 'relation'): 62 relName = rel.getAttributeNS(None, 'name') 63 fields = [] 64 for fld in rel.childNodes: 65 if fld.nodeType == elementType: 66 if fld.localName == 'object': 67 oid = getFirstData(fld) 68 fields.append([oid, 'VARCHAR', oid]) 69 elif fld.localName == 'field': 70 fname = fld.getAttributeNS(None, 'name') 71 ftype = getFirstData(fld) 72 fields.append([fname, ftype, '']) 73 self.relations[relName] = fields
74 75
    def _verifyDatabases(self, session):
        # Connect to the database and lazily create the main table and any
        # configured relation tables on first use.
        try:
            self.cxn = pg.connect(self.database)
        except pg.InternalError, e:
            raise ConfigFileException(e.args)

        try:
            # Cheap probe: raises ProgrammingError when the table is absent.
            query = "SELECT identifier FROM %s LIMIT 1" % self.table
            res = self.query(query)
        except pg.ProgrammingError, e:
            # no table for self, initialise
            query = """
            CREATE TABLE %s (identifier VARCHAR PRIMARY KEY,
            data BYTEA,
            digest VARCHAR(41),
            size INT,
            schema VARCHAR,
            parentStore VARCHAR,
            parentIdentifier VARCHAR,
            timeCreated TIMESTAMP,
            timeModified TIMESTAMP);
            """ % self.table
            self.query(query)

        # And check additional relations
        for (name, fields) in self.relations.items():
            try:
                # NOTE(review): relation tables here are named
                # "<self.id>_<name>", while PostgresResultSetStore uses
                # "<self.table>_<name>" -- confirm which is intended.
                query = "SELECT identifier FROM %s_%s LIMIT 1" % (self.id, name)
                res = self.query(query)
            except pg.ProgrammingError, e:
                # No table for relation, initialise
                query = "CREATE TABLE %s_%s (identifier SERIAL PRIMARY KEY, " % (self.id, name)
                for f in fields:
                    query += ("%s %s" % (f[0], f[1]))
                    if f[2]:
                        # Foreign Key
                        query += (" REFERENCES %s (identifier)" % f[2])
                    query += ", "
                # Trim the trailing ", " and close the statement.
                query = query[:-2] + ");"
                res = self.query(query)
118 - def _openContainer(self, session):
119 if self.cxn == None: 120 self.cxn = pg.connect(self.database)
121
122 - def _closeContainer(self, session):
123 self.cxn.close() 124 self.cxn = None
125
126 - def query(self, query):
127 query = query.encode('utf-8') 128 res = self.cxn.query(query) 129 return res
130
131 - def begin_storing(self, session):
132 self._openContainer(session) 133 return None
134
135 - def commit_storing(self, session):
136 self._closeContainer(session) 137 return None
138
139 - def generate_id(self, session):
140 self._openContainer(session) 141 # Find greatest current id 142 if (self.currentId == -1 or session.environment == 'apache'): 143 query = "SELECT identifier FROM %s ORDER BY identifier DESC LIMIT 1;" % self.table 144 res = self.query(query) 145 try: 146 id = int(res.dictresult()[0]['identifier']) +1 147 except: 148 id = 0 149 self.currentId = id 150 return id 151 else: 152 self.currentId += 1 153 return self.currentId
154
155 - def store_data(self, session, id, data, size=0):
156 self._openContainer(session) 157 id = str(id) 158 now = time.strftime("%Y-%m-%d %H:%M:%S") 159 if (self.idNormalizer <> None): 160 id = self.idNormalizer.process_string(session, id) 161 data = data.replace(nonTextToken, '\\\\000\\\\001') 162 163 query = "INSERT INTO %s (identifier, timeCreated) VALUES ('%s', '%s');" % (self.table, id, now) 164 try: 165 self.query(query) 166 except: 167 # already exists 168 pass 169 170 ndata = data.replace("'", "\\'") 171 if (size): 172 query = "UPDATE %s SET data = '%s', size = %d, timeModified = '%s' WHERE identifier = '%s';" % (self.table, ndata, size, now, id) 173 else: 174 query = "UPDATE %s SET data = '%s', timeModified = '%s' WHERE identifier = '%s';" % (self.table, ndata, now, id) 175 176 try: 177 self.query(query) 178 except: 179 # Uhhh... 180 raise 181 return None 182
183 - def verify_checkSum(self, session, id, data, store=1):
184 # Check record doesn't already exist 185 digest = self.get_setting(session, "digest") 186 if (digest): 187 if (digest == 'md5'): 188 dmod = md5 189 elif (digest == 'sha'): 190 dmod = sha 191 else: 192 raise ConfigFileException("Unknown digest type: %s" % digest) 193 m = dmod.new() 194 195 data = data.encode('utf-8') 196 m.update(data) 197 digest = m.hexdigest() 198 199 query = "SELECT identifier FROM %s WHERE digest = '%s'" % (self.table, digest) 200 res = self.query(query) 201 exist = res.dictresult() 202 if exist: 203 raise ObjectAlreadyExistsException(exist[0]['identifier']) 204 elif store: 205 self.store_checkSum(session, id, digest) 206 return digest
207
208 - def store_checkSum(self, session, id, digest):
209 id = str(id) 210 if self.idNormalizer != None: 211 id = self.idNormalizer.process_string(session, id) 212 query = "UPDATE %s SET digest = '%s' WHERE identifier = '%s';" % (self.table, digest, id) 213 self.cxn.query(query)
214
215 - def fetch_data(self, session, id):
216 self._openContainer(session) 217 sid = str(id) 218 if (self.idNormalizer != None): 219 sid = self.idNormalizer.process_string(session, sid) 220 query = "SELECT data FROM %s WHERE identifier = '%s';" % (self.table, sid) 221 res = self.query(query) 222 data = res.dictresult()[0]['data'] 223 data = data.replace('\\000\\001', nonTextToken) 224 data = data.replace('\\012', '\n') 225 return data
226
227 - def delete_item(self, session, id):
228 self._openContainer(session) 229 sid = str(id) 230 if (self.idNormalizer <> None): 231 sid = self.idNormalizer.process_string(session, str(id)) 232 query = "DELETE FROM %s WHERE identifier = '%s';" % (self.table, sid) 233 self.query(query) 234 return None
235
236 - def fetch_size(self, session, id):
237 self._openContainer(session) 238 sid = str(id) 239 if (self.idNormalizer <> None): 240 sid = self.idNormalizer.process_string(session, sid) 241 query = "SELECT size FROM %s WHERE identifier = '%s';" % (self.table, sid) 242 res = self.query(query) 243 rsz = res.dictresult()[0]['size'] 244 if (rsz): 245 return long(rsz) 246 else: 247 return -1
248
249 - def fetch_checksum(self, session, id):
250 self._openContainer(session) 251 sid = str(id) 252 if self.idNormalizer != None: 253 sid = self.idNormalizer.process_string(session, sid) 254 query = "SELECT digest FROM %s WHERE identifier = '%s';" % (self.table, sid) 255 res = self.query(query) 256 return res.dictresult()[0]['digest']
257
258 - def fetch_idList(self, session, numReq=-1, start=""):
259 # return numReq ids from start 260 ids = [] 261 self._openContainer(session) 262 sid = str(start) 263 if self.idNormalizer != None: 264 sid = self.idNormalizer.process_string(session, sid) 265 query = "SELECT identifier FROM %s ORDER BY identifier DESC" % (self.table) 266 if numReq != -1: 267 query += (" LIMIT %d" % numReq) 268 res = self.query(query) 269 all = res.dictresult() 270 for item in all: 271 ids.append(item['identifier']) 272 return ids
273 274 297 298 310 337
class PostgresRecordStore(PostgresStore, SimpleRecordStore):
    """RecordStore persisted in a PostgreSQL table."""

    def __init__(self, session, node, parent):
        # Record-store behaviour first, then the SQL backing store.
        SimpleRecordStore.__init__(self, session, node, parent)
        PostgresStore.__init__(self, session, node, parent)
class PostgresResultSetStore(PostgresStore, SimpleResultSetStore):
    # ResultSetStore backed by PostgreSQL: serialised result sets live in
    # a table carrying creation, last-access and expiry timestamps.

    def __init__(self, session, node, parent):
        SimpleResultSetStore.__init__(self, session, node, parent)
        PostgresStore.__init__(self, session, node, parent)

    def _verifyDatabases(self, session):
        # Custom resultSetStore table
        try:
            self.cxn = pg.connect(self.database)
        except pg.InternalError, e:
            raise ConfigFileException(e.args)

        try:
            # Probe for the table; ProgrammingError means it must be made.
            query = "SELECT identifier FROM %s LIMIT 1" % self.table
            res = self.query(query)
        except pg.ProgrammingError, e:
            # no table for self, initialise
            query = """
            CREATE TABLE %s (identifier VARCHAR PRIMARY KEY,
            data BYTEA,
            size INT,
            class VARCHAR,
            timeCreated TIMESTAMP,
            timeAccessed TIMESTAMP,
            timeExpires TIMESTAMP);
            """ % self.table
            self.query(query)
            # rs.id, rs.serialise(), digest, len(rs), rs.__class__, now, expireTime
            # NB: rs can't be modified

        # And check additional relations
        for (name, fields) in self.relations.items():
            try:
                query = "SELECT identifier FROM %s_%s LIMIT 1" % (self.table, name)
                res = self.query(query)
            except pg.ProgrammingError, e:
                # No table for relation, initialise
                query = "CREATE TABLE %s_%s (identifier SERIAL PRIMARY KEY, " % (self.table, name)
                for f in fields:
                    query += ("%s %s" % (f[0], f[1]))
                    if f[2]:
                        # Foreign Key
                        query += (" REFERENCES %s (identifier)" % f[2])
                    query += ", "
                # Trim trailing ", " and close the statement.
                query = query[:-2] + ");"
                res = self.query(query)

    def store_data(self, session, id, data, size=0):
        # should call store_resultSet
        raise NotImplementedError

    def create_resultSet(self, session, rset):
        # Allocate an id, mark the set retryable on key clash, persist it.
        id = self.generate_id(session)
        rset.id = id
        rset.retryOnFail = 1
        self.store_resultSet(session, rset)
        return id

    def store_resultSet(self, session, rset):
        # Serialise rset and INSERT it with created/accessed/expiry times.
        self._openContainer(session)
        now = time.time()
        nowStr = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(now))
        if (rset.expires):
            expires = now + rset.expires
        else:
            expires = now + self.get_default(session, 'resultSetTimeout', 600)
        rset.timeExpires = expires
        expiresStr = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))
        id = rset.id
        if (self.idNormalizer != None):
            id = self.idNormalizer.process_string(session, id)

        # Serialise and store
        srlz = rset.serialise(session)
        cl = str(rset.__class__)
        # Escape NULs and quotes for the SQL literal; fetch reverses this.
        data = srlz.replace('\x00', '\\\\000')
        ndata = data.replace("'", "\\'")

        query = "INSERT INTO %s (identifier, data, size, class, timeCreated, timeAccessed, timeExpires) VALUES ('%s', '%s', %s, '%s', '%s', '%s', '%s')" % (self.table, id, ndata, len(rset), cl, nowStr, nowStr, expiresStr)

        try:
            self.query(query)
        except pg.ProgrammingError, e:
            # already exists, retry for create
            if hasattr(rset, 'retryOnFail'):
                # generate new id, re-store
                id = self.generate_id(session)
                if (self.idNormalizer != None):
                    id = self.idNormalizer.process_string(session, id)
                query = "INSERT INTO %s (identifier, data, size, class, timeCreated, timeAccessed, timeExpires) VALUES ('%s', '%s', %s, '%s', '%s', '%s', '%s')" % (self.table, id, ndata, len(rset), cl, nowStr, nowStr, expiresStr)
                self.query(query)
            else:
                raise ObjectAlreadyExistsException(self.id + '/' + id)
        return rset

    def fetch_resultSet(self, session, rsid):
        # Load and deserialise a stored set, then extend its expiry time.
        self._openContainer(session)

        sid = str(rsid)
        if (self.idNormalizer != None):
            sid = self.idNormalizer.process_string(session, sid)
        query = "SELECT class, data FROM %s WHERE identifier = '%s';" % (self.table, sid)
        res = self.query(query)
        try:
            rdict = res.dictresult()[0]
        except IndexError:
            raise ObjectDoesNotExistException(self.id + '/' + sid)

        data = rdict['data']
        # Reverse the escaping applied in store_resultSet.
        data = data.replace('\\000', '\x00')
        data = data.replace('\\012', '\n')
        # data is res.dictresult()
        cl = rdict['class']
        # Rebuild an empty instance of the stored class and deserialise.
        rset = dynamic.buildObject(session, cl, [])
        rset.deserialise(session, data)

        # Update expires
        now = time.time()
        nowStr = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(now))
        expires = now + self.get_default(session, 'resultSetTimeout', 600)
        rset.timeExpires = expires
        expiresStr = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(expires))

        query = "UPDATE %s SET timeAccessed = '%s', timeExpires = '%s' WHERE identifier = '%s';" % (self.table, nowStr, expiresStr, sid)
        self.query(query)
        return rset

    def delete_resultSet(self, session, rsid):
        self._openContainer(session)
        sid = str(rsid)
        if (self.idNormalizer != None):
            sid = self.idNormalizer.process_string(session, sid)
        query = "DELETE FROM %s WHERE identifier = '%s';" % (self.table, sid)
        self.query(query)

    def fetch_resultSetList(self, session):
        pass

    def clean(self, session):
        # here is where sql is nice...
        # Drop every stored set whose expiry time has passed.
        self._openContainer(session)
        nowStr = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))
        query = "DELETE FROM %s WHERE timeExpires < '%s';" % (self.table, nowStr)
        self.query(query)
489 490 491 # -- non proximity, just store occurences of type per record 492 # CREATE TABLE parent.id + self.id + index.id (identifier SERIAL PRIMARY KEY, term VARCHAR, occurences INT, recordId VARCHAR, stem VARCHAR, pos VARCHAR); 493 494 # -- proximity. Store each token, not each type, per record 495 # CREATE TABLE parent.id + self.id + index.id (identifier SERIAL PRIMARY KEY, term VARCHAR, field VARCHAR, recordId VARCHAR, stem VARCHAR, pos VARCHAR); 496 497 # -- then check adjacency via identifier comparison (plus field/recordId) 498 499 # -- recordId = recordStore.id / record.id 500 # -- so then can do easy intersection/union on them 501 502 # CREATE INDEX parent.id+self.id+index.id+termIndex on aboveTable (term); 503 # BEGIN 504 # INSERT INTO aboveTable (...) VALUES (...); 505 # COMMIT 506 # CLUSTER aboveIndex on aboveTable; 507
class PostgresIndexStore(IndexStore, PostgresStore):
    # IndexStore keeping one PostgreSQL table per index; terms are written
    # unsorted and physically ordered afterwards via CLUSTER.
    database = ""
    transaction = 0

    def __init__(self, session, config, parent):
        IndexStore.__init__(self, session, config, parent)
        # Open connection
        self.database = self.get_path(session, 'databaseName', 'cheshire3')
        # multiple tables, one per index
        self.transaction = 0

    def generate_tableName(self, session, index):
        # parentId__storeId__indexId, lowercased with '-' mapped to '_',
        # so the result is a legal SQL identifier.
        base = self.parent.id + "__" + self.id + "__" + index.id
        return base.replace('-', '_').lower()

    def contains_index(self, session, index):
        # True iff the index's table exists (per pg_stat_user_tables).
        self._openContainer(session)
        table = self.generate_tableName(session, index)
        query = "SELECT relname FROM pg_stat_user_tables WHERE relname = '%s'" % table;
        res = self.query(query)
        return len(res.dictresult()) == 1

    def create_index(self, session, index):
        # Create the per-index term table plus a btree index on 'term'.
        self._openContainer(session)
        table = self.generate_tableName(session, index)
        query = "CREATE TABLE %s (identifier SERIAL PRIMARY KEY, term VARCHAR, occurences INT, recordId VARCHAR, stem VARCHAR, pos VARCHAR)" % table
        query2 = "CREATE INDEX %s ON %s (term)" % (table + "_INDEX", table)
        self._openContainer(session)
        self.query(query)
        self.query(query2)

    def begin_indexing(self, session, index):
        # Start a transaction so term INSERTs are batched until COMMIT.
        self._openContainer(session)
        if not self.transaction:
            self.query('BEGIN')
            self.transaction = 1

    def commit_indexing(self, session, index):
        # COMMIT the batch, then physically order the table by term.
        if self.transaction:
            self.query('COMMIT')
            self.transaction = 0
        table = self.generate_tableName(session, index)
        termIdx = table + "_INDEX"
        self.query('CLUSTER %s ON %s' % (termIdx, table))

    def store_terms(self, session, index, termhash, record):
        # write directly to db, as sort comes as CLUSTER later
        table = self.generate_tableName(session, index)
        # NOTE(review): '%r' embeds repr(record) as the recordId and
        # search() later splits that value on '/'; confirm the record
        # class's repr really is 'store/id'.
        queryTmpl = "INSERT INTO %s (term, occurences, recordId) VALUES ('%%s', %%s, '%r')" % (table, record)

        for t in termhash.values():
            # Escape single quotes for the SQL string literal.
            term = t['text'].replace("'", "\\'")
            query = queryTmpl % (term, t['occurences'])
            self.query(query)

    def fetch_term(self, session, index, term, prox=True):
        # should return info to create result set
        # --> [(rec, occs), ...]
        table = self.generate_tableName(session, index)
        query = "SELECT recordId, occurences FROM %s WHERE term='%s'" % (table, term)
        res = self.query(query)
        dr = res.dictresult()
        totalRecs = len(dr)
        # Second query totals occurrences across all matching rows.
        occq = "SELECT SUM(occurences) as sum FROM %s WHERE term='%s'" % (table, term)
        res = self.query(occq)
        totalOccs = res.dictresult()[0]['sum']
        return {'totalRecs' : totalRecs, 'records' : dr, 'totalOccs' : totalOccs}

    def fetch_termList(self, session, index, term):
        pass

    def _cql_to_sql(self, session, query, pm):
        # Recursively translate a CQL parse tree into one SELECT statement;
        # boolean nodes become INTERSECT/UNION/EXCEPT between subqueries.
        if (isinstance(query, cql.SearchClause)):
            idx = pm.resolveIndex(session, query)

            if (idx != None):
                # check if 'stem' relmod

                # get the index to chunk the term
                pn = idx.get_setting(session, 'termProcess')
                if (pn == None):
                    pn = 0
                else:
                    pn = int(pn)
                process = idx.sources[pn][1]
                res = idx._processChain(session, [query.term.value], process)
                # Single-token result: remember it for the scalar relations.
                if len(res) == 1:
                    nterm = res.keys()[0]

                # check stem
                if query.relation['stem']:
                    termCol = 'stem'
                else:
                    termCol = 'term'
                table = self.generate_tableName(session, idx)
                qrval = query.relation.value

                if qrval == "any":
                    # any: match records containing at least one token.
                    terms = []
                    for t in res:
                        terms.append("'" + t + "'")
                    inStr = ', '.join(terms)
                    q = "SELECT DISTINCT recordid FROM %s WHERE %s in (%s)" % (table, termCol, inStr)
                elif qrval == "all":
                    # all: intersect one SELECT per token.
                    qs = []
                    for t in res:
                        qs.append("SELECT recordid FROM %s WHERE %s = '%s'" % (table, termCol, t))
                    q = " INTERSECT ".join(qs)
                elif qrval == "exact":
                    q = "SELECT recordid FROM %s WHERE %s = '%s'" % (table, termCol, nterm)
                elif qrval == "within":
                    q = "SELECT recordid FROM %s WHERE %s BETWEEN '%s' AND '%s'" % (table, termCol, res[0], nterm)
                elif qrval in ['>', '<', '>=', '<=', '<>']:
                    q = "SELECT recordid FROM %s WHERE %s %s '%s'" % (table, termCol, qrval, nterm)
                elif qrval == '=':
                    # no prox
                    raise NotImplementedError()
                else:
                    raise NotImplementedError(qrval)

                return q
            else:
                # Unresolvable index -> SRW diagnostic 16 (unsupported index).
                d = SRWDiagnostics.Diagnostic16()
                d.details = query.index.toCQL()
                raise d
        else:
            left = self._cql_to_sql(session, query.leftOperand, pm)
            right = self._cql_to_sql(session, query.rightOperand, pm)
            bl = query.boolean
            if bl.value == "and":
                b = 'INTERSECT'
            elif bl.value == "or":
                b = 'UNION'
            elif bl.value == 'not':
                b = 'EXCEPT'
            else:
                raise NotImplementedError()
            q = "(%s %s %s)" % (left, b, right)
            return q

    def search(self, session, query, db):
        # Translate CQL to SQL, run it, and wrap the ids in a result set.
        pm = db.get_path(session, 'protocolMap')
        if not pm:
            db._cacheProtocolMaps(session)
            pm = db.protocolMaps.get('http://www.loc.gov/zing/srw/')
            db.paths['protocolMap'] = pm
        query = self._cql_to_sql(session, query, pm)
        res = self.query(query)
        dr = res.dictresult()
        # NOTE(review): PostgresIndex.construct_resultSet calls
        # SimpleResultSet(session, []) -- confirm this one-argument call
        # is the intended signature.
        rset = SimpleResultSet([])
        rsilist = []
        for t in dr:
            # recordid is stored as 'store/id'.
            (store, id) = t['recordid'].split('/', 1)
            item = SimpleResultSetItem(session, id, store, 1, resultSet=rset)
            rsilist.append(item)
        rset.fromList(rsilist)
        return rset
665 666 667
class PostgresIndex(SimpleIndex):
    """Index that builds result sets from PostgresIndexStore term rows."""

    def construct_resultSet(self, session, terms, queryHash=None):
        """Build a SimpleResultSet from terms['records'] (pg dictresult rows).

        'terms' is the dict returned by PostgresIndexStore.fetch_term;
        'queryHash', when given, carries the query term's text and
        occurrence count onto the result set.
        """
        # Mutable default argument ({}) replaced with None; 'if queryHash:'
        # behaves identically for both None and {}.
        s = SimpleResultSet(session, [])
        rsilist = []
        for t in terms['records']:
            # recordid is stored as 'store/id'.
            (store, id) = t['recordid'].split('/', 1)
            occs = t['occurences']
            item = SimpleResultSetItem(session, id, store, occs, resultSet=s)
            rsilist.append(item)
        s.fromList(rsilist)
        s.index = self
        if queryHash:
            s.queryTerm = queryHash['text']
            s.queryFreq = queryHash['occurences']
        s.totalRecs = terms['totalRecs']
        s.totalOccs = terms['totalOccs']
        return s