From e2f14482cb42f90df6afb42edfc08e413335e582 Mon Sep 17 00:00:00 2001 From: VLOU <919070296@qq.com> Date: Fri, 10 May 2024 00:46:24 +0800 Subject: [PATCH] =?UTF-8?q?[add]=E6=B7=BB=E5=8A=A0=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/app/api/models/chatchat/route.ts | 40 +++++++ .../src/app/settings/llm/ChatChat/index.tsx | 6 + .../settings/llm/components/ModelSeletor.tsx | 110 ++++++++++++++++++ .../src/features/ModelSwitchPanel/index.tsx | 2 +- frontend/src/locales/default/setting.ts | 7 ++ frontend/src/services/_url.ts | 3 + frontend/src/services/models.ts | 28 +++++ .../settings/selectors/modelProvider.ts | 9 +- frontend/src/types/message/index.ts | 5 + frontend/src/types/models.ts | 15 +++ frontend/src/types/settings/modelProvider.ts | 12 ++ 11 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 frontend/src/app/api/models/chatchat/route.ts create mode 100644 frontend/src/app/settings/llm/components/ModelSeletor.tsx create mode 100644 frontend/src/services/models.ts create mode 100644 frontend/src/types/models.ts diff --git a/frontend/src/app/api/models/chatchat/route.ts b/frontend/src/app/api/models/chatchat/route.ts new file mode 100644 index 00000000..93e41d76 --- /dev/null +++ b/frontend/src/app/api/models/chatchat/route.ts @@ -0,0 +1,40 @@ +import { getServerConfig } from '@/config/server'; +import { createErrorResponse } from '@/app/api/errorResponse'; +import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth'; +import { getJWTPayload } from '../../chat/auth'; + +export const GET = async (req: Request) => { + + // get Authorization from header + const authorization = req.headers.get(LOBE_CHAT_AUTH_HEADER); + + const { CHATCHAT_PROXY_URL } = getServerConfig(); + + let baseURL = CHATCHAT_PROXY_URL; + + // 为了方便拿到 endpoint,这里直接解析 JWT + if (authorization) { + const jwtPayload = await getJWTPayload(authorization); + if (jwtPayload.endpoint) { + baseURL = jwtPayload.endpoint; + } + } + + let res: Response; + + try { + console.log('get models from:', baseURL) + + res = await fetch(`${baseURL}/models`); + + if (!res.ok) { + // throw new Error(`Failed to fetch models: ${res.status}`); + return createErrorResponse(500, { error: `Failed to fetch models: ${res.status}` }); + } + + return res; + + } catch (e) { + return createErrorResponse(500, { error: e }); + } +} \ No newline at end of file diff --git a/frontend/src/app/settings/llm/ChatChat/index.tsx b/frontend/src/app/settings/llm/ChatChat/index.tsx index 5d7c82e8..b3f90eea 100644 --- a/frontend/src/app/settings/llm/ChatChat/index.tsx +++ b/frontend/src/app/settings/llm/ChatChat/index.tsx @@ -11,6 +11,7 @@ import { ModelProvider } from '@/libs/agent-runtime'; import Checker from '../components/Checker'; import ProviderConfig from '../components/ProviderConfig'; import { LLMProviderBaseUrlKey, LLMProviderConfigKey } from '../const'; +import ModelSelector from '../components/ModelSeletor'; const providerKey = 'chatchat'; @@ -39,6 +40,11 @@ const ChatChatProvider = memo(() => { label: t('llm.ChatChat.customModelName.title'), name: [LLMProviderConfigKey, providerKey, 'customModelName'], }, + { + children: , + desc: t('llm.selectorModel.desc'), + label: t('llm.selectorModel.title'), + }, { children: , desc: t('llm.ChatChat.checker.desc'), diff --git a/frontend/src/app/settings/llm/components/ModelSeletor.tsx b/frontend/src/app/settings/llm/components/ModelSeletor.tsx new file mode 100644 index 00000000..5966fa58 --- /dev/null +++ b/frontend/src/app/settings/llm/components/ModelSeletor.tsx @@ -0,0 +1,110 @@ +import { CheckCircleFilled } from '@ant-design/icons'; +import { Alert, Highlighter } from '@lobehub/ui'; +import { Button } from 'antd'; +import { useTheme } from 'antd-style'; +import { memo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Flexbox } from 'react-layout-kit'; + +import { useIsMobile } from '@/hooks/useIsMobile'; +import { ModelSelectorError } from '@/types/message'; +import { modelsServer } from '@/services/models'; +import { useGlobalStore } from '@/store/global'; +import { GlobalLLMProviderKey } from '@/types/settings/modelProvider'; +import { currentSettings } from '@/store/global/slices/settings/selectors/settings'; + +interface FetchModelParams { + provider: GlobalLLMProviderKey; +} + +const ModelSelector = memo(({ provider }) => { + const { t } = useTranslation('setting'); + + const [loading, setLoading] = useState(false); + const [pass, setPass] = useState(false); + + const theme = useTheme(); + const [error, setError] = useState(); + + const [setConfig, languageModel ] = useGlobalStore((s) => [ + s.setModelProviderConfig, + currentSettings(s).languageModel, + ]); + + const enable = languageModel[provider]?.enabled || false; + + // 过滤格式 + const filterModel = (data: any[] = []) => { + return data.map((item) => { + + return { + tokens: item?.tokens || 8000, + displayName: item.displayName || item.id, + functionCall: false, // false 默认都不能用使用插件,chatchat 的插件还没弄 + ...item + } + }) + } + + const processProviderModels = () => { + if(!enable) return + + setLoading(true); + + modelsServer.getModels(provider).then((data) => { + if (data.error) { + setError({ message: data.error, type: 500}); + } else { + // 更新模型 + setConfig(provider, { models: filterModel(data.data) }); + + setError(undefined); + setPass(true); + } + + }).finally(() => { + setLoading(false); + }) + } + + const isMobile = useIsMobile(); + + return ( + + + {pass && ( + + + {t('llm.selectorModel.pass')} + + )} + + + {error && ( + + + + {JSON.stringify(error, null, 2)} + + + } + message={t(`response.${error.type}` as any, { ns: 'error' })} + showIcon + type={'error'} + /> + + )} + + ); +}); + +export default ModelSelector; diff --git a/frontend/src/features/ModelSwitchPanel/index.tsx b/frontend/src/features/ModelSwitchPanel/index.tsx index da10ec41..af00cf22 100644 --- a/frontend/src/features/ModelSwitchPanel/index.tsx +++ b/frontend/src/features/ModelSwitchPanel/index.tsx @@ -44,7 +44,7 @@ const ModelSwitchPanel = memo(({ children }) => { provider.chatModels .filter((c) => !c.hidden) .map((model) => ({ - key: model.id, + key: `${provider.id}-${model.id}`, label: , onClick: () => { updateAgentConfig({ model: model.id, provider: provider.id }); diff --git a/frontend/src/locales/default/setting.ts b/frontend/src/locales/default/setting.ts index 3fd542d4..66e2a827 100644 --- a/frontend/src/locales/default/setting.ts +++ b/frontend/src/locales/default/setting.ts @@ -198,6 +198,13 @@ export default { }, }, + selectorModel: { + button: '更新', + desc: '选择代理地址所有模型,默认/v1/models获取', + pass: '更新成功', + title: '更新模型到本地', + }, + checker: { button: '检查', desc: '测试 Api Key 与代理地址是否正确填写', diff --git a/frontend/src/services/_url.ts b/frontend/src/services/_url.ts index 29ff60c7..9e863329 100644 --- a/frontend/src/services/_url.ts +++ b/frontend/src/services/_url.ts @@ -36,6 +36,9 @@ export const API_ENDPOINTS = mapWithBasePath({ // image images: '/api/openai/images', + // models + models: (provider: string) => withBasePath(`/api/models/${provider}`), + // TTS & STT stt: '/api/openai/stt', tts: '/api/openai/tts', diff --git a/frontend/src/services/models.ts b/frontend/src/services/models.ts new file mode 100644 index 00000000..2a0fce3e --- /dev/null +++ b/frontend/src/services/models.ts @@ -0,0 +1,28 @@ +import { getMessageError } from "@/utils/fetch"; +import { API_ENDPOINTS } from "./_url"; +import { createHeaderWithAuth } from "./_auth"; +import { ModelsResponse } from "@/types/models"; +import { GlobalLLMProviderKey } from "@/types/settings/modelProvider"; + + +class ModelsServer{ + getModels = async (provider: GlobalLLMProviderKey): Promise => { + const headers = await createHeaderWithAuth({ provider, headers: { 'Content-Type': 'application/json' } }); + + try { + const res = await fetch(API_ENDPOINTS.models(provider), { + headers, + }); + + if (!res.ok) { + throw await getMessageError(res); + } + + return res.json(); + } catch (error) { + return { error: JSON.stringify(error) }; + } + } +} + +export const modelsServer = new ModelsServer(); \ No newline at end of file diff --git a/frontend/src/store/global/slices/settings/selectors/modelProvider.ts b/frontend/src/store/global/slices/settings/selectors/modelProvider.ts index c0738ce5..67411c8a 100644 --- a/frontend/src/store/global/slices/settings/selectors/modelProvider.ts +++ b/frontend/src/store/global/slices/settings/selectors/modelProvider.ts @@ -63,6 +63,7 @@ const anthropicAPIKey = (s: GlobalStore) => modelProvider(s).anthropic.apiKey; const enableChatChat = (s: GlobalStore) => modelProvider(s).chatchat.enabled; const chatChatProxyUrl = (s: GlobalStore) => modelProvider(s).chatchat.endpoint; +const chatChatModels = (s: GlobalStore) => modelProvider(s).chatchat.models || []; // const azureModelList = (s: GlobalStore): ModelProviderCard => { // const azure = azureConfig(s); @@ -138,6 +139,12 @@ const modelSelectList = (s: GlobalStore): ModelProviderCard[] => { const ollamaChatModels = processChatModels(ollamaModelConfig, OllamaProvider.chatModels); + + const chatChatModelConfig = parseModelString( + currentSettings(s).languageModel.chatchat.customModelName + ) + const chatChatChatModels = processChatModels(chatChatModelConfig, chatChatModels(s)) + return [ { ...OpenAIProvider, @@ -152,7 +159,7 @@ const modelSelectList = (s: GlobalStore): ModelProviderCard[] => { { ...PerplexityProvider, enabled: enablePerplexity(s) }, { ...AnthropicProvider, enabled: enableAnthropic(s) }, { ...MistralProvider, enabled: enableMistral(s) }, - { ...ChatChatProvider, enabled: enableChatChat(s) }, + { ...ChatChatProvider, chatModels: chatChatChatModels, enabled: enableChatChat(s) }, ]; }; diff --git a/frontend/src/types/message/index.ts b/frontend/src/types/message/index.ts index 0b8b5a1e..0f28496d 100644 --- a/frontend/src/types/message/index.ts +++ b/frontend/src/types/message/index.ts @@ -17,6 +17,11 @@ export interface ChatMessageError { type: ErrorType | IPluginErrorType | ILobeAgentRuntimeErrorType; } +export interface ModelSelectorError { + message: string; + type: ErrorType; +} + export interface ChatTranslate extends Translate { content?: string; } diff --git a/frontend/src/types/models.ts b/frontend/src/types/models.ts new file mode 100644 index 00000000..49102f5f --- /dev/null +++ b/frontend/src/types/models.ts @@ -0,0 +1,15 @@ +interface Model { + id: string; + created: number; // 时间戳 + platform_name: string; + owned_by: string; + object: string; + tokens?: number; + displayName?: string; +} + +export interface ModelsResponse { + object?: 'list'; + data?: Model[]; + error?: string; +} \ No newline at end of file diff --git a/frontend/src/types/settings/modelProvider.ts b/frontend/src/types/settings/modelProvider.ts index 78866506..0ca05fdb 100644 --- a/frontend/src/types/settings/modelProvider.ts +++ b/frontend/src/types/settings/modelProvider.ts @@ -1,3 +1,5 @@ +import { ChatModelCard } from "../llm"; + export type CustomModels = { displayName: string; id: string }[]; export interface OpenAIConfig { @@ -22,23 +24,27 @@ export interface AzureOpenAIConfig { deployments: string; enabled: boolean; endpoint?: string; + models?: ChatModelCard[] } export interface ZhiPuConfig { apiKey?: string; enabled: boolean; endpoint?: string; + models?: ChatModelCard[] } export interface MoonshotConfig { apiKey?: string; enabled: boolean; + models?: ChatModelCard[] } export interface GoogleConfig { apiKey?: string; enabled: boolean; endpoint?: string; + models?: ChatModelCard[] } export interface AWSBedrockConfig { @@ -46,34 +52,40 @@ export interface AWSBedrockConfig { enabled: boolean; region?: string; secretAccessKey?: string; + models?: ChatModelCard[] } export interface OllamaConfig { customModelName?: string; enabled?: boolean; endpoint?: string; + models?: ChatModelCard[] } export interface PerplexityConfig { apiKey?: string; enabled: boolean; endpoint?: string; + models?: ChatModelCard[] } export interface AnthropicConfig { apiKey?: string; enabled: boolean; + models?: ChatModelCard[] } export interface MistralConfig { apiKey?: string; enabled: boolean; + models?: ChatModelCard[] } export interface ChatChatConfig { customModelName?: string; enabled?: boolean; endpoint?: string; + models?: ChatModelCard[] } export interface GlobalLLMConfig {