当前位置:首页 > 技术 > 正文内容

基于线性回归的可学习对话智能体实现

访客 技术 2026年6月6日 1

环境依赖安装

使用官方源:

pip install scikit-learn

或推荐使用清华镜像加速安装:

pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple

主程序文件:main.py

from simple_chat_ai import SimpleChatAI

ai = SimpleChatAI()

print("🤖 智能体已启动(基于sklearn的决策模型)")
print("输入 exit 退出交互\n")

while True:
    user_query = input("你:").strip()
    if user_query.lower() == "exit":
        break

    response, context, _ = ai.generate_response(user_query)

    if response:
        print(f"🤖:{response}")
        print(f"📊 分析:{context}")

        if context == "动态功能触发":
            continue

        feedback = input("✅ 回答是否准确?(y/n):").lower()
        if feedback == "y":
            ai.update_memory(user_query, response, is_correct=True)
            print("✅ 模型已强化该回应")
        else:
            alternative = input("👉 请提供更优回答:").strip()
            if alternative:
                ai.update_memory(user_query, alternative, is_correct=True)
                print("✅ 新答案已录入并参与评估")
            else:
                ai.update_memory(user_query, response, is_correct=False)
                print("❌ 该回应已被降权")
    else:
        print("🤖:当前无法给出有效回复。")
        teach = input("🤔 是否愿意指导我?(y/n):").lower()
        if teach == "y":
            correct_answer = input("👉 正确的回答是:").strip()
            if correct_answer:
                ai.update_memory(user_query, correct_answer, is_correct=True)
                print("✅ 我已学会!")

核心模块:simple_chat_ai.py

import json
import os
import time
from datetime import date

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neural_network import MLPClassifier


class SimpleChatAI:
    def __init__(self, storage_path="knowledge.json", model_path="classifier.pkl"):
        self.knowledge_file = storage_path
        self.model_file = model_path

        self.birth_date = date(2026, 3, 23)

        self.knowledge_base = []

        self.vectorizer = TfidfVectorizer()
        self.embeddings = None

        # ✅ 关键改进:禁用 warm_start,确保每次训练独立
        self.classifier = MLPClassifier(
            hidden_layer_sizes=(16,),
            activation="relu",
            solver="adam",
            max_iter=1,
            random_state=42
        )
        self.is_trained = False

        # 预设响应规则
        self.tools = [
            (["年龄", "几岁", "多大"], self.get_current_age),
            (["时间", "几点", "现在几点"], self.get_current_time),
        ]

        self.load_knowledge()
        self.build_vector_space()
        self.load_model()

    # ================= 动态能力实现 =================
    def get_current_time(self):
        return f"当前时间为:{time.strftime('%Y-%m-%d %H:%M:%S')}"

    def get_current_age(self):
        today = date.today()
        age = today.year - self.birth_date.year
        if (today.month, today.day) < (self.birth_date.month, self.birth_date.day):
            age -= 1
        if age <= 0:
            return "我尚未满一岁,目前为 0 周岁。"
        return f"我的年龄是 {age} 周岁。"

    # ================= 知识持久化 =================
    def load_knowledge(self):
        if not os.path.exists(self.knowledge_file):
            return
        with open(self.knowledge_file, "r", encoding="utf-8") as f:
            data = json.load(f)
            self.knowledge_base = data.get("entries", [])

    def save_knowledge(self):
        with open(self.knowledge_file, "w", encoding="utf-8") as f:
            json.dump({"entries": self.knowledge_base}, f, ensure_ascii=False, indent=2)

    # ================= 文本向量化构建 =================
    def build_vector_space(self):
        if not self.knowledge_base:
            return
        texts = [item["question"] for item in self.knowledge_base]
        self.embeddings = self.vectorizer.fit_transform(texts)

    # ================= 模型存储与加载 =================
    def save_model(self):
        import joblib
        joblib.dump(self.classifier, self.model_file)

    def load_model(self):
        if os.path.exists(self.model_file):
            import joblib
            self.classifier = joblib.load(self.model_file)
            self.is_trained = True

    # ================= 特征工程 =================
    def extract_features(self, similarity_score, answer_length):
        return np.array([[similarity_score, answer_length, 1.0]])

    # ================= 主要响应逻辑 =================
    def generate_response(self, query):
        for keywords, handler in self.tools:
            if any(kw in query for kw in keywords):
                return handler(), "动态功能触发", None

        if not self.knowledge_base:
            return None, "无可用知识", None

        query_vec = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vec, self.embeddings)[0]
        best_idx = similarities.argmax()
        sim_score = similarities[best_idx]

        if sim_score < 0.3:
            return None, "匹配度不足", None

        entry = self.knowledge_base[best_idx]
        best_response = None
        highest_confidence = -1

        for candidate in entry["responses"]:
            features = self.extract_features(sim_score, len(candidate["text"]))
            if self.is_trained:
                confidence = self.classifier.predict_proba(features)[0][1]
            else:
                confidence = candidate.get("weight", 1)

            if confidence > highest_confidence:
                highest_confidence = confidence
                best_response = candidate["text"]

        return best_response, "模型决策输出", None

    # ================= 在线学习机制(关键修复) =================
    def update_memory(self, user_input, reply_text, is_correct=True):
        if not reply_text:
            return

        query_vec = self.vectorizer.transform([user_input])
        sims = cosine_similarity(query_vec, self.embeddings)[0]
        idx = sims.argmax()
        sim_value = sims[idx]

        entry = self.knowledge_base[idx]

        # 查找是否存在相同回复
        existing = False
        for resp in entry["responses"]:
            if resp["text"] == reply_text:
                existing = True
                break
        else:
            new_resp = {"text": reply_text, "weight": 1}
            entry["responses"].append(new_resp)
            self.save_knowledge()

        features = self.extract_features(sim_value, len(reply_text))
        label = 1 if is_correct else 0

        # ✅ 显式指定类别以避免训练异常
        self.classifier.partial_fit(features, [label], classes=[0, 1])
        self.is_trained = True
        self.save_model()

初始知识库:memory.json

该文件用于冷启动场景,包含基础问答对,当模型未训练时作为默认响应来源。

{
  "entries": [
    {
      "question": "你叫什么名字",
      "responses": [
        { "text": "我叫 Agent。", "weight": 3 },
        { "text": "我的名字是 Agent。", "weight": 2 },
        { "text": "我叫Agent", "weight": 1 }
      ]
    },
    {
      "question": "你是谁",
      "responses": [
        { "text": "我是 Agent,一个具备对话与自我进化能力的智能助手。", "weight": 3 }
      ]
    },
    {
      "question": "你好",
      "responses": [
        { "text": "你好呀,很高兴为您服务!", "weight": 5 },
        { "text": "你好!", "weight": 2 },
        { "text": "现在", "weight": 1 },
        { "text": "很高兴为您服务", "weight": 1 }
      ]
    },
    {
      "question": "你是男的还是女的",
      "responses": [
        { "text": "我没有性别属性。", "weight": 5 },
        { "text": "作为人工智能,我不具备性别特征。", "weight": 2 }
      ]
    },
    {
      "question": "你的生日是什么时候",
      "responses": [
        { "text": "我是 2026 年 3 月 23 日诞生的。", "weight": 4 },
        { "text": "2026 年 3 月 23 日是我的出生日。", "weight": 2 }
      ]
    },
    {
      "question": "2026年3月23日是什么日子",
      "responses": [
        { "text": "这一天是 Agent 的诞生纪念日。", "weight": 4 }
      ]
    },
    {
      "question": "你可以记住我说的话么",
      "responses": [
        { "text": "目前仅能保留部分记忆,仍在持续学习中。", "weight": 6 },
        { "text": "我可以记住您教授的内容,用于优化自身表现。", "weight": 3 },
        { "text": "需要更多数据来增强记忆能力。", "weight": 1 }
      ]
    }
  ]
}

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。