embedding兼容

This commit is contained in:
glide-the 2024-04-16 14:56:17 +08:00
parent f30bdf01fc
commit 55b32b29fe
5 changed files with 60 additions and 25 deletions

View File

@ -18,6 +18,7 @@ from typing import (
cast, cast,
) )
import tiktoken
import uvicorn import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -208,16 +209,39 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
logger.info( logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
) )
model_instance = self._provider_manager.get_model_instance( try:
provider=provider, model_instance = self._provider_manager.get_model_instance(
model_type=ModelType.TEXT_EMBEDDING, provider=provider,
model=embeddings_request.model, model_type=ModelType.TEXT_EMBEDDING,
) model=embeddings_request.model,
texts = embeddings_request.input )
if isinstance(texts, str):
texts = [texts] # 判断embeddings_request.input是否为list
response = model_instance.invoke_text_embedding(texts=texts, user="abc-123") input = ''
return await openai_embedding_text(response) if isinstance(embeddings_request.input, list):
tokens = embeddings_request.input
try:
encoding = tiktoken.encoding_for_model(embeddings_request.model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, token in enumerate(tokens):
text = encoding.decode(token)
input += text
else:
input = embeddings_request.input
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
return await openai_embedding_text(response)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except InvokeError as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
async def create_chat_completion( async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest self, provider: str, request: Request, chat_request: ChatCompletionRequest

View File

@ -25,7 +25,7 @@ class _CommonZhipuaiAI:
if "zhipuai_api_key" in credentials if "zhipuai_api_key" in credentials
else None, else None,
"api_base": credentials["api_base"] "api_base": credentials["api_base"]
if "api_key" in credentials if "api_base" in credentials
else credentials["zhipuai_api_base"] else credentials["zhipuai_api_base"]
if "zhipuai_api_base" in credentials if "zhipuai_api_base" in credentials
else None, else None,

View File

@ -1,6 +1,6 @@
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest import pytest
import logging import logging
@ -20,3 +20,17 @@ def test_llm(init_server: str):
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m") logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
@pytest.mark.requires("openai")
def test_embedding(init_server: str):
embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
openai_api_key="YOUR_API_KEY",
openai_api_base=f"{init_server}/zhipuai/v1")
text = "你好"
query_result = embeddings.embed_query(text)
logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")

View File

@ -1,4 +1,4 @@
zhipuai: zhipuai:
provider_credential: provider_credential:
api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1' api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
api_base: 'https://test.bigmodel.cn/stage-api/paas/v4' # api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'

View File

@ -1,6 +1,6 @@
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import pytest import pytest
import logging import logging
@ -25,18 +25,15 @@ def test_llm(init_server: str):
@pytest.mark.requires("zhipuai") @pytest.mark.requires("zhipuai")
def test_embedding(init_server: str): def test_embedding(init_server: str):
llm = ChatOpenAI(
model_name="glm-4", embeddings = OpenAIEmbeddings(model="text_embedding",
openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1") openai_api_key="YOUR_API_KEY",
template = """Question: {question} openai_api_base=f"{init_server}/zhipuai/v1")
Answer: Let's think step by step.""" text = "你好"
prompt = PromptTemplate.from_template(template) query_result = embeddings.embed_query(text)
llm_chain = LLMChain(prompt=prompt, llm=llm) logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
responses = llm_chain.run("你好")
logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")