Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions pathtraits/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,42 @@ class TraitsDB:
cursor = None
traits = []

@staticmethod
def row_factory(cursor, row):
"""
Turns sqlite3 row into a dict. Only works on a single row at once.

:param cursor: Description
:param row: Description
"""
fields = [column[0] for column in cursor.description]
res = dict(zip(fields, row))
return res

@staticmethod
def merge_rows(rows: list):
"""
Merges a list of row dicts of a path into a sinle dict by pooling trait keys

:param res: Description
"""
res = {}
for row in rows:
for k, v in row.items():
# pylint: disable=C0201
if k in res.keys() and v not in res[k]:
res[k].append(v)
else:
res[k] = [v]
# simplify lists with just one element
# ensure fixed order of list entries
res = {k: sorted(v) if len(v) > 1 else v[0] for k, v in res.items()}
return res

def __init__(self, db_path):
db_path = os.path.join(db_path)
self.cursor = sqlite3.connect(db_path, autocommit=True).cursor()
self.cursor.row_factory = TraitsDB.row_factory

init_path_table_query = """
CREATE TABLE IF NOT EXISTS path (
Expand Down Expand Up @@ -67,15 +100,19 @@ def get(self, table, cols="*", condition=None, **kwargs):
for (k, v) in kwargs.items()
}
condition = " AND ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
get_row_query = f"SELECT {cols} FROM {table} WHERE {condition} LIMIT 1;"
get_row_query = f"SELECT {cols} FROM {table} WHERE {condition};"
response = self.execute(get_row_query)
values = response.fetchone()

if values is None:
if response is None:
return None

keys = map(lambda x: x[0], response.description)
res = dict(zip(keys, values))
res = response.fetchall()
if len(res) == 1:
return res[0]

if isinstance(res, list) and len(res) > 1:
res = TraitsDB.merge_rows(res)

return res

def put_path_id(self, path):
Expand Down Expand Up @@ -130,13 +167,14 @@ def sql_type(value_type):
sql_type = sqlite_types.get(value_type, "TEXT")
return sql_type

def put(self, table, condition=None, **kwargs):
def put(self, table, condition=None, update=True, **kwargs):
"""
Puts a row into a table. Creates a row if not present, updates otherwise.
:param update; overwrite existing data
"""
escaped_kwargs = {k: TraitsDB.escape(v) for (k, v) in kwargs.items()}

if self.get(table, condition=condition, **kwargs):
if update and self.get(table, condition=condition, **kwargs):
# update
values = " , ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
if condition:
Expand Down Expand Up @@ -193,7 +231,7 @@ def update_traits(self):
ORDER BY name;
"""
traits = self.execute(get_traits_query).fetchall()
self.traits = [x[0] for x in traits]
self.traits = [list(x.values())[0] for x in traits]
self.put_data_view()

def create_trait_table(self, trait_name, value_type):
Expand Down Expand Up @@ -223,7 +261,7 @@ def create_trait_table(self, trait_name, value_type):
self.execute(add_table_query)
self.update_traits()

def put_trait(self, path_id, trait_name, value):
def put_trait(self, path_id, trait_name, value, update=True):
"""
Put a trait to the database

Expand All @@ -233,7 +271,7 @@ def put_trait(self, path_id, trait_name, value):
:param value: trait value
"""
kwargs = {"path": path_id, trait_name: value}
self.put(trait_name, condition=f"path = {path_id}", **kwargs)
self.put(trait_name, condition=f"path = {path_id}", update=update, **kwargs)

def add_pathpair(self, pair: PathPair):
"""
Expand All @@ -259,8 +297,17 @@ def add_pathpair(self, pair: PathPair):
for k, v in traits.items():
# same YAML key might have different value types
# Therefore, add type to key
k = f"{k}_{TraitsDB.sql_type(type(v))}"

# get element type for list
# add: handle lists with mixed element type
t = type(v[0]) if isinstance(v, list) else type(v)
k = f"{k}_{TraitsDB.sql_type(t)}"
if k not in self.traits:
self.create_trait_table(k, type(v))
t = type(v[0]) if isinstance(v, list) else type(v)
self.create_trait_table(k, t)
if k in self.traits:
self.put_trait(path_id, k, v)
if isinstance(v, list):
for vv in v:
self.put_trait(path_id, k, vv, update=False)
else:
self.put_trait(path_id, k, v)
12 changes: 11 additions & 1 deletion test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,18 @@ def test_example(self):
for k, v in target.items():
self.assertEqual(source[k], v)

source = pathtraits.access.get_dict(db, "test/example/EU")
target = {
"description_TEXT": "EU data",
"is_example_BOOL": True,
"score_REAL": 3.5,
"users_TEXT": ["dloos", "fgans"],
}
for k, v in target.items():
self.assertEqual(source[k], v)

source = len(db.execute("SELECT * FROM data;").fetchall())
target = 4
target = 6
self.assertEqual(source, target)
os.remove(db_path)

Expand Down