xinference code

Uploading it first; I'll revise it later.
Leb 2024-03-13 16:37:37 +08:00 committed by GitHub
parent 7b2b24c0bc
commit eec8d51a91
12 changed files with 1321 additions and 0 deletions


@@ -0,0 +1,42 @@
<svg width="152" height="24" viewBox="0 0 152 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="xorbits 1" clip-path="url(#clip0_9866_6170)">
<path id="Vector" d="M8.00391 12.3124C8.69334 13.0754 9.47526 13.7494 10.3316 14.3188C11.0667 14.8105 11.8509 15.2245 12.6716 15.5541C14.1617 14.1465 15.3959 12.4907 16.3192 10.6606L21.7051 0L12.3133 7.38353C10.5832 8.74456 9.12178 10.416 8.00391 12.3124Z" fill="url(#paint0_linear_9866_6170)"/>
<path id="Vector_2" d="M7.23504 18.9512C6.56092 18.5012 5.92386 18.0265 5.3221 17.5394L2.06445 24L7.91975 19.3959C7.69034 19.2494 7.46092 19.103 7.23504 18.9512Z" fill="url(#paint1_linear_9866_6170)"/>
<path id="Vector_3" d="M19.3161 8.57474C21.0808 10.9147 21.5961 13.5159 20.3996 15.3053C18.6526 17.9189 13.9161 17.8183 9.82024 15.0812C5.72435 12.3441 3.82024 8.0065 5.56729 5.39297C6.76377 3.60356 9.36318 3.0865 12.2008 3.81886C7.29318 1.73474 2.62376 1.94121 0.813177 4.64474C-1.45976 8.04709 1.64435 14.1177 7.74494 18.1889C13.8455 22.26 20.6361 22.8124 22.9091 19.4118C24.7179 16.703 23.1173 12.3106 19.3161 8.57474Z" fill="url(#paint2_linear_9866_6170)"/>
<g id="Xorbits Inference">
<path d="M35.5162 12.142L38.5402 17H36.7482L34.5502 13.472L32.4922 17H30.7142L33.7382 12.142L30.7002 7.27002H32.4922L34.7042 10.826L36.7762 7.27002H38.5542L35.5162 12.142Z" fill="#1D2939"/>
<path d="M43.3584 17.126C42.6304 17.126 41.9724 16.9627 41.3844 16.636C40.7964 16.3 40.3344 15.8334 39.9984 15.236C39.6624 14.6294 39.4944 13.9293 39.4944 13.136C39.4944 12.352 39.6671 11.6567 40.0124 11.05C40.3577 10.4434 40.8291 9.97668 41.4264 9.65002C42.0237 9.32335 42.6911 9.16002 43.4284 9.16002C44.1657 9.16002 44.8331 9.32335 45.4304 9.65002C46.0277 9.97668 46.4991 10.4434 46.8444 11.05C47.1897 11.6567 47.3624 12.352 47.3624 13.136C47.3624 13.92 47.185 14.6154 46.8304 15.222C46.4757 15.8287 45.9904 16.3 45.3744 16.636C44.7677 16.9627 44.0957 17.126 43.3584 17.126ZM43.3584 15.74C43.769 15.74 44.1517 15.642 44.5064 15.446C44.8704 15.25 45.1644 14.956 45.3884 14.564C45.6124 14.172 45.7244 13.696 45.7244 13.136C45.7244 12.576 45.6171 12.1047 45.4024 11.722C45.1877 11.33 44.9031 11.036 44.5484 10.84C44.1937 10.644 43.8111 10.546 43.4004 10.546C42.9897 10.546 42.607 10.644 42.2524 10.84C41.9071 11.036 41.6317 11.33 41.4264 11.722C41.221 12.1047 41.1184 12.576 41.1184 13.136C41.1184 13.9667 41.3284 14.6107 41.7484 15.068C42.1777 15.516 42.7144 15.74 43.3584 15.74Z" fill="#1D2939"/>
<path d="M50.2561 10.406C50.4895 10.014 50.7974 9.71068 51.1801 9.49602C51.5721 9.27202 52.0341 9.16002 52.5661 9.16002V10.812H52.1601C51.5348 10.812 51.0588 10.9707 50.7321 11.288C50.4148 11.6054 50.2561 12.156 50.2561 12.94V17H48.6601V9.28602H50.2561V10.406Z" fill="#1D2939"/>
<path d="M55.3492 10.434C55.6198 10.0607 55.9885 9.75735 56.4552 9.52402C56.9312 9.28135 57.4585 9.16002 58.0372 9.16002C58.7185 9.16002 59.3345 9.32335 59.8852 9.65002C60.4358 9.97668 60.8698 10.4434 61.1872 11.05C61.5045 11.6473 61.6632 12.3333 61.6632 13.108C61.6632 13.8827 61.5045 14.578 61.1872 15.194C60.8698 15.8007 60.4312 16.2767 59.8712 16.622C59.3205 16.958 58.7092 17.126 58.0372 17.126C57.4398 17.126 56.9078 17.0093 56.4412 16.776C55.9838 16.5427 55.6198 16.244 55.3492 15.88V17H53.7532V6.64002H55.3492V10.434ZM60.0392 13.108C60.0392 12.576 59.9272 12.1187 59.7032 11.736C59.4885 11.344 59.1992 11.05 58.8352 10.854C58.4805 10.6487 58.0978 10.546 57.6872 10.546C57.2858 10.546 56.9032 10.6487 56.5392 10.854C56.1845 11.0594 55.8952 11.358 55.6712 11.75C55.4565 12.142 55.3492 12.604 55.3492 13.136C55.3492 13.668 55.4565 14.1347 55.6712 14.536C55.8952 14.928 56.1845 15.2267 56.5392 15.432C56.9032 15.6374 57.2858 15.74 57.6872 15.74C58.0978 15.74 58.4805 15.6374 58.8352 15.432C59.1992 15.2174 59.4885 14.9093 59.7032 14.508C59.9272 14.1067 60.0392 13.64 60.0392 13.108Z" fill="#1D2939"/>
<path d="M63.7734 8.26402C63.4841 8.26402 63.2414 8.16602 63.0454 7.97002C62.8494 7.77402 62.7514 7.53135 62.7514 7.24202C62.7514 6.95268 62.8494 6.71002 63.0454 6.51402C63.2414 6.31802 63.4841 6.22002 63.7734 6.22002C64.0534 6.22002 64.2914 6.31802 64.4874 6.51402C64.6834 6.71002 64.7814 6.95268 64.7814 7.24202C64.7814 7.53135 64.6834 7.77402 64.4874 7.97002C64.2914 8.16602 64.0534 8.26402 63.7734 8.26402ZM64.5574 9.28602V17H62.9614V9.28602H64.5574Z" fill="#1D2939"/>
<path d="M68.2348 10.588V14.858C68.2348 15.1474 68.3002 15.3573 68.4309 15.488C68.5709 15.6093 68.8042 15.67 69.1308 15.67H70.1109V17H68.8508C68.1322 17 67.5815 16.832 67.1988 16.496C66.8162 16.16 66.6248 15.614 66.6248 14.858V10.588H65.7148V9.28602H66.6248V7.36802H68.2348V9.28602H70.1109V10.588H68.2348Z" fill="#1D2939"/>
<path d="M74.1018 17.126C73.4952 17.126 72.9492 17.0187 72.4638 16.804C71.9878 16.58 71.6098 16.2813 71.3298 15.908C71.0498 15.5253 70.9005 15.1007 70.8818 14.634H72.5338C72.5618 14.9607 72.7158 15.236 72.9958 15.46C73.2852 15.6747 73.6445 15.782 74.0738 15.782C74.5218 15.782 74.8672 15.698 75.1098 15.53C75.3618 15.3527 75.4878 15.1287 75.4878 14.858C75.4878 14.5687 75.3478 14.354 75.0678 14.214C74.7972 14.074 74.3632 13.92 73.7658 13.752C73.1872 13.5933 72.7158 13.4394 72.3518 13.29C71.9878 13.1407 71.6705 12.912 71.3998 12.604C71.1385 12.296 71.0078 11.89 71.0078 11.386C71.0078 10.9753 71.1292 10.602 71.3718 10.266C71.6145 9.92068 71.9598 9.65002 72.4078 9.45402C72.8652 9.25802 73.3878 9.16002 73.9758 9.16002C74.8532 9.16002 75.5578 9.38402 76.0898 9.83202C76.6312 10.2707 76.9205 10.8727 76.9578 11.638H75.3618C75.3338 11.2927 75.1938 11.0173 74.9418 10.812C74.6898 10.6067 74.3492 10.504 73.9198 10.504C73.4998 10.504 73.1778 10.5833 72.9538 10.742C72.7298 10.9007 72.6178 11.1107 72.6178 11.372C72.6178 11.5773 72.6925 11.75 72.8418 11.89C72.9912 12.03 73.1732 12.142 73.3878 12.226C73.6025 12.3007 73.9198 12.3987 74.3398 12.52C74.8998 12.6693 75.3572 12.8233 75.7118 12.982C76.0758 13.1314 76.3885 13.3554 76.6498 13.654C76.9112 13.9527 77.0465 14.3493 77.0558 14.844C77.0558 15.2827 76.9345 15.6747 76.6918 16.02C76.4492 16.3654 76.1038 16.636 75.6558 16.832C75.2172 17.028 74.6992 17.126 74.1018 17.126Z" fill="#1D2939"/>
<path d="M83.4531 7.27002V17H81.8571V7.27002H83.4531Z" fill="#1D2939"/>
<path d="M89.1605 9.16002C89.7671 9.16002 90.3085 9.28602 90.7845 9.53802C91.2698 9.79002 91.6478 10.1633 91.9185 10.658C92.1891 11.1527 92.3245 11.75 92.3245 12.45V17H90.7425V12.688C90.7425 11.9973 90.5698 11.47 90.2245 11.106C89.8791 10.7327 89.4078 10.546 88.8105 10.546C88.2131 10.546 87.7371 10.7327 87.3825 11.106C87.0371 11.47 86.8645 11.9973 86.8645 12.688V17H85.2685V9.28602H86.8645V10.168C87.1258 9.85068 87.4571 9.60335 87.8585 9.42602C88.2691 9.24868 88.7031 9.16002 89.1605 9.16002Z" fill="#1D2939"/>
<path d="M97.3143 10.588H95.8863V17H94.2763V10.588H93.3663V9.28602H94.2763V8.74002C94.2763 7.85335 94.5096 7.20935 94.9763 6.80802C95.4523 6.39735 96.1943 6.19202 97.2023 6.19202V7.52202C96.7169 7.52202 96.3763 7.61535 96.1803 7.80202C95.9843 7.97935 95.8863 8.29202 95.8863 8.74002V9.28602H97.3143V10.588Z" fill="#1D2939"/>
<path d="M105.519 12.954C105.519 13.2433 105.5 13.5047 105.463 13.738H99.5687C99.6154 14.354 99.844 14.8487 100.255 15.222C100.665 15.5954 101.169 15.782 101.767 15.782C102.625 15.782 103.232 15.4227 103.587 14.704H105.309C105.075 15.4133 104.651 15.9967 104.035 16.454C103.428 16.902 102.672 17.126 101.767 17.126C101.029 17.126 100.367 16.9627 99.7787 16.636C99.2 16.3 98.7427 15.8334 98.4067 15.236C98.08 14.6294 97.9167 13.9293 97.9167 13.136C97.9167 12.3427 98.0754 11.6473 98.3927 11.05C98.7194 10.4434 99.172 9.97668 99.7507 9.65002C100.339 9.32335 101.011 9.16002 101.767 9.16002C102.495 9.16002 103.143 9.31868 103.713 9.63602C104.282 9.95335 104.725 10.4014 105.043 10.98C105.36 11.5493 105.519 12.2073 105.519 12.954ZM103.853 12.45C103.843 11.862 103.633 11.3907 103.223 11.036C102.812 10.6813 102.303 10.504 101.697 10.504C101.146 10.504 100.675 10.6813 100.283 11.036C99.8907 11.3813 99.6574 11.8527 99.5827 12.45H103.853Z" fill="#1D2939"/>
<path d="M108.405 10.406C108.639 10.014 108.947 9.71068 109.329 9.49602C109.721 9.27202 110.183 9.16002 110.715 9.16002V10.812H110.309C109.684 10.812 109.208 10.9707 108.881 11.288C108.564 11.6054 108.405 12.156 108.405 12.94V17H106.809V9.28602H108.405V10.406Z" fill="#1D2939"/>
<path d="M118.972 12.954C118.972 13.2433 118.954 13.5047 118.916 13.738H113.022C113.069 14.354 113.298 14.8487 113.708 15.222C114.119 15.5954 114.623 15.782 115.22 15.782C116.079 15.782 116.686 15.4227 117.04 14.704H118.762C118.529 15.4133 118.104 15.9967 117.488 16.454C116.882 16.902 116.126 17.126 115.22 17.126C114.483 17.126 113.82 16.9627 113.232 16.636C112.654 16.3 112.196 15.8334 111.86 15.236C111.534 14.6294 111.37 13.9293 111.37 13.136C111.37 12.3427 111.529 11.6473 111.846 11.05C112.173 10.4434 112.626 9.97668 113.204 9.65002C113.792 9.32335 114.464 9.16002 115.22 9.16002C115.948 9.16002 116.597 9.31868 117.166 9.63602C117.736 9.95335 118.179 10.4014 118.496 10.98C118.814 11.5493 118.972 12.2073 118.972 12.954ZM117.306 12.45C117.297 11.862 117.087 11.3907 116.676 11.036C116.266 10.6813 115.757 10.504 115.15 10.504C114.6 10.504 114.128 10.6813 113.736 11.036C113.344 11.3813 113.111 11.8527 113.036 12.45H117.306Z" fill="#1D2939"/>
<path d="M124.155 9.16002C124.762 9.16002 125.303 9.28602 125.779 9.53802C126.264 9.79002 126.642 10.1633 126.913 10.658C127.184 11.1527 127.319 11.75 127.319 12.45V17H125.737V12.688C125.737 11.9973 125.564 11.47 125.219 11.106C124.874 10.7327 124.402 10.546 123.805 10.546C123.208 10.546 122.732 10.7327 122.377 11.106C122.032 11.47 121.859 11.9973 121.859 12.688V17H120.263V9.28602H121.859V10.168C122.12 9.85068 122.452 9.60335 122.853 9.42602C123.264 9.24868 123.698 9.16002 124.155 9.16002Z" fill="#1D2939"/>
<path d="M128.543 13.136C128.543 12.3427 128.701 11.6473 129.019 11.05C129.345 10.4434 129.793 9.97668 130.363 9.65002C130.932 9.32335 131.585 9.16002 132.323 9.16002C133.256 9.16002 134.026 9.38402 134.633 9.83202C135.249 10.2707 135.664 10.9007 135.879 11.722H134.157C134.017 11.3394 133.793 11.0407 133.485 10.826C133.177 10.6113 132.789 10.504 132.323 10.504C131.669 10.504 131.147 10.7373 130.755 11.204C130.372 11.6613 130.181 12.3053 130.181 13.136C130.181 13.9667 130.372 14.6153 130.755 15.082C131.147 15.5487 131.669 15.782 132.323 15.782C133.247 15.782 133.858 15.376 134.157 14.564H135.879C135.655 15.348 135.235 15.9733 134.619 16.44C134.003 16.8973 133.237 17.126 132.323 17.126C131.585 17.126 130.932 16.9627 130.363 16.636C129.793 16.3 129.345 15.8334 129.019 15.236C128.701 14.6294 128.543 13.9293 128.543 13.136Z" fill="#1D2939"/>
<path d="M144.259 12.954C144.259 13.2433 144.241 13.5047 144.203 13.738H138.309C138.356 14.354 138.585 14.8487 138.995 15.222C139.406 15.5954 139.91 15.782 140.507 15.782C141.366 15.782 141.973 15.4227 142.327 14.704H144.049C143.816 15.4133 143.391 15.9967 142.775 16.454C142.169 16.902 141.413 17.126 140.507 17.126C139.77 17.126 139.107 16.9627 138.519 16.636C137.941 16.3 137.483 15.8334 137.147 15.236C136.821 14.6294 136.657 13.9293 136.657 13.136C136.657 12.3427 136.816 11.6473 137.133 11.05C137.46 10.4434 137.913 9.97668 138.491 9.65002C139.079 9.32335 139.751 9.16002 140.507 9.16002C141.235 9.16002 141.884 9.31868 142.453 9.63602C143.023 9.95335 143.466 10.4014 143.783 10.98C144.101 11.5493 144.259 12.2073 144.259 12.954ZM142.593 12.45C142.584 11.862 142.374 11.3907 141.963 11.036C141.553 10.6813 141.044 10.504 140.437 10.504C139.887 10.504 139.415 10.6813 139.023 11.036C138.631 11.3813 138.398 11.8527 138.323 12.45H142.593Z" fill="#1D2939"/>
</g>
</g>
<defs>
<linearGradient id="paint0_linear_9866_6170" x1="2.15214" y1="24.3018" x2="21.2921" y2="0.0988218" gradientUnits="userSpaceOnUse">
<stop stop-color="#E9A85E"/>
<stop offset="1" stop-color="#F52B76"/>
</linearGradient>
<linearGradient id="paint1_linear_9866_6170" x1="2.06269" y1="24.2294" x2="21.2027" y2="0.028252" gradientUnits="userSpaceOnUse">
<stop stop-color="#E9A85E"/>
<stop offset="1" stop-color="#F52B76"/>
</linearGradient>
<linearGradient id="paint2_linear_9866_6170" x1="-0.613606" y1="3.843" x2="21.4449" y2="18.7258" gradientUnits="userSpaceOnUse">
<stop stop-color="#6A0CF5"/>
<stop offset="1" stop-color="#AB66F3"/>
</linearGradient>
<clipPath id="clip0_9866_6170">
<rect width="152" height="24" fill="white"/>
</clipPath>
</defs>
</svg>


@@ -0,0 +1,24 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Xorbits Square" clip-path="url(#clip0_9850_26870)">
<path id="Vector" d="M8.00391 12.3124C8.69334 13.0754 9.47526 13.7494 10.3316 14.3188C11.0667 14.8105 11.8509 15.2245 12.6716 15.5541C14.1617 14.1465 15.3959 12.4907 16.3192 10.6606L21.7051 0L12.3133 7.38353C10.5832 8.74456 9.12178 10.416 8.00391 12.3124Z" fill="url(#paint0_linear_9850_26870)"/>
<path id="Vector_2" d="M7.23504 18.9512C6.56092 18.5012 5.92386 18.0265 5.3221 17.5394L2.06445 24L7.91975 19.3959C7.69034 19.2494 7.46092 19.103 7.23504 18.9512Z" fill="url(#paint1_linear_9850_26870)"/>
<path id="Vector_3" d="M19.3161 8.57474C21.0808 10.9147 21.5961 13.5159 20.3996 15.3053C18.6526 17.9189 13.9161 17.8183 9.82024 15.0812C5.72435 12.3441 3.82024 8.0065 5.56729 5.39297C6.76377 3.60356 9.36318 3.0865 12.2008 3.81886C7.29318 1.73474 2.62376 1.94121 0.813177 4.64474C-1.45976 8.04709 1.64435 14.1177 7.74494 18.1889C13.8455 22.26 20.6361 22.8124 22.9091 19.4118C24.7179 16.703 23.1173 12.3106 19.3161 8.57474Z" fill="url(#paint2_linear_9850_26870)"/>
</g>
<defs>
<linearGradient id="paint0_linear_9850_26870" x1="2.15214" y1="24.3018" x2="21.2921" y2="0.0988218" gradientUnits="userSpaceOnUse">
<stop stop-color="#E9A85E"/>
<stop offset="1" stop-color="#F52B76"/>
</linearGradient>
<linearGradient id="paint1_linear_9850_26870" x1="2.06269" y1="24.2294" x2="21.2027" y2="0.028252" gradientUnits="userSpaceOnUse">
<stop stop-color="#E9A85E"/>
<stop offset="1" stop-color="#F52B76"/>
</linearGradient>
<linearGradient id="paint2_linear_9850_26870" x1="-0.613606" y1="3.843" x2="21.4449" y2="18.7258" gradientUnits="userSpaceOnUse">
<stop stop-color="#6A0CF5"/>
<stop offset="1" stop-color="#AB66F3"/>
</linearGradient>
<clipPath id="clip0_9850_26870">
<rect width="24" height="24" fill="white"/>
</clipPath>
</defs>
</svg>


@@ -0,0 +1,734 @@
from collections.abc import Generator, Iterator
from typing import cast
from openai import (
APIConnectionError,
APITimeoutError,
AuthenticationError,
ConflictError,
InternalServerError,
NotFoundError,
OpenAI,
PermissionDeniedError,
RateLimitError,
UnprocessableEntityError,
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
from xinference_client.client.restful.restful_client import (
Client,
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.xinference_helper import (
XinferenceHelper,
XinferenceModelExtraParameter,
)
from core.model_runtime.utils import helper
class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
invoke LLM
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
"""
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99
return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user,
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
validate credentials
credentials should be like:
{
'model_type': 'text-generation',
'server_url': 'server url',
'model_uid': 'model uid',
}
"""
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
credentials['completion_type'] = 'chat'
elif 'generate' in extra_param.model_ability:
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
if extra_param.support_function_call:
credentials['support_function_call'] = True
if extra_param.context_length:
credentials['context_length'] = extra_param.context_length
except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
except KeyError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
except Exception as e:
raise e
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
"""
get number of tokens
because Xinference LLM is a customized model, we cannot detect which tokenizer to use,
so we just take the GPT-2 tokenizer as the default
"""
return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int:
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
if is_completion_model:
return sum([tokens(str(message.content)) for message in messages])
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
if key == "function_call":
for t_key, t_value in value.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
else:
num_tokens += tokens(str(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
:param tools: tools for tool calling
:return: number of tokens
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
num_tokens = 0
for tool in tools:
# calculate num tokens for function object
num_tokens += tokens('name')
num_tokens += tokens(tool.name)
num_tokens += tokens('description')
num_tokens += tokens(tool.description)
parameters = tool.parameters
num_tokens += tokens('parameters')
num_tokens += tokens('type')
num_tokens += tokens(parameters.get("type"))
if 'properties' in parameters:
num_tokens += tokens('properties')
for key, value in parameters.get('properties').items():
num_tokens += tokens(key)
for field_key, field_value in value.items():
num_tokens += tokens(field_key)
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += tokens(enum_field)
else:
num_tokens += tokens(field_key)
num_tokens += tokens(str(field_value))
if 'required' in parameters:
num_tokens += tokens('required')
for required_field in parameters['required']:
num_tokens += 3
num_tokens += tokens(required_field)
return num_tokens
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
"""
convert prompt message to text
"""
text = ''
for item in message:
if isinstance(item, UserPromptMessage):
text += item.content
elif isinstance(item, SystemPromptMessage):
text += item.content
elif isinstance(item, AssistantPromptMessage):
text += item.content
else:
raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
return text
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature',
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
),
),
ParameterRule(
name='top_p',
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
),
ParameterRule(
name='max_tokens',
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=credentials.get('context_length', 2048),
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
]
completion_type = None
if 'completion_type' in credentials:
if credentials['completion_type'] == 'chat':
completion_type = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'chat' in extra_args.model_ability:
completion_type = LLMMode.CHAT.value
elif 'generate' in extra_args.model_ability:
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
support_function_call = credentials.get('support_function_call', False)
context_length = credentials.get('context_length', 2048)
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
parameter_rules=rules
)
return entity
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
generate text from LLM
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
extra_model_kwargs can be obtained via `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key='abc',
max_retries=3,
timeout=60,
)
xinference_client = Client(
base_url=credentials['server_url'],
)
xinference_model = xinference_client.get_model(credentials['model_uid'])
generate_config = {
'temperature': model_parameters.get('temperature', 1.0),
'top_p': model_parameters.get('top_p', 0.7),
'max_tokens': model_parameters.get('max_tokens', 512),
}
if stop:
generate_config['stop'] = stop
if tools and len(tools) > 0:
generate_config['tools'] = [
{
'type': 'function',
'function': helper.dump_model(tool)
} for tool in tools
]
if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
resp = client.chat.completions.create(
model=credentials['model_uid'],
messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
stream=stream,
user=user,
**generate_config,
)
if stream:
if tools and len(tools) > 0:
raise InvokeBadRequestError('xinference tool calls do not support stream mode')
return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
tools=tools, resp=resp)
return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
tools=tools, resp=resp)
elif isinstance(xinference_model, RESTfulGenerateModelHandle):
resp = client.completions.create(
model=credentials['model_uid'],
prompt=self._convert_prompt_message_to_text(prompt_messages),
stream=stream,
user=user,
**generate_config,
)
if stream:
return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
tools=tools, resp=resp)
return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
tools=tools, resp=resp)
else:
raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported')
def _extract_response_tool_calls(self,
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name,
arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id,
type=response_tool_call.type,
function=function
)
tool_calls.append(tool_call)
return tool_calls
def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=response_function_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name,
type="function",
function=function
)
return tool_call
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: ChatCompletion) -> LLMResult:
"""
handle normal chat generate response
"""
if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_message = resp.choices[0].message
# convert tool call to assistant message tool call
tool_calls = assistant_message.tool_calls
assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
function_call = assistant_message.function_call
if function_call:
assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content,
tool_calls=assistant_prompt_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=resp.system_fingerprint,
usage=usage,
message=assistant_prompt_message,
)
return response
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[ChatCompletionChunk]) -> Generator:
"""
handle stream chat generate response
"""
full_response = ''
for chunk in resp:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
continue
# check if there is a tool call in the response
function_call = None
tool_calls = []
if delta.delta.tool_calls:
tool_calls += delta.delta.tool_calls
if delta.delta.function_call:
function_call = delta.delta.function_call
assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls)
if function_call:
assistant_message_tool_calls += [self._extract_response_function_call(function_call)]
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=assistant_message_tool_calls
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=assistant_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
),
)
full_response += delta.delta.content
def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Completion) -> LLMResult:
"""
handle normal completion generate response
"""
if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_message = resp.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message,
tool_calls=[]
)
prompt_tokens = self._get_num_tokens_by_gpt2(
self._convert_prompt_message_to_text(prompt_messages)
)
completion_tokens = self._num_tokens_from_messages(
messages=[assistant_prompt_message], tools=[], is_completion_model=True
)
usage = self._calc_response_usage(
model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=resp.system_fingerprint,
usage=usage,
message=assistant_prompt_message,
)
return response
def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[Completion]) -> Generator:
"""
handle stream completion generate response
"""
full_response = ''
for chunk in resp:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.text if delta.text else '',
tool_calls=[]
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=[]
)
prompt_tokens = self._get_num_tokens_by_gpt2(
self._convert_prompt_message_to_text(prompt_messages)
)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
)
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
),
)
else:
if delta.text is None or delta.text == '':
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
),
)
full_response += delta.text
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
APIConnectionError,
APITimeoutError,
],
InvokeServerUnavailableError: [
InternalServerError,
ConflictError,
NotFoundError,
UnprocessableEntityError,
PermissionDeniedError
],
InvokeRateLimitError: [
RateLimitError
],
InvokeAuthorizationError: [
AuthenticationError
],
InvokeBadRequestError: [
ValueError
]
}
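For reference, a minimal sketch of the model_parameters dict this class consumes; the parameter names and defaults mirror the ParameterRules defined in get_customizable_model_schema and the generate_config built in _generate, while the concrete values here are placeholders:

model_parameters = {
    'temperature': 0.7,  # _invoke clamps values below 0.01 up to 0.01 and values above 1.0 down to 0.99
    'top_p': 0.7,
    'max_tokens': 512,   # bounded by the model's context_length via the max_tokens parameter rule
}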


@@ -0,0 +1,160 @@
from typing import Optional
from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class XinferenceRerankModel(RerankModel):
"""
Model class for Xinference rerank model.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(
model=model,
docs=[]
)
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
response = xinference_client.rerank(
documents=docs,
query=query,
top_n=top_n,
)
rerank_documents = []
for idx, result in enumerate(response['results']):
# format document
index = result['index']
page_content = result['document']
rerank_document = RerankDocument(
index=index,
text=page_content,
score=result['relevance_score'],
)
# score threshold check
if score_threshold is not None:
if result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return RerankResult(
model=model,
docs=rerank_documents
)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
self.invoke(
model=model,
credentials=credentials,
query="Whose kasumi",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={ },
parameter_rules=[]
)
return entity
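For reference, an illustrative shape of the xinference rerank response that _invoke iterates over; the field names follow the loop over response['results'] above, and the scores and texts are made up:

response = {
    'results': [
        {'index': 0, 'document': 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', 'relevance_score': 0.92},
        {'index': 2, 'document': 'and she leads a team named PopiParty.', 'relevance_score': 0.15},
    ]
}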


@@ -0,0 +1,201 @@
import time
from typing import Optional
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Xinference text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
credentials should be like:
{
'server_url': 'server url',
'model_uid': 'model uid',
}
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
raise InvokeAuthorizationError(e)
if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
try:
embeddings = handle.create_embedding(input=texts)
except RuntimeError as e:
raise InvokeServerUnavailableError(e)
"""
for convenience, the response json is like:
class Embedding(TypedDict):
object: Literal["list"]
model: str
data: List[EmbeddingData]
usage: EmbeddingUsage
class EmbeddingUsage(TypedDict):
prompt_tokens: int
total_tokens: int
class EmbeddingData(TypedDict):
index: int
object: str
embedding: List[float]
"""
usage = embeddings['usage']
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
result = TextEmbeddingResult(
model=model,
embeddings=[embedding['embedding'] for embedding in embeddings['data']],
usage=usage
)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
num_tokens = 0
for text in texts:
# use GPT2Tokenizer to get num tokens
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError as e:
raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}')
except RuntimeError as e:
raise CredentialsValidateFailedError(e)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.MAX_CHUNKS: 1,
ModelPropertyKey.CONTEXT_SIZE: credentials.get('max_tokens', 512),
},
parameter_rules=[]
)
return entity


@@ -0,0 +1,10 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class XinferenceAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials.
Xinference only exposes customizable models, so credentials are validated per model
in each model class's validate_credentials; nothing needs to be checked at the provider level.
"""
pass


@@ -0,0 +1,47 @@
provider: xinference
label:
en_US: Xorbits Inference
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#FAF5FF"
help:
title:
en_US: How to deploy Xinference
zh_Hans: 如何部署 Xinference
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server URL
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 http://192.168.1.100:9997
en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
- variable: model_uid
label:
zh_Hans: 模型UID
en_US: Model UID
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的Model UID
en_US: Enter the model uid
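For reference, the credentials collected by this schema reach the model classes as a plain dict; a placeholder example (the URL is the sample value from the placeholder above, the uid is hypothetical):

credentials = {
    'server_url': 'http://192.168.1.100:9997',
    'model_uid': 'my-xinference-model-uid',
}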


@@ -0,0 +1,103 @@
from threading import Lock
from time import time
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL
class XinferenceModelExtraParameter:
model_format: str
model_handle_type: str
model_ability: list[str]
max_tokens: int = 512
context_length: int = 2048
support_function_call: bool = False
def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
support_function_call: bool, max_tokens: int, context_length: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
self.max_tokens = max_tokens
self.context_length = context_length
cache = {}
cache_lock = Lock()
class XinferenceHelper:
@staticmethod
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache()
with cache_lock:
if model_uid not in cache:
cache[model_uid] = {
'expires': time() + 300,
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
}
return cache[model_uid]['value']
@staticmethod
def _clean_cache() -> None:
try:
with cache_lock:
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
for model_uid in expired_keys:
del cache[model_uid]
except RuntimeError:
pass
@staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
"""
get xinference model extra parameter like model_format and model_handle_type
"""
if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
raise RuntimeError('server_url and model_uid must not be empty')
url = str(URL(server_url) / 'v1' / 'models' / model_uid)
# this method is called under a lock, and a default requests call may hang forever, so we mount an HTTPAdapter with max_retries=3
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
try:
response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
response_json = response.json()
model_format = response_json.get('model_format', 'ggmlv3')
model_ability = response_json.get('model_ability', [])
if response_json.get('model_type') == 'embedding':
model_handle_type = 'embedding'
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
model_handle_type = 'chatglm'
elif 'generate' in model_ability:
model_handle_type = 'generate'
elif 'chat' in model_ability:
model_handle_type = 'chat'
else:
raise NotImplementedError(f'xinference model ability {model_ability} is not supported')
support_function_call = 'tools' in model_ability
max_tokens = response_json.get('max_tokens', 512)
context_length = response_json.get('context_length', 2048)
return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability,
support_function_call=support_function_call,
max_tokens=max_tokens,
context_length=context_length
)
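A minimal usage sketch of the helper above; the server URL and model uid are placeholders. Results are cached per model_uid for 300 seconds, so repeated calls within that window reuse the cached metadata instead of hitting the server again:

extra = XinferenceHelper.get_xinference_extra_parameter(
    server_url='http://192.168.1.100:9997',
    model_uid='my-xinference-model-uid',
)
print(extra.model_handle_type, extra.model_ability, extra.context_length, extra.support_function_call)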