diff --git a/pathtraits/db.py b/pathtraits/db.py index d111a9c..95f3b98 100644 --- a/pathtraits/db.py +++ b/pathtraits/db.py @@ -19,9 +19,42 @@ class TraitsDB: cursor = None traits = [] + @staticmethod + def row_factory(cursor, row): + """ + Turns sqlite3 row into a dict. Only works on a single row at once. + + :param cursor: Description + :param row: Description + """ + fields = [column[0] for column in cursor.description] + res = dict(zip(fields, row)) + return res + + @staticmethod + def merge_rows(rows: list): + """ + Merges a list of row dicts of a path into a sinle dict by pooling trait keys + + :param res: Description + """ + res = {} + for row in rows: + for k, v in row.items(): + # pylint: disable=C0201 + if k in res.keys() and v not in res[k]: + res[k].append(v) + else: + res[k] = [v] + # simplify lists with just one element + # ensure fixed order of list entries + res = {k: sorted(v) if len(v) > 1 else v[0] for k, v in res.items()} + return res + def __init__(self, db_path): db_path = os.path.join(db_path) self.cursor = sqlite3.connect(db_path, autocommit=True).cursor() + self.cursor.row_factory = TraitsDB.row_factory init_path_table_query = """ CREATE TABLE IF NOT EXISTS path ( @@ -67,15 +100,19 @@ def get(self, table, cols="*", condition=None, **kwargs): for (k, v) in kwargs.items() } condition = " AND ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()]) - get_row_query = f"SELECT {cols} FROM {table} WHERE {condition} LIMIT 1;" + get_row_query = f"SELECT {cols} FROM {table} WHERE {condition};" response = self.execute(get_row_query) - values = response.fetchone() - if values is None: + if response is None: return None - keys = map(lambda x: x[0], response.description) - res = dict(zip(keys, values)) + res = response.fetchall() + if len(res) == 1: + return res[0] + + if isinstance(res, list) and len(res) > 1: + res = TraitsDB.merge_rows(res) + return res def put_path_id(self, path): @@ -130,13 +167,14 @@ def sql_type(value_type): sql_type = sqlite_types.get(value_type, "TEXT") return sql_type - def put(self, table, condition=None, **kwargs): + def put(self, table, condition=None, update=True, **kwargs): """ Puts a row into a table. Creates a row if not present, updates otherwise. + :param update; overwrite existing data """ escaped_kwargs = {k: TraitsDB.escape(v) for (k, v) in kwargs.items()} - if self.get(table, condition=condition, **kwargs): + if update and self.get(table, condition=condition, **kwargs): # update values = " , ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()]) if condition: @@ -193,7 +231,7 @@ def update_traits(self): ORDER BY name; """ traits = self.execute(get_traits_query).fetchall() - self.traits = [x[0] for x in traits] + self.traits = [list(x.values())[0] for x in traits] self.put_data_view() def create_trait_table(self, trait_name, value_type): @@ -223,7 +261,7 @@ def create_trait_table(self, trait_name, value_type): self.execute(add_table_query) self.update_traits() - def put_trait(self, path_id, trait_name, value): + def put_trait(self, path_id, trait_name, value, update=True): """ Put a trait to the database @@ -233,7 +271,7 @@ def put_trait(self, path_id, trait_name, value): :param value: trait value """ kwargs = {"path": path_id, trait_name: value} - self.put(trait_name, condition=f"path = {path_id}", **kwargs) + self.put(trait_name, condition=f"path = {path_id}", update=update, **kwargs) def add_pathpair(self, pair: PathPair): """ @@ -259,8 +297,17 @@ def add_pathpair(self, pair: PathPair): for k, v in traits.items(): # same YAML key might have different value types # Therefore, add type to key - k = f"{k}_{TraitsDB.sql_type(type(v))}" + + # get element type for list + # add: handle lists with mixed element type + t = type(v[0]) if isinstance(v, list) else type(v) + k = f"{k}_{TraitsDB.sql_type(t)}" if k not in self.traits: - self.create_trait_table(k, type(v)) + t = type(v[0]) if isinstance(v, list) else type(v) + self.create_trait_table(k, t) if k in self.traits: - self.put_trait(path_id, k, v) + if isinstance(v, list): + for vv in v: + self.put_trait(path_id, k, vv, update=False) + else: + self.put_trait(path_id, k, v) diff --git a/test/test.py b/test/test.py index f505d7e..d772e5c 100644 --- a/test/test.py +++ b/test/test.py @@ -31,8 +31,18 @@ def test_example(self): for k, v in target.items(): self.assertEqual(source[k], v) + source = pathtraits.access.get_dict(db, "test/example/EU") + target = { + "description_TEXT": "EU data", + "is_example_BOOL": True, + "score_REAL": 3.5, + "users_TEXT": ["dloos", "fgans"], + } + for k, v in target.items(): + self.assertEqual(source[k], v) + source = len(db.execute("SELECT * FROM data;").fetchall()) - target = 4 + target = 6 self.assertEqual(source, target) os.remove(db_path)