| 1 | """Database utilities""" |
|---|
| 2 | |
|---|
| 3 | import os |
|---|
| 4 | import md5 |
|---|
| 5 | import time |
|---|
| 6 | import random |
|---|
| 7 | import urllib |
|---|
| 8 | |
|---|
| 9 | from twisted.enterprise import adbapi |
|---|
| 10 | from twisted.python import reflect |
|---|
| 11 | |
|---|
# Maps a database URI scheme to the name of the DB-API module handling it.
db_modules = dict(mysql="MySQLdb", sqlite="sqlite3")
|---|
| 14 | |
|---|
def make_random_etag(uri):
    """Return a new ETag for ``uri`` that differs on every call.

    The tag is the MD5 hex digest of the URI combined with the current
    time and a random number.  MD5 is used purely as a fingerprint here,
    not for any security purpose.
    """
    # hashlib replaces the ``md5`` module, which is deprecated since
    # Python 2.5 and no longer exists in Python 3.
    import hashlib

    seed = "%s%s%s" % (uri, time.time(), random.random())
    if isinstance(seed, type(u"")):
        # Hash functions operate on bytes; encode text explicitly so that
        # non-ASCII URIs do not fail on an implicit ASCII conversion.
        seed = seed.encode("utf-8")
    return hashlib.md5(seed).hexdigest()
|---|
| 17 | |
|---|
def make_etag(uri, document):
    """Return a deterministic ETag for ``document`` stored at ``uri``.

    The tag is the MD5 hex digest of the URI immediately followed by the
    document, both formatted with ``%s``; identical inputs always give
    the same tag.  MD5 is used purely as a fingerprint here.
    """
    # hashlib replaces the ``md5`` module, which is deprecated since
    # Python 2.5 and no longer exists in Python 3.
    import hashlib

    payload = "%s%s" % (uri, document)
    if isinstance(payload, type(u"")):
        # Hash functions operate on bytes; encode text explicitly so that
        # non-ASCII content does not fail on an implicit ASCII conversion.
        payload = payload.encode("utf-8")
    return hashlib.md5(payload).hexdigest()
|---|
| 20 | |
|---|
def parseURI(uri):
    """Split a database URI into its components.

    Recognised shapes are ``scheme://[user[:password]@]host[:port]/path``,
    ``scheme:/path`` and ``scheme:///path``, each optionally followed by a
    ``?name=value&name=value`` query string.

    Returns the tuple ``(schema, user, password, host, port, path, args)``.
    ``user``, ``password``, ``host`` and ``port`` are None when the URI
    does not carry them, ``port`` is an int otherwise, and ``args`` is a
    dict mapping query argument names to their URL-unquoted values.  The
    path keeps its leading slash, except for the special ``:memory:`` path.
    When ``os.name`` is ``'nt'`` a ``C|/...`` path is turned into ``C:/...``.

    Raises ValueError when the URI has no ``:``, when the port is not an
    integer in the range 1-65535, or when a query argument lacks ``=``.
    Raises AssertionError when the part after the scheme does not start
    with ``/``.

    >>> parseURI('mysql://username:123@localhost/openser')
    ('mysql', 'username', '123', 'localhost', None, '/openser', {})
    >>> parseURI('sqlite:/:memory:')
    ('sqlite', None, None, None, None, ':memory:', {})
    >>> parseURI('sqlite:///full/path/to/database')
    ('sqlite', None, None, None, None, '/full/path/to/database', {})
    >>> parseURI('sqlite:/C|/full/path/to/database')
    ('sqlite', None, None, None, None, '/C|/full/path/to/database', {})
    """
    try:
        from urllib.parse import unquote  # Python 3 location
    except ImportError:
        from urllib import unquote  # Python 2 location

    schema, rest = uri.split(':', 1)
    if not rest.startswith('/'):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest)

    # Separate the network location (if any) from the path.
    if not rest.startswith('//'):
        # scheme:/path -- no host part at all
        host = None
        rest = rest[1:]
    elif rest.startswith('///'):
        # scheme:///path -- explicitly empty host part
        host = None
        rest = rest[3:]
    else:
        # scheme://host/path
        rest = rest[2:]
        if '/' in rest:
            host, rest = rest.split('/', 1)
        else:
            host, rest = rest, ''

    # Credentials: split on the LAST '@' so passwords may contain '@'.
    user = password = None
    if host and '@' in host:
        at = host.rfind('@')
        user, host = host[:at], host[at + 1:]
        if ':' in user:
            user, password = user.split(':', 1)

    # Port: everything after the first ':' of the host must be a valid port.
    port = None
    if host and ':' in host:
        host, port_text = host.split(':', 1)
        try:
            port = int(port_text)
        except ValueError:
            raise ValueError("port must be integer, got '%s' instead" % port_text)
        if not (1 <= port <= 65535):
            raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)

    path = '/' + rest
    if os.name == 'nt' and len(rest) > 1 and rest[1] == '|':
        # 'C|/dir/file' is the URI spelling of the Windows path 'C:/dir/file'.
        path = "%s:%s" % (rest[0], rest[2:])

    # Query string: name=value pairs separated by '&'.
    args = {}
    if '?' in path:
        path, query = path.split('?', 1)
        for single in query.split('&'):
            if not single:
                # Tolerate a bare '?' or a stray '&'.
                continue
            if '=' not in single:
                raise ValueError("query argument %r must have the form name=value" % single)
            argname, argvalue = single.split('=', 1)
            args[argname] = unquote(argvalue)

    if path == '/:memory:':
        # Special in-memory database name: drop the leading slash.
        path = path[1:]
    return schema, user, password, host, port, path, args
|---|
| 82 | |
|---|
def connectionForURI(uri):
    """Return a Twisted adbapi connection pool for a given database URI.

    The URI is split with parseURI() and its scheme is looked up in
    ``db_modules`` to find the DB-API module to use.  For MySQL the host,
    port, user, password and database name are passed as keyword
    arguments; for SQLite the path is passed as the single positional
    argument.  Query-string arguments of the URI are not forwarded to the
    driver.  The returned pool carries the URI scheme in its ``schema``
    attribute.

    Raises ValueError if the scheme has no entry in ``db_modules`` (or if
    parseURI() rejects the URI).
    """
    schema, user, password, host, port, path, args = parseURI(uri)
    try:
        module = db_modules[schema]
    except KeyError:
        raise ValueError("Database scheme '%s' is not supported." % schema)

    # reconnecting is safe since we don't use transactions.
    # the following code prefers MySQLdb native reconnect if it's available,
    # falling back to twisted's cp_reconnect.
    # mysql's reconnect is preferred because it's better tested than twisted's
    kwargs = {}
    if module == 'MySQLdb':
        MySQLdb = reflect.namedModule(module)
        if MySQLdb.version_info[:3] >= (1, 2, 2):
            kwargs.setdefault('reconnect', 1)
        kwargs.setdefault('host', host or 'localhost')
        if port is not None:
            # Forward the port from the URI; previously it was parsed and
            # validated but never passed on, so the default port was used.
            kwargs.setdefault('port', port)
        kwargs.setdefault('user', user or '')
        kwargs.setdefault('passwd', password or '')
        # The database name is the URI path without its leading slash.
        path = path.lstrip('/')
        kwargs.setdefault('db', path)
        args = ()
    elif module == 'sqlite3':
        if path == ':memory:':
            # otherwise a database per connection is created
            kwargs['cp_min'] = kwargs['cp_max'] = 1
        args = (path, )

    if 'reconnect' not in kwargs:
        # note that some versions of MySQLdb don't provide reconnect parameter,
        # but set it to 1.
        # hopefully, if underlying reconnect was enabled, twisted will never see
        # a disconnect and its reconnection code won't interfere.
        kwargs.setdefault('cp_reconnect', 1)

    kwargs.setdefault('cp_noisy', False)

    pool = adbapi.ConnectionPool(module, *args, **kwargs)
    pool.schema = schema
    return pool
|---|
| 124 | |
|---|
def repeat_on_error(N, errorinfo, func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and retry it when it fails.

    ``func`` must return a Deferred.  When that Deferred fails with an
    error whose value is an instance of ``errorinfo`` (an exception class
    or tuple of classes), ``func`` is called again, up to ``N`` extra
    times.  Any other failure, or a failure after the retries are used
    up, is passed on unchanged.

    Returns the Deferred of the first call, which ends up with the
    outcome of the last attempt made.
    """
    def attempt(retries_left):
        # Run one attempt and arrange for the next one if it fails.
        outcome = func(*args, **kwargs)

        def on_failure(failure):
            if isinstance(failure.value, errorinfo) and retries_left > 0:
                # Returning a Deferred from an errback chains its result.
                return attempt(retries_left - 1)
            return failure

        outcome.addErrback(on_failure)
        return outcome

    return attempt(N)
|---|
| 139 | |
|---|
if __name__ == '__main__':
    # Small manual demonstration of repeat_on_error(); the traces are
    # written with the parenthesised single-argument print form, which
    # produces the same output under Python 2 and Python 3.
    from twisted.internet import defer

    def s():
        # Always succeeds.
        print('s()')
        return defer.succeed(True)

    def f():
        # Always fails.
        print('f()')
        return defer.fail(ZeroDivisionError())

    def getcb(msg):
        # Build a (callback, errback) pair that reports which one fired.
        def callback(x):
            print('%s callback: %r' % (msg, x))

        def errback(x):
            print('%s errback: %r' % (msg, x))

        return callback, errback

    # calls s()'s callback
    d = repeat_on_error(1, Exception, s)
    d.addCallbacks(*getcb('s'))

    # calls f() for 4 times (1+3), then gives up and calls last f()'s errback
    d = repeat_on_error(3, Exception, f)
    d.addCallbacks(*getcb('f'))

    # bad_func() fails on its first two calls and would succeed on the
    # third; with a single retry only the two failing calls happen, so
    # the errback fires.
    pending = [f, f, s]

    def bad_func():
        return pending.pop(0)()

    d = repeat_on_error(1, Exception, bad_func)
    d.addCallbacks(*getcb('bad_func'))
|---|
| 174 | |
|---|