Add knowledge chat session interface

This commit is contained in:
VLOU 2024-03-31 00:29:10 +08:00
parent c6b92bc4d0
commit ed9ecebffc
7 changed files with 232 additions and 0 deletions

View File

@@ -21,6 +21,7 @@ import {
LobeRuntimeAI,
LobeZhipuAI,
ModelProvider,
LobeKnowledgeAI,
} from '@/libs/agent-runtime';
import { TraceClient } from '@/libs/traces';
@@ -167,6 +168,11 @@ class AgentRuntime {
runtimeModel = this.initMistral(payload);
break;
}
case ModelProvider.Knowledge: {
runtimeModel = this.initKnowledge(payload);
break;
}
}
return new AgentRuntime(runtimeModel);
@@ -268,6 +274,13 @@ class AgentRuntime {
return new LobeMistralAI({ apiKey });
}
private static initKnowledge(payload: JWTPayload) {
const { KNOWLEDGE_PROXY_URL } = getServerConfig();
const baseURL = payload?.endpoint || KNOWLEDGE_PROXY_URL;
return new LobeKnowledgeAI({ baseURL });
}
}
export default AgentRuntime;
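
Note: the baseURL that initKnowledge hands to LobeKnowledgeAI is resolved in a fixed precedence order. A minimal sketch of that order follows; the helper name is hypothetical and not part of this commit.

// Hypothetical helper illustrating the precedence used by initKnowledge above
// and by the LobeKnowledgeAI constructor further down; not part of this commit.
const resolveKnowledgeBaseURL = (
payloadEndpoint?: string, // payload?.endpoint carried in the JWT
knowledgeProxyUrl?: string, // KNOWLEDGE_PROXY_URL from getServerConfig()
): string => payloadEndpoint || knowledgeProxyUrl || 'http://localhost:7861/v1';

// e.g. resolveKnowledgeBaseURL(undefined, process.env.KNOWLEDGE_PROXY_URL)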

View File

@@ -114,5 +114,7 @@ export const getProviderConfig = () => {
ENABLE_OLLAMA: !!process.env.OLLAMA_PROXY_URL,
OLLAMA_PROXY_URL: process.env.OLLAMA_PROXY_URL || '',
KNOWLEDGE_PROXY_URL: process.env.KNOWLEDGE_PROXY_URL || '',
};
};

View File

@@ -34,6 +34,9 @@ export const AgentRuntimeErrorType = {
InvalidAnthropicAPIKey: 'InvalidAnthropicAPIKey',
AnthropicBizError: 'AnthropicBizError',
InvalidKnowledgeArgs: 'InvalidKnowledgeArgs',
KnowledgeBizError: 'KnowledgeBizError',
} as const;
export type ILobeAgentRuntimeErrorType =

View File

@@ -12,3 +12,4 @@ export { LobePerplexityAI } from './perplexity';
export * from './types';
export { AgentRuntimeError } from './utils/createError';
export { LobeZhipuAI } from './zhipu';
export { LobeKnowledgeAI } from './knowledge';

View File

@@ -0,0 +1,102 @@
// @vitest-environment node
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { ChatStreamCallbacks } from '@/libs/agent-runtime';
import * as debugStreamModule from '../utils/debugStream';
import { LobeKnowledgeAI } from './index';
const provider = 'knowledge';
const defaultBaseURL = 'http://localhost:7861/v1';
const bizErrorType = 'KnowledgeBizError';
const invalidErrorType = 'InvalidKnowledgeArgs';
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});
let instance: LobeKnowledgeAI;
beforeEach(() => {
instance = new LobeKnowledgeAI({ apiKey: 'knowledge', baseURL: defaultBaseURL });
// Use vi.spyOn to mock the chat.completions.create method
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
});
afterEach(() => {
vi.clearAllMocks();
});
describe('LobeKnowledgeAI', () => {
describe('init', () => {
it('should init with default baseURL', () => {
expect(instance.baseURL).toBe(defaultBaseURL);
});
});
describe('chat', () => {
it('should return a StreamingTextResponse on successful API call', async () => {
// Arrange
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);
(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
// Act
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'gpt-3.5-turbo',
temperature: 0,
});
// Assert
expect(result).toBeInstanceOf(Response);
});
it('should return a Response when calling with a non-streaming payload', async () => {
// Arrange
const mockResponse = Promise.resolve({
"id": "chatcmpl-98QIb3NiYLYlRTB6t0VrJ0wntNW6K",
"object": "chat.completion",
"created": 1711794745,
"model": "gpt-3.5-turbo-0125",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!有什么可以帮助你的吗?"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 17,
"total_tokens": 26
},
"system_fingerprint": "fp_b28b39ffa8"
});
(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
// Act
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'gpt-3.5-turbo',
stream: false,
temperature: 0,
});
// Assert
expect(result).toBeInstanceOf(Response);
});
});
});
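
The tests above only assert the response type. A possible follow-up test (a sketch, not part of this commit) could also check the body: it relies on the non-streaming branch of chat() re-wrapping choices[0].message.content into a single-chunk text stream, as implemented in the next file.

// Sketch of an additional assertion for the non-streaming branch; not part of this commit.
import { expect, it, vi } from 'vitest';
import { LobeKnowledgeAI } from './index';

it('returns the completion content as the response body when stream is false', async () => {
const ai = new LobeKnowledgeAI({ apiKey: 'knowledge', baseURL: 'http://localhost:7861/v1' });
// Resolve with a minimal non-streaming completion shape.
vi.spyOn(ai['client'].chat.completions, 'create').mockResolvedValue({
choices: [{ message: { content: 'hi' } }],
} as any);
const result = await ai.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'gpt-3.5-turbo',
stream: false,
temperature: 0,
});
// The non-streaming result is re-wrapped as a single-chunk text stream.
expect(await result.text()).toBe('hi');
});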

View File

@@ -0,0 +1,110 @@
import { OpenAIStream, StreamingTextResponse } from 'ai';
import OpenAI, { ClientOptions } from 'openai';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { desensitizeUrl } from '../utils/desensitizeUrl';
import { handleOpenAIError } from '../utils/handleOpenAIError';
import { Stream } from 'openai/streaming';
const DEFAULT_BASE_URL = 'http://localhost:7861/v1';
export class LobeKnowledgeAI implements LobeRuntimeAI {
private client: OpenAI;
baseURL: string;
constructor({ apiKey = 'knowledge', baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) {
if (!baseURL) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidKnowledgeArgs);
this.client = new OpenAI({ apiKey, baseURL, ...res });
this.baseURL = baseURL;
}
async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
try {
const response = await this.client.chat.completions.create(
payload as unknown as (OpenAI.ChatCompletionCreateParamsStreaming | OpenAI.ChatCompletionCreateParamsNonStreaming),
);
if (LobeKnowledgeAI.isStream(response)) {
const [prod, debug] = response.tee();
if (process.env.DEBUG_OLLAMA_CHAT_COMPLETION === '1') {
debugStream(debug.toReadableStream()).catch(console.error);
}
return new StreamingTextResponse(OpenAIStream(prod, options?.callback), {
headers: options?.headers,
});
} else {
if (process.env.DEBUG_OLLAMA_CHAT_COMPLETION === '1') {
console.debug(JSON.stringify(response));
}
const stream = LobeKnowledgeAI.createChatCompletionStream(response?.choices[0].message.content || '');
return new StreamingTextResponse(stream);
}
} catch (error) {
let desensitizedEndpoint = this.baseURL;
if (this.baseURL !== DEFAULT_BASE_URL) {
desensitizedEndpoint = desensitizeUrl(this.baseURL);
}
if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: error as any,
errorType: AgentRuntimeErrorType.InvalidKnowledgeArgs,
provider: ModelProvider.Knowledge,
});
}
default: {
break;
}
}
}
const { errorResult, RuntimeError } = handleOpenAIError(error);
const errorType = RuntimeError || AgentRuntimeErrorType.KnowledgeBizError;
throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: errorResult,
errorType,
provider: ModelProvider.Knowledge,
});
}
}
static isStream(obj: unknown): obj is Stream<OpenAI.Chat.Completions.ChatCompletionChunk> {
return typeof Stream !== 'undefined' && (obj instanceof Stream || obj instanceof ReadableStream);
}
// Create a single-chunk byte stream carrying the completion text, so the
// non-streaming branch can still be returned as a StreamingTextResponse.
static createChatCompletionStream(text: string): ReadableStream<Uint8Array> {
const encoder = new TextEncoder();
return new ReadableStream({
start(controller) {
// Response bodies must yield Uint8Array chunks, so encode the text first.
controller.enqueue(encoder.encode(text));
controller.close();
},
});
}
}
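
For orientation, a minimal usage sketch of the new runtime (assumptions: the service behind baseURL speaks the OpenAI-compatible chat completions API, and the model name is illustrative):

import { LobeKnowledgeAI } from '@/libs/agent-runtime';

// baseURL falls back to KNOWLEDGE_PROXY_URL and then to http://localhost:7861/v1 when omitted.
const knowledge = new LobeKnowledgeAI({ baseURL: 'http://localhost:7861/v1' });

const response = await knowledge.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'gpt-3.5-turbo',
temperature: 0,
});

// chat() always returns a web Response; read it in one go or consume response.body as a stream.
console.log(await response.text());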

View File

@@ -34,4 +34,5 @@ export enum ModelProvider {
Perplexity = 'perplexity',
Tongyi = 'tongyi',
ZhiPu = 'zhipu',
Knowledge = 'knowledge',
}