跳转至

记忆

概述

Memory(智能体记忆) 模块提供智能体的记忆的存储和召回。 智能体的记忆主要由智能体的对话和交互过程产生。 记忆的最小单位是 Message,每个 message 都有个 role(比如:"User","Assistant")和 content,还有额外的 metadata 构成。

Message

描述 message 内部结构的信息

Role

角色用于区分对话中不同类型的消息,帮助模型理解要如何针对特定的消息做出相应的回应。 message 有 5 种类型,分别代表 user/system/assistant/tool/unspecified

角色 描述
system 用于告诉聊天模型如何做并提供额外的上下文。
user 代表用户与模型交互的输入,通常是文本或其他交互式输入。
assistant 代表模型的响应,可以包括文本或调用工具的请求。
tool 用于在调用工具后,将外部数据或处理结果传递回模型的消息。适用于支持工具调用的聊天模型。
unspecified 用户未指定时的默认值。
class MessageRole(Enum):
    """
    Enum class for message.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
    UNSPECIFIED = "unspecified"

Chunks

message 中的核心内容是 chunks,chunks 是个 list,每个 chunk 里面存储 message 的具体内容。设计成 list,可以支持 message 包含多模态数据,比如同时有文本和语言数据。

MessageChunk 的 content 可以是 str/bytes 和任意可序列化的对象。 这么设计是为了支持 agent 记忆的内存存储和持久化存储两种模式。

content 是可序列化的 python 对象的好处是:

  • 使用内存存储、读取记忆场景,调用方读到的是 python 对象,不需要序列化和反序列化的过程。

  • 需要持久化的场景,用户调用 add_chunk() 后,存储侧可以自动调用Serializable对象的 dump() 方法把 Python 对象存储起来, 反序列化的时候,可以根据Serializable对象的load() 方法把对象还原。

import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Protocol, TypeVar, runtime_checkable

def get_current_timestamp() -> int:
    return time.time_ns()

@dataclass
class Message:
    """
    Message model
    """

    chunks: list[MessageChunk] = field(default_factory=list)
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: int = field(default_factory=get_current_timestamp)
    meta_data: dict = field(default_factory=dict)
    role: MessageRole = MessageRole.UNSPECIFIED

@dataclass
class MessageChunk:
    """
    Message chunk model
    """

    content: str | bytes | Serializable
    require_upload: bool = False  # 是否需要上传
    media_id: str | None = None


T = TypeVar("T", bound="Serializable")
@runtime_checkable
class Serializable(Protocol):
    def dump(self) -> str:
        raise NotImplementedError

    @classmethod
    def load(cls: type[T], data: str) -> T:
        """从字符串反序列化对象"""
        raise NotImplementedError

额外信息

同时 message 里面还含有如下额外信息,用于对记忆做召回和排序

  • id: message 的唯一标识,默认使用 uuid
  • timestamp: message 生成的时间戳
  • meta_data: 其他 meta_data,dict 形式,由用户指定,比如可以设置 agent_id,session_id 等,用于对不同的会话做持久化

快速开始

Memory 实例化

from tongagents.memory.memory import SimpleMemory
memory = SimpleMemory()

这个 SimpleMemory 用于 agent runtime 框架运行时,存储完全在内存中,不带持久化功能,需要持久化功能的话,可以参考后面的 PersistenceMemory

记忆存储

调用 add_chunk() 存储信息,参数 content 可以是 str | bytes|Serializable 类型。 Serializable 的详细信息见下面的 API 解释

content = "some content"
memory.add_chunk(content=content)

存储多模态数据到 message

message 是记忆存储的最小单位, 一条 message 中可以包含 文本和图片、语音混合的内容

from tongagents.message.message import (
  MessageChunk,
  MessageRole,
)

chunks = []
text_chunk = MessageChunk(content="请描述图片内容")
chunks.append(text_chunk)
image_chunk = MessageChunk(content=b"image content")
chunks.append(image_chunk)
role = MessageRole.USER
memory.add_chunks(chunks=chunks,role=role)

存储多模态数据,用户需要先构建 MessageChunk,每个 chunk 存储一种类型的数据。 使用 memory.add_chunks() 把 chunk 的 list 添加到记忆模块中去。 chunk list 会被隐式的用来构建一条 message。 Message 的具体定义可以参考 Message

Get Messages

获取记忆中的 messages 信息

all = memory.get_messages()

get_messages() 返回的是 list of message

Get Contents

大多数情况下, 用户希望从 message 中拿到 content,而不是 message 对象

contents = memory.get_contents()

contents 是 list[str|bytes|Serializable],对 agent 来说, contents 返回的是用户当时存进去的 python 对象,没有序列化的过程,可以直接使用

带持久化功能的 Memory

用户存到记忆模块的 message, 带持久化功能

  from tongagents.memory.memory import PersistenceMemory
  from tongagents.utils.media import MediaFileManager

  media_manager = MediaFileManager(base_path="/tmp/TongAgents/tests/medias")

  agent_id = "test_agent"
  session_id = "test_session"
  memory = PersistenceMemory(
      agent_id=agent_id, session_id=session_id, media_manager=media_manager
  )

media_manager 是个存储管理器,用于存储消息中的图片、音频等对象,兼容本地存储和 s3。 用户可以指定 base_path base_path 可以是本地文件系统路径, 也可以是 s3 的 URL。

远程存储引擎

PersistenceMemory 默认会使用本地的 sqllite 作为持久化存储引擎。用户也可以使用自己的 mysql 等关系型数据库作为存储引擎。 PersistenceMemory 底层使用 sqlalchemy 来管理存储引擎, 所以用户使用 mysql 作为存储引擎的时候,可以通过指定 URL 来初始化一个存储引擎。 代码如下:

from tongagents.memory.sql_engine import SQLStoreEngine
from tongagents.memory.memory import PersistenceMemory
from tongagents.utils.media import MediaFileManager

DATABASE_URL = "mysql+mysqlconnector://root:root@127.0.0.1:3306/test_db"

agent_id = "test_agent"
session_id = "test_session"
memory = PersistenceMemory(
        agent_id=agent_id,
        session_id=session_id,
        media_manager=media_manager,
        engine=SQLStoreEngine(DATABASE_URL),
)

存储简单 str content

from tongagents.message.message import (
MessageRole,
)

role = MessageRole.USER # 按需替换
memory.add_chunk(content="simple_string" role = role)

memory 为前面实例化的 PersistenceMemory 对象

存储二进制内容

二进制内容如果需要 s3 存储,需要指定 require_upload = True

memory.add_chunk(content=b"binary_data", require_upload=True)

用户也可以自己调 s3 接口(自己的 s3 URL),拿到 s3 返回的 media_id,然后再存到记忆中

memory.add_chunk(media_id=media_id)

存储可序列化内容

有些场景下,用户需要直接把自定义的 python 对象扔到存储里面。 为了方便从持久化存储中 load 出当时存储的 python 对象,需要对象是个Serializable类型,实现 load()dump 方法, 定义如下:

@runtime_checkable
class Serializable(Protocol):
  def dump(self) -> str:
      raise NotImplementedError

  @classmethod
  def load(cls: type[T], data: str) -> T:
      """从字符串反序列化对象"""
      raise NotImplementedError
#type alias, str is the name of the class, callable is the deserializer
ConfigDeserializer = dict[str, Callable[[str], Serializable]]

ConfigDeserializer是个 dict 形式的解析器配置,用于从持久化存储中 load 出当时存的 python 对象。 key 是可序列化对象的class name,value 是这个对象的load() 方法。使用例子如下:

from tongagents.memory.memory import PersistenceMemory
from tongagents.utils.media import MediaFileManager
from tongagents.message.message import ConfigDeserializer, Serializable
import json

agent_id = "test_agent1"
session_id = "test_session1"
media_manager = MediaFileManager(base_path="/tmp/TongAgents/tests/medias")


class ToolCall:
    def __init__(self, data: dict):
        self.data = data

    def dump(self) -> str:
        return json.dumps(self.data)  # JSON 序列化

    @classmethod
    def load(cls, data: str) -> "ToolCall":
        parsed_data = json.loads(data)  # 反序列化
        return cls(parsed_data)


content: Serializable = ToolCall(
    {"tool_call": {"name": "weather", "params": {"city": "Paris"}}}
)

memory = PersistenceMemory(
    agent_id=agent_id,
    session_id=session_id,
    media_manager=media_manager,
)
memory.add_chunk(content=content)
name = type(content).__name__
config = ConfigDeserializer({name: ToolCall.load})
result = memory.get_contents(config=config)[0]
assert isinstance(result, ToolCall)
assert result.data["tool_call"]["name"] == "weather"

TooCall 是我们定义的一个 Serializable 对象。 使用 memory.add_chunk()ToolCall 对象存储到 PersistenceMemory 中。

从持久化 memory 中取回 ToolCall 对象需要配置 ConfigDeserializer dict, 用以从持久化存储中读到的 str 反序列成 ToolCall 对象。

configkeyToolCallclass.__name__valueToolCallload() 方法

API

MemoryBase 定义了记忆的抽象接口,提供存储、查询记忆的功能

add_chunk

调用add_chunkSerializable python 对象添加到记忆中

class MemoryBase(ABC):
  @abstractmethod
  def add_chunk(
      self,
      content: Serializable,
      role: MessageRole = MessageRole.UNSPECIFIED,
      require_upload: bool = False,
  ):
      """
      abstract interface
      Add messages  to memory
      Args:
          content: Serializable
          role: MessageRole
          require_upload: bool
      """

add_chunks

调用 add_chunks 把多模态数据添加到记忆中。 chunks 会隐式的构建一条 message, 适用于一条 message 中包含用户的对话文本、图片、语音等多种内容的场景

@abstractmethod
def add_chunks(self, chunks: list[MessageChunk], role=MessageRole.UNSPECIFIED):
  """
  abstract interface
  support for a message contains multiple chunks, such as one message contains
  text and image chunks
  """

add_messages

存储 list of message 到记忆

@abstractmethod
def add_messages(self, messages: list[Message]):
    """
    abstract interface
    Add messages to memory
    """

get_messages

从记忆中读取 messages

@abstractmethod
def get_messages(
    self,
    agent_id: str | None = None,
    session_id: str | None = None,
    recent_n: int | None = None,
) -> list[Message]:
    """
    abstract interface
    Get messages from memory
    """

get_contents

从记忆中读取 message 结构里的 content 实体,不含元信息。

content 是用户存入记忆的文本信息 or 二进制信息或者原始的 python 对象

@abstractmethod
def get_contents(self) -> list[str | bytes | Serializable]:
    """
    abstract interface
    Get contents from memory
    """

SimpleMemory

最简单的记忆存储,纯内存形式,message 存储到列表中。

SimpleMemory 继承 MemoryBase, 和 MemoryBase 中定义的接口保持一致。

class SimpleMemory(MemoryBase):
    """
    agent running memory
    """

    def __init__(self):
        self.store = []

PersistenceMemory

持久化的记忆, 底层使用 s3、 关系型 DB 存储相关数据

from tongagents.memory.memory import MemoryBase
from tongagents.memory.engine import StoreEngine
from tongagents.memory.sql_engine import SQLStoreEngine
from tongagents.utils.media import MediaFileManager

DEFAULT_MEDIA_MANAGER = MediaFileManager()

class PersistenceMemory(MemoryBase):
    """
    support for persistence memory
    """

    def __init__(
        self,
        agent_id: str,
        session_id: str,
        engine: StoreEngine | None = None,
        media_manager: MediaFileManager = DEFAULT_MEDIA_MANAGER,
    ):
        if engine is None:
            engine = SQLStoreEngine()
        self.engine = engine
        self.meta_data = {}
        self.meta_data["agent_id"] = agent_id
        self.meta_data["session_id"] = session_id
        self.media_manager = media_manager