Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pathtraits/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@
logger = logging.getLogger(__name__)


def nest_dict(flat_dict, delimiter="/"):
"""
Transforms a flat dictionary with path-like keys into a nested dictionary.

:param flat_dict: The flat dictionary with path-like keys.
:param delimiter: The delimiter used in the keys (default is '/').
:return: A nested dictionary.
"""
nested_dict = {}

for path, value in flat_dict.items():
keys = path.split(delimiter)
current = nested_dict

for key in keys[:-1]:
# If the key doesn't exist or is not a dictionary, create/overwrite it as a dictionary
if key not in current or not isinstance(current[key], dict):
current[key] = {}
current = current[key]

# Set the value at the final key
current[keys[-1]] = value

return nested_dict


def get_dict(self, path):
"""
Get traits for a path as a Python dictionary
Expand Down Expand Up @@ -40,6 +66,7 @@ def get_dict(self, path):
if not (v and k != "path"):
continue
res[k] = v
res = nest_dict(res)
return res


Expand Down
53 changes: 39 additions & 14 deletions pathtraits/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import sqlite3
import os
from collections.abc import MutableMapping
import yaml
from pathtraits.pathpair import PathPair

Expand Down Expand Up @@ -33,12 +34,12 @@ def row_factory(cursor, row):
if v is None:
continue
# sqlite don't know bool
if k.endswith("_BOOL"):
if k.endswith("/BOOL"):
v = v > 0
if isinstance(v, float):
v_int = int(v)
v = v_int if v_int == v else v
k = k.removesuffix("_TEXT").removesuffix("_REAL").removesuffix("_BOOL")
k = k.removesuffix("/TEXT").removesuffix("/REAL").removesuffix("/BOOL")
res[k] = v
return res

Expand All @@ -53,15 +54,35 @@ def merge_rows(rows: list):
for row in rows:
for k, v in row.items():
# pylint: disable=C0201
if k in res.keys() and v not in res[k]:
if not k in res.keys():
res[k] = []
if not v in res[k]:
res[k].append(v)
else:
res[k] = [v]

# simplify lists with just one element
# ensure fixed order of list entries
res = {k: sorted(v, key=str) if len(v) > 1 else v[0] for k, v in res.items()}
return res

@staticmethod
def flatten_dict(dictionary: dict, root_key: str = "", separator: str = "/"):
"""
Docstring for flatten_dict

:param d: Description
:type d: dict
"""
items = []
for key, value in dictionary.items():
new_key = root_key + separator + key if root_key else key
if isinstance(value, MutableMapping):
items.extend(
TraitsDB.flatten_dict(value, new_key, separator=separator).items()
)
else:
items.append((new_key, value))
return dict(items)

def __init__(self, db_path):
db_path = os.path.join(db_path)
self.cursor = sqlite3.connect(db_path, autocommit=True).cursor()
Expand Down Expand Up @@ -189,15 +210,15 @@ def put(self, table, condition=None, update=True, **kwargs):
# update
values = " , ".join([f"{k}={v}" for (k, v) in escaped_kwargs.items()])
if condition:
update_query = f"UPDATE {table} SET {values} WHERE {condition};"
update_query = f"UPDATE [{table}] SET {values} WHERE {condition};"
else:
update_query = f"UPDATE {table} SET {values};"
update_query = f"UPDATE [{table}] SET {values};"
self.execute(update_query)
else:
# insert
keys = " , ".join(escaped_kwargs.keys())
keys = "[" + "], [".join(escaped_kwargs.keys()) + "]"
values = " , ".join([str(x) for x in escaped_kwargs.values()])
insert_query = f"INSERT INTO {table} ({keys}) VALUES ({values});"
insert_query = f"INSERT INTO [{table}] ({keys}) VALUES ({values});"
self.execute(insert_query)

def put_data_view(self):
Expand All @@ -209,15 +230,15 @@ def put_data_view(self):
if self.traits:
join_query = " ".join(
[
f"LEFT JOIN {x} ON {x}.path = path.id"
f"LEFT JOIN [{x}] ON [{x}].path = path.id \n"
for x in self.traits
if x != "path"
]
)

create_view_query = f"""
CREATE VIEW data AS
SELECT path.path, {', '.join(self.traits)}
SELECT path.path, [{'], ['.join(self.traits)}]
FROM path
{join_query};
"""
Expand Down Expand Up @@ -263,9 +284,9 @@ def create_trait_table(self, trait_name, value_type):
return
sql_type = TraitsDB.sql_type(value_type)
add_table_query = f"""
CREATE TABLE {trait_name} (
CREATE TABLE [{trait_name}] (
path INTEGER,
{trait_name} {sql_type},
[{trait_name}] {sql_type},
FOREIGN KEY(path) REFERENCES path(id)
);
"""
Expand Down Expand Up @@ -303,6 +324,8 @@ def add_pathpair(self, pair: PathPair):
if not isinstance(traits, dict):
return

traits = TraitsDB.flatten_dict(traits)

# put path in db only if there are traits
path_id = self.put_path_id(os.path.abspath(pair.object_path))
for k, v in traits.items():
Expand All @@ -312,13 +335,15 @@ def add_pathpair(self, pair: PathPair):
# get element type for list
# add: handle lists with mixed element type
t = type(v[0]) if isinstance(v, list) else type(v)
k = f"{k}_{TraitsDB.sql_type(t)}"
k = f"{k}/{TraitsDB.sql_type(t)}"
if k not in self.traits:
t = type(v[0]) if isinstance(v, list) else type(v)
self.create_trait_table(k, t)
if k in self.traits:
# add to list
if isinstance(v, list):
for vv in v:
self.put_trait(path_id, k, vv, update=False)
# overwrite scalar
else:
self.put_trait(path_id, k, v)
5 changes: 5 additions & 0 deletions test/example/EU/meta.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@ users:
- dloos
- fgans
score: 3.5
foo:
bar:
a: 1
b: 2
c: [1, 2, 3]
3 changes: 2 additions & 1 deletion test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_eu(self):
"is_example": True,
"score": 3.5,
"users": ["dloos", "fgans"],
"foo": {"bar": {"a": 1, "b": 2, "c": [1, 2, 3]}},
}
for k, v in target.items():
self.assertEqual(source[k], v)
Expand All @@ -63,7 +64,7 @@ def test_example(self):

def test_data_view(self):
source = len(self.db.execute("SELECT * FROM data;").fetchall())
target = 6
target = 10
self.assertEqual(source, target)


Expand Down