|
|
|
@ -1,26 +1,35 @@ |
|
|
|
|
import re |
|
|
|
|
|
|
|
|
|
from flask import current_app |
|
|
|
|
from mongoengine.base import BaseField |
|
|
|
|
from mongoengine.queryset import STRING_OPERATORS |
|
|
|
|
from Crypto.PublicKey import RSA |
|
|
|
|
from Crypto.Cipher import PKCS1_OAEP |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncryptedStringField(BaseField): |
|
|
|
|
"""A unicode string field.""" |
|
|
|
|
"""A unicode encrypted string field.""" |
|
|
|
|
keyPair = None |
|
|
|
|
|
|
|
|
|
def __init__(self, **kwargs): |
|
|
|
|
|
|
|
|
|
self.keyPair = RSA.importKey(open("privkey.pem").read()) |
|
|
|
|
|
|
|
|
|
def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): |
|
|
|
|
""" |
|
|
|
|
:param regex: (optional) A string pattern that will be applied during validation |
|
|
|
|
:param max_length: (optional) A max length that will be applied during validation |
|
|
|
|
:param min_length: (optional) A min length that will be applied during validation |
|
|
|
|
:param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField` |
|
|
|
|
""" |
|
|
|
|
self.regex = re.compile(regex) if regex else None |
|
|
|
|
self.max_length = max_length |
|
|
|
|
self.min_length = min_length |
|
|
|
|
super().__init__(**kwargs) |
|
|
|
|
|
|
|
|
|
def to_python(self, value): |
|
|
|
|
def __get__(self, instance, owner): |
|
|
|
|
value = instance._data.get(self.name) |
|
|
|
|
encryptor = PKCS1_OAEP.new(self.keyPair.publickey()) |
|
|
|
|
return encryptor.decrypt(value) |
|
|
|
|
|
|
|
|
|
def __set__(self, instance, value): |
|
|
|
|
|
|
|
|
|
key = self.name |
|
|
|
|
encryptor = PKCS1_OAEP.new(self.keyPair.publickey()) |
|
|
|
|
|
|
|
|
|
instance._data[key] = encryptor.encrypt(value) |
|
|
|
|
instance._mark_as_changed(key) |
|
|
|
|
|
|
|
|
|
def to_python(self, value): |
|
|
|
|
if isinstance(value, str): |
|
|
|
|
return value |
|
|
|
|
try: |
|
|
|
@ -29,41 +38,8 @@ class EncryptedStringField(BaseField): |
|
|
|
|
pass |
|
|
|
|
return value |
|
|
|
|
|
|
|
|
|
def validate(self, value): |
|
|
|
|
if not isinstance(value, str): |
|
|
|
|
self.error("StringField only accepts string values") |
|
|
|
|
|
|
|
|
|
if self.max_length is not None and len(value) > self.max_length: |
|
|
|
|
self.error("String value is too long") |
|
|
|
|
|
|
|
|
|
if self.min_length is not None and len(value) < self.min_length: |
|
|
|
|
self.error("String value is too short") |
|
|
|
|
|
|
|
|
|
if self.regex is not None and self.regex.match(value) is None: |
|
|
|
|
self.error("String value did not match validation regex") |
|
|
|
|
|
|
|
|
|
def lookup_member(self, member_name): |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
def prepare_query_value(self, op, value): |
|
|
|
|
if not isinstance(op, str): |
|
|
|
|
return value |
|
|
|
|
|
|
|
|
|
if op in STRING_OPERATORS: |
|
|
|
|
case_insensitive = op.startswith("i") |
|
|
|
|
op = op.lstrip("i") |
|
|
|
|
|
|
|
|
|
flags = re.IGNORECASE if case_insensitive else 0 |
|
|
|
|
|
|
|
|
|
regex = r"%s" |
|
|
|
|
if op == "startswith": |
|
|
|
|
regex = r"^%s" |
|
|
|
|
elif op == "endswith": |
|
|
|
|
regex = r"%s$" |
|
|
|
|
elif op == "exact": |
|
|
|
|
regex = r"^%s$" |
|
|
|
|
|
|
|
|
|
# escape unsafe characters which could lead to a re.error |
|
|
|
|
value = re.escape(value) |
|
|
|
|
value = re.compile(regex % value, flags) |
|
|
|
|
return super().prepare_query_value(op, value) |