Dependencies needed for the OpenAI API; add inference dependencies to requirements

zR 2024-04-13 23:16:59 +08:00
parent 6981085d87
commit e58d99f8ca
2 changed files with 18 additions and 5 deletions

@@ -1,7 +1,4 @@
-from typing import Dict
-from typing import List
-from typing import Tuple
 import argparse
 import gradio as gr
 import torch
@@ -16,7 +13,7 @@ import warnings
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
 parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
@@ -55,7 +52,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         str: real-time generation results of hf model
     """
     inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
-    enc = tokenizer(inputs, return_tensors="pt").to("cuda")
+    enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
     streamer = TextIteratorStreamer(tokenizer)
     generation_kwargs = dict(
         enc,
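
Why the change above works, as a hedged sketch: next(model.parameters()).device reads the device of the model's first parameter tensor, so inputs follow the model whether it was loaded on CUDA, CPU, or MPS, instead of hard-coding "cuda". The Linear layer below is a stand-in for the HF model, not the demo's code.

import torch

model = torch.nn.Linear(4, 4)             # stand-in for the HF model
device = next(model.parameters()).device  # e.g. device(type='cpu')
x = torch.randn(1, 4).to(device)          # inputs land wherever the model is
print(model(x).shape)                     # torch.Size([1, 4])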

requirements.txt Normal file

@@ -0,0 +1,16 @@
+# for MiniCPM-2B hf inference
+torch>=2.0.0
+transformers>=4.36.2
+gradio>=4.26.0
+# for openai api inference
+openai>=1.17.1
+tiktoken>=0.6.0
+loguru>=0.7.2
+sentence_transformers>=2.6.1
+sse_starlette>=2.1.0
+# for MiniCPM-V hf inference
+Pillow>=10.3.0
+timm>=0.9.16
+sentencepiece>=0.2.0
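
The new file is installed with pip install -r requirements.txt. The "openai api inference" group enables pointing the openai>=1.17 client at an OpenAI-compatible server; the sketch below shows that pattern, with the base_url, api_key, and model name as placeholder assumptions rather than values defined by this commit.

from openai import OpenAI

# Placeholder endpoint: adjust to wherever the API server is actually run.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-needed")
resp = client.chat.completions.create(
    model="MiniCPM-2B-dpo-fp16",  # hypothetical served model name
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)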