imClumsyPanda 995c5e300e
Pre-release v0.3.0 (#4159)
* publish 0.2.10 (#2797)

新功能:
- 优化 PDF 文件的 OCR,过滤无意义的小图片 by @liunux4odoo #2525
- 支持 Gemini 在线模型 by @yhfgyyf #2630
- 支持 GLM4 在线模型 by @zRzRzRzRzRzRzR
- elasticsearch更新https连接 by @xldistance #2390
- 增强对PPT、DOC知识库文件的OCR识别 by @596192804 #2013
- 更新 Agent 对话功能 by @zRzRzRzRzRzRzR
- 每次创建对象时从连接池获取连接,避免每次执行方法时都新建连接 by @Lijia0 #2480
- 实现 ChatOpenAI 判断token有没有超过模型的context上下文长度 by @glide-the
- 更新运行数据库报错和项目里程碑 by @zRzRzRzRzRzRzR #2659
- 更新配置文件/文档/依赖 by @imClumsyPanda @zRzRzRzRzRzRzR
- 添加日文版 readme by @eltociear #2787

修复:
- langchain 更新后,PGVector 向量库连接错误 by @HALIndex #2591
- Minimax's model worker 错误 by @xyhshen 
- ES库无法向量检索.添加mappings创建向量索引 by MSZheng20 #2688

* Update README.md

* Add files via upload

* Update README.md

* 修复PDF旋转的BUG

* Support Chroma

* perf delete unused import

* 忽略测试代码

* 更新文件

* API前端丢失问题解决

* 更新了chromadb的打印的符号

* autodl代号错误

* Update README.md

* Update README.md

* Update README.md

* 修复milvus相关bug

* 支持星火3.5模型

* 修复es 知识库查询bug (#2848)

* 修复es 知识库查询bug (#2848)

* 更新zhipuai请求方式

* 增加对 .htm 扩展名的显式支持

* 更新readme

* Docker镜像制作与K8S YAML部署操作说明 (#2892)

* Dev (#2280)

* 修复Azure 不设置Max token的bug

* 重写agent

1. 修改Agent实现方式,支持多参数,仅剩 ChatGLM3-6b和 OpenAI GPT4 支持,剩余模型将在暂时缺席Agent功能
2. 删除agent_chat 集成到llm_chat中
3. 重写大部分工具,适应新Agent

* 更新架构

* 删除web_chat,自动融合

* 移除所有聊天,都变成Agent控制

* 更新配置文件

* 更新配置模板和提示词

* 更改参数选择bug

* 修复模型选择的bug

* 更新一些内容

* 更新多模态 语音 视觉的内容

1. 更新本地模型语音 视觉多模态功能并设置了对应工具

* 支持多模态Grounding

1. 美化了chat的代码
2. 支持视觉工具输出Grounding任务
3. 完善工具调用的流程

* 支持XPU,修改了glm3部分agent

* 添加 qwen agent

* 对其ChatGLM3-6B与Qwen-14B

* fix callback handler

* 更新Agent工具返回

* fix: LLMChain no output when no tools selected

* 跟新了langchain 0.1.x需要的依赖和修改的代码

* 更新chatGLM3 langchain0.1.x Agent写法

* 按照 langchain 0.1 重写 qwen agent

* 修复 callback 无效的问题

* 添加文生图工具

* webui 支持文生图

* 集成openai plugins插件

* 删除fastchat的配置

* 增加openai插件

* 集成openai plugins插件

* 更新模型执行列表和今晚修改的内容

* 集成openai_plugins/imitater插件

* 集成openai_plugins/imitater插件

* 集成openai_plugins/imitater插件

* 减少错误的显示

* 标准配置

* vllm参数配置

* 增加智谱插件

* 删除本地fschat配置

* 删除本地fschat配置,pydantic升级到2

* 删除本地fschat workers

* openai-plugins-list.json

* 升级agent,pydantic升级到2

* fix model_config是系统关键词问题

* embeddings模块集成openai plugins插件,使用统一api调用

* loom模型服务update_store更新逻辑

* 集成LOOM在线embedding业务

* 本地知识库搜索字段修改

* 知识库在线api接入点配置在线api接入点配置更新逻辑

* Update model_config.py.example

* 修改模型配置方式,所有模型以 openai 兼容框架的形式接入,chatchat 自身不再加载模型。
改变 Embeddings 模型改为使用框架 API,不再手动加载,删除自定义 Embeddings Keyword 代码
修改依赖文件,移除 torch transformers 等重依赖
暂时移出对 loom 的集成

后续:
1、优化目录结构
2、检查合并中有无被覆盖的 0.2.10 内容

* move document_loaders & text_splitter under server

* make torch & transformers optional
import pydantic Model & Field from langchain.pydantic_v1 instead of pydantic.v1

* - pydantic 限定为 v1,并统一项目中所有 pydantic 导入路径,为以后升级 v2 做准备
- 重构 api.py:
    - 按模块划分为不同的 router
    - 添加 openai 兼容的转发接口,项目默认使用该接口以实现模型负载均衡
    - 添加 /tools 接口,可以获取/调用编写的 agent tools
    - 移除所有 EmbeddingFuncAdapter,统一改用 get_Embeddings
    - 待办:
        - /chat/chat 接口改为 openai 兼容
        - 添加 /chat/kb_chat 接口,openai 兼容
        - 改变 ntlk/knowledge_base/logs 等数据目录位置

* 移除 llama-index 依赖;修复 /v1/models 错误

* 原因:windows下启动失败提示补充python-multipart包 (#3184)

改动:requirements添加python-multipart==0.0.9
版本:0.0.9  Requires: Python >=3.8

Co-authored-by: XuCai <liangxc@akulaku.com>

* 添加 xinference 本地模型和自定义模型配置 UI: streamlit run model_loaders/xinference_manager.py

* update xinference manager ui

* fix merge conflict

* model_config 中补充 oneapi 默认在线模型;/v1/models 接口支持 oneapi 平台,统一返回模型列表

* 重写 calculate 工具

* 调整根目录结构,kb/logs/media/nltk_data 移动到专用数据目录(可配置,默认 data)。注意知识库文件要做相应移动

* update kb_config.py.example

* 优化 ES 知识库
- 开发者
    - get_OpenAIClient 的 local_wrap 默认值改为 False,避免 API 服务未启动导致其它功能受阻(如Embeddings)
    - 修改 ES 知识库服务:
	- 检索策略改为 ApproxRetrievalStrategy
	- 设置 timeout 为 60, 避免文档过多导致 ConnecitonTimeout Error
    - 修改 LocalAIEmbeddings,使用多线程进行  embed_texts,效果不明显,瓶颈可能主要在提供 Embedding 的服务器上

* 修复glm3 agent被注释的agent会话文本结构解析代码
看起来输出的文本占位符如下,目前解析代码是有问题的
Thought <|assistant|> Action\r
```python
tool_call(action_input)
```<|observation|>

* make qwen agent work with langchain>=0.1 (#3228)

* make xinference model manager support xinference 0.9.x

* 使用多进程提高导入知识库的速度 (#3276)

* xinference的代码

先传 我后面来改

* Delete server/xinference directory

* Create khazic

* diiii

diii

* Revert "xinference的代码"

* fix markdown header split (#1825) (#3324)

* dify model_providers configuration
This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.

* fix merge conflict: langchain Embeddings not imported in server.utils

* 添加 react 编写的新版 WEBUI (#3417)

* feat:提交前端代码

* feat:提交logo样式切换

* feat:替换avatar、部分位置icon、chatchat相关说明、git链接、Wiki链接、关于、设置、反馈与建议等功能,关闭lobehub自检更新功能

* fix:移除多余代码

---------

Co-authored-by: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>

* model_providers bootstrap

* model_providers bootstrap

* update to pydantic v2 (#3486)

* 使用poetry管理项目

* 使用poetry管理项目

* dev分支解决pydantic版本冲突问题,增加ollama配置,支持ollama会话和向量接口 (#3508)

* dev分支解决pydantic版本冲突问题,增加ollama配置,支持ollama会话和向量接口
1、因dev版本的pydantic升级到了v2版本,由于在class History(BaseModel)中使用了from server.pydantic_v1,而fastapi的引用已变为pydantic的v2版本,所以fastapi用v2版本去校验用v1版本定义的对象,当会话历史histtory不为空的时候,会报错:TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given。经测试,解方法为在class History(BaseModel)中也使用v2版本即可;
2、配置文件参照其它平台配置,增加了ollama平台相关配置,会话模型用户可根据实际情况自行添加,向量模型目前支持nomic-embed-text(必须升级ollama到0.1.29以上)。
3、因ollama官方只在会话部分对openai api做了兼容,向量api暂未适配,好在langchain官方库支持OllamaEmbeddings,因而在get_Embeddings方法中添加了相关支持代码。

* 修复 pydantic 升级到 v2 后 DocumentWithVsID 和 /v1/embeddings 兼容性问题

---------

Co-authored-by: srszzw <srszzw@163.com>
Co-authored-by: liunux4odoo <liunux@qq.com>

* 对python的要求降级到py38

* fix bugs; make poetry using tsinghua mirror of pypi

* update gitignore; remove unignored files

* update wiki sub module

* 20240326

* 20240326

* qqqq

* 删除历史文件

* 移动项目模块

* update .gitignore; fix model version error in api_schemas

* 封装ModelManager

* - 重写 tool 部分: (#3553)

- 简化 tool 的定义方式
    - 所有 tool 和 tool_config 支持热加载
    - 修复:json_schema_extra warning

* 使用yaml加载用户配置适配器

* 格式化代码

* 格式化

* 优化工具定义;添加 openai 兼容的统一 chat 接口 (#3570)

- 修复:
    - Qwen Agent 的 OutputParser 不再抛出异常,遇到非 COT 文本直接返回
    - CallbackHandler 正确处理工具调用信息

- 重写 tool 定义方式:
    - 添加 regist_tool 简化 tool 定义:
        - 可以指定一个用户友好的名称
        - 自动将函数的 __doc__ 作为 tool.description
	- 支持用 Field 定义参数,不再需要额外定义 ModelSchema
        - 添加 BaseToolOutput 封装 tool	返回结果,以便同时获取原始值、给LLM的字符串值
        - 支持工具热加载(有待测试)

- 增加 openai 兼容的统一 chat 接口,通过 tools/tool_choice/extra_body 不同参数组合支持:
    - Agent 对话
    - 指定工具调用(如知识库RAG)
    - LLM 对话

- 根据后端功能更新 webui

* 修复:search_local_knowledge_base 工具返回值错误;/tools 路由错误;webui 中“正在思考”一直显示 (#3571)

* 添加 openai 兼容的 files 接口 (#3573)

* 使用BootstrapWebBuilder适配RESTFulOpenAIBootstrapBaseWeb加载

* 格式化和代码检查说明

* 模型列表适配

* make format

* chat_completions接口报文适配

* make format

* xinference 插件示例

* 一些默认参数

* exec path fix

* 解决ollama部署的qwen,执行agent,返回的json格式不正确问题。

* provider_configuration.py
查询所有的平台信息,包含计费策略和配置schema_validators(参数必填信息校验规则)
/workspaces/current/model-providers
查询平台模型分类的详细默认信息,包含了模型类型,模型参数,模型状态
workspaces/current/models/model-types/{model_type}

* 开发手册

* 兼容model_providers,集成webui及API中平台配置的初始化 (#3625)

* provider_configuration init of MODEL_PLATFORMS

* 开发手册

* 兼容model_providers,集成webui及API中平台配置的初始化

* Dev model providers (#3628)


* gemini 初始化参数问题

* gemini 同步工具调用

* embedding convert endpoint

* 修复 --api -w命令

* /v1/models 接口返回值由 List[Model] 改为 {'data': List[Model]},兼容最新版 xinference

* 3.8兼容 (#3769)

* 增加使用说明

* 3.8兼容性配置

* fix

* formater

* 不同平台兼容测试用例

* embedding兼容

* 增加日志信息

* pip源仓库设置,一些版本问题,启动说明  配置说明 (#3854)

* 仓库设置,一些版本问题

* pip源仓库设置,一些版本问题,启动说明

* 配置说明

* 泛型标记错误 (#3855)

* 仓库设置,一些版本问题

* pip源仓库设置,一些版本问题,启动说明

* 配置说明

* 发布的依赖信息

* 泛型标记错误

* 泛型标记错误

* CICD github action build publish pypi、Release Tag (#3886)

* 测试用例

* CICD 流程

* CICD 流程

* CICD 流程

* 一些agent数据处理的问题,model_runtime模块的说明文档 (#3943)

* 一些agent数据出来的问题

* Changes:
- Translated and updated the Model Runtime documentation to reflect the latest changes and features.
- Clarified the decoupling benefits of the Model Runtime module from the Chatchat service.
- Removed outdated information regarding the model configuration storage module.
- Detailed the retained functionalities post-removal of the Dify configuration page.
- Provided a comprehensive overview of the Model Runtime's three-layered structure.
- Included the status of the `fetch-from-remote` feature and its non-implementation in Dify.
- Added instructions for custom service provider model capabilities.

* - 新功能 (#3944)

- streamlit 更新到 1.34,webui 支持 Dialog 操作
    - streamlit-chatbox 更新到 1.1.12,更好的多会话支持
- 开发者
    - 在 API 中增加项目图片路由(/img/{file_name}),方便前端使用

* 修改包名

* 修改包信息

* ollama配置解析问题

* 用户配置动态加载 (#3951)

* version = "0.3.0.20240506"

* version = "0.3.0.20240506"

* version = "0.3.0.20240506"

* version = "0.3.0.20240506"

* 启动说明

* 一些bug

* 修复了一些配置重载的bug

* 配置的加载行为修改

* 配置的加载行为修改

* agent代码优化

* ollama 代码升级,使用openai协议

* 支持deepseek客户端

* contributing (#4043)

* 添加了贡献说明 docs/contributing,包含了一些代码仓库说明和开发规范,以及在model_providers下面编写了一些单元测试的示例

* 关于providers的配置说明

* python3.8兼容

* python3.8兼容

* ollama兼容

* ollama兼容

* 一些兼容 pydantic<3,>=1.9.0  的代码,

* 一些兼容 pydantic<3,>=1.9.0 model_config 的代码,

* make format

* test

* 更新版本

* get_img_base64

* get_img_base64

* get_img_base64

* get_img_base64

* get_img_base64

* 统一模型类型编码

* 向量处理问题

* 优化目录结构 (#4058)

* 优化目录结构

* 修改一些测试问题

---------

Co-authored-by: glide-the <2533736852@qq.com>

* repositories

* 调整日志

* 调整日志zdf

* 增加可选依赖extras

* feat:Added some documentation. (#4085)

* feat:Added some documentation.

* feat:Added some documentation.

* feat:Added some documentation.

---------

Co-authored-by: yuehuazhang <yuehuazhang@tencent.com>

* fix code.md typos

* fix chatchat-server/pyproject.toml typos

* feat:README (#4118)

Co-authored-by: yuehuazhang <yuehuazhang@tencent.com>

* 初始化数据库集成model_providers

* 关闭守护进程

* 1、修改知识库列表接口,返回全量属性字段,同时修改受影响的相关代码。 (#4119)

2、run_in_process_pool改为run_in_thread_pool,解决兼容性问题。
3、poetry配置文件修复。

* 动态更新Prompt中的知识库描述信息,使大模型更容易判断使用哪个知识库。 (#4121)

* 1、修改知识库列表接口,返回全量属性字段,同时修改受影响的相关代码。
2、run_in_process_pool改为run_in_thread_pool,解决兼容性问题。
3、poetry配置文件修复。

* 1、动态更新Prompt中的知识库描述信息,使大模型更容易判断使用哪个知识库。

* fix: 补充 xinference 配置信息 (#4123)

* feat:README

* feat:补充 xinference 平台 llm 和 embedding 模型配置.

---------

Co-authored-by: yuehuazhang <yuehuazhang@tencent.com>

* 知识库工具的下拉列表改为动态获取,不必重启服务。 (#4126)

* 1、知识库工具的下拉列表改为动态获取,不必重启服务。

* update README and imgs

* update README and imgs

* update README and imgs

* update README and imgs

* 修改安装说明描述问题

* make formater

* 更新版本"0.3.0.20240606

* Update code.md

* 优化知识库相关功能 (#4153)

- 新功能
    - pypi 包新增 chatchat-kb 命令脚本,对应 init_database.py 功能

- 开发者
    - _model_config.py 中默认包含 xinference 配置项
    - 所有涉及向量库的操作,前置检查当前 Embed 模型是否可用
    - /knowledge_base/create_knowledge_base 接口增加 kb_info 参数
    - /knowledge_base/list_files 接口返回所有数据库字段,而非文件名称列表
    - 修正 xinference 模型管理脚本

* 消除警告

* 一些依赖问题

* 增加text2sql工具,支持特定表、智能判定表,支持对表名进行额外说明 (#4154)

* 1、增加text2sql工具,支持特定表、智能判定表,支持对表名进行额外说明

* 支持SQLAlchemy大部分数据库、新增read-only模式,提高安全性、增加text2sql使用建议 (#4155)

* 1、修改text2sql连接配置,支持SQLAlchemy大部分数据库;
2、新增read-only模式,若有数据库写保护需求,会从大模型判断、SQLAlchemy拦截器两个层面进行写拦截,提高安全性;
3、增加text2sql使用建议;

* dotenv

* dotenv 配置

* 用户工作空间操作 (#4156)

工作空间的配置预设,提供ConfigBasic建造方法产生实例。
  该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
  工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
  注意:不存在则读取默认

提供了操作入口
指令` chatchat-config` 工作空间配置

options:
```
  -h, --help            show this help message and exit
  -v {true,false}, --verbose {true,false}
                        是否开启详细日志
  -d DATA, --data DATA  数据存放路径
  -f FORMAT, --format FORMAT
                        日志格式
  --clear               清除配置
```

* 配置路径问题

* fix faiss_cache bug

* Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever.

* Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever.

* fix xinference manager bug

* Fix(File RAG): use jieba instead of cutword

* Fix(File RAG): update kb_doc_api.py

* 工作空间的配置预设,提供ConfigBasic建造 实例。 (#4158)

- ConfigWorkSpace接口说明
```text

ConfigWorkSpace是一个配置工作空间的抽象类,提供基础的配置信息存储和读取功能。
提供ConfigFactory建造方法产生实例。
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.chatchat/workspace/workspace_config.json文件中。
注意:不存在则读取默认
```

* 编写配置说明

* 编写配置说明

---------

Co-authored-by: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Co-authored-by: glide-the <2533736852@qq.com>
Co-authored-by: tonysong <tonysong@digitalgd.com.cn>
Co-authored-by: songpb <songpb@gmail.com>
Co-authored-by: showmecodett <showmecodett@gmail.com>
Co-authored-by: zR <2448370773@qq.com>
Co-authored-by: zqt <1178747941@qq.com>
Co-authored-by: zqt996 <67185303+zqt996@users.noreply.github.com>
Co-authored-by: fengyaojie <fengyaojie@xdf.cn>
Co-authored-by: Hans WAN <hanswan@tom.com>
Co-authored-by: thinklover <thinklover@gmail.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
Co-authored-by: xucailiang <74602715+xucailiang@users.noreply.github.com>
Co-authored-by: XuCai <liangxc@akulaku.com>
Co-authored-by: dignfei <913015993@qq.com>
Co-authored-by: Leb <khazzz1c@gmail.com>
Co-authored-by: Sumkor <sumkor@foxmail.com>
Co-authored-by: panhong <381500590@qq.com>
Co-authored-by: srszzw <741992282@qq.com>
Co-authored-by: srszzw <srszzw@163.com>
Co-authored-by: yuehua-s <41819795+yuehua-s@users.noreply.github.com>
Co-authored-by: yuehuazhang <yuehuazhang@tencent.com>
2024-06-10 22:48:35 +08:00

372 lines
13 KiB
Python

from __future__ import annotations
import logging
import warnings
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_community.utils.openai import is_openai_v1
from tenacity import (
AsyncRetrying,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from chatchat.server.utils import run_in_thread_pool
logger = logging.getLogger(__name__)
def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.Timeout)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
async_retrying = AsyncRetrying(
reraise=True,
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.Timeout)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def wrap(func: Callable) -> Callable:
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
async for _ in async_retrying:
return await func(*args, **kwargs)
raise AssertionError("this is unreachable")
return wrapped_f
return wrap
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
def _check_response(response: dict) -> dict:
if any([len(d.embedding) == 1 for d in response.data]):
import openai
raise openai.APIError("LocalAI API returned an empty embedding")
return response
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response)
return _embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.async_client.acreate(**kwargs)
return _check_response(response)
return await _async_embed_with_retry(**kwargs)
class LocalAIEmbeddings(BaseModel, Embeddings):
"""LocalAI embedding models.
Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
uses the ``openai`` Python package's ``openai.Embedding`` as its client.
Thus, you should have the ``openai`` python package installed, and defeat
the environment variable ``OPENAI_API_KEY`` by setting to a random string.
You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
service endpoint.
Example:
.. code-block:: python
from langchain_community.embeddings import LocalAIEmbeddings
openai = LocalAIEmbeddings(
openai_api_key="random-string",
openai_api_base="http://localhost:8080"
)
"""
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "text-embedding-ada-002"
deployment: str = model
openai_api_version: Optional[str] = None
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
# to support explicit proxy for LocalAI
openai_proxy: Optional[str] = None
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_organization: Optional[str] = Field(default=None, alias="organization")
allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
chunk_size: int = 1000
"""Maximum number of texts to embed in each batch"""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout in seconds for the LocalAI request."""
headers: Any = None
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)
default_api_version = ""
values["openai_api_version"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
default=default_api_version,
)
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
try:
import openai
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
}
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).embeddings
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(
**client_params
).embeddings
elif not values.get("client"):
values["client"] = openai.Embedding
else:
pass
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
return values
@property
def _invocation_params(self) -> Dict:
openai_args = {
"model": self.model,
"timeout": self.request_timeout,
"extra_headers": self.headers,
**self.model_kwargs,
}
if self.openai_proxy:
import openai
openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501
return openai_args
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
# handle large input text
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
self,
input=[text],
**self._invocation_params,
).data[0].embedding
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
# handle large input text
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
self,
input=[text],
**self._invocation_params,
)
).data[0].embedding
def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
# call _embedding_func for each text with multithreads
def task(seq, text):
return (seq, self._embedding_func(text, engine=self.deployment))
params = [{"seq": i, "text": text} for i, text in enumerate(texts)]
result = list(run_in_thread_pool(func=task, params=params))
result = sorted(result, key=lambda x: x[0])
return [x[1] for x in result]
async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint async for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
response = await self._aembedding_func(text, engine=self.deployment)
embeddings.append(response)
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.deployment)
return embedding
async def aembed_query(self, text: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint async for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embedding = await self._aembedding_func(text, engine=self.deployment)
return embedding