From e2f14482cb42f90df6afb42edfc08e413335e582 Mon Sep 17 00:00:00 2001
From: VLOU <919070296@qq.com>
Date: Fri, 10 May 2024 00:46:24 +0800
Subject: [PATCH] [add] Add an API for fetching the model list
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
frontend/src/app/api/models/chatchat/route.ts | 40 +++++++
.../src/app/settings/llm/ChatChat/index.tsx | 6 +
.../settings/llm/components/ModelSeletor.tsx | 110 ++++++++++++++++++
.../src/features/ModelSwitchPanel/index.tsx | 2 +-
frontend/src/locales/default/setting.ts | 7 ++
frontend/src/services/_url.ts | 3 +
frontend/src/services/models.ts | 28 +++++
.../settings/selectors/modelProvider.ts | 9 +-
frontend/src/types/message/index.ts | 5 +
frontend/src/types/models.ts | 15 +++
frontend/src/types/settings/modelProvider.ts | 12 ++
11 files changed, 235 insertions(+), 2 deletions(-)
create mode 100644 frontend/src/app/api/models/chatchat/route.ts
create mode 100644 frontend/src/app/settings/llm/components/ModelSeletor.tsx
create mode 100644 frontend/src/services/models.ts
create mode 100644 frontend/src/types/models.ts
diff --git a/frontend/src/app/api/models/chatchat/route.ts b/frontend/src/app/api/models/chatchat/route.ts
new file mode 100644
index 00000000..93e41d76
--- /dev/null
+++ b/frontend/src/app/api/models/chatchat/route.ts
@@ -0,0 +1,40 @@
+import { createErrorResponse } from '@/app/api/errorResponse';
+import { getServerConfig } from '@/config/server';
+import { LOBE_CHAT_AUTH_HEADER } from '@/const/auth';
+import { getJWTPayload } from '../../chat/auth';
+
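+// GET /api/models/chatchat: proxy the ChatChat server's /models list, preferring the endpoint carried in the client's JWT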
+export const GET = async (req: Request) => {
+  // get Authorization from header
+  const authorization = req.headers.get(LOBE_CHAT_AUTH_HEADER);
+
+  const { CHATCHAT_PROXY_URL } = getServerConfig();
+
+  let baseURL = CHATCHAT_PROXY_URL;
+
+  // parse the JWT directly so the client-configured endpoint can override the server default
+  if (authorization) {
+    const jwtPayload = await getJWTPayload(authorization);
+    if (jwtPayload.endpoint) {
+      baseURL = jwtPayload.endpoint;
+    }
+  }
+
+  let res: Response;
+
+  try {
+    console.log('get models from:', baseURL);
+
+    res = await fetch(`${baseURL}/models`);
+
+    if (!res.ok) {
+      return createErrorResponse(500, { error: `Failed to fetch models: ${res.status}` });
+    }
+
+    return res;
+  } catch (e) {
+    return createErrorResponse(500, { error: e });
+  }
+};
\ No newline at end of file
diff --git a/frontend/src/app/settings/llm/ChatChat/index.tsx b/frontend/src/app/settings/llm/ChatChat/index.tsx
index 5d7c82e8..b3f90eea 100644
--- a/frontend/src/app/settings/llm/ChatChat/index.tsx
+++ b/frontend/src/app/settings/llm/ChatChat/index.tsx
@@ -11,6 +11,7 @@ import { ModelProvider } from '@/libs/agent-runtime';
import Checker from '../components/Checker';
import ProviderConfig from '../components/ProviderConfig';
import { LLMProviderBaseUrlKey, LLMProviderConfigKey } from '../const';
+import ModelSelector from '../components/ModelSeletor';
const providerKey = 'chatchat';
@@ -39,6 +40,11 @@ const ChatChatProvider = memo(() => {
label: t('llm.ChatChat.customModelName.title'),
name: [LLMProviderConfigKey, providerKey, 'customModelName'],
},
+  {
+    children: <ModelSelector provider={providerKey} />,
+    desc: t('llm.selectorModel.desc'),
+    label: t('llm.selectorModel.title'),
+  },
{
    children: <Checker provider={ModelProvider.ChatChat} />,
desc: t('llm.ChatChat.checker.desc'),
diff --git a/frontend/src/app/settings/llm/components/ModelSeletor.tsx b/frontend/src/app/settings/llm/components/ModelSeletor.tsx
new file mode 100644
index 00000000..5966fa58
--- /dev/null
+++ b/frontend/src/app/settings/llm/components/ModelSeletor.tsx
@@ -0,0 +1,110 @@
+import { CheckCircleFilled } from '@ant-design/icons';
+import { Alert, Highlighter } from '@lobehub/ui';
+import { Button } from 'antd';
+import { useTheme } from 'antd-style';
+import { memo, useState } from 'react';
+import { useTranslation } from 'react-i18next';
+import { Flexbox } from 'react-layout-kit';
+
+import { useIsMobile } from '@/hooks/useIsMobile';
+import { ModelSelectorError } from '@/types/message';
+import { modelsServer } from '@/services/models';
+import { useGlobalStore } from '@/store/global';
+import { GlobalLLMProviderKey } from '@/types/settings/modelProvider';
+import { currentSettings } from '@/store/global/slices/settings/selectors/settings';
+
+interface FetchModelParams {
+ provider: GlobalLLMProviderKey;
+}
+
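+// Fetches the /models list for the given provider and stores the normalized result in that provider's config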
+const ModelSelector = memo<FetchModelParams>(({ provider }) => {
+  const { t } = useTranslation('setting');
+
+  const [loading, setLoading] = useState(false);
+  const [pass, setPass] = useState(false);
+
+  const theme = useTheme();
+  const [error, setError] = useState<ModelSelectorError>();
+
+  const [setConfig, languageModel] = useGlobalStore((s) => [
+    s.setModelProviderConfig,
+    currentSettings(s).languageModel,
+  ]);
+
+  const enable = languageModel[provider]?.enabled || false;
+
+  // normalize the raw /models payload into the ChatModelCard shape
+  const filterModel = (data: any[] = []) => {
+    return data.map((item) => ({
+      ...item,
+      displayName: item.displayName || item.id,
+      // plugins stay disabled by default; ChatChat plugin support is not implemented yet
+      functionCall: false,
+      tokens: item?.tokens || 8000,
+    }));
+  };
+
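+  // fetch the model list and persist it; service failures are surfaced in-band via data.error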
+  const processProviderModels = () => {
+    if (!enable) return;
+
+    setLoading(true);
+
+    modelsServer
+      .getModels(provider)
+      .then((data) => {
+        if (data.error) {
+          setError({ message: data.error, type: 500 });
+        } else {
+          // write the fetched models into this provider's config
+          setConfig(provider, { models: filterModel(data.data) });
+
+          setError(undefined);
+          setPass(true);
+        }
+      })
+      .finally(() => {
+        setLoading(false);
+      });
+  };
+
+ const isMobile = useIsMobile();
+
+  return (
+    <Flexbox gap={8}>
+      <Flexbox align={'center'} gap={12} horizontal>
+        <Button loading={loading} onClick={processProviderModels}>
+          {t('llm.selectorModel.button')}
+        </Button>
+        {pass && (
+          <Flexbox gap={4} horizontal>
+            <CheckCircleFilled style={{ color: theme.colorSuccess }} />
+            {t('llm.selectorModel.pass')}
+          </Flexbox>
+        )}
+      </Flexbox>
+      {error && (
+        <Flexbox gap={8} style={{ maxWidth: isMobile ? '100%' : '80%' }}>
+          <Alert
+            banner
+            extra={
+              <Highlighter copyButtonSize={'small'} language={'json'} type={'pure'}>
+                {JSON.stringify(error, null, 2)}
+              </Highlighter>
+            }
+            message={t(`response.${error.type}` as any, { ns: 'error' })}
+            showIcon
+            type={'error'}
+          />
+        </Flexbox>
+      )}
+    </Flexbox>
+  );
+});
+
+export default ModelSelector;
diff --git a/frontend/src/features/ModelSwitchPanel/index.tsx b/frontend/src/features/ModelSwitchPanel/index.tsx
index da10ec41..af00cf22 100644
--- a/frontend/src/features/ModelSwitchPanel/index.tsx
+++ b/frontend/src/features/ModelSwitchPanel/index.tsx
@@ -44,7 +44,7 @@ const ModelSwitchPanel = memo<PropsWithChildren>(({ children }) => {
provider.chatModels
.filter((c) => !c.hidden)
.map((model) => ({
- key: model.id,
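+        // prefix with the provider id so identical model ids from different providers get unique menu keys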
+ key: `${provider.id}-${model.id}`,
        label: <ModelItemRender {...model} />,
onClick: () => {
updateAgentConfig({ model: model.id, provider: provider.id });
diff --git a/frontend/src/locales/default/setting.ts b/frontend/src/locales/default/setting.ts
index 3fd542d4..66e2a827 100644
--- a/frontend/src/locales/default/setting.ts
+++ b/frontend/src/locales/default/setting.ts
@@ -198,6 +198,13 @@ export default {
},
},
+ selectorModel: {
+ button: '更新',
+ desc: '选择代理地址所有模型,默认/v1/models获取',
+ pass: '更新成功',
+ title: '更新模型到本地',
+ },
+
checker: {
button: '检查',
desc: '测试 Api Key 与代理地址是否正确填写',
diff --git a/frontend/src/services/_url.ts b/frontend/src/services/_url.ts
index 29ff60c7..9e863329 100644
--- a/frontend/src/services/_url.ts
+++ b/frontend/src/services/_url.ts
@@ -36,6 +36,9 @@ export const API_ENDPOINTS = mapWithBasePath({
// image
images: '/api/openai/images',
+ // models
+ models: (provider: string) => withBasePath(`/api/models/${provider}`),
+
// TTS & STT
stt: '/api/openai/stt',
tts: '/api/openai/tts',
diff --git a/frontend/src/services/models.ts b/frontend/src/services/models.ts
new file mode 100644
index 00000000..2a0fce3e
--- /dev/null
+++ b/frontend/src/services/models.ts
@@ -0,0 +1,28 @@
+import { ModelsResponse } from '@/types/models';
+import { GlobalLLMProviderKey } from '@/types/settings/modelProvider';
+import { getMessageError } from '@/utils/fetch';
+
+import { createHeaderWithAuth } from './_auth';
+import { API_ENDPOINTS } from './_url';
+
+
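+// Thin client for the /api/models/[provider] route; request failures are returned in-band as { error }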
+class ModelsServer {
+  getModels = async (provider: GlobalLLMProviderKey): Promise<ModelsResponse> => {
+ const headers = await createHeaderWithAuth({ provider, headers: { 'Content-Type': 'application/json' } });
+
+ try {
+ const res = await fetch(API_ENDPOINTS.models(provider), {
+ headers,
+ });
+
+ if (!res.ok) {
+ throw await getMessageError(res);
+ }
+
+ return res.json();
+ } catch (error) {
+ return { error: JSON.stringify(error) };
+ }
+ }
+}
+
+export const modelsServer = new ModelsServer();
\ No newline at end of file
diff --git a/frontend/src/store/global/slices/settings/selectors/modelProvider.ts b/frontend/src/store/global/slices/settings/selectors/modelProvider.ts
index c0738ce5..67411c8a 100644
--- a/frontend/src/store/global/slices/settings/selectors/modelProvider.ts
+++ b/frontend/src/store/global/slices/settings/selectors/modelProvider.ts
@@ -63,6 +63,7 @@ const anthropicAPIKey = (s: GlobalStore) => modelProvider(s).anthropic.apiKey;
const enableChatChat = (s: GlobalStore) => modelProvider(s).chatchat.enabled;
const chatChatProxyUrl = (s: GlobalStore) => modelProvider(s).chatchat.endpoint;
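+// model list fetched from the ChatChat server (populated by the settings ModelSelector)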
+const chatChatModels = (s: GlobalStore) => modelProvider(s).chatchat.models || [];
// const azureModelList = (s: GlobalStore): ModelProviderCard => {
// const azure = azureConfig(s);
@@ -138,6 +139,12 @@ const modelSelectList = (s: GlobalStore): ModelProviderCard[] => {
const ollamaChatModels = processChatModels(ollamaModelConfig, OllamaProvider.chatModels);
+
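+  // merge user-defined custom model names with the model list fetched from the ChatChat server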
+  const chatChatModelConfig = parseModelString(
+    currentSettings(s).languageModel.chatchat.customModelName,
+  );
+  const chatChatChatModels = processChatModels(chatChatModelConfig, chatChatModels(s));
+
return [
{
...OpenAIProvider,
@@ -152,7 +159,7 @@ const modelSelectList = (s: GlobalStore): ModelProviderCard[] => {
{ ...PerplexityProvider, enabled: enablePerplexity(s) },
{ ...AnthropicProvider, enabled: enableAnthropic(s) },
{ ...MistralProvider, enabled: enableMistral(s) },
- { ...ChatChatProvider, enabled: enableChatChat(s) },
+ { ...ChatChatProvider, chatModels: chatChatChatModels, enabled: enableChatChat(s) },
];
};
diff --git a/frontend/src/types/message/index.ts b/frontend/src/types/message/index.ts
index 0b8b5a1e..0f28496d 100644
--- a/frontend/src/types/message/index.ts
+++ b/frontend/src/types/message/index.ts
@@ -17,6 +17,11 @@ export interface ChatMessageError {
type: ErrorType | IPluginErrorType | ILobeAgentRuntimeErrorType;
}
+export interface ModelSelectorError {
+ message: string;
+ type: ErrorType;
+}
+
export interface ChatTranslate extends Translate {
content?: string;
}
diff --git a/frontend/src/types/models.ts b/frontend/src/types/models.ts
new file mode 100644
index 00000000..49102f5f
--- /dev/null
+++ b/frontend/src/types/models.ts
@@ -0,0 +1,15 @@
+interface Model {
+ id: string;
+  created: number; // unix timestamp
+ platform_name: string;
+ owned_by: string;
+ object: string;
+ tokens?: number;
+ displayName?: string;
+}
+
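+// OpenAI-compatible /models response, with an optional in-band error message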
+export interface ModelsResponse {
+ object?: 'list';
+ data?: Model[];
+ error?: string;
+}
\ No newline at end of file
diff --git a/frontend/src/types/settings/modelProvider.ts b/frontend/src/types/settings/modelProvider.ts
index 78866506..0ca05fdb 100644
--- a/frontend/src/types/settings/modelProvider.ts
+++ b/frontend/src/types/settings/modelProvider.ts
@@ -1,3 +1,5 @@
+import { ChatModelCard } from '../llm';
+
export type CustomModels = { displayName: string; id: string }[];
export interface OpenAIConfig {
@@ -22,23 +24,27 @@ export interface AzureOpenAIConfig {
deployments: string;
enabled: boolean;
endpoint?: string;
+  models?: ChatModelCard[];
}
export interface ZhiPuConfig {
apiKey?: string;
enabled: boolean;
endpoint?: string;
+  models?: ChatModelCard[];
}
export interface MoonshotConfig {
apiKey?: string;
enabled: boolean;
+  models?: ChatModelCard[];
}
export interface GoogleConfig {
apiKey?: string;
enabled: boolean;
endpoint?: string;
+  models?: ChatModelCard[];
}
export interface AWSBedrockConfig {
@@ -46,34 +52,40 @@ export interface AWSBedrockConfig {
enabled: boolean;
region?: string;
secretAccessKey?: string;
+  models?: ChatModelCard[];
}
export interface OllamaConfig {
customModelName?: string;
enabled?: boolean;
endpoint?: string;
+  models?: ChatModelCard[];
}
export interface PerplexityConfig {
apiKey?: string;
enabled: boolean;
endpoint?: string;
+  models?: ChatModelCard[];
}
export interface AnthropicConfig {
apiKey?: string;
enabled: boolean;
+  models?: ChatModelCard[];
}
export interface MistralConfig {
apiKey?: string;
enabled: boolean;
+  models?: ChatModelCard[];
}
export interface ChatChatConfig {
customModelName?: string;
enabled?: boolean;
endpoint?: string;
+  models?: ChatModelCard[];
}
export interface GlobalLLMConfig {