gte-base-zh LangChain集成教程:将gte-base-zh作为Embeddings类注入RAG链

张开发
2026/4/9 5:04:43 15 分钟阅读

分享文章

gte-base-zh LangChain集成教程:将gte-base-zh作为Embeddings类注入RAG链
gte-base-zh LangChain集成教程将gte-base-zh作为Embeddings类注入RAG链1. 环境准备与模型部署在开始集成之前我们需要先确保gte-base-zh模型已经正确部署并运行。这个模型由阿里巴巴达摩院训练基于BERT框架专门为中文文本嵌入优化在信息检索、语义相似度计算等场景下表现优异。1.1 模型部署步骤首先确认模型文件位置gte-base-zh模型默认安装在/usr/local/bin/AI-ModelScope/gte-base-zh启动xinference服务这是模型服务的基础xinference-local --host 0.0.0.0 --port 9997然后通过专用脚本启动模型服务python /usr/local/bin/launch_model_server.py1.2 验证服务状态部署完成后检查服务是否正常启动cat /root/workspace/model_server.log看到类似下面的输出说明模型服务已经成功启动Model gte-base-zh loaded successfully Service started on port 99972. 理解gte-base-zh嵌入模型gte-base-zh是一个专门针对中文优化的文本嵌入模型它在海量中文文本对上进行训练能够将文本转换为高质量的向量表示。2.1 模型核心能力这个模型的主要优势包括中文优化专门为中文文本设计和训练高质量嵌入生成的向量能够很好地捕捉语义信息多场景适用支持信息检索、语义相似度、文本重排序等任务易于集成提供标准的API接口方便与其他系统集成2.2 测试模型功能在集成到LangChain之前可以先通过Web界面测试模型功能访问xinference的Web UI界面点击示例文本或输入自定义文本点击相似度比对按钮查看模型返回的相似度结果这样可以帮助你理解模型的工作原理和效果。3. LangChain集成实战现在进入核心部分我们将把gte-base-zh嵌入模型集成到LangChain的RAG链中。3.1 安装必要依赖首先确保安装了必要的Python包pip install langchain openai xinference3.2 创建自定义Embeddings类我们需要创建一个继承自LangChain BaseEmbeddings类的自定义类from langchain.embeddings.base import Embeddings from typing import List import requests import json class GTEBaseZHEmbeddings(Embeddings): def __init__(self, base_urlhttp://localhost:9997): self.base_url base_url self.model_name gte-base-zh def embed_documents(self, texts: List[str]) - List[List[float]]: 为文档列表生成嵌入向量 embeddings [] for text in texts: embedding self._get_embedding(text) embeddings.append(embedding) return embeddings def embed_query(self, text: str) - List[float]: 为查询文本生成嵌入向量 return self._get_embedding(text) def _get_embedding(self, text: str) - List[float]: 调用gte-base-zh模型API获取嵌入向量 url f{self.base_url}/v1/embeddings payload { model: self.model_name, input: text } try: response requests.post(url, jsonpayload) response.raise_for_status() result response.json() return result[data][0][embedding] except Exception as e: print(f获取嵌入向量失败: {e}) return [0.0] * 768 # 返回默认向量3.3 集成到RAG链中现在我们可以将自定义的嵌入类集成到完整的RAG链中from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.vectorstores import Chroma from langchain.chains import RetrievalQA from langchain.llms import OpenAI # 初始化自定义嵌入模型 embeddings GTEBaseZHEmbeddings() # 准备文档并分割 text_splitter RecursiveCharacterTextSplitter( chunk_size1000, chunk_overlap200 ) documents [你的文档内容在这里...] # 替换为实际文档 texts text_splitter.split_documents(documents) # 创建向量存储 vectorstore Chroma.from_documents( documentstexts, embeddingembeddings, persist_directory./chroma_db ) # 创建检索器 retriever vectorstore.as_retriever( search_typesimilarity, search_kwargs{k: 5} ) # 创建RAG链 qa_chain RetrievalQA.from_chain_type( llmOpenAI(), # 可以使用其他LLM chain_typestuff, retrieverretriever, return_source_documentsTrue )4. 实际应用示例让我们看一个完整的应用示例展示如何使用集成后的RAG链。4.1 问答系统实现def setup_rag_system(documents_path): 设置完整的RAG系统 # 读取文档 with open(documents_path, r, encodingutf-8) as f: content f.read() # 初始化嵌入模型 embeddings GTEBaseZHEmbeddings() # 分割文档 text_splitter RecursiveCharacterTextSplitter( chunk_size1000, chunk_overlap200 ) texts text_splitter.split_text(content) # 创建向量存储 vectorstore Chroma.from_texts( textstexts, embeddingembeddings, persist_directory./rag_db ) # 创建检索QA链 qa_chain RetrievalQA.from_chain_type( llmOpenAI(temperature0), chain_typestuff, retrievervectorstore.as_retriever(), return_source_documentsTrue ) return qa_chain # 使用示例 rag_system setup_rag_system(your_documents.txt) result rag_system(你的问题是什么) print(result[result])4.2 批量处理优化对于大量文档的处理我们可以优化性能from concurrent.futures import ThreadPoolExecutor class BatchGTEBaseZHEmbeddings(GTEBaseZHEmbeddings): def embed_documents(self, texts: List[str], batch_size: int 32) - List[List[float]]: 批量处理文档嵌入提高效率 all_embeddings [] with ThreadPoolExecutor(max_workers4) as executor: for i in range(0, len(texts), batch_size): batch texts[i:i batch_size] batch_embeddings list(executor.map(self._get_embedding, batch)) all_embeddings.extend(batch_embeddings) return all_embeddings5. 性能优化与最佳实践为了获得更好的效果这里有一些实用的优化建议。5.1 配置优化调整模型参数以获得更好的性能# 优化后的嵌入类配置 class OptimizedGTEEmbeddings(GTEBaseZHEmbeddings): def __init__(self, base_urlhttp://localhost:9997, timeout30): super().__init__(base_url) self.timeout timeout self.max_retries 3 def _get_embedding(self, text: str) - List[float]: 带重试机制的嵌入获取 for attempt in range(self.max_retries): try: url f{self.base_url}/v1/embeddings payload { model: self.model_name, input: text[:512] # 限制文本长度 } response requests.post( url, jsonpayload, timeoutself.timeout ) response.raise_for_status() result response.json() return result[data][0][embedding] except Exception as e: if attempt self.max_retries - 1: print(f所有重试失败: {e}) return [0.0] * 768 continue5.2 缓存策略实现嵌入向量缓存减少重复计算from functools import lru_cache class CachedGTEEmbeddings(GTEBaseZHEmbeddings): lru_cache(maxsize1000) def _get_embedding_cached(self, text: str) - List[float]: 带缓存的嵌入获取 return self._get_embedding(text) def embed_query(self, text: str) - List[float]: return self._get_embedding_cached(text) def embed_documents(self, texts: List[str]) - List[List[float]]: return [self._get_embedding_cached(text) for text in texts]6. 常见问题解决在实际使用过程中可能会遇到一些问题这里提供解决方案。6.1 连接问题处理如果遇到连接问题可以这样处理def check_service_health(base_url): 检查模型服务健康状态 try: response requests.get(f{base_url}/health, timeout5) return response.status_code 200 except: return False # 使用健康检查 if not check_service_health(http://localhost:9997): print(模型服务未就绪请检查服务状态) # 可以在这里添加自动重启服务的逻辑6.2 性能监控添加性能监控代码import time from datetime import datetime class MonitoredGTEEmbeddings(GTEBaseZHEmbeddings): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.request_times [] def _get_embedding(self, text: str) - List[float]: start_time time.time() try: result super()._get_embedding(text) elapsed time.time() - start_time self.request_times.append(elapsed) return result except Exception as e: elapsed time.time() - start_time print(f请求失败耗时 {elapsed:.2f}s: {e}) raise def get_performance_stats(self): 获取性能统计 if not self.request_times: return 暂无请求数据 avg_time sum(self.request_times) / len(self.request_times) max_time max(self.request_times) min_time min(self.request_times) return f平均: {avg_time:.2f}s, 最大: {max_time:.2f}s, 最小: {min_time:.2f}s7. 总结通过本教程你已经学会了如何将gte-base-zh嵌入模型集成到LangChain的RAG链中。关键要点包括模型部署正确启动xinference和gte-base-zh模型服务自定义集成创建继承自BaseEmbeddings的自定义类实战应用将嵌入模型应用到完整的RAG系统中性能优化通过批量处理、缓存等策略提升性能问题解决处理常见的连接和性能问题这种集成方式不仅适用于gte-base-zh也可以作为其他嵌入模型集成到LangChain的参考模板。在实际项目中你可以根据具体需求调整配置参数优化性能表现。记住定期检查模型服务的健康状态确保RAG系统的稳定运行。如果遇到问题可以参考常见问题解决部分或者查看模型服务的日志文件来排查问题。获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。

更多文章