mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 07:23:29 +08:00

commit 55b32b29fe (parent f30bdf01fc)

    embedding兼容 (embedding compatibility)
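In short: the commit makes the OpenAI-compatible embeddings endpoint accept pre-tokenized input (decoding token ids back to text with tiktoken) and wraps the handler in try/except error responses, fixes a wrong key check when resolving api_base in the Zhipu AI provider credentials, rotates the test credentials, and adds embedding round-trip tests to the openai and zhipuai test suites.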
@@ -18,6 +18,7 @@ from typing import (
     cast,
 )
 
+import tiktoken
 import uvicorn
 from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
 from fastapi.middleware.cors import CORSMiddleware
@@ -208,16 +209,39 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
-        model_instance = self._provider_manager.get_model_instance(
-            provider=provider,
-            model_type=ModelType.TEXT_EMBEDDING,
-            model=embeddings_request.model,
-        )
-        texts = embeddings_request.input
-        if isinstance(texts, str):
-            texts = [texts]
-        response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
-        return await openai_embedding_text(response)
+        try:
+            model_instance = self._provider_manager.get_model_instance(
+                provider=provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=embeddings_request.model,
+            )
+
+            # Check whether embeddings_request.input is a list
+            input = ''
+            if isinstance(embeddings_request.input, list):
+                tokens = embeddings_request.input
+                try:
+                    encoding = tiktoken.encoding_for_model(embeddings_request.model)
+                except KeyError:
+                    logger.warning("Warning: model not found. Using cl100k_base encoding.")
+                    model = "cl100k_base"
+                    encoding = tiktoken.get_encoding(model)
+                for i, token in enumerate(tokens):
+                    text = encoding.decode(token)
+                    input += text
+
+            else:
+                input = embeddings_request.input
+
+            response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
+            return await openai_embedding_text(response)
+
+        except ValueError as e:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+        except InvokeError as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+            )
+
     async def create_chat_completion(
         self, provider: str, request: Request, chat_request: ChatCompletionRequest
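Why the tiktoken path exists: OpenAI-style clients often submit embedding input as lists of token ids rather than plain strings (langchain-openai's OpenAIEmbeddings does this by default), and the handler above decodes those ids back to text before invoking the local embedding model. A minimal standalone sketch of that round trip, not the server code itself; note that the handler decodes each list element, so it assumes input shaped like [[token_id, ...], ...] (one id list per text) rather than a flat list of ints:

    import tiktoken

    def detokenize(model: str, token_lists: list) -> str:
        try:
            # Look up the encoding registered for this model name.
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to cl100k_base, as the handler does.
            encoding = tiktoken.get_encoding("cl100k_base")
        # Decode each id list back to text and concatenate, as the handler does.
        return "".join(encoding.decode(ids) for ids in token_lists)

    enc = tiktoken.get_encoding("cl100k_base")
    token_lists = [enc.encode("hello world")]
    assert detokenize("text-embedding-3-large", token_lists) == "hello world"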
@@ -25,7 +25,7 @@ class _CommonZhipuaiAI:
             if "zhipuai_api_key" in credentials
             else None,
             "api_base": credentials["api_base"]
-            if "api_key" in credentials
+            if "api_base" in credentials
             else credentials["zhipuai_api_base"]
             if "zhipuai_api_base" in credentials
             else None,
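The one-character fix above corrects a copy-paste slip: the chained conditional reads credentials["api_base"] but guarded it with "api_key" in credentials, so a credentials dict holding api_key without api_base raised KeyError instead of falling back to zhipuai_api_base. A standalone sketch of the corrected fallback order (resolve_api_base is an illustrative name, not a function in the repo):

    from typing import Optional

    def resolve_api_base(credentials: dict) -> Optional[str]:
        # Prefer "api_base", then "zhipuai_api_base", else None.
        return (
            credentials["api_base"]
            if "api_base" in credentials  # old guard was: "api_key" in credentials
            else credentials["zhipuai_api_base"]
            if "zhipuai_api_base" in credentials
            else None
        )

    # Under the old guard the first call raised KeyError("api_base");
    # with the fix both fall through cleanly.
    assert resolve_api_base({"api_key": "sk-xxx"}) is None
    assert resolve_api_base({"zhipuai_api_base": "https://example/v4"}) == "https://example/v4"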
@@ -1,6 +1,6 @@
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 import pytest
 import logging
 
@@ -20,3 +20,17 @@ def test_llm(init_server: str):
     logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
 
 
+
+
+@pytest.mark.requires("openai")
+def test_embedding(init_server: str):
+
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
+
+    text = "你好"
+
+    query_result = embeddings.embed_query(text)
+
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
@@ -1,4 +1,4 @@
 zhipuai:
   provider_credential:
-    api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1'
-    api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
+    api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
+    # api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
@@ -1,6 +1,6 @@
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 import pytest
 import logging
 
@@ -25,18 +25,15 @@ def test_llm(init_server: str):
 
 @pytest.mark.requires("zhipuai")
 def test_embedding(init_server: str):
-    llm = ChatOpenAI(
-        model_name="glm-4",
-        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
-
-    template = """Question: {question}
-
-Answer: Let's think step by step."""
-
-    prompt = PromptTemplate.from_template(template)
-
-    llm_chain = LLMChain(prompt=prompt, llm=llm)
-    responses = llm_chain.run("你好")
-    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+    embeddings = OpenAIEmbeddings(model="text_embedding",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
+
+    text = "你好"
+
+    query_result = embeddings.embed_query(text)
+
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
 
 
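The new tests drive the endpoint through langchain-openai, and the same check works against a running server directly. A minimal client sketch, assuming the server is reachable at an illustrative http://127.0.0.1:9997 (in the tests the init_server fixture supplies the base URL); check_embedding_ctx_length=False makes the client send plain strings, while the default of True sends token ids, which is exactly the case the new tiktoken path handles:

    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(
        model="text_embedding",  # model name used in the zhipuai test above
        openai_api_key="YOUR_API_KEY",
        openai_api_base="http://127.0.0.1:9997/zhipuai/v1",  # illustrative host/port
        check_embedding_ctx_length=False,  # send plain text instead of token ids
    )

    vector = embeddings.embed_query("你好")
    print(len(vector))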