基于 TF-IDF 检索与 MLP 分类器(scikit-learn)的可学习对话智能体实现
环境依赖安装
使用官方源:
pip install scikit-learn
或推荐使用清华镜像加速安装:
pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple
主程序文件:main.py
from simple_chat_ai import SimpleChatAI


def _collect_feedback(ai, user_query, response):
    """Ask whether *response* answered *user_query* correctly and teach *ai*.

    A "y" reinforces the answer. Anything else marks it wrong. The user may
    then supply a better answer, which is recorded as correct.
    """
    feedback = input("✅ 回答是否准确?(y/n):").strip().lower()
    if feedback == "y":
        ai.update_memory(user_query, response, is_correct=True)
        print("✅ 模型已强化该回应")
        return
    alternative = input("👉 请提供更优回答:").strip()
    # The shown answer was judged wrong: demote it even when a replacement is
    # given, otherwise a higher-ranked wrong answer keeps beating the new one.
    ai.update_memory(user_query, response, is_correct=False)
    if alternative:
        ai.update_memory(user_query, alternative, is_correct=True)
        print("✅ 新答案已录入并参与评估")
    else:
        print("❌ 该回应已被降权")


def _offer_teaching(ai, user_query):
    """Let the user teach an answer for a query the agent could not answer."""
    teach = input("🤔 是否愿意指导我?(y/n):").strip().lower()
    if teach != "y":
        return
    correct_answer = input("👉 正确的回答是:").strip()
    if correct_answer:
        ai.update_memory(user_query, correct_answer, is_correct=True)
        print("✅ 我已学会!")


def main():
    """Run the interactive console chat loop until 'exit', EOF or Ctrl-C."""
    ai = SimpleChatAI()
    print("🤖 智能体已启动(基于sklearn的决策模型)")
    print("输入 exit 退出交互\n")
    try:
        while True:
            user_query = input("你:").strip()
            if user_query.lower() == "exit":
                break
            if not user_query:
                # Blank line: nothing to answer and nothing worth learning.
                continue
            response, context, _ = ai.generate_response(user_query)
            if not response:
                print("🤖:当前无法给出有效回复。")
                _offer_teaching(ai, user_query)
                continue
            print(f"🤖:{response}")
            print(f"📊 分析:{context}")
            # Tool answers (time/age) are computed live rather than taken from
            # the knowledge base, so there is nothing to give feedback on.
            if context != "动态功能触发":
                _collect_feedback(ai, user_query, response)
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or Ctrl-C: end the session cleanly instead of crashing.
        print()


if __name__ == "__main__":
    main()
核心模块:simple_chat_ai.py
import json
import os
import time
from datetime import date
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neural_network import MLPClassifier
class SimpleChatAI:
    """Retrieval-based chat agent that learns online from user feedback.

    Stored questions are embedded with TF-IDF. A query is answered from the
    candidate responses of the most similar stored question.

    Candidates are ranked by their ``weight``. Once the classifier has seen at
    least one feedback sample, the rank is ``weight`` multiplied by the
    classifier's probability that the answer is correct.
    """

    def __init__(self, storage_path="knowledge.json", model_path="classifier.pkl",
                 similarity_threshold=0.3):
        """Load the knowledge base and any previously saved classifier.

        Args:
            storage_path: JSON file holding ``{"entries": [...]}``.
            model_path: joblib file the classifier is saved to / loaded from.
            similarity_threshold: minimum cosine similarity for a stored
                question to count as a match for a query.
        """
        self.knowledge_file = storage_path
        self.model_file = model_path
        self.similarity_threshold = similarity_threshold
        self.birth_date = date(2026, 3, 23)
        self.knowledge_base = []
        # Character n-grams: the default word tokenizer (\b\w\w+\b) turns a
        # whole run of CJK characters into ONE token, so any paraphrase of a
        # Chinese question would otherwise score a similarity of 0.
        self.vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(1, 2))
        self.embeddings = None
        # Trained only through partial_fit (one pass per call), so max_iter
        # and warm_start have no effect on the online updates below.
        self.classifier = MLPClassifier(
            hidden_layer_sizes=(16,),
            activation="relu",
            solver="adam",
            max_iter=1,
            random_state=42,
        )
        self.is_trained = False
        # Keyword rules checked before the knowledge base; each handler
        # computes its answer live.
        self.tools = [
            (["年龄", "几岁", "多大"], self.get_current_age),
            (["时间", "几点", "现在几点"], self.get_current_time),
        ]
        self.load_knowledge()
        self.build_vector_space()
        self.load_model()

    # ================= Dynamic abilities =================
    def get_current_time(self):
        """Return the current local time as a formatted sentence."""
        return f"当前时间为:{time.strftime('%Y-%m-%d %H:%M:%S')}"

    def get_current_age(self):
        """Return the agent's age in whole years since ``birth_date``."""
        today = date.today()
        age = today.year - self.birth_date.year
        if (today.month, today.day) < (self.birth_date.month, self.birth_date.day):
            age -= 1
        if age <= 0:
            return "我尚未满一岁,目前为 0 周岁。"
        return f"我的年龄是 {age} 周岁。"

    # ================= Knowledge persistence =================
    def load_knowledge(self):
        """Load ``entries`` from the knowledge file, if it exists."""
        if not os.path.exists(self.knowledge_file):
            return
        with open(self.knowledge_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.knowledge_base = data.get("entries", [])

    def save_knowledge(self):
        """Write the knowledge base to disk, replacing the file atomically."""
        tmp_path = f"{self.knowledge_file}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump({"entries": self.knowledge_base}, f, ensure_ascii=False, indent=2)
        # A crash mid-write leaves the old file intact instead of a truncated one.
        os.replace(tmp_path, self.knowledge_file)

    # ================= Vector space =================
    def build_vector_space(self):
        """(Re)fit the vectorizer and embed every stored question."""
        if not self.knowledge_base:
            return
        texts = [item["question"] for item in self.knowledge_base]
        self.embeddings = self.vectorizer.fit_transform(texts)

    def _best_match(self, query):
        """Return ``(index, cosine similarity)`` of the question closest to *query*."""
        query_vec = self.vectorizer.transform([query])
        sims = cosine_similarity(query_vec, self.embeddings)[0]
        idx = int(sims.argmax())
        return idx, float(sims[idx])

    # ================= Model persistence =================
    def save_model(self):
        """Persist the classifier with joblib."""
        import joblib
        joblib.dump(self.classifier, self.model_file)

    def load_model(self):
        """Load a previously saved classifier, if present.

        joblib.load unpickles the file and can execute arbitrary code, so the
        model file must only ever be one written by this program.
        """
        if os.path.exists(self.model_file):
            import joblib
            self.classifier = joblib.load(self.model_file)
            self.is_trained = True

    # ================= Features =================
    def extract_features(self, similarity_score, answer_length):
        """Build the 1x3 classifier input: similarity, answer length, bias."""
        return np.array([[similarity_score, answer_length, 1.0]])

    def _score_candidate(self, candidate, sim_score):
        """Score one candidate response for a question matched at *sim_score*."""
        weight = candidate.get("weight", 1)
        if not self.is_trained:
            return weight
        features = self.extract_features(sim_score, len(candidate["text"]))
        # The classifier only sees similarity and length, so it cannot tell
        # candidates apart by content; scaling by the learned weight keeps
        # per-response feedback effective after training starts.
        return weight * self.classifier.predict_proba(features)[0][1]

    # ================= Response generation =================
    def generate_response(self, query):
        """Answer *query*.

        Returns:
            ``(text, context, None)``. ``text`` is ``None`` when no answer is
            available. ``context`` is one of "动态功能触发", "无可用知识",
            "匹配度不足", "模型决策输出". The third element is always ``None``.
        """
        for keywords, handler in self.tools:
            if any(kw in query for kw in keywords):
                return handler(), "动态功能触发", None
        if not self.knowledge_base or self.embeddings is None:
            return None, "无可用知识", None
        best_idx, sim_score = self._best_match(query)
        if sim_score < self.similarity_threshold:
            return None, "匹配度不足", None
        best_response = None
        highest_confidence = -1
        for candidate in self.knowledge_base[best_idx]["responses"]:
            confidence = self._score_candidate(candidate, sim_score)
            if confidence > highest_confidence:
                highest_confidence = confidence
                best_response = candidate["text"]
        return best_response, "模型决策输出", None

    # ================= Online learning =================
    @staticmethod
    def _adjust_weight(entry, reply_text, is_correct):
        """Raise (correct) or lower, floored at 0 (wrong), the reply's weight.

        A correct reply not yet in *entry* is added with weight 1. A wrong
        reply that is not stored is ignored.
        """
        for resp in entry["responses"]:
            if resp["text"] == reply_text:
                current = resp.get("weight", 1)
                resp["weight"] = current + 1 if is_correct else max(0, current - 1)
                return
        if is_correct:
            entry["responses"].append({"text": reply_text, "weight": 1})

    def update_memory(self, user_input, reply_text, is_correct=True):
        """Learn that *reply_text* is a right/wrong answer to *user_input*.

        If a stored question matches *user_input* at or above the similarity
        threshold, that entry's weights are updated. Otherwise, for a correct
        reply, a new entry is created for *user_input*, rather than attaching
        the answer to an unrelated nearest question.

        The classifier always receives one online sample. Knowledge (when
        changed) and the model are then saved.
        """
        if not reply_text:
            return
        if self.knowledge_base and self.embeddings is not None:
            idx, sim_value = self._best_match(user_input)
        else:
            idx, sim_value = None, 0.0

        entry = None
        if idx is not None and sim_value >= self.similarity_threshold:
            entry = self.knowledge_base[idx]
        elif is_correct:
            entry = {"question": user_input, "responses": []}
            self.knowledge_base.append(entry)
            self.build_vector_space()
            # The same query will now match this new question exactly.
            sim_value = 1.0

        if entry is not None:
            self._adjust_weight(entry, reply_text, is_correct)
            self.save_knowledge()

        features = self.extract_features(sim_value, len(reply_text))
        label = 1 if is_correct else 0
        # classes must be given on the first partial_fit call.
        self.classifier.partial_fit(features, [label], classes=[0, 1])
        self.is_trained = True
        self.save_model()
初始知识库:knowledge.json
该文件用于冷启动场景,包含基础问答对及各候选回答的初始权重(weight)。文件名需与 SimpleChatAI 的 storage_path 参数一致(默认为 knowledge.json),否则该文件不会被加载。
{
"entries": [
{
"question": "你叫什么名字",
"responses": [
{ "text": "我叫 Agent。", "weight": 3 },
{ "text": "我的名字是 Agent。", "weight": 2 },
{ "text": "我叫Agent", "weight": 1 }
]
},
{
"question": "你是谁",
"responses": [
{ "text": "我是 Agent,一个具备对话与自我进化能力的智能助手。", "weight": 3 }
]
},
{
"question": "你好",
"responses": [
{ "text": "你好呀,很高兴为您服务!", "weight": 5 },
{ "text": "你好!", "weight": 2 },
{ "text": "现在", "weight": 1 },
{ "text": "很高兴为您服务", "weight": 1 }
]
},
{
"question": "你是男的还是女的",
"responses": [
{ "text": "我没有性别属性。", "weight": 5 },
{ "text": "作为人工智能,我不具备性别特征。", "weight": 2 }
]
},
{
"question": "你的生日是什么时候",
"responses": [
{ "text": "我是 2026 年 3 月 23 日诞生的。", "weight": 4 },
{ "text": "2026 年 3 月 23 日是我的出生日。", "weight": 2 }
]
},
{
"question": "2026年3月23日是什么日子",
"responses": [
{ "text": "这一天是 Agent 的诞生纪念日。", "weight": 4 }
]
},
{
"question": "你可以记住我说的话么",
"responses": [
{ "text": "目前仅能保留部分记忆,仍在持续学习中。", "weight": 6 },
{ "text": "我可以记住您教授的内容,用于优化自身表现。", "weight": 3 },
{ "text": "需要更多数据来增强记忆能力。", "weight": 1 }
]
}
]
}