Dependencies required for the OpenAI API; add inference dependencies to requirements

This commit is contained in:
zR 2024-04-13 23:16:59 +08:00
parent 6981085d87
commit e58d99f8ca
2 changed files with 18 additions and 5 deletions

View File

@ -1,7 +1,4 @@
from typing import Dict
from typing import List from typing import List
from typing import Tuple
import argparse import argparse
import gradio as gr import gradio as gr
import torch import torch
@ -16,7 +13,7 @@ import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="") parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1") parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860) parser.add_argument("--server_port", type=int, default=7860)
@ -55,7 +52,7 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
str: real-time generation results of hf model str: real-time generation results of hf model
""" """
inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False) inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
enc = tokenizer(inputs, return_tensors="pt").to("cuda") enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer) streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict( generation_kwargs = dict(
enc, enc,

16
requirements.txt Normal file
View File

@ -0,0 +1,16 @@
# for MiniCPM-2B hf inference
torch>=2.0.0
transformers>=4.36.2
gradio>=4.26.0
# for openai api inference
openai>=1.17.1
tiktoken>=0.6.0
loguru>=0.7.2
sentence_transformers>=2.6.1
sse_starlette>=2.1.0
# for MiniCPM-V hf inference
Pillow>=10.3.0
timm>=0.9.16
sentencepiece>=0.2.0