标签:
使用 sqlite3 和 metadata 简单地封装了一个简易的 ORM
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""A very small ORM on top of ``sqlite3``.

Models declare their columns as ``Field`` class attributes; ``ModelMeta``
collects those declarations into ``Model.fields`` and a ``Mapper`` turns model
instances into SQL statements.  The code is written to run unchanged on
Python 2.6+ and Python 3.
"""
import sqlite3
import sys
import threading

# Public API of the ORM part of this listing (the test case below is
# deliberately not exported).
__all__ = [
    "setup_database",
    "bind_mapper",
    "with_metaclass",
    "FieldError",
    "Field",
    "IntegerField",
    "ModelMeta",
    "ModelMinix",
    "Model",
    "Mapper",
]

# Handle on this very module; bind_mapper() publishes mappers as attributes
# of it (e.g. ``db.UserMapper``).
__module__ = sys.modules[__name__]


def setup_database(database):
    """Register *database* (a path accepted by ``sqlite3.connect``)."""
    Mapper.initialize(database)


def bind_mapper(mapper_name, model_cls, mapper_cls=None):
    """Create a mapper for *model_cls* and expose it as a module attribute.

    mapper_name -- attribute name under which the mapper is published
    model_cls   -- the Model subclass the mapper will serve
    mapper_cls  -- mapper class to instantiate; defaults to ``Mapper``
    """
    if mapper_cls is None:
        mapper_cls = Mapper
    setattr(__module__, mapper_name, mapper_cls(model_cls))


class FieldError(Exception):
    """Raised when a field is read before being set or fails validation."""


class Field(object):
    """Declaration of one model column.

    name      -- display name of the field (used by ``Model.__str__``)
    affinity  -- SQL type text used by ``Mapper.create_table``
    validator -- optional callable invoked by ``validate()`` with the value
    required  -- whether an empty value is rejected
    default   -- value given to model instances that do not set the field
    Extra keyword arguments are stored untouched in ``self.kwargs``.
    """

    def __init__(self, name, affinity=None, validator=None, required=False,
                 default=None, **kwargs):
        self.affinity = affinity
        self.validator = validator
        self.default = default
        self.required = required
        self.name = name
        self.kwargs = kwargs

    def get_value(self):
        """Return the stored data; raise FieldError if nothing was stored."""
        if hasattr(self, "data"):
            return self.data
        raise FieldError("Field is not initialized")

    def set_value(self, value):
        """Store *value* after running it through ``process_formdata``."""
        # Must not assign to ``self.value`` here: that is this very property
        # and would recurse forever.
        self.process_formdata(value)

    value = property(get_value, set_value)

    def validate(self):
        """Check the stored data.

        Raises FieldError when a required field is empty, or when the
        validator returns exactly ``False`` for the value.  Validators that
        signal problems by raising are left to do so.
        """
        data = getattr(self, "data", None)
        if self.required and data is None:
            raise FieldError("Field is required!")
        if data == self.default:
            return
        if self.validator is not None and self.validator(data) is False:
            raise FieldError("Field %s is not valid" % self.name)

    def process_formdata(self, value):
        """Store *value* as ``self.data``.

        An empty value for a required field is stored as ``None`` so that
        ``validate()`` can report it.
        """
        if value or not self.required:
            self.data = value
        else:
            self.data = None

    def _pre_validate(self):
        pass

    def __call__(self):
        """Shortcut for reading ``self.value``."""
        return self.value


class IntegerField(Field):
    """A field whose input is coerced to an integer.

    Input that cannot be converted raises ValueError and is not kept.
    """

    def __init__(self, name, affinity=None, validator=None, required=False,
                 default=None, **kwargs):
        # ``defalut`` was the (misspelt) keyword of earlier versions; keep
        # accepting it so existing callers do not break.
        legacy_default = kwargs.pop("defalut", None)
        if default is None:
            default = legacy_default
        # Pass by keyword: positional forwarding used to shift every
        # argument one slot to the left.
        Field.__init__(self, name, affinity=affinity or "INTEGER",
                       validator=validator, required=required,
                       default=default, **kwargs)

    def process_formdata(self, value):
        """Store ``int(value)``; ``None`` and ``""`` are stored as ``None``."""
        # Test for emptiness explicitly so that 0 is accepted as a number.
        if value is None or value == "":
            self.data = None
            return
        try:
            self.data = int(value)
        except (TypeError, ValueError):
            self.data = None
            raise ValueError("Not a valid integer value")


def with_metaclass(meta, bases=(object,)):
    """Return a throw-away base class built by *meta* (Python 2/3 portable)."""
    return meta("NewBase", bases, {})


class ModelMeta(type):
    """Metaclass that moves ``Field`` attributes into the ``fields`` dict."""

    def __new__(metacls, cls_name, bases, attrs):
        fields = {}
        new_attrs = {}
        for key, attr in attrs.items():
            if isinstance(attr, Field):
                fields[key] = attr
            else:
                new_attrs[key] = attr
        cls = type.__new__(metacls, cls_name, bases, new_attrs)
        # Copy so a subclass never mutates its parent's mapping; tolerate
        # bases that define no ``fields`` at all.
        merged = dict(getattr(cls, "fields", {}))
        merged.update(fields)
        cls.fields = merged
        return cls


class ModelMinix(object):
    """Behaviour shared by all models: string form and ``save()``."""

    fields = {}

    def __str__(self):
        pairs = ", ".join(
            "%s=%s" % (field.name, getattr(self, column))
            for column, field in self.fields.items()
        )
        return "<" + self.__class__.__name__ + " : {" + pairs + "}>"

    def save(self):
        """Insert this instance through the mapper bound to its class."""
        return self.__mapper__.save(self)


class Model(with_metaclass(ModelMeta, (ModelMinix,))):
    """Base class for models; subclasses declare ``Field`` attributes."""

    def __init__(self, **kwargs):
        # Every declared field becomes an instance attribute, falling back
        # to the field default, so save() and __str__ never hit a missing
        # attribute.  Keyword arguments naming no field are ignored.
        for column, field in self.fields.items():
            setattr(self, column, kwargs.get(column, field.default))

    def __json__(self):
        raise NotImplementedError("subclass of Model must implement __json__")

    @classmethod
    def create_table(cls):
        raise NotImplementedError(
            "subclass of Model must implement create_table")


class Mapper(object):
    """Database mapper: turns one model class into SQLite statements."""

    # One sqlite3 connection per thread.
    _local = threading.local()

    def __init__(self, model_cls):
        self.model = model_cls
        self.model.__mapper__ = self

    @staticmethod
    def initialize(database):
        """Remember *database* for later ``connect()`` calls.

        The value is always (re)assigned, which is what the previous,
        never-false guard did in practice.
        """
        Mapper.__database__ = database

    @staticmethod
    def initialized():
        """Return True if the class variable __database__ has been set."""
        return hasattr(Mapper, "__database__")

    @classmethod
    def connect(cls):
        """Return this thread's connection, creating it on first use.

        Raises sqlite3.Error if the database cannot be opened.
        """
        conn = getattr(cls._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(cls.__database__)
            conn.row_factory = sqlite3.Row
            cls._local.conn = conn
        return conn

    def _run(self, sql, params=()):
        """Run one statement in a transaction; report success as a bool."""
        try:
            conn = self.connect()
            # The connection context manager commits on success and rolls
            # back when the statement raises.
            with conn:
                conn.execute(sql, params)
        except sqlite3.Error as e:
            print("SQLite Error: %s" % e)
            return False
        return True

    def execute(self, sql, params=()):
        """Execute *sql* with optional *params*; return True on success."""
        return self._run(sql, params)

    def create_table(self):
        """Create the model's table from its declared fields if missing.

        Column names are the field attribute names (the same names ``save``
        and ``deleteby`` use); the column type is the field's affinity.
        """
        columns = ", ".join(
            "%s %s" % (column, field.affinity) if field.affinity else column
            for column, field in self.model.fields.items()
        )
        sql = "CREATE TABLE IF NOT EXISTS %s (%s)" % (
            self.model.__tablename__, columns)
        created = self._run(sql)
        if created:
            print("Create table %s" % (self.model.__tablename__))
        return created

    def drop_table(self):
        """Drop the model's table if it exists; return True on success."""
        sql = "DROP TABLE IF EXISTS " + self.model.__tablename__
        print("Drop table " + self.model.__tablename__)
        return self.execute(sql)

    def deleteby(self, paterns=None, **kwargs):
        """Delete rows whose columns equal the given keyword values.

        Returns True on success, False on an SQLite error or when no
        keyword names a declared field (nothing is executed then).
        ``paterns`` is accepted for compatibility and not used.
        """
        # NOTE(review): keywords that name no field are skipped, so a
        # misspelt keyword widens the delete - confirm this is intended.
        conditions = []
        values = []
        for column, value in kwargs.items():
            if column in self.model.fields:
                conditions.append(column + "=?")
                values.append(value)
        if not conditions:
            return False
        sql = "DELETE FROM %s WHERE %s" % (
            self.model.__tablename__, " AND ".join(conditions))
        return self._run(sql, values)

    def save(self, model):
        """Insert *model* as a new row; return True on success."""
        columns = list(self.model.fields)
        values = [getattr(model, column) for column in columns]
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            self.model.__tablename__,
            ", ".join(columns),
            ", ".join(["?"] * len(columns)),
        )
        saved = self._run(sql, values)
        if saved:
            print("save %s" % model)
        return saved


# ---------------------------------------------------------------------------
# Tests.  They import the code above as ``db`` - presumably this part was a
# separate file in the original post.
# ---------------------------------------------------------------------------
import os
import threading
import unittest

try:
    import db
except ImportError:
    # Both parts live in one file: the ORM names are defined in this module.
    db = sys.modules[__name__]

TEST_DATABASE = "conf/test.sqlite"


class ModelTest(unittest.TestCase):
    """Round-trip a ``User`` model through the mapper."""

    def setUp(self):
        class User(db.Model):
            __tablename__ = "users"
            name = db.Field("name", "varchar(50)")
            email = db.Field("email", "varchar(20)",
                             validator=lambda x: 7 < len(x) < 21)

        # sqlite3 does not create missing directories.
        directory = os.path.dirname(TEST_DATABASE)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory)

        db.setup_database(TEST_DATABASE)
        db.bind_mapper("UserMapper", User)
        db.UserMapper.create_table()
        self.user = User(name="test", email="test@example.com")

    def tearDown(self):
        res = db.UserMapper.deleteby(name=self.user.name,
                                     email=self.user.email)
        self.assertEqual(True, res)

    def test_model_save(self):
        self.assertEqual(True, self.user.save())

    def test_mapper_save(self):
        self.assertEqual(True, db.UserMapper.save(self.user))

    def test_in_threads(self):
        # Smoke test: saving from several threads must not raise.
        workers = [threading.Thread(target=self.user.save) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()


if __name__ == "__main__":
    unittest.main()
标签:
原文地址:http://www.cnblogs.com/nagi/p/4191463.html