记忆¶
概述¶
Memory(智能体记忆) 模块提供智能体的记忆的存储和召回。 智能体的记忆主要由智能体的对话和交互过程产生。 记忆的最小单位是 Message,每个 message 都有个 role(比如:"User","Assistant")和 content,还有额外的 metadata 构成。
Message¶
描述 message 内部结构的信息
Role¶
角色用于区分对话中不同类型的消息,帮助模型理解要如何针对特定的消息做出相应的回应。 message 有 5 种类型,分别代表 user/system/assistant/tool/unspecified
| 角色 | 描述 |
|---|---|
| system | 用于告诉聊天模型如何做并提供额外的上下文。 |
| user | 代表用户与模型交互的输入,通常是文本或其他交互式输入。 |
| assistant | 代表模型的响应,可以包括文本或调用工具的请求。 |
| tool | 用于在调用工具后,将外部数据或处理结果传递回模型的消息。适用于支持工具调用的聊天模型。 |
| unspecified | 用户未指定时的默认值。 |
class MessageRole(Enum):
"""
Enum class for message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
UNSPECIFIED = "unspecified"
Chunks¶
message 中的核心内容是 chunks,chunks 是个 list,每个 chunk 里面存储 message 的具体内容。设计成 list,可以支持 message 包含多模态数据,比如同时有文本和语言数据。
MessageChunk 的 content 可以是 str/bytes 和任意可序列化的对象。 这么设计是为了支持 agent 记忆的内存存储和持久化存储两种模式。
content 是可序列化的 python 对象的好处是:
-
使用内存存储、读取记忆场景,调用方读到的是 python 对象,不需要序列化和反序列化的过程。
-
需要持久化的场景,用户调用 add_chunk() 后,存储侧可以自动调用
Serializable对象的dump()方法把 Python 对象存储起来, 反序列化的时候,可以根据Serializable对象的load()方法把对象还原。
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Protocol, TypeVar, runtime_checkable
def get_current_timestamp() -> int:
return time.time_ns()
@dataclass
class Message:
"""
Message model
"""
chunks: list[MessageChunk] = field(default_factory=list)
id: str = field(default_factory=lambda: str(uuid.uuid4()))
timestamp: int = field(default_factory=get_current_timestamp)
meta_data: dict = field(default_factory=dict)
role: MessageRole = MessageRole.UNSPECIFIED
@dataclass
class MessageChunk:
"""
Message chunk model
"""
content: str | bytes | Serializable
require_upload: bool = False # 是否需要上传
media_id: str | None = None
T = TypeVar("T", bound="Serializable")
@runtime_checkable
class Serializable(Protocol):
def dump(self) -> str:
raise NotImplementedError
@classmethod
def load(cls: type[T], data: str) -> T:
"""从字符串反序列化对象"""
raise NotImplementedError
额外信息¶
同时 message 里面还含有如下额外信息,用于对记忆做召回和排序
- id: message 的唯一标识,默认使用 uuid
- timestamp: message 生成的时间戳
- meta_data: 其他 meta_data,dict 形式,由用户指定,比如可以设置 agent_id,session_id 等,用于对不同的会话做持久化
快速开始¶
Memory 实例化¶
这个 SimpleMemory 用于 agent runtime 框架运行时,存储完全在内存中,不带持久化功能,需要持久化功能的话,可以参考后面的 PersistenceMemory
记忆存储¶
调用 add_chunk() 存储信息,参数 content 可以是 str | bytes|Serializable 类型。 Serializable 的详细信息见下面的 API 解释
存储多模态数据到 message¶
message 是记忆存储的最小单位, 一条 message 中可以包含 文本和图片、语音混合的内容
from tongagents.message.message import (
MessageChunk,
MessageRole,
)
chunks = []
text_chunk = MessageChunk(content="请描述图片内容")
chunks.append(text_chunk)
image_chunk = MessageChunk(content=b"image content")
chunks.append(image_chunk)
role = MessageRole.USER
memory.add_chunks(chunks=chunks,role=role)
存储多模态数据,用户需要先构建 MessageChunk,每个 chunk 存储一种类型的数据。 使用 memory.add_chunks() 把 chunk 的 list 添加到记忆模块中去。 chunk list 会被隐式的用来构建一条 message。 Message 的具体定义可以参考 Message
Get Messages¶
获取记忆中的 messages 信息
get_messages() 返回的是 list of message
Get Contents¶
大多数情况下, 用户希望从 message 中拿到 content,而不是 message 对象
contents 是 list[str|bytes|Serializable],对 agent 来说, contents 返回的是用户当时存进去的 python 对象,没有序列化的过程,可以直接使用
带持久化功能的 Memory¶
用户存到记忆模块的 message, 带持久化功能
from tongagents.memory.memory import PersistenceMemory
from tongagents.utils.media import MediaFileManager
media_manager = MediaFileManager(base_path="/tmp/TongAgents/tests/medias")
agent_id = "test_agent"
session_id = "test_session"
memory = PersistenceMemory(
agent_id=agent_id, session_id=session_id, media_manager=media_manager
)
media_manager 是个存储管理器,用于存储消息中的图片、音频等对象,兼容本地存储和 s3。 用户可以指定 base_path base_path 可以是本地文件系统路径, 也可以是 s3 的 URL。
远程存储引擎¶
PersistenceMemory 默认会使用本地的 sqllite 作为持久化存储引擎。用户也可以使用自己的 mysql 等关系型数据库作为存储引擎。 PersistenceMemory 底层使用 sqlalchemy 来管理存储引擎, 所以用户使用 mysql 作为存储引擎的时候,可以通过指定 URL 来初始化一个存储引擎。 代码如下:
from tongagents.memory.sql_engine import SQLStoreEngine
from tongagents.memory.memory import PersistenceMemory
from tongagents.utils.media import MediaFileManager
DATABASE_URL = "mysql+mysqlconnector://root:root@127.0.0.1:3306/test_db"
agent_id = "test_agent"
session_id = "test_session"
memory = PersistenceMemory(
agent_id=agent_id,
session_id=session_id,
media_manager=media_manager,
engine=SQLStoreEngine(DATABASE_URL),
)
存储简单 str content¶
from tongagents.message.message import (
MessageRole,
)
role = MessageRole.USER # 按需替换
memory.add_chunk(content="simple_string", role = role)
memory 为前面实例化的 PersistenceMemory 对象
存储二进制内容¶
二进制内容如果需要 s3 存储,需要指定 require_upload = True
用户也可以自己调 s3 接口(自己的 s3 URL),拿到 s3 返回的 media_id,然后再存到记忆中
存储可序列化内容¶
有些场景下,用户需要直接把自定义的 python 对象扔到存储里面。
为了方便从持久化存储中 load 出当时存储的 python 对象,需要对象是个Serializable类型,实现 load(), dump 方法, 定义如下:
@runtime_checkable
class Serializable(Protocol):
def dump(self) -> str:
raise NotImplementedError
@classmethod
def load(cls: type[T], data: str) -> T:
"""从字符串反序列化对象"""
raise NotImplementedError
#type alias, str is the name of the class, callable is the deserializer
ConfigDeserializer = dict[str, Callable[[str], Serializable]]
ConfigDeserializer是个 dict 形式的解析器配置,用于从持久化存储中 load 出当时存的 python 对象。
key 是可序列化对象的class name,value 是这个对象的load() 方法。使用例子如下:
from tongagents.memory.memory import PersistenceMemory
from tongagents.utils.media import MediaFileManager
from tongagents.message.message import ConfigDeserializer, Serializable
import json
agent_id = "test_agent1"
session_id = "test_session1"
media_manager = MediaFileManager(base_path="/tmp/TongAgents/tests/medias")
class ToolCall:
def __init__(self, data: dict):
self.data = data
def dump(self) -> str:
return json.dumps(self.data) # JSON 序列化
@classmethod
def load(cls, data: str) -> "ToolCall":
parsed_data = json.loads(data) # 反序列化
return cls(parsed_data)
content: Serializable = ToolCall(
{"tool_call": {"name": "weather", "params": {"city": "Paris"}}}
)
memory = PersistenceMemory(
agent_id=agent_id,
session_id=session_id,
media_manager=media_manager,
)
memory.add_chunk(content=content)
name = type(content).__name__
config = ConfigDeserializer({name: ToolCall.load})
result = memory.get_contents(config=config)[0]
assert isinstance(result, ToolCall)
assert result.data["tool_call"]["name"] == "weather"
TooCall 是我们定义的一个 Serializable 对象。 使用 memory.add_chunk() 把 ToolCall 对象存储到 PersistenceMemory 中。
从持久化 memory 中取回 ToolCall 对象需要配置 ConfigDeserializer dict, 用以从持久化存储中读到的 str 反序列成 ToolCall 对象。
config 的 key 是 ToolCall 的 class.__name__, value 是ToolCall 的load() 方法
API¶
MemoryBase 定义了记忆的抽象接口,提供存储、查询记忆的功能
add_chunk¶
调用add_chunk 把 Serializable python 对象添加到记忆中
class MemoryBase(ABC):
@abstractmethod
def add_chunk(
self,
content: Serializable,
role: MessageRole = MessageRole.UNSPECIFIED,
require_upload: bool = False,
):
"""
abstract interface
Add messages to memory
Args:
content: Serializable
role: MessageRole
require_upload: bool
"""
add_chunks¶
调用 add_chunks 把多模态数据添加到记忆中。 chunks 会隐式的构建一条 message, 适用于一条 message 中包含用户的对话文本、图片、语音等多种内容的场景
@abstractmethod
def add_chunks(self, chunks: list[MessageChunk], role=MessageRole.UNSPECIFIED):
"""
abstract interface
support for a message contains multiple chunks, such as one message contains
text and image chunks
"""
add_messages¶
存储 list of message 到记忆
@abstractmethod
def add_messages(self, messages: list[Message]):
"""
abstract interface
Add messages to memory
"""
get_messages¶
从记忆中读取 messages
@abstractmethod
def get_messages(
self,
agent_id: str | None = None,
session_id: str | None = None,
recent_n: int | None = None,
) -> list[Message]:
"""
abstract interface
Get messages from memory
"""
get_contents¶
从记忆中读取 message 结构里的 content 实体,不含元信息。
content 是用户存入记忆的文本信息 or 二进制信息或者原始的 python 对象
@abstractmethod
def get_contents(self) -> list[str | bytes | Serializable]:
"""
abstract interface
Get contents from memory
"""
SimpleMemory¶
最简单的记忆存储,纯内存形式,message 存储到列表中。
SimpleMemory 继承 MemoryBase, 和 MemoryBase 中定义的接口保持一致。
PersistenceMemory¶
持久化的记忆, 底层使用 s3、 关系型 DB 存储相关数据
from tongagents.memory.memory import MemoryBase
from tongagents.memory.engine import StoreEngine
from tongagents.memory.sql_engine import SQLStoreEngine
from tongagents.utils.media import MediaFileManager
DEFAULT_MEDIA_MANAGER = MediaFileManager()
class PersistenceMemory(MemoryBase):
"""
support for persistence memory
"""
def __init__(
self,
agent_id: str,
session_id: str,
engine: StoreEngine | None = None,
media_manager: MediaFileManager = DEFAULT_MEDIA_MANAGER,
):
if engine is None:
engine = SQLStoreEngine()
self.engine = engine
self.meta_data = {}
self.meta_data["agent_id"] = agent_id
self.meta_data["session_id"] = session_id
self.media_manager = media_manager