from threading import Lock
from time import time

from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class XinferenceModelExtraParameter:
    model_format: str
    model_handle_type: str
    model_ability: list[str]
    max_tokens: int = 512
    context_length: int = 2048
    support_function_call: bool = False

    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
        self.model_format = model_format
        self.model_handle_type = model_handle_type
        self.model_ability = model_ability
        self.support_function_call = support_function_call
        self.max_tokens = max_tokens
        self.context_length = context_length


# Module-level cache of model metadata keyed by model_uid, refreshed after a
# five-minute TTL (see the 'expires' timestamp below).
cache = {}
cache_lock = Lock()


class XinferenceHelper:
    @staticmethod
    def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
        """Return the extra parameters for the given model, fetching from the server on a cache miss."""
        XinferenceHelper._clean_cache()
        with cache_lock:
            if model_uid not in cache:
                cache[model_uid] = {
                    'expires': time() + 300,
                    'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
                }
            return cache[model_uid]['value']

    @staticmethod
    def _clean_cache() -> None:
        try:
            with cache_lock:
                # drop entries whose TTL has expired
                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
                for model_uid in expired_keys:
                    del cache[model_uid]
        except RuntimeError:
            # a failed cleanup is harmless; expired entries are retried on the next call
            pass

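    # Illustrative sketch of the response that _get_xinference_extra_parameter below
    # parses from GET {server_url}/v1/models/{model_uid}. Only the keys read by that
    # method are listed; the example values are assumptions, not taken from this repository:
    #
    #     {
    #         "model_type": "LLM",
    #         "model_name": "llama-2-chat",
    #         "model_format": "ggmlv3",
    #         "model_ability": ["generate", "chat"],
    #         "max_tokens": 512,
    #         "context_length": 2048
    #     }
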
    @staticmethod
    def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
        """
        Get xinference model extra parameters such as model_format and model_handle_type.
        """
        if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
            raise RuntimeError('server_url and model_uid must not be empty')

        url = str(URL(server_url) / 'v1' / 'models' / model_uid)

        # this method is called while the cache lock is held, and a default requests session
        # may hang forever, so mount an HTTPAdapter with max_retries=3
        session = Session()
        session.mount('http://', HTTPAdapter(max_retries=3))
        session.mount('https://', HTTPAdapter(max_retries=3))

        try:
            response = session.get(url, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
        if response.status_code != 200:
            raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')

        response_json = response.json()

        model_format = response_json.get('model_format', 'ggmlv3')
        model_ability = response_json.get('model_ability', [])

        if response_json.get('model_type') == 'embedding':
            model_handle_type = 'embedding'
        elif model_format == 'ggmlv3' and 'chatglm' in response_json.get('model_name', ''):
            model_handle_type = 'chatglm'
        elif 'generate' in model_ability:
            model_handle_type = 'generate'
        elif 'chat' in model_ability:
            model_handle_type = 'chat'
        else:
            raise NotImplementedError(f'xinference model ability {model_ability} is not supported')

        support_function_call = 'tools' in model_ability
        max_tokens = response_json.get('max_tokens', 512)
        context_length = response_json.get('context_length', 2048)

        return XinferenceModelExtraParameter(
            model_format=model_format,
            model_handle_type=model_handle_type,
            model_ability=model_ability,
            support_function_call=support_function_call,
            max_tokens=max_tokens,
            context_length=context_length
        )
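

# Minimal usage sketch, not part of the original module. The server URL and model
# UID below are placeholders and must match a running Xinference deployment; the
# endpoint shown here is an assumption.
if __name__ == '__main__':
    params = XinferenceHelper.get_xinference_extra_parameter(
        server_url='http://127.0.0.1:9997',  # placeholder Xinference server address
        model_uid='my-model-uid',            # placeholder model UID
    )
    print(params.model_handle_type, params.max_tokens, params.context_length)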