embedding compatibility

glide-the 2024-04-16 14:56:17 +08:00
parent f30bdf01fc
commit 55b32b29fe
5 changed files with 60 additions and 25 deletions

View File

@@ -18,6 +18,7 @@ from typing import (
    cast,
)
+import tiktoken
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
@@ -208,16 +209,39 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
        logger.info(
            f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
        )
-        model_instance = self._provider_manager.get_model_instance(
-            provider=provider,
-            model_type=ModelType.TEXT_EMBEDDING,
-            model=embeddings_request.model,
-        )
-        texts = embeddings_request.input
-        if isinstance(texts, str):
-            texts = [texts]
-        response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
-        return await openai_embedding_text(response)
+        try:
+            model_instance = self._provider_manager.get_model_instance(
+                provider=provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=embeddings_request.model,
+            )
+            # Clients may send pre-tokenized input (lists of token ids) instead
+            # of plain strings; decode it back to text before embedding.
+            input_text = ""
+            if isinstance(embeddings_request.input, list):
+                try:
+                    encoding = tiktoken.encoding_for_model(embeddings_request.model)
+                except KeyError:
+                    logger.warning("Model not found in tiktoken; using cl100k_base encoding.")
+                    encoding = tiktoken.get_encoding("cl100k_base")
+                for item in embeddings_request.input:
+                    # Each item is either a plain string or a list of token ids.
+                    input_text += item if isinstance(item, str) else encoding.decode(item)
+            else:
+                input_text = embeddings_request.input
+            response = model_instance.invoke_text_embedding(texts=[input_text], user="abc-123")
+            return await openai_embedding_text(response)
+        except ValueError as e:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+        except InvokeError as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+            )
    async def create_chat_completion(
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
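Context for the hunk above (an editorial sketch, not part of the diff): OpenAI-compatible clients such as langchain's `OpenAIEmbeddings` tokenize input locally with tiktoken and post token-id arrays rather than strings, which is why the endpoint now decodes list input back to text before invoking the embedding backend. A minimal round-trip, assuming only that tiktoken is installed:

```python
import tiktoken

# cl100k_base is the fallback encoding the endpoint uses when tiktoken
# does not recognize the requested model name.
encoding = tiktoken.get_encoding("cl100k_base")

tokens = encoding.encode("你好")  # what a tokenizing client sends
text = encoding.decode(tokens)    # what the endpoint reconstructs
assert text == "你好"
```

Note one design consequence of the handler: it concatenates all decoded items into a single text, so a batch request yields one embedding rather than one embedding per input item.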

View File

@@ -25,7 +25,7 @@ class _CommonZhipuaiAI:
            if "zhipuai_api_key" in credentials
            else None,
            "api_base": credentials["api_base"]
-           if "api_key" in credentials
+           if "api_base" in credentials
            else credentials["zhipuai_api_base"]
            if "zhipuai_api_base" in credentials
            else None,
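The one-line fix above corrects the key being tested: the old code gated the `api_base` lookup on the presence of `api_key`, so it could raise KeyError when `api_base` was absent, or silently skip an `api_base` that was actually provided. A minimal sketch of the corrected fallback as a hypothetical standalone function, same chained-conditional shape:

```python
def resolve_api_base(credentials: dict) -> str | None:
    # Prefer the generic key, then the provider-specific one, then None.
    return (
        credentials["api_base"]
        if "api_base" in credentials
        else credentials["zhipuai_api_base"]
        if "zhipuai_api_base" in credentials
        else None
    )

assert resolve_api_base({"api_base": "https://a"}) == "https://a"
assert resolve_api_base({"zhipuai_api_base": "https://b"}) == "https://b"
assert resolve_api_base({}) is None
```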

View File

@@ -1,6 +1,6 @@
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
@@ -20,3 +20,17 @@ def test_llm(init_server: str):
    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+@pytest.mark.requires("openai")
+def test_embedding(init_server: str):
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
+    text = "你好"
+    query_result = embeddings.embed_query(text)
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
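An aside on this test (not part of the diff): by default `OpenAIEmbeddings` tokenizes the query client-side and sends token arrays, which exercises the new tiktoken decode path above. If you instead want the raw string sent over the wire, langchain_openai exposes a flag for that; a sketch, with `init_server` standing in for the pytest fixture's base URL:

```python
from langchain_openai import OpenAIEmbeddings

init_server = "http://127.0.0.1:8000"  # placeholder for the pytest fixture value

# check_embedding_ctx_length=False skips client-side tokenization, so the
# server receives plain strings and the tiktoken decode branch is bypassed.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    openai_api_key="YOUR_API_KEY",
    openai_api_base=f"{init_server}/zhipuai/v1",
    check_embedding_ctx_length=False,
)
print(len(embeddings.embed_query("你好")))
```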

View File

@@ -1,4 +1,4 @@
zhipuai:
  provider_credential:
-    api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1'
-    api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
+    api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
+    # api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'

View File

@@ -1,6 +1,6 @@
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest
import logging
@@ -25,18 +25,15 @@ def test_llm(init_server: str):
@pytest.mark.requires("zhipuai")
def test_embedding(init_server: str):
-    llm = ChatOpenAI(
-        model_name="glm-4",
-        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
-    template = """Question: {question}
-    Answer: Let's think step by step."""
+    embeddings = OpenAIEmbeddings(model="text_embedding",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
-    prompt = PromptTemplate.from_template(template)
+    text = "你好"
-    llm_chain = LLMChain(prompt=prompt, llm=llm)
-    responses = llm_chain.run("你好")
-    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+    query_result = embeddings.embed_query(text)
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")