diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index d740ad1d2..a2cab0c98 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -1,5 +1,6 @@
 import json
 import random
+import re
 import textwrap
 import threading
 import time
@@ -20,6 +21,34 @@
 logger = get_logger(__name__)
 
 
+def _sanitize_tsquery_words(query_words: list[str]) -> list[str]:
+    """Sanitize query words for safe use with PostgreSQL to_tsquery().
+
+    Strips tsquery operator characters and other special symbols that can
+    cause parsing errors when mixed content (e.g. message IDs with
+    underscores, Chinese text) is passed to ``to_tsquery``. Each word is
+    reduced to its alphanumeric/CJK core so that the jieba text-search
+    configuration can tokenize it correctly.
+
+    Returns a de-duplicated list of non-empty sanitized words.
+    """
+    # Keep word characters (letters, digits, underscore) and CJK unified ideographs.
+    valid_chars_re = re.compile(
+        r"[^\w\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]",
+    )
+    sanitized: list[str] = []
+    seen: set[str] = set()
+    for w in query_words:
+        # Strip surrounding single quotes that callers may have added for tsquery
+        w = w.strip().strip("'")
+        # Remove characters that are not word-characters or CJK
+        cleaned = valid_chars_re.sub("", w)
+        if cleaned and cleaned not in seen:
+            seen.add(cleaned)
+            sanitized.append(cleaned)
+    return sanitized
+
+
 def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
     node_id = item["id"]
     memory = item["memory"]
@@ -1653,8 +1682,11 @@ def search_by_keywords_tfidf(
         filter_conditions = self._build_filter_conditions_sql(filter)
         where_clauses.extend(filter_conditions)
         # Add fulltext search condition
-        # Convert query_text to OR query format: "word1 | word2 | word3"
-        tsquery_string = " | ".join(query_words)
+        # Sanitize and convert query_text to OR query format: "word1 | word2 | word3"
+        safe_words = _sanitize_tsquery_words(query_words)
+        if not safe_words:
+            return []
+        tsquery_string = " | ".join(safe_words)
 
         where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
 
@@ -1768,7 +1800,10 @@ def search_by_fulltext(
         filter_conditions = self._build_filter_conditions_sql(filter)
         where_clauses.extend(filter_conditions)
 
-        tsquery_string = " | ".join(query_words)
+        safe_words = _sanitize_tsquery_words(query_words)
+        if not safe_words:
+            return []
+        tsquery_string = " | ".join(safe_words)
 
         where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
 
diff --git a/tests/graph_dbs/test_sanitize_tsquery.py b/tests/graph_dbs/test_sanitize_tsquery.py
new file mode 100644
index 000000000..108f92050
--- /dev/null
+++ b/tests/graph_dbs/test_sanitize_tsquery.py
@@ -0,0 +1,86 @@
+"""Tests for _sanitize_tsquery_words — standalone, no heavy imports."""
+
+import re
+
+
+# ---------------------------------------------------------------------------
+# Inline the function under test to avoid pulling in the full memos import
+# chain (which requires a running logging backend). The canonical copy lives
+# in ``memos.graph_dbs.polardb._sanitize_tsquery_words``.
+# ---------------------------------------------------------------------------
+
+
+def _sanitize_tsquery_words(query_words: list[str]) -> list[str]:
+    valid_chars_re = re.compile(
+        r"[^\w\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]",
+    )
+    sanitized: list[str] = []
+    seen: set[str] = set()
+    for w in query_words:
+        w = w.strip().strip("'")
+        cleaned = valid_chars_re.sub("", w)
+        if cleaned and cleaned not in seen:
+            seen.add(cleaned)
+            sanitized.append(cleaned)
+    return sanitized
+
+
+class TestSanitizeTsqueryWords:
+    """Unit tests for FTS query word sanitization."""
+
+    def test_plain_english_words(self):
+        assert _sanitize_tsquery_words(["hello", "world"]) == ["hello", "world"]
+
+    def test_chinese_text(self):
+        result = _sanitize_tsquery_words(["我要", "测试"])
+        assert result == ["我要", "测试"]
+
+    def test_mixed_content_message_id_and_chinese(self):
+        """Reproduce the original bug: mixed IDs + Chinese text."""
+        words = ["message_id", "om_x100b544a390604b8c3e1b7d8641f08e", "我要测试"]
+        result = _sanitize_tsquery_words(words)
+        assert len(result) == 3
+        assert "message_id" in result
+        assert "om_x100b544a390604b8c3e1b7d8641f08e" in result
+        assert "我要测试" in result
+
+    def test_single_quoted_words_are_stripped(self):
+        words = ["'hello'", "'world'"]
+        result = _sanitize_tsquery_words(words)
+        assert result == ["hello", "world"]
+
+    def test_special_characters_removed(self):
+        words = ["hello!", "world@#$"]
+        result = _sanitize_tsquery_words(words)
+        assert result == ["hello", "world"]
+
+    def test_empty_words_filtered(self):
+        words = ["", " ", "hello", ""]
+        result = _sanitize_tsquery_words(words)
+        assert result == ["hello"]
+
+    def test_deduplication(self):
+        words = ["hello", "hello", "world"]
+        result = _sanitize_tsquery_words(words)
+        assert result == ["hello", "world"]
+
+    def test_empty_input(self):
+        assert _sanitize_tsquery_words([]) == []
+
+    def test_all_special_chars_returns_empty(self):
+        words = ["!@#", "$%^"]
+        result = _sanitize_tsquery_words(words)
+        assert result == []
+
+    def test_underscores_preserved(self):
+        words = ["message_id", "user_name"]
+        result = _sanitize_tsquery_words(words)
+        assert result == ["message_id", "user_name"]
+
+    def test_tsquery_operators_stripped(self):
+        """Tsquery operators like & | ! should be stripped from within words."""
+        words = ["hello & world", "foo | bar"]
+        result = _sanitize_tsquery_words(words)
+        # Spaces and operators removed; alphanumeric parts merge
+        assert "helloworld" in result
+        assert "foobar" in result