diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py
index 514d8467..fe11a19c 100644
--- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py
+++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py
@@ -18,6 +18,7 @@ from typing import (
     cast,
 )
 
+import tiktoken
 import uvicorn
 from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
 from fastapi.middleware.cors import CORSMiddleware
@@ -208,16 +209,39 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
-        model_instance = self._provider_manager.get_model_instance(
-            provider=provider,
-            model_type=ModelType.TEXT_EMBEDDING,
-            model=embeddings_request.model,
-        )
-        texts = embeddings_request.input
-        if isinstance(texts, str):
-            texts = [texts]
-        response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
-        return await openai_embedding_text(response)
+        try:
+            model_instance = self._provider_manager.get_model_instance(
+                provider=provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=embeddings_request.model,
+            )
+
+            # If the input is a list, treat it as tokenized input and decode it back to text
+            input_text = ''
+            if isinstance(embeddings_request.input, list):
+                tokens = embeddings_request.input
+                try:
+                    encoding = tiktoken.encoding_for_model(embeddings_request.model)
+                except KeyError:
+                    logger.warning("Model not found. Using cl100k_base encoding.")
+                    fallback = "cl100k_base"
+                    encoding = tiktoken.get_encoding(fallback)
+                for token in tokens:
+                    text = encoding.decode(token)
+                    input_text += text
+
+            else:
+                input_text = embeddings_request.input
+
+            response = model_instance.invoke_text_embedding(texts=[input_text], user="abc-123")
+            return await openai_embedding_text(response)
+
+        except ValueError as e:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+        except InvokeError as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+            )
 
     async def create_chat_completion(
         self, provider: str, request: Request, chat_request: ChatCompletionRequest
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py
index 4b27cb9b..f19135de 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py
@@ -25,7 +25,7 @@ class _CommonZhipuaiAI:
             if "zhipuai_api_key" in credentials
             else None,
             "api_base": credentials["api_base"]
-            if "api_key" in credentials
+            if "api_base" in credentials
             else credentials["zhipuai_api_base"]
             if "zhipuai_api_base" in credentials
             else None,
diff --git a/model-providers/tests/openai_providers_test/test_openai_service.py b/model-providers/tests/openai_providers_test/test_openai_service.py
index 6cdb2731..958fa108 100644
--- a/model-providers/tests/openai_providers_test/test_openai_service.py
+++ b/model-providers/tests/openai_providers_test/test_openai_service.py
@@ -1,6 +1,6 @@
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 
 import pytest
 import logging
@@ -20,3 +20,17 @@ def test_llm(init_server: str):
 
     logger.info("\033[1;32m" +
                 f"llm_chain: {responses}" + "\033[0m")
+
+
+@pytest.mark.requires("openai")
+def test_embedding(init_server: str):
+
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-large",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
+
+    text = "你好"
+
+    query_result = embeddings.embed_query(text)
+
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
\ No newline at end of file
diff --git a/model-providers/tests/zhipuai_providers_test/model_providers.yaml b/model-providers/tests/zhipuai_providers_test/model_providers.yaml
index dd9a6fa7..ec13ffc4 100644
--- a/model-providers/tests/zhipuai_providers_test/model_providers.yaml
+++ b/model-providers/tests/zhipuai_providers_test/model_providers.yaml
@@ -1,4 +1,4 @@
 zhipuai:
   provider_credential:
-    api_key: 'e6a98ef1c54484c2afeac1ae8cef93ef.1'
-    api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
\ No newline at end of file
+    api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
+#    api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
\ No newline at end of file
diff --git a/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py b/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py
index 94992e3b..c110b71c 100644
--- a/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py
+++ b/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py
@@ -1,6 +1,6 @@
 from langchain.chains import LLMChain
 from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 
 import pytest
 import logging
@@ -25,18 +25,15 @@ def test_llm(init_server: str):
 
 
 @pytest.mark.requires("zhipuai")
 def test_embedding(init_server: str):
-    llm = ChatOpenAI(
-        model_name="glm-4",
-        openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/zhipuai/v1")
-    template = """Question: {question}
-
-    Answer: Let's think step by step."""
+    embeddings = OpenAIEmbeddings(model="text_embedding",
+                                  openai_api_key="YOUR_API_KEY",
+                                  openai_api_base=f"{init_server}/zhipuai/v1")
 
-    prompt = PromptTemplate.from_template(template)
+    text = "你好"
 
-    llm_chain = LLMChain(prompt=prompt, llm=llm)
-    responses = llm_chain.run("你好")
-    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+    query_result = embeddings.embed_query(text)
+
+    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")