mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
embedding兼容
This commit is contained in:
parent
f30bdf01fc
commit
55b32b29fe
@ -18,6 +18,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@ -208,16 +209,39 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
logger.info(
|
||||
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
||||
)
|
||||
model_instance = self._provider_manager.get_model_instance(
|
||||
provider=provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embeddings_request.model,
|
||||
)
|
||||
texts = embeddings_request.input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
|
||||
return await openai_embedding_text(response)
|
||||
try:
|
||||
model_instance = self._provider_manager.get_model_instance(
|
||||
provider=provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embeddings_request.model,
|
||||
)
|
||||
|
||||
# Check whether embeddings_request.input is a list (token IDs that are decoded back to text below)
|
||||
input = ''
|
||||
if isinstance(embeddings_request.input, list):
|
||||
tokens = embeddings_request.input
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(embeddings_request.model)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, token in enumerate(tokens):
|
||||
text = encoding.decode(token)
|
||||
input += text
|
||||
|
||||
else:
|
||||
input = embeddings_request.input
|
||||
|
||||
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
|
||||
return await openai_embedding_text(response)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except InvokeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||
|
||||
@ -25,7 +25,7 @@ class _CommonZhipuaiAI:
|
||||
if "zhipuai_api_key" in credentials
|
||||
else None,
|
||||
"api_base": credentials["api_base"]
|
||||
if "api_key" in credentials
|
||||
if "api_base" in credentials
|
||||
else credentials["zhipuai_api_base"]
|
||||
if "zhipuai_api_base" in credentials
|
||||
else None,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
@ -20,3 +20,17 @@ def test_llm(init_server: str):
|
||||
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_embedding(init_server: str):
    """Smoke-test embedding a short query through the local OpenAI-compatible API.

    Builds an ``OpenAIEmbeddings`` client against the server under test,
    embeds one short text and logs the resulting vector. The test makes no
    assertion on the vector itself; it fails only if the request raises.

    Args:
        init_server: Base URL of the server under test (presumably provided
            by a pytest fixture of the same name -- verify in conftest).
    """
    # This test is gated on the "openai" requirement and asks for an OpenAI
    # embedding model, so it should go through the openai provider route
    # rather than the zhipuai one.
    # NOTE(review): assumes routes follow the "<server>/<provider>/v1"
    # pattern used by the sibling tests -- confirm against the router.
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-large",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/openai/v1",
    )

    text = "你好"
    query_result = embeddings.embed_query(text)

    # The ANSI escape sequences wrap the message in bold green for readability.
    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
|
||||
@ -1,4 +1,4 @@
|
||||
zhipuai:
|
||||
provider_credential:
|
||||
api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1'
|
||||
api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
|
||||
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
|
||||
# api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
|
||||
@ -1,6 +1,6 @@
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
@ -25,18 +25,15 @@ def test_llm(init_server: str):
|
||||
|
||||
@pytest.mark.requires("zhipuai")
|
||||
def test_embedding(init_server: str):
|
||||
llm = ChatOpenAI(
|
||||
|
||||
model_name="glm-4",
|
||||
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
embeddings = OpenAIEmbeddings(model="text_embedding",
|
||||
openai_api_key="YOUR_API_KEY",
|
||||
openai_api_base=f"{init_server}/zhipuai/v1")
|
||||
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
text = "你好"
|
||||
|
||||
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
||||
responses = llm_chain.run("你好")
|
||||
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user