Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 19 additions & 61 deletions astrbot/core/db/po.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint


class TimestampMixin(SQLModel):
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)},
)


class PlatformStat(SQLModel, table=True):
"""This class represents the statistics of bot usage across different platforms.
Expand All @@ -30,7 +38,7 @@ class PlatformStat(SQLModel, table=True):
)


class ConversationV2(SQLModel, table=True):
class ConversationV2(TimestampMixin, SQLModel, table=True):
__tablename__: str = "conversations"

inner_conversation_id: int | None = Field(
Expand All @@ -47,11 +55,7 @@ class ConversationV2(SQLModel, table=True):
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
content: list | None = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
token_usage: int = Field(default=0, nullable=False)
Expand All @@ -68,7 +72,7 @@ class ConversationV2(SQLModel, table=True):
)


class PersonaFolder(SQLModel, table=True):
class PersonaFolder(TimestampMixin, SQLModel, table=True):
"""Persona 文件夹,支持递归层级结构。
用于组织和管理多个 Persona,类似于文件系统的目录结构。
Expand All @@ -92,11 +96,6 @@ class PersonaFolder(SQLModel, table=True):
"""父文件夹ID,NULL表示根目录"""
description: str | None = Field(default=None, sa_type=Text)
sort_order: int = Field(default=0)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand All @@ -106,7 +105,7 @@ class PersonaFolder(SQLModel, table=True):
)


class Persona(SQLModel, table=True):
class Persona(TimestampMixin, SQLModel, table=True):
"""Persona is a set of instructions for LLMs to follow.
It can be used to customize the behavior of LLMs.
Expand All @@ -131,11 +130,6 @@ class Persona(SQLModel, table=True):
"""所属文件夹ID,NULL 表示在根目录"""
sort_order: int = Field(default=0)
"""排序顺序"""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand All @@ -145,7 +139,7 @@ class Persona(SQLModel, table=True):
)


class Preference(SQLModel, table=True):
class Preference(TimestampMixin, SQLModel, table=True):
"""This class represents preferences for bots."""

__tablename__: str = "preferences"
Expand All @@ -161,11 +155,6 @@ class Preference(SQLModel, table=True):
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
key: str = Field(nullable=False)
value: dict = Field(sa_type=JSON, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand All @@ -177,7 +166,7 @@ class Preference(SQLModel, table=True):
)


class PlatformMessageHistory(SQLModel, table=True):
class PlatformMessageHistory(TimestampMixin, SQLModel, table=True):
"""This class represents the message history for a specific platform.
It is used to store messages that are not LLM-generated, such as user messages
Expand All @@ -198,14 +187,9 @@ class PlatformMessageHistory(SQLModel, table=True):
default=None,
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)


class PlatformSession(SQLModel, table=True):
class PlatformSession(TimestampMixin, SQLModel, table=True):
"""Platform session table for managing user sessions across different platforms.
A session represents a chat window for a specific user on a specific platform.
Expand Down Expand Up @@ -233,11 +217,6 @@ class PlatformSession(SQLModel, table=True):
"""Display name for the session"""
is_group: int = Field(default=0, nullable=False)
"""0 for private chat, 1 for group chat (not implemented yet)"""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand All @@ -247,7 +226,7 @@ class PlatformSession(SQLModel, table=True):
)


class Attachment(SQLModel, table=True):
class Attachment(TimestampMixin, SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
Expand All @@ -269,11 +248,6 @@ class Attachment(SQLModel, table=True):
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
mime_type: str = Field(nullable=False) # MIME type of the file
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand All @@ -283,7 +257,7 @@ class Attachment(SQLModel, table=True):
)


class ChatUIProject(SQLModel, table=True):
class ChatUIProject(TimestampMixin, SQLModel, table=True):
"""This class represents projects for organizing ChatUI conversations.
Projects allow users to group related conversations together.
Expand All @@ -310,11 +284,6 @@ class ChatUIProject(SQLModel, table=True):
"""Title of the project"""
description: str | None = Field(default=None, max_length=1000)
"""Description of the project"""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand All @@ -338,7 +307,6 @@ class SessionProjectRelation(SQLModel, table=True):
"""Session ID from PlatformSession"""
project_id: str = Field(nullable=False, max_length=36)
"""Project ID from ChatUIProject"""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

__table_args__ = (
UniqueConstraint(
Expand All @@ -348,7 +316,7 @@ class SessionProjectRelation(SQLModel, table=True):
)


class CommandConfig(SQLModel, table=True):
class CommandConfig(TimestampMixin, SQLModel, table=True):
"""Per-command configuration overrides for dashboard management."""

__tablename__ = "command_configs" # type: ignore
Expand All @@ -368,14 +336,9 @@ class CommandConfig(SQLModel, table=True):
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_managed: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)


class CommandConflict(SQLModel, table=True):
class CommandConflict(TimestampMixin, SQLModel, table=True):
"""Conflict tracking for duplicated command names."""

__tablename__ = "command_conflicts" # type: ignore
Expand All @@ -392,11 +355,6 @@ class CommandConflict(SQLModel, table=True):
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_generated: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)

__table_args__ = (
UniqueConstraint(
Expand Down