大模型学习笔记:Memory 临时、长期记忆

Memory临时记忆

RunnableWithMessageHistory 是LangChain内Runnable接口的实现,主要用于创建一个带有历史记忆功能的Runnable实例(链),在创建时需要提供一个BaseChatMessageHistory的具体实现(用来存储历史消息),InMemoryChatMessageHistory可以实现在内存中存储历史,额外,如果需要在invoke或stream执行链的同时,将提示词print出来,可以在链中加入自定义函数实现。注意函数的输入应原封不动返回出去,避免破坏原有业务。

from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import InMemoryChatMessageHistory

model = ChatTongyi(model="qwen3-max")
prompt = PromptTemplate.from_template(
    "你需要根据会话历史回应用户问题。对话历史:{chat_history},用户提问:{input},请回答"
)

"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你需要根据会话历史回应用户问题。对话历史:"),
        MessagesPlaceholder("chat_history"),
        ("human", "请回答如下问题:{input}")
    ]
)

"""

str_parser = StrOutputParser()

def print_prompt(full_prompt):
    print("="*20, full_prompt.to_string(), "="*20)
    return full_prompt

base_chain = prompt | print_prompt | model | str_parser

store = {}      # key就是session,value就是InMemoryChatMessageHistory类对象

# 实现通过会话id获取InMemoryChatMessageHistory类对象
def get_history(session_id):
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]

# 创建一个新的链,对原有链增强功能:自动附加历史消息
conversation_chain = RunnableWithMessageHistory(
    base_chain,     # 被增强的原有chain
    get_history,    # 通过会话id获取InMemoryChatMessageHistory类对象
    input_messages_key="input",             # 表示用户输入在模板中的占位符
    history_messages_key="chat_history"     # 表示用户输入在模板中的占位符
)

if __name__ == '__main__':
    # 固定格式,添加LangChain的配置,为当前程序配置所属的session_id
    session_config = {
        "configurable": {
            "session_id": "user_001"
        }
    }

    res = conversation_chain.invoke({"input": "小明有2个猫"}, session_config)
    print("第1次执行:", res)
    res = conversation_chain.invoke({"input": "小刚有1只狗"}, session_config)
    print("第2次执行:", res)
    res = conversation_chain.invoke({"input": "总共有几个宠物"}, session_config)
    print("第3次执行:", res)

Memory长期记忆

import os, json
from typing import Sequence
from langchain_core.messages import messages_from_dict, message_to_dict, BaseMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory


#message_to_dict: 单个消息对象 (BaseMessage类实例)--> dict
#messages_to_dict: 消息对象列表(List[BaseMessage])--> List[dict]
# AIMessage, HumanMessage, SystemMessage等消息对象都继承自BaseMessage类

class FileChatMessageHistory(BaseChatMessageHistory):
    def __init__(self, session_id, storage_path):
        self.session_id = session_id        #会话ID
        self.storage_path = storage_path    #不同会话ID的存储文件,所在的文件夹路径
        self.file_path = os.path.join(self.storage_path, f"{self.session_id}.json")
        # 如果文件不存在,则创建一个空文件
        os.makedirs(os.path.dirname(self.file_path), exist_ok=True)

    def add_messages(self, message: Sequence[BaseMessage]):
        #Sequence 序列,类和list,tuple等都可以被看作是Sequence的子类
        all_messages = list(self.messages)  # 获取当前会话的所有消息,messages是BaseChatMessageHistory类的属性,表示当前会话的消息列表
        all_messages.extend(message)    # 将新消息添加到当前会话的消息列表中    
        # 将消息列表转换为字典列表,并写入文件
        #类对象写入文件 --> 一堆二进制
        #为了方便,可以将BaseMessage类对象转换为dict字典对象,再借用json模块将json写入文件
        #官方message_to_dict
        new_messages = []
        for msg in all_messages:
            d = message_to_dict(msg)   # 将BaseMessage类对象转换为dict字典对象
            new_messages.append(d)
        #将数据写入文件
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(new_messages, f)

    @property #表示将一个方法变成属性调用,调用时不需要加括号
    def messages(self) -> list[BaseMessage]:
        #当前文件内:list[dict] --> list[BaseMessage]  
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                message_data = json.load(f)    # 从文件中读取数据,得到一个list[dict]对象
                return messages_from_dict(message_data)   # 将list[dict]对象转换为list[BaseMessage]对象
        except FileNotFoundError:
            return []  
    def clear(self):
        #清空文件内容
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump([], f)


model = ChatTongyi(model="qwen3-max")
prompt = PromptTemplate.from_template(
    "你需要根据会话历史回应用户问题。对话历史:{chat_history},用户提问:{input},请回答"
)
"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你需要根据会话历史回应用户问题。对话历史:"),
        MessagesPlaceholder("chat_history"),
        ("human", "请回答如下问题:{input}")
    ]
)

"""
str_parser = StrOutputParser()
def print_prompt(full_prompt):
    print("="*20, full_prompt.to_string(), "="*20)
    return full_prompt
base_chain = prompt | print_prompt | model | str_parser
store = {}      # key就是session,value就是InMemoryChatMessageHistory类对象

# 实现通过会话id获取InMemoryChatMessageHistory类对象
def get_history(session_id):
    return FileChatMessageHistory(session_id, storage_path="./chat_history")   # 每个会话ID对应一个文件,存储在./chat_history文件夹下

# 创建一个新的链,对原有链增强功能:自动附加历史消息
conversation_chain = RunnableWithMessageHistory(
    base_chain,     # 被增强的原有chain
    get_history,    # 通过会话id获取InMemoryChatMessageHistory类对象
    input_messages_key="input",             # 表示用户输入在模板中的占位符
    history_messages_key="chat_history"     # 表示用户输入在模板中的占位符
)

if __name__ == '__main__':
    # 固定格式,添加LangChain的配置,为当前程序配置所属的session_id
    session_config = {
        "configurable": {
            "session_id": "user_001"
        }
    }
    #res = conversation_chain.invoke({"input": "小明有2个猫"}, session_config)
    #print("第1次执行:", res)
    #res = conversation_chain.invoke({"input": "小刚有1只狗"}, session_config)
    #print("第2次执行:", res)
    res = conversation_chain.invoke({"input": "总共有几个宠物"}, session_config)
    print("第3次执行:", res)

 

 

Leave a Comment

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.