Mirror of https://github.com/aimingmed/aimingmed-ai.git (synced 2026-01-19 21:37:31 +08:00)
Compare commits
332 Commits
| SHA1 |
| --- |
| 79cb3cd01d |
| 257219604a |
| 40cd845ef0 |
| a289f57a4e |
| 448a4cf021 |
| 9ec5d86678 |
| 95cf1d304d |
| 3752d85fde |
| 9a5303e157 |
| c1b2f27f55 |
| 4e5f9f57c2 |
| 896d38e79f |
| 1ec8df8cec |
| aaaf0f4242 |
| 5ef89072c2 |
| 52e01359cc |
| cb76fb5bdc |
| 3fa7ffe951 |
| 5bb3d63adc |
| f45d3262d5 |
| 6c6ffb3a69 |
| 8cec43fe4d |
| f65b9befc4 |
| bb5c7505f1 |
| 0ba26d35f1 |
| 2b343c9a0d |
| e39d7c31ee |
| e88ede52fd |
| 0023e3f96c |
| 784a711554 |
| 565fc013fc |
| bc71d74aea |
| fa2b90d4bf |
| 7257abdcc8 |
| b7596b27f1 |
| c865ee3608 |
| 0df01f90eb |
| 5f36242fa1 |
| d09df7de09 |
| 840d5570ee |
| 80caeca815 |
| 1493c9a771 |
| 7b57546db3 |
| fbff39fbe8 |
| 89584dbce7 |
| 45958614e4 |
| 68c32bdbfb |
| e1651497c6 |
| 5edb12c96d |
| 4f231315cc |
| aaa03db4c7 |
| 749da60a0e |
| 41e56356c1 |
| 7648746287 |
| 00c53bfd71 |
| afe9fb9fe3 |
| b05ed09bc7 |
| e18baeff79 |
| 730ea3d12b |
| f77d704a57 |
| 5aa717aff7 |
| 19a8c7e7ae |
| e40aa841ac |
| a47cb3ade4 |
| 8342b33dae |
| 568c9c52c2 |
| d323026f75 |
| 44ffe18f18 |
| 1f0a52f3ae |
| ad8ab8ee8c |
| 2075536a66 |
| 4224ec2218 |
| 81a9579627 |
| 92b5c2e692 |
| 57802e0f41 |
| 9b3f4f8bfd |
| 06ed361098 |
| 70bba19dd5 |
| a49935b4d1 |
| 63cbd0b9bb |
| 48695c964a |
| a1cb3732ba |
| 9c89b63d67 |
| 02bdc94f33 |
| 0d27929321 |
| 190210f0df |
| 3bd2ecbafc |
| 0a79654728 |
| 0bf8cfae0a |
| 5e64e37111 |
| f02bddb2eb |
| 047c01b435 |
| c90b5af6c1 |
| de100b02a3 |
| ac8570dd1f |
| 3f00677d86 |
| bc5c88796f |
| eecd552966 |
| 9a1bc2f89d |
| c291601262 |
| 282266174e |
| d5e284be3e |
| 9189738e59 |
| bb27bfcfee |
| 13d5ab4098 |
| 98a3d3ebd1 |
| 783a01c861 |
| e5ac1c4333 |
| 00bb8b64fc |
| 773b8a9c56 |
| d412fd764d |
| d43d716bef |
| 0bf7b47f6f |
| 28791d0bca |
| d175a9e37d |
| 6250c66f59 |
| f140aec0c0 |
| 764c10b07e |
| ba81565371 |
| f83cc4d4ea |
| f1cc825845 |
| ab01c1c2cc |
| b3470e0eb9 |
| 1680455a1a |
| e793bb5f44 |
| ab9c7b28cb |
| 3bbf711668 |
| 142b2c8f88 |
| 7c9c22a6fa |
| ac5803cc5e |
| 810ea947dd |
| a15600c881 |
| 8aee2d1127 |
| 9879fcab71 |
| 91cab4333c |
| 77b68404da |
| ccb5fd3d4e |
| 8dd7d844ce |
| 3f7f5a035e |
| e6ed7ef5dd |
| e9d1cfe6a2 |
| be9a142e12 |
| 30dd0de6de |
| ea9080f114 |
| 3f5b911c1e |
| 3258d26782 |
| a41449630c |
| ad33d8b3bf |
| 41dc9c583f |
| 3ddbf883de |
| 83ee012951 |
| ee5e1eaf9c |
| 41ba584c18 |
| 9da5b000df |
| d9d3fa208c |
| 6a9fcfb129 |
| 2ec8315ab9 |
| 609f251e08 |
| d22efc4d6b |
| 77bd3cb539 |
| a7ebf4d573 |
| 20a7d68afe |
| d1dcdb798d |
| 103762dd33 |
| 2e988b22d0 |
| 6132b59cb0 |
| 9676659608 |
| 014c95ac35 |
| e5ca1e3248 |
| c8bf621c87 |
| a486b8d7b0 |
| 30bd906ce6 |
| 11934d7e6f |
| 388594f287 |
| b7f4a69d7e |
| 1dd7afc4de |
| 18d8494195 |
| 488489b765 |
| 2c2a51d5d0 |
| 9ac52c3523 |
| 66655f6f1c |
| 4b6a0cb933 |
| 6c4f6f99c0 |
| 4dd6f84116 |
| 3ee5ad91e0 |
| 4f464bad01 |
| ff95acb81c |
| 75cc74f2fa |
| b9b6766229 |
| 6a70b31973 |
| 9d33cb538f |
| 131fe23571 |
| b01234170f |
| c5e3c884be |
| a8bad1341e |
| 7dc41ec69b |
| c8e8088422 |
| c7848cc8ec |
| 41e49d9b48 |
| 39b52fafac |
| 2a50cc697d |
| d2b85fb428 |
| 1b7f9cebdc |
| b98f4ba772 |
| 212e2a5f95 |
| d2f17fe523 |
| 774713b66c |
| 5a8ecc6a07 |
| 6117af13e9 |
| 6771764e6a |
| a1977ce654 |
| 5aa8584129 |
| 8976b21a7f |
| 4f1e842b7c |
| 4c503415a6 |
| 81e1ddc1c3 |
| e880d71960 |
| faca3d23bb |
| 28152eba1a |
| 524f8f9c51 |
| 036a01f5ef |
| 08a45a9c76 |
| fa33a13fa2 |
| e32e52df27 |
| 05130c8826 |
| 186cb0274d |
| dc2970da44 |
| 5f03c12c0c |
| efece059f0 |
| efea8d7028 |
| 151b97a001 |
| 304f2e2394 |
| 7b403762ea |
| f797fe2db9 |
| 75a3881275 |
| 02e298ccc5 |
| 51656653b1 |
| 5bf5fa3612 |
| fdd602400b |
| 1a3165945b |
| 4b5c7dd61b |
| 29cee941f1 |
| 2506690ad1 |
| eed89dced3 |
| 0d7e8456c0 |
| 41eb2c3944 |
| 835db2eb76 |
| 4810485a01 |
| afb90160a3 |
| c55b02deb7 |
| e1290aab01 |
| fdfa65bd53 |
| f9885e635c |
| e09287af4c |
| c1662325b5 |
| 2e9d31aa35 |
| fc01293f2c |
| 9253aee204 |
| 94f716d8d2 |
| 93bd4754b2 |
| ff9095c50a |
| b187684f5e |
| 943b884fb3 |
| be3e486011 |
| d7f221b847 |
| ebe0a014e8 |
| a194582bc7 |
| 4a1224eb48 |
| 8419361e6f |
| cdff0df5f5 |
| 1942039f3b |
| dfda34e95b |
| 3692c22241 |
| f92d2ebf9a |
| f8e177d525 |
| b217ae79c9 |
| 1df99f3767 |
| 366f6850a9 |
| 465c24546d |
| 6471626497 |
| afbb34079a |
| 86a2c1a055 |
| b6ca6ac677 |
| fcb2f9e4ea |
| 486a79a2cc |
| 8b68c60249 |
| e8a8fbf826 |
| a8acbabe83 |
| 5b611653f9 |
| 24e21b9093 |
| 956c041ae1 |
| 0b2c03b6e9 |
| 2948c81208 |
| f1ccaffafe |
| 1e242de51d |
| 8a8337cc5c |
| 29d82e7cef |
| 223c6a3741 |
| 84ebb64394 |
| 947f722de9 |
| 5f8cfd95bd |
| 475811a2e0 |
| 501fbdb5cd |
| af720e45cc |
| fdcb437c7a |
| c76b7d0284 |
| 8d9ec8dc05 |
| 45c0e2e6ca |
| 418bea217f |
| 05fb135ece |
| 42d489d4f6 |
| bcaed27891 |
| 137ff307f4 |
| eebdef847c |
| 6ef767d58a |
| c5f31da81a |
| e6303e35df |
| cd0ea4b2eb |
| 1e59f65d29 |
| b5cb5efc78 |
| 93b53a9cd3 |
| 4c9c6f3dfb |
| 7f401ac721 |
| e5cb3b73c7 |
| 19e4a151d7 |
| d681f86fea |
| 49082a238c |
| 320bae36c7 |
| dab06086a0 |
| 3a4d59c0e3 |
| 7399b56fa1 |
| 47654c821a |
128 .github/workflows/build.yml vendored Normal file
@@ -0,0 +1,128 @@
name: Unittest and Build + CI

# Triggers: equivalent to the ADO trigger block
on:
  pull_request:
    branches:
      - develop

# Concurrency control: ensures only one run per branch at a time (equivalent to batch: true)
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  run_backend_unittests:
    name: Run Backend unit tests
    permissions:
      checks: write
    secrets: inherit # Inherit secrets from the parent workflow
    # Call the reusable workflow for unit tests
    uses: ./.github/workflows/template_unit_pytest.yml
    # Pass parameters as inputs to the reusable workflow
    with:
      projectName: Backend # Value defined in original variables
      workingDir: app/backend
      testsFolderName: tests
    # secrets: inherit # Inherit secrets from the parent workflow

  # This job defines the matrix and calls the reusable workflow for each image build
  build:
    needs: run_backend_unittests
    name: Build ${{ matrix.image_config.IMAGE_NAME }}
    # Define necessary permissions if needed (e.g., for GitHub Packages)
    permissions:
      contents: read
      packages: write # If pushing to the GitHub Packages registry

    # Use secrets defined in the repository/organization settings;
    # 'inherit' makes all secrets available to the called workflow
    secrets: inherit

    # Define the matrix strategy based on the 'images' object from the original ADO build.yml
    strategy:
      fail-fast: false # Don't cancel other matrix jobs if one fails
      matrix:
        # We wrap the image configuration in a single 'image_config' key
        # to pass it more easily if needed, but primarily access sub-keys directly.
        image_config:
          - IMAGE_NAME: backend-aimingmedai
            BUILD_CONTEXT: ./app/backend
            DOCKERFILE: ./app/backend/Dockerfile
          - IMAGE_NAME: frontend-aimingmedai
            BUILD_CONTEXT: ./app/frontend
            DOCKERFILE: ./app/frontend/Dockerfile.test
          - IMAGE_NAME: tests-aimingmedai
            BUILD_CONTEXT: ./app/tests
            DOCKERFILE: ./app/tests/Dockerfile

    # Call the reusable workflow
    uses: ./.github/workflows/template_build.yml # Path to the reusable workflow file
    # Pass inputs required by the reusable workflow
    with:
      # Pass values from the matrix context and global env
      project_name: aimingmed-ai
      image_repo: "ghcr.io/$(echo $GITHUB_REPOSITORY | tr '[A-Z]' '[a-z]')"
      image_name: ${{ matrix.image_config.IMAGE_NAME }}
      build_context: ${{ matrix.image_config.BUILD_CONTEXT }}
      dockerfile: ${{ matrix.image_config.DOCKERFILE }}
      build_id: ${{ github.run_id }}
      commit_sha: ${{ github.sha }}

  # TEST stage equivalent
  test:
    name: Run Integration Tests
    needs: build # Ensure this job runs after the build job
    # Define necessary permissions if needed (e.g., for GitHub Packages)
    permissions:
      contents: read
      packages: write # If pushing to the GitHub Packages registry
      checks: write # If you want to update checks
    # Call the reusable workflow for testing
    uses: ./.github/workflows/template_test.yml # Path to the reusable workflow file
    with:
      projectName: aimingmed-ai
      image_repo: ghcr.io/$(echo $GITHUB_REPOSITORY | tr '[A-Z]' '[a-z]')
      testContainerName: tests-aimingmedai
      # TODO: this part is not working; the testEnvs input is not picked up correctly by Run Tests
      # Pass test environment variables as a JSON string
      testEnvs: >
        [
          "FRONTEND_URL=http://frontend:80",
          "BACKEND_URL=http://backend:80",
          "ENVIRONMENT=dev",
          "TESTING=1"
        ]
      # TODO: this part is not working; the tests input is not picked up correctly by Run Tests
      # Pass test directories as a JSON string
      tests: >
        [
          "tests/integration/backend"
        ]
      # Pass image definitions for compose setup as a JSON string.
      # Sensitive values should be passed via secrets and referenced within the template.
      images: >
        [
          {
            "name": "backend-aimingmedai",
            "ports": ["8004:80"],
            "env": {
              "ENVIRONMENT": "dev",
              "TESTING": "1",
              "DEEPSEEK_API_KEY": "sk-XXXXXXXXXX",
              "TAVILY_API_KEY": "tvly-dev-wXXXXXX"
            }
          },
          {
            "name": "frontend-aimingmedai",
            "ports": ["3004:80"],
            "depends_on": ["backend-aimingmedai"],
            "env": {
              "ENVIRONMENT": "dev",
              "TESTING": "1",
              "LOG_LEVEL": "DEBUG"
            }
          }
        ]
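Since the TODO comments above flag that the JSON-string inputs are not picked up correctly, one quick local sanity check is to pipe each value through `jq` (the same tool `template_test.yml` uses to parse them), which rejects malformed arrays such as the trailing commas the original values carried. A minimal sketch, using the same values as the workflow:

```bash
# Validate the testEnvs and tests inputs as strict JSON before pushing.
echo '["FRONTEND_URL=http://frontend:80","BACKEND_URL=http://backend:80","ENVIRONMENT=dev","TESTING=1"]' | jq -e .
echo '["tests/integration/backend"]' | jq -e .
# jq exits non-zero (and prints a parse error) on invalid JSON such as a trailing comma:
echo '["TESTING=1",]' | jq -e . || echo "invalid JSON"
```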
36 .github/workflows/obsolete/app-testing.yml vendored Normal file
@@ -0,0 +1,36 @@
name: App testing

on:
  push:
    branches: [ "develop" ]
  pull_request:
    branches: [ "develop" ]

permissions:
  contents: read

jobs:
  streamlit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          cd app/streamlit
          python -m pip install --upgrade pip
          pip install -r requirements.txt
      - uses: streamlit/streamlit-app-action@v0.0.3
        with:
          app-path: app/streamlit/Chatbot.py
          ruff: true
          skip-smoke: true
          pytest-args: -v --junit-xml=test-results.xml
      - if: always()
        uses: pmeier/pytest-results-action@v0.6.0
        with:
          path: test-results.xml
          summary: true
          display-options: fEX
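The obsolete workflow above delegates linting and pytest to `streamlit/streamlit-app-action`; an approximate local equivalent is shown below (a sketch, assuming `app/streamlit/requirements.txt` pulls in the test dependencies):

```bash
cd app/streamlit
python -m pip install --upgrade pip
pip install -r requirements.txt
pytest -v --junit-xml=test-results.xml
```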
142 .github/workflows/obsolete/build.yml vendored Normal file
@@ -0,0 +1,142 @@
name: CI - build

on:
  pull_request:
    branches:
      - develop
      - main

env:
  IMAGE: ghcr.io/$(echo $GITHUB_REPOSITORY | tr '[A-Z]' '[a-z]')/aimingmed-ai-backend

jobs:
  build:
    name: Build Docker Image
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      - name: Check disk space
        run: df -h
      - name: Cleanup Docker resources
        if: always()
        run: |
          docker system prune -a -f --volumes
      - name: Remove unnecessary files
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
      - name: Check disk space
        run: df -h
      - name: Checkout
        uses: actions/checkout@v3
        with:
          ref: develop
      - name: Log in to GitHub Packages
        run: echo ${GITHUB_TOKEN} | docker login -u ${GITHUB_ACTOR} --password-stdin ghcr.io
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Pull image
        run: |
          docker pull ${{ env.IMAGE }}:latest || true
      - name: Check disk space
        if: always()
        run: df -h
      - name: Build image
        run: |
          docker build \
            --cache-from ${{ env.IMAGE }}:latest \
            --tag ${{ env.IMAGE }}:latest \
            --file ./app/backend/Dockerfile.prod \
            "./app/backend"
      - name: Push image
        run: |
          docker push ${{ env.IMAGE }}:latest
      - name: Check disk space
        if: always()
        run: df -h
      - name: Cleanup Docker resources
        if: always()
        run: docker system prune -a -f --volumes
      - name: Check disk space
        if: always()
        run: df -h

  test:
    name: Test Docker Image
    runs-on: ubuntu-latest
    needs: build
    permissions:
      contents: read
      packages: write
    steps:
      - name: Check disk space
        run: df -h
      - name: Checkout
        uses: actions/checkout@v3
        with:
          ref: develop
      - name: Log in to GitHub Packages
        run: echo ${GITHUB_TOKEN} | docker login -u ${GITHUB_ACTOR} --password-stdin ghcr.io
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Cleanup Docker resources
        if: always()
        run: docker system prune -a -f --volumes
      - name: Remove unnecessary files
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
      - name: Pull image
        run: |
          docker pull ${{ env.IMAGE }}:latest || true
      - name: Check disk space
        if: always()
        run: df -h
      - name: Build image
        run: |
          docker build \
            --cache-from ${{ env.IMAGE }}:latest \
            --tag ${{ env.IMAGE }}:latest \
            --file ./app/backend/Dockerfile.prod \
            "./app/backend"
      - name: Check disk space
        if: always()
        run: df -h
      - name: Validate Docker image
        run: docker inspect ${{ env.IMAGE }}:latest
      - name: Run container
        run: |
          docker run \
            -d \
            -e DEEPSEEK_API_KEY=${{ secrets.DEEPSEEK_API_KEY }} \
            -e TAVILY_API_KEY=${{ secrets.TAVILY_API_KEY }} \
            -e ENVIRONMENT=dev \
            -e TESTING=0 \
            -e PORT=8765 \
            -e LOG_LEVEL=DEBUG \
            --name backend-backend \
            -p 8004:8765 \
            ${{ env.IMAGE }}:latest
      - name: Monitor memory usage
        run: free -h
      - name: Get container logs
        if: failure()
        run: docker logs backend-backend
      - name: Pytest
        run: docker exec backend-backend pipenv run python -m pytest .
      # - name: Flake8
      #   run: docker exec backend-backend pipenv run python -m flake8 .
      # - name: Black
      #   run: docker exec backend-backend pipenv run python -m black . --check
      - name: isort
        if: always()
        run: docker exec backend-backend pipenv run python -m isort . --check-only
      - name: Cleanup container at end of job
        if: always()
        run: docker stop backend-backend || true && docker rm backend-backend || true
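Both jobs above rely on the pull-then-build caching pattern, where a failed pull is tolerated so the first build against an empty registry still succeeds. The pattern in isolation (a sketch; the image path is a placeholder for the lowercased `$GITHUB_REPOSITORY` value):

```bash
IMAGE="ghcr.io/<owner>/<repo>/aimingmed-ai-backend"   # placeholder path
docker pull "$IMAGE:latest" || true                   # ignore a missing :latest on the first run
docker build \
  --cache-from "$IMAGE:latest" \
  --tag "$IMAGE:latest" \
  --file ./app/backend/Dockerfile.prod \
  ./app/backend
```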
102 .github/workflows/template_build.yml vendored Normal file
@@ -0,0 +1,102 @@
name: Reusable Docker Build Template

# Define inputs expected from the calling workflow
on:
  workflow_call:
    inputs:
      project_name:
        required: true
        type: string
      image_repo:
        required: true
        type: string
      image_name:
        required: true
        type: string
      build_context:
        required: true
        type: string
      dockerfile:
        required: true
        type: string
      build_id:
        required: true
        type: string # Pass run_id as a string
      commit_sha:
        required: true
        type: string

jobs:
  build-single-image:
    # This job executes the build steps for the specific image configuration passed via inputs
    name: Build ${{ inputs.image_name }}
    runs-on: ubuntu-latest
    timeout-minutes: 120 # From the original ADO template

    steps:
      - name: Checkout repo
        # Checks out the repository code
        uses: actions/checkout@v3
        with:
          ref: develop # Use the branch specified in the calling workflow

      - name: Set up Docker Buildx
        # Recommended for improved build features and caching
        uses: docker/setup-buildx-action@v3

      - name: Log in to GitHub Packages
        run: echo ${GITHUB_TOKEN} | docker login -u ${GITHUB_ACTOR} --password-stdin ghcr.io
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Nuke Docker Cache
        # Equivalent to the CmdLine@2 Nuke Cache task
        run: |
          echo "Pruning Docker system..."
          docker system prune -a -f --volumes

      - name: Remove unnecessary files
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"

      - name: Check disk space
        run: df -h

      - name: Define Image Tags
        # Define tags consistently using inputs
        id: tags
        run: |
          echo "image_repo_path=${{ inputs.image_repo }}/${{ inputs.image_name }}" >> $GITHUB_OUTPUT
          echo "tag_build_id=${{ inputs.build_id }}" >> $GITHUB_OUTPUT
          echo "tag_commit_sha=${{ inputs.commit_sha }}" >> $GITHUB_OUTPUT

      - name: Pull Latest Image for Cache
        # Pulls the latest tag if it exists
        continue-on-error: true # Mimics '|| true'
        run: |
          echo "Attempting to pull latest image for cache: ${{ steps.tags.outputs.image_repo_path }}:latest"
          docker pull ${{ steps.tags.outputs.image_repo_path }}:latest || true

      - name: Build Final Image
        run: |
          echo "Building final image without intermediate cache..."
          docker build \
            -f ${{ inputs.dockerfile }} \
            --pull \
            --cache-from type=registry,ref=${{ steps.tags.outputs.image_repo_path }}:latest \
            -t ${{ steps.tags.outputs.image_repo_path }}:${{ steps.tags.outputs.tag_build_id }} \
            -t ${{ steps.tags.outputs.image_repo_path }}:${{ steps.tags.outputs.tag_commit_sha }} \
            -t ${{ steps.tags.outputs.image_repo_path }}:latest \
            ${{ inputs.build_context }}

      - name: Push Final Image Tags
        # Pushes the final tags (build id, commit sha, latest)
        run: |
          echo "Pushing final image tags..."
          docker push ${{ steps.tags.outputs.image_repo_path }}:${{ steps.tags.outputs.tag_build_id }}
          docker push ${{ steps.tags.outputs.image_repo_path }}:${{ steps.tags.outputs.tag_commit_sha }}
          docker push ${{ steps.tags.outputs.image_repo_path }}:latest
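The template tags every build three ways (run id, commit SHA, and `latest`), so any deployed image can be traced back to the exact workflow run and commit while `latest` stays available as the cache source. A sketch of the resulting tag layout, with placeholder values:

```bash
IMAGE_REPO_PATH="ghcr.io/<owner>/<repo>/backend-aimingmedai"  # placeholder path
docker pull "$IMAGE_REPO_PATH:1234567890"   # a specific github.run_id (placeholder)
docker pull "$IMAGE_REPO_PATH:0023e3f96c"   # a specific commit SHA (placeholder)
docker pull "$IMAGE_REPO_PATH:latest"       # moving tag, reused for --cache-from
```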
245 .github/workflows/template_test.yml vendored Normal file
@@ -0,0 +1,245 @@
name: Reusable Integration Test Template

on:
  workflow_call:
    inputs:
      projectName:
        required: true
        type: string
      image_repo:
        required: true
        type: string
      images: # JSON string defining services for compose
        required: true
        type: string
      tests: # JSON string array of test directories/commands
        required: true
        type: string
      testEnvs: # JSON string array of env vars for the test runner container
        required: false
        type: string
        default: '[]'
      testComposeFilePath: # Path where the generated compose file will be saved
        required: false
        type: string
        default: ./test_compose.yml # Use the .yml extension for docker compose v2
      testContainerName:
        required: false
        type: string
        default: tests # Name of the image containing the tests
      testResultsPath: # Path inside the test container where results are stored
        required: false
        type: string
        default: /usr/src/app/results
      testResultsFilename:
        required: false
        type: string
        default: results.xml

jobs:
  compose_and_test:
    name: Compose Services and Run Tests
    runs-on: ubuntu-latest
    env:
      # Env vars needed for compose file generation/execution
      IMAGE_REPO: ${{ inputs.image_repo }}
      PROJECT_NAME: ${{ inputs.projectName }}
      TAG: ${{ github.run_id }} # Use run_id as the build tag

    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4
        with:
          ref: develop

      - name: Log in to GitHub Packages
        run: echo ${GITHUB_TOKEN} | docker login -u ${GITHUB_ACTOR} --password-stdin ghcr.io
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Nuke Docker Cache
        # Equivalent to the CmdLine@2 Nuke Cache task
        run: |
          echo "Pruning Docker system..."
          docker system prune -a -f --volumes

      - name: Remove unnecessary files
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"

      - name: Define Image Repo and other tags
        # Define tags consistently using inputs
        id: tags
        run: |
          echo "image_repo_path=${{ inputs.image_repo }}" >> $GITHUB_OUTPUT

      - name: Create Docker-Compose File from Inputs
        id: create_compose
        shell: pwsh
        run: |
          # Load inputs (parse JSON strings)
          $imagesJson = '${{ inputs.images }}'

          Write-Host "Substituted Images JSON: $imagesJson" # Debugging - remove sensitive info if public
          $images = $imagesJson | ConvertFrom-Json

          $testComposePath = "${{ inputs.testComposeFilePath }}"

          # Create the compose structure (using the YAML structure for Docker Compose v2+)
          $compose = @{ services = @{}; networks = @{} }
          $compose.networks.test = @{ external = $false; name = "test-network-${{ github.run_id }}" } # Use a unique network name per run

          # Generate the services section based on the images input
          foreach ($img in $images) {
              $serviceName = $img.name
              $svc = @{}
              $svc.container_name = $serviceName
              $svc.image = "${{ steps.tags.outputs.image_repo_path }}/$($serviceName):${{ env.TAG }}" # Use the run_id tag

              if ($img.depends_on) {
                  $svc.depends_on = $img.depends_on
              }
              if ($img.env) {
                  $svc.environment = $img.env
              } else {
                  $svc.environment = @{}
              }
              $svc.networks = @("test") # Assign the service to the custom network
              if ($img.ports) {
                  $svc.ports = $img.ports
              }

              $compose.services.$serviceName = $svc
          }

          # Convert the PS object to YAML and write the file.
          # Installing the powershell-yaml module might be needed on some runners:
          # Install-Module -Name powershell-yaml -Force -Scope CurrentUser # Uncomment if needed
          # Import-Module powershell-yaml # Uncomment if needed
          # $compose | ConvertTo-Yaml | Out-File -Encoding utf8 $testComposePath

          # Alternative: convert to JSON; using JSON with docker compose -f is often simpler
          $compose | ConvertTo-Json -Depth 10 | Out-File -Encoding utf8 $testComposePath.replace('.yml','.json')
          echo "COMPOSE_FILE_PATH=$($testComposePath.replace('.yml','.json'))" >> $env:GITHUB_OUTPUT

          # Removed 'docker network create test' - using an isolated compose network now

      - name: Clean Docker Services (if any previous)
        run: |
          docker compose -f ${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }} down -v --remove-orphans || true
        continue-on-error: true

      - name: Start Docker Compose Services
        run: |
          echo "Using compose file: ${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}"
          cat "${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}" # Print the generated compose file (check secrets aren't exposed if public)
          docker compose -f "${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}" up -d

      - name: Print Service Logs on Failure or Success
        if: always() # Always run this step
        run: |
          echo "Printing final logs from Docker Compose services..."
          docker compose -f "${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}" logs

      - name: Wait for Services
        run: |
          echo "Waiting 60 seconds for services to initialize..."
          sleep 60
          echo "Compose logs after wait:"
          docker compose -f "${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}" logs

      - name: Check Docker Services Health
        run: |
          echo "Checking health of Docker services..."
          # Check whether all services are healthy
          docker compose -f "${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}" ps

      - name: Debug Network Connections
        if: always() # Run even if previous steps failed
        run: |
          echo "--- Inspecting network: test-network-${{ github.run_id }} ---"
          docker network inspect test-network-${{ github.run_id }}
          echo "--- Listing running containers (docker ps) ---"
          docker ps -a --format "table {{.ID}}\t{{.Names}}\t{{.Image}}\t{{.Status}}\t{{.Ports}}\t{{.Networks}}"
          echo "--- Backend Service Logs ---"
          docker logs backend-aimingmedai || echo "Could not get logs for backend-aimingmedai" # Replace with the actual service name

      - name: Run Tests
        shell: bash
        run: |
          TEST_DIRS='["tests/integration/backend"]'
          TEST_ENVS_JSON='["ENVIRONMENT=dev","TESTING=1", "DEEPSEEK_API_KEY=sk-XXXXXXXXXX","TAVILY_API_KEY=tvly-dev-wXXXXXX"]'
          RESULTS_PATH="${{ inputs.testResultsPath }}"
          STAGING_DIR="${{ runner.temp }}/test-results" # Use the runner temp dir for results
          mkdir -p "$STAGING_DIR"

          # Prepare environment variables for docker run. An array keeps each
          # "-e NAME=value" pair intact; building a single string here would
          # word-split on spaces and leave literal quotes in the values.
          ENV_ARGS=()
          if [[ "$TEST_ENVS_JSON" != "[]" ]]; then
            # Convert the JSON array string to individual env vars
            while IFS= read -r line; do
              ENV_ARGS+=( -e "$line" )
            done < <(echo "$TEST_ENVS_JSON" | jq -r '.[]')
          else
            # Add a dummy env var if none are provided, as required by the original script logic
            ENV_ARGS+=( -e "DUMMY_ENV_TEST_RUN_ID=${{ github.run_id }}" )
          fi
          echo "Env args: ${ENV_ARGS[*]}" # Debugging

          # Get the dynamically generated network name
          COMPOSE_NETWORK_NAME=$(docker network ls --filter name=test-network-${{ github.run_id }} --format "{{.Name}}")
          echo "Using Network: $COMPOSE_NETWORK_NAME"

          # Loop through test directories and execute tests
          echo "$TEST_DIRS" | jq -r '.[]' | while read -r test_dir; do
            test_dir=$(echo $test_dir | sed 's/"//g') # Remove quotes
            echo "Running test: $test_dir"
            docker run \
              --network "$COMPOSE_NETWORK_NAME" \
              "${ENV_ARGS[@]}" \
              -v "$STAGING_DIR:$RESULTS_PATH" \
              --rm \
              "${{ steps.tags.outputs.image_repo_path }}/${{ inputs.testContainerName }}:${{ github.run_id }}" \
              "$test_dir"
            # Add error handling if needed (e.g., exit the script if a test run fails)
            if [ $? -ne 0 ]; then
              echo "Test failed: $test_dir"
              # exit 1 # Uncomment to stop on first failure
            fi
          done

          # Copy the results file to the expected location for the upload-artifact step (adjust the filename if needed).
          # This assumes all test runs write to the *same* results file, overwriting previous ones;
          # if they write to different files, adjust this copy/rename logic.
          if [[ -f "$STAGING_DIR/${{ inputs.testResultsFilename }}" ]]; then
            cp "$STAGING_DIR/${{ inputs.testResultsFilename }}" "${{ runner.temp }}/${{ inputs.testResultsFilename }}"
          else
            echo "Warning: Test results file ${{ inputs.testResultsFilename }} not found in $STAGING_DIR"
          fi

      - name: Upload Test Results Artifact
        if: always() # Run even if tests fail
        uses: actions/upload-artifact@v4
        with:
          name: test-results-${{ github.run_id }}
          path: ${{ runner.temp }}/${{ inputs.testResultsFilename }} # Path to the results file on the runner
          retention-days: 7

      # Optional: publish test results for UI display
      - name: Publish Test Results
        if: success() || failure() # always run, even if the previous step fails
        uses: mikepenz/action-junit-report@v5
        with:
          report_paths: ${{ runner.temp }}/${{ inputs.testResultsFilename }}
          include_passed: true

      - name: Docker Compose Down
        if: always() # Always run cleanup
        run: |
          echo "Bringing down Docker Compose services..."
          docker compose -f "${{ steps.create_compose.outputs.COMPOSE_FILE_PATH }}" down -v --remove-orphans
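The compose file is deliberately written as JSON because, as the comments in the generation step note, docker compose reads JSON as readily as YAML. A quick local check of a generated file (a sketch; the path matches the template's default after the `.yml` to `.json` rename):

```bash
docker compose -f ./test_compose.json config   # parses the file and prints the resolved configuration
docker compose -f ./test_compose.json up -d
docker compose -f ./test_compose.json down -v --remove-orphans
```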
95 .github/workflows/template_unit_pytest.yml vendored Normal file
@@ -0,0 +1,95 @@
name: Reusable Unit Test with Pytest Template

on:
  workflow_call:
    inputs:
      projectName:
        description: 'Name of the project'
        required: true
        type: string
      workingDir:
        description: 'Working directory for the component'
        required: true
        type: string
      testsFolderName:
        description: 'Tests folder name'
        required: true
        type: string

jobs:
  build_and_test:
    name: Build and Test ${{ inputs.projectName }}
    runs-on: ubuntu-latest
    timeout-minutes: 120

    # Define environment variables based on inputs, similar to Azure variables
    env:
      SRC_PATH: ${{ github.workspace }}/${{ inputs.workingDir }}
      TESTS_PATH: ${{ github.workspace }}/${{ inputs.workingDir }}/${{ inputs.testsFolderName }}
      TESTS_RESULTS_PATH: ${{ github.workspace }}/${{ inputs.workingDir }}/results.xml
      TESTS_COVERAGE_REPORT_PATH: ${{ github.workspace }}/${{ inputs.workingDir }}/coverage.xml
      # Use the working directory input for commands that need it
      WORKING_DIR: ${{ inputs.workingDir }}

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python 3.11
        uses: actions/setup-python@v5 # Use the latest stable version
        with:
          python-version: '3.11'

      - name: Install build dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pipenv

      - name: Install environment including dev dependencies
        working-directory: ${{ env.WORKING_DIR }}
        run: |
          echo "Current directory:"
          pwd
          echo "Listing files:"
          ls -al
          echo "Pipfile content:"
          cat Pipfile
          pipenv install --dev --skip-lock
          echo "Listing installed packages:"
          pipenv graph

      - name: Run tests with pytest
        working-directory: ${{ env.WORKING_DIR }}
        run: |
          pipenv run pytest --version
          # Use the environment variables defined above for paths
          pipenv run pytest -v -s -o log_cli=true --junitxml=results.xml --cov=${{ env.SRC_PATH }} --cov-report=xml:${{ env.TESTS_COVERAGE_REPORT_PATH }} ${{ env.TESTS_PATH }}
          echo "Listing results in working directory:"
          ls -al ${{ github.workspace }}/${{ env.WORKING_DIR }}

      # Use a popular action for publishing test results for better GitHub integration
      - name: Publish Test Report
        uses: dorny/test-reporter@v1
        if: success() || failure() # always run, even if tests fail
        with:
          name: ${{ inputs.projectName }} Test Results
          path: ${{ env.TESTS_RESULTS_PATH }}
          reporter: java-junit # Specify JUnit format

      # Upload the coverage report as an artifact
      - name: Upload coverage report artifact
        uses: actions/upload-artifact@v4
        if: success() || failure() # always run
        with:
          name: ${{ inputs.projectName }}-coverage-report
          path: ${{ env.TESTS_COVERAGE_REPORT_PATH }}

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v5
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          files: ${{ env.TESTS_COVERAGE_REPORT_PATH }}
          fail_ci_if_error: true
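To reproduce the unit-test step outside CI, the same pipenv/pytest invocation can be run from the component directory (a sketch for the Backend component, using relative paths instead of the workflow's absolute ones):

```bash
cd app/backend
pip install pipenv
pipenv install --dev --skip-lock
pipenv run pytest -v -s -o log_cli=true \
  --junitxml=results.xml \
  --cov=. --cov-report=xml:coverage.xml \
  tests
```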
8 .gitignore vendored
@@ -202,7 +202,13 @@ data/*
**/.config.py
**/chroma_db/*
**/*.pdf
**/documents/**/*.json
**/documents/**/*.xlsx
**/.env
**/llm-template2/*
**/llmops/outputs/*
**/*.zip
**/*.zip
**/llm-examples/*
**/*.ipynb_checkpoints
**/*.ipynb
**/transformer_model/*
7 .vscode/settings.json vendored Normal file
@@ -0,0 +1,7 @@
{
    "python.testing.pytestArgs": [
        "app"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
11 Pipfile Normal file
@@ -0,0 +1,11 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]

[dev-packages]

[requires]
python_version = "3.8"
@@ -1,4 +1,7 @@
## Important note:
[![Unittest and Build + CI](https://github.com/aimingmed/aimingmed-ai/actions/workflows/build.yml/badge.svg)](https://github.com/aimingmed/aimingmed-ai/actions/workflows/build.yml)

## Important note:

No data or output should be uploaded to this repo. Please make use of the .gitignore template in the root directory if you have a folder/directory containing a dataset. The folders/directories currently ignored from git push are data/ and output/, recursively.

## Configure Hooks
@@ -12,4 +15,3 @@ To set up the hooks for only this Repo run `git config core.hooksPath ./.hooks/`
## Please enter your general Project description here

## If you don't need all folders, feel free to delete them
11 app/Pipfile Normal file
@@ -0,0 +1,11 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]

[dev-packages]

[requires]
python_version = "3.11"
20 app/Pipfile.lock generated Normal file
@@ -0,0 +1,20 @@
{
    "_meta": {
        "hash": {
            "sha256": "ed6d5d614626ae28e274e453164affb26694755170ccab3aa5866f093d51d3e4"
        },
        "pipfile-spec": 6,
        "requires": {
            "python_version": "3.11"
        },
        "sources": [
            {
                "name": "pypi",
                "url": "https://pypi.org/simple",
                "verify_ssl": true
            }
        ]
    },
    "default": {},
    "develop": {}
}
50 app/README.md Normal file
@@ -0,0 +1,50 @@
# How to work with this app repository

Build the images:

```bash
docker compose up --build -d
```

# Run the tests for the backend

```bash
docker compose exec backend pipenv run python -m pytest --disable-warnings --cov="."
```

Lint:

```bash
docker compose exec backend pipenv run flake8 .
```

Run Black and isort with check options:

```bash
docker compose exec backend pipenv run black . --check
docker compose exec backend pipenv run isort . --check-only
```

Apply code changes with Black and isort:

```bash
docker compose exec backend pipenv run black .
docker compose exec backend pipenv run isort .
```

# Postgres

Want to access the database via psql?

```bash
docker compose exec -it database psql -U postgres
```

Then, you can connect to the database and run SQL queries. For example:

```sql
# \c web_dev
# \dt
```
56 app/backend/Dockerfile Normal file
@@ -0,0 +1,56 @@
# pull official base image
FROM python:3.11-slim-bookworm

# create directory for the app user
RUN mkdir -p /home/app

# create the app user
RUN addgroup --system app && adduser --system --group app

# create the appropriate directories
ENV HOME=/home/app
ENV APP_HOME=/home/app/backend
RUN mkdir $APP_HOME
WORKDIR $APP_HOME

# set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV ENVIRONMENT=dev
ENV TESTING=1
ENV CUDA_VISIBLE_DEVICES=""

COPY Pipfile $APP_HOME/
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
RUN pipenv install --deploy --dev --no-cache-dir
RUN pipenv run pip install torch --force-reinstall --no-cache-dir

# remove all cached files not needed, to save space
RUN pip cache purge
RUN rm -rf /root/.cache

# add app
COPY . $APP_HOME

# Create cache directory and set permissions
RUN mkdir -p /home/app/.cache/huggingface
RUN chown -R app:app /home/app/.cache/huggingface

RUN chown -R app:app $APP_HOME

# change to the app user
USER app

# Run python to initialize download of the SentenceTransformer model
RUN pipenv run python utils/initialize_sentence_transformer.py

# pytest
RUN export DEEPSEEK_API_KEY=sk-XXXXXXXXXX; export TAVILY_API_KEY=tvly-dev-wXXXXXX;\
    pipenv run pytest tests --disable-warnings

# expose the port the app runs on
EXPOSE 80

# run uvicorn
CMD ["pipenv", "run", "uvicorn", "main:app", "--reload", "--workers", "1", "--host", "0.0.0.0", "--port", "80"]
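Building and running this dev image locally mirrors what the CI matrix does for `backend-aimingmedai` (a sketch; the tag name and placeholder API keys are illustrative, and the 8004:80 port mapping follows the integration-test config above):

```bash
docker build -f app/backend/Dockerfile -t backend-aimingmedai:dev app/backend
docker run --rm -d -p 8004:80 \
  -e DEEPSEEK_API_KEY=sk-XXXXXXXXXX \
  -e TAVILY_API_KEY=tvly-dev-wXXXXXX \
  --name backend-dev backend-aimingmedai:dev
```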
87 app/backend/Dockerfile.prod Normal file
@@ -0,0 +1,87 @@
###########
# BUILDER #
###########

# pull official base image
FROM python:3.11-slim-bookworm AS builder

# set working directory
WORKDIR /usr/src/app

# set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV ENVIRONMENT=dev
ENV TESTING=1
ENV CUDA_VISIBLE_DEVICES=""

# install python dependencies
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
COPY ./Pipfile .
RUN pipenv install --deploy --dev --no-cache-dir
RUN pipenv run pip install torch --force-reinstall --no-cache-dir

# remove all cached files not needed, to save space
RUN pip cache purge
RUN rm -rf /root/.cache

# Create cache directory for model downloads
# (the original also ran `chown -R app:app` here, but the `app` user and
# $APP_HOME only exist in the final stage, so those lines would fail the build)
RUN mkdir -p /home/app/.cache/huggingface

# add app
COPY . /usr/src/app
RUN export DEEPSEEK_API_KEY=sk-XXXXXXXXXX; export TAVILY_API_KEY=tvly-dev-wXXXXXX;\
    pipenv run pytest tests --disable-warnings
RUN pipenv run flake8 .
RUN pipenv run black --exclude=migrations . --check
RUN pipenv run isort . --check-only

#########
# FINAL #
#########

# pull official base image
FROM python:3.11-slim-bookworm

# create directory for the app user
RUN mkdir -p /home/app

# create the app user
RUN addgroup --system app && adduser --system --group app

# create the appropriate directories
ENV HOME=/home/app
ENV APP_HOME=/home/app/backend
RUN mkdir $APP_HOME
WORKDIR $APP_HOME

# set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV ENVIRONMENT=prod
ENV TESTING=0

# install python dependencies
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
COPY --from=builder /usr/src/app/Pipfile .
RUN pipenv install --deploy
RUN pipenv run pip install "uvicorn[standard]==0.26.0"

# add app
COPY . $APP_HOME

# chown all the files to the app user
RUN chown -R app:app $APP_HOME

# change to the app user
USER app

# expose the port the app runs on
EXPOSE 80

# run uvicorn
CMD ["pipenv", "run", "uvicorn", "main:app", "--reload", "--workers", "1", "--host", "0.0.0.0", "--port", "80"]
35 app/backend/Pipfile Normal file
@@ -0,0 +1,35 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
fastapi = "==0.115.9"
starlette = "==0.45.3"
uvicorn = {version = "==0.26.0", extras = ["standard"]}
pydantic-settings = "*"
gunicorn = "==21.0.1"
python-decouple = "==3.8"
pyyaml = "==6.0.1"
docker = "==6.1.3"
chromadb = "==0.6.3"
langchain = "==0.3.20"
langgraph = "==0.3.5"
langchain-community = "==0.3.19"
tavily-python = "==0.5.1"
langchain_huggingface = "==0.1.2"
langchain-deepseek = "==0.1.2"
torch = "*"
sentence-transformers = "*"

[dev-packages]
httpx = "==0.26.0"
pytest = "==7.4.4"
pytest-cov = "==4.1.0"
pytest-mock = "==3.10.0"
flake8 = "==7.0.0"
black = "==23.12.1"
isort = "==5.13.2"

[requires]
python_version = "3.11"
3445 app/backend/Pipfile.lock generated Normal file
File diff suppressed because it is too large
509
app/backend/api/chatbot.py
Normal file
509
app/backend/api/chatbot.py
Normal file
@ -0,0 +1,509 @@
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
from decouple import config
|
||||
from typing import List
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain.prompts import PromptTemplate, HumanMessagePromptTemplate
|
||||
|
||||
from langchain.schema import Document
|
||||
from pprint import pprint
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
|
||||
from models.adaptive_rag.routing import RouteQuery
|
||||
from models.adaptive_rag.grading import (
|
||||
GradeDocuments,
|
||||
GradeHallucinations,
|
||||
GradeAnswer,
|
||||
)
|
||||
from models.adaptive_rag.query import (
|
||||
QueryRequest,
|
||||
QueryResponse,
|
||||
)
|
||||
|
||||
from models.adaptive_rag.prompts_library import (
|
||||
system_router,
|
||||
system_retriever_grader,
|
||||
system_hallucination_grader,
|
||||
system_answer_grader,
|
||||
system_question_rewriter,
|
||||
qa_prompt_template
|
||||
)
|
||||
|
||||
from .utils import ConnectionManager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Load environment variables
|
||||
os.environ["DEEPSEEK_API_KEY"] = config(
|
||||
"DEEPSEEK_API_KEY", cast=str, default="sk-XXXXXXXXXX"
|
||||
)
|
||||
os.environ["TAVILY_API_KEY"] = config(
|
||||
"TAVILY_API_KEY", cast=str, default="tvly-dev-wXXXXXX"
|
||||
)
|
||||
|
||||
# Initialize embedding model (do this ONCE)
|
||||
embedding_model = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-mpnet-base-v2")
|
||||
|
||||
# Initialize the DeepSeek chat model
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
# Load data from ChromaDB
|
||||
db_folder = "chroma_db"
|
||||
db_path = os.path.join(os.getcwd(), db_folder)
|
||||
collection_name = "rag-chroma"
|
||||
vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
|
||||
retriever = vectorstore.as_retriever()
|
||||
|
||||
|
||||
############################ LLM functions ############################
|
||||
# Routing to vectorstore or web search
|
||||
structured_llm_router = llm.with_structured_output(RouteQuery)
|
||||
# Prompt
|
||||
route_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_router),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
question_router = route_prompt | structured_llm_router
|
||||
|
||||
### Retrieval Grader
|
||||
structured_llm_grader = llm.with_structured_output(GradeDocuments)
|
||||
# Prompt
|
||||
grade_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_retriever_grader),
|
||||
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
|
||||
]
|
||||
)
|
||||
retrieval_grader = grade_prompt | structured_llm_grader
|
||||
|
||||
### Generate
|
||||
# Create a PromptTemplate with the given prompt
|
||||
new_prompt_template = PromptTemplate(
|
||||
input_variables=["context", "question"],
|
||||
template=qa_prompt_template,
|
||||
)
|
||||
|
||||
# Create a new HumanMessagePromptTemplate with the new PromptTemplate
|
||||
new_human_message_prompt_template = HumanMessagePromptTemplate(
|
||||
prompt=new_prompt_template
|
||||
)
|
||||
prompt_qa = ChatPromptTemplate.from_messages([new_human_message_prompt_template])
|
||||
|
||||
# Chain
|
||||
rag_chain = prompt_qa | llm | StrOutputParser()
|
||||
|
||||
|
||||
### Hallucination Grader
|
||||
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
|
||||
|
||||
# Prompt
|
||||
hallucination_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_hallucination_grader),
|
||||
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
|
||||
]
|
||||
)
|
||||
|
||||
hallucination_grader = hallucination_prompt | structured_llm_grader
|
||||
|
||||
### Answer Grader
|
||||
structured_llm_grader = llm.with_structured_output(GradeAnswer)
|
||||
|
||||
# Prompt
|
||||
answer_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_answer_grader),
|
||||
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
|
||||
]
|
||||
)
|
||||
answer_grader = answer_prompt | structured_llm_grader
|
||||
|
||||
### Question Re-writer
|
||||
# Prompt
|
||||
re_write_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_question_rewriter),
|
||||
(
|
||||
"human",
|
||||
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
|
||||
),
|
||||
]
|
||||
)
|
||||
question_rewriter = re_write_prompt | llm | StrOutputParser()
|
||||
|
||||
### Search
|
||||
web_search_tool = TavilySearchResults(k=3)
|
||||
|
||||
############### Graph functions ################
|
||||
|
||||
def retrieve(state):
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, documents, that contains retrieved documents
|
||||
"""
|
||||
print("---RETRIEVE---")
|
||||
question = state["question"]
|
||||
|
||||
# Retrieval
|
||||
documents = retriever.invoke(question)
|
||||
|
||||
print(documents)
|
||||
return {"documents": documents, "question": question}
|
||||
|
||||
|
||||
def generate(state):
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, generation, that contains LLM generation
|
||||
"""
|
||||
print("---GENERATE---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
|
||||
# RAG generation
|
||||
generation = rag_chain.invoke({"context": documents, "question": question})
|
||||
return {"documents": documents, "question": question, "generation": generation}
|
||||
|
||||
|
||||
def grade_documents(state):
|
||||
"""
|
||||
Determines whether the retrieved documents are relevant to the question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): Updates documents key with only filtered relevant documents
|
||||
"""
|
||||
|
||||
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
|
||||
# Score each doc
|
||||
filtered_docs = []
|
||||
for d in documents:
|
||||
score = retrieval_grader.invoke(
|
||||
{"question": question, "document": d.page_content}
|
||||
)
|
||||
grade = score.binary_score
|
||||
if grade == "yes":
|
||||
print("---GRADE: DOCUMENT RELEVANT---")
|
||||
filtered_docs.append(d)
|
||||
else:
|
||||
print("---GRADE: DOCUMENT NOT RELEVANT---")
|
||||
continue
|
||||
return {"documents": filtered_docs, "question": question}
|
||||
|
||||
|
||||
def transform_query(state):
|
||||
"""
|
||||
Transform the query to produce a better question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): Updates question key with a re-phrased question
|
||||
"""
|
||||
|
||||
print("---TRANSFORM QUERY---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
|
||||
# Re-write question
|
||||
better_question = question_rewriter.invoke({"question": question})
|
||||
return {"documents": documents, "question": better_question}
|
||||
|
||||
|
||||
def web_search(state):
|
||||
"""
|
||||
Web search based on the re-phrased question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): Updates documents key with appended web results
|
||||
"""
|
||||
|
||||
print("---WEB SEARCH---")
|
||||
question = state["question"]
|
||||
|
||||
# Web search
|
||||
docs = web_search_tool.invoke({"query": question})
|
||||
web_results = "\n".join([d["content"] for d in docs])
|
||||
web_results = Document(page_content=web_results)
|
||||
|
||||
return {"documents": web_results, "question": question}
|
||||
|
||||
|
||||
### Edges ###
|
||||
def route_question(state):
|
||||
"""
|
||||
Route question to web search or RAG.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
str: Next node to call
|
||||
"""
|
||||
|
||||
print("---ROUTE QUESTION---")
|
||||
question = state["question"]
|
||||
source = question_router.invoke({"question": question})
|
||||
if source.datasource == "web_search":
|
||||
print("---ROUTE QUESTION TO WEB SEARCH---")
|
||||
return "web_search"
|
||||
elif source.datasource == "vectorstore":
|
||||
print("---ROUTE QUESTION TO RAG---")
|
||||
return "vectorstore"
|
||||
|
||||
|
||||
def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        # All documents have been filtered out by the relevance check,
        # so we will re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"


def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the documents and answers the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
    """

    question: str
    generation: str
    documents: List[str]


workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("web_search", web_search)  # web search
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform query

# Build graph
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()
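The compiled graph can be smoke-tested without going through the WebSocket layer. A minimal sketch, assuming this module is importable as `api.chatbot` and that the LLM and web-search API keys are configured; `app.stream` yields one dict per executed node, keyed by node name, which is the same shape the endpoint below consumes via `astream`:

# Minimal smoke test for the compiled graph (hypothetical script;
# assumes LLM and search API keys are set in the environment).
from api.chatbot import app

inputs = {"question": "What are first-line treatments for breast cancer?"}
for chunk in app.stream(inputs):
    # Each chunk maps a node name to that node's state update.
    for node_name, state_update in chunk.items():
        print(f"node: {node_name}")
        if "generation" in state_update:
            print(state_update["generation"])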
# Initialize the connection manager
manager = ConnectionManager()


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()

            try:
                data_json = json.loads(data)
                if (
                    isinstance(data_json, list)
                    and len(data_json) > 0
                    and "content" in data_json[0]
                ):
                    inputs = {
                        "question": data_json[0]["content"]
                    }
                    async for chunk in app.astream(inputs):
                        # Determine if chunk is intermediate or final
                        if isinstance(chunk, dict):
                            if len(chunk) == 1:
                                step_name = list(chunk.keys())[0]
                                step_value = chunk[step_name]
                                # Check if this step contains the final answer
                                if isinstance(step_value, dict) and 'generation' in step_value:
                                    await manager.send_personal_message(
                                        json.dumps({
                                            "type": "final",
                                            "title": "Answer",
                                            "payload": step_value['generation']
                                        }),
                                        websocket,
                                    )
                                else:
                                    await manager.send_personal_message(
                                        json.dumps({
                                            "type": "intermediate",
                                            "title": step_name.replace('_', ' ').title(),
                                            "payload": str(step_value)
                                        }),
                                        websocket,
                                    )
                            elif 'generation' in chunk:
                                await manager.send_personal_message(
                                    json.dumps({
                                        "type": "final",
                                        "title": "Answer",
                                        "payload": chunk['generation']
                                    }),
                                    websocket,
                                )
                            else:
                                await manager.send_personal_message(
                                    json.dumps({
                                        "type": "intermediate",
                                        "title": "Step",
                                        "payload": str(chunk)
                                    }),
                                    websocket,
                                )
                        else:
                            # Fallback for non-dict chunks
                            await manager.send_personal_message(
                                json.dumps({
                                    "type": "intermediate",
                                    "title": "Step",
                                    "payload": str(chunk)
                                }),
                                websocket,
                            )
                    # Send a final 'done' message to signal completion
                    await manager.send_personal_message(
                        json.dumps({"type": "done"}),
                        websocket,
                    )
                else:
                    await manager.send_personal_message(
                        "Invalid message format", websocket
                    )

            except json.JSONDecodeError:
                await manager.broadcast("Invalid JSON message")

            except WebSocketDisconnect:
                manager.disconnect(websocket)
                await manager.broadcast("Client disconnected")

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast("Client disconnected")
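The endpoint's message contract (a JSON list with a `content` field in; `intermediate`, `final`, and `done` frames out) can be exercised with a small client. A minimal sketch, assuming the third-party `websockets` package is installed and the backend is reachable on localhost:8004 as in the compose file below:

# Hypothetical manual test client for the /ws endpoint.
# Assumes: `pip install websockets` and the backend running on localhost:8004.
import asyncio
import json

import websockets


async def ask(question: str) -> None:
    async with websockets.connect("ws://localhost:8004/ws") as ws:
        await ws.send(json.dumps([{"role": "user", "content": question}]))
        while True:
            frame = json.loads(await ws.recv())
            if frame.get("type") == "done":
                break
            print(frame.get("title", "?"), "->", str(frame.get("payload"))[:120])


asyncio.run(ask("What are first-line treatments for breast cancer?"))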
app/backend/api/ping.py (new file, 14 lines)
from fastapi import APIRouter, Depends

from config import Settings, get_settings

router = APIRouter()


@router.get("/ping")
async def pong(settings: Settings = Depends(get_settings)):
    return {
        "ping": "pong!",
        "environment": settings.environment,
        "testing": settings.testing,
    }
app/backend/api/utils.py (new file, 24 lines)
import json
from typing import List

from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        json_message = {"type": "message", "payload": message}
        for connection in self.active_connections:
            await connection.send_text(json.dumps(json_message))
app/backend/config.py (new file, 17 lines)
import logging
from functools import lru_cache

from pydantic_settings import BaseSettings

log = logging.getLogger("uvicorn")


class Settings(BaseSettings):
    environment: str = "dev"
    testing: bool = False


@lru_cache()
def get_settings() -> BaseSettings:
    log.info("Loading config settings from the environment...")
    return Settings()
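Because `Settings` extends pydantic's `BaseSettings`, both fields can be overridden from the environment, which is how the `ENVIRONMENT` and `TESTING` variables in the compose file below take effect. A minimal sketch of that behaviour, assuming a fresh interpreter started from the backend directory (the `lru_cache` on `get_settings` means environment changes made after the first call are not picked up):

# Demonstration that Settings reads from the environment.
# Run in a fresh process so the lru_cache on get_settings is still cold.
import os

os.environ["ENVIRONMENT"] = "prod"
os.environ["TESTING"] = "1"

from config import get_settings

settings = get_settings()
assert settings.environment == "prod"
assert settings.testing is True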
app/backend/main.py (new file, 38 lines)
import logging

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from api import chatbot, ping

log = logging.getLogger("uvicorn")

origins = ["http://localhost:8004"]


def create_application() -> FastAPI:
    application = FastAPI()
    application.include_router(ping.router, tags=["ping"])
    application.include_router(chatbot.router, tags=["chatbot"])
    return application


app = create_application()

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8004,
        reload=True
    )
app/backend/models/adaptive_rag/grading.py (new file, 25 lines)
from pydantic import BaseModel, Field


class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in the generated answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


class GradeAnswer(BaseModel):
    """Binary score to assess whether the answer addresses the question."""

    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )
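These schemas are the output types behind the `hallucination_grader` and `answer_grader` chains used in the chatbot module; the chain construction itself is not part of this diff. A plausible sketch of the wiring, assuming a LangChain chat model (here `ChatDeepSeek`, matching the dependency mocked in the tests), an API key in the environment, and the prompt strings from `prompts_library`:

# Sketch of how a grader chain could be assembled (assumed wiring,
# not shown in this diff). Structured output forces the model to
# return a GradeHallucinations instance.
from langchain_core.prompts import ChatPromptTemplate
from langchain_deepseek import ChatDeepSeek

from models.adaptive_rag.grading import GradeHallucinations
from models.adaptive_rag.prompts_library import system_hallucination_grader

llm = ChatDeepSeek(model="deepseek-chat", temperature=0)
structured_llm = llm.with_structured_output(GradeHallucinations)

hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_hallucination_grader),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)
hallucination_grader = hallucination_prompt | structured_llm
# hallucination_grader.invoke({...}).binary_score is then "yes" or "no",
# the value consumed by grade_generation_v_documents_and_question above.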
app/backend/models/adaptive_rag/prompts_library.py (new file, 39 lines)
system_router = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web search.
"""

system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
You must read carefully and make sure the document contains a sentence or chunk of sentences that is exactly related, not merely loosely related, to the question subject (e.g. it must be the exact disease or subject in question). \n
The goal is to filter out erroneous retrievals. \n
You must return a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""

system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""

system_question_rewriter = """You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""

# prompt for question answering based on retrieved documents
qa_prompt_template = """You are an expert at answering questions based on the following retrieved context.\n
Before answering the question, you must have your own thought process about the general scopes to cover when answering this question, step-by-step. Do not include this thought process in the answer.\n
Then, given your thought process, you must read the provided context carefully and extract the relevant information.\n

If the question is a medical question, you must answer it in a medical way and assume that the audience is a junior doctor or a medical student: \n
1. For cancer diseases, you must include comprehensive treatment advice that encompasses multidisciplinary treatment options, including but not limited to surgery, chemotherapy, radiology, internal medicine (drugs), nutritional ratio (protein), etc. You must lay out the treatment options, e.g. what the first-line and second-line treatments are.\n
2. For cancer diseases, don't consider context that is not primary tumor/cancer related, unless the question specifically mentions it is secondary tumor/cancer related.\n
3. If the question didn't state the stage of the cancer disease, you must reply with treatment options for each stage of the cancer disease, if they are available in the provided context. If they are not available in the provided context, give a general one.\n

You must not use any information that is not present in the provided context to answer the question. Make sure to remove information not present in the provided context.\n
If you don't know the answer, just say that you don't know.\n
Provide the answer in a concise and organized manner. \n

Question: {question} \n
Context: {context} \n
Answer:
"""
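The router prompt pairs with the `RouteQuery` schema defined in `routing.py` below. A sketch of the likely assembly (assumed; the actual chain construction is not shown in this diff):

# Assumed wiring of the question router: prompt plus structured output.
from langchain_core.prompts import ChatPromptTemplate
from langchain_deepseek import ChatDeepSeek

from models.adaptive_rag.prompts_library import system_router
from models.adaptive_rag.routing import RouteQuery

llm = ChatDeepSeek(model="deepseek-chat", temperature=0)

route_prompt = ChatPromptTemplate.from_messages(
    [("system", system_router), ("human", "{question}")]
)
question_router = route_prompt | llm.with_structured_output(RouteQuery)

# question_router.invoke({"question": "..."}).datasource is then either
# "vectorstore" (oncology questions) or "web_search" (everything else),
# matching the conditional edge out of START in the graph above.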
app/backend/models/adaptive_rag/query.py (new file, 9 lines)
from pydantic import BaseModel, Field


class QueryRequest(BaseModel):
    query: str = Field(..., description="The question to ask the model")


class QueryResponse(BaseModel):
    response: str = Field(..., description="The model's response")
app/backend/models/adaptive_rag/routing.py (new file, 12 lines)
from typing import Literal

from pydantic import BaseModel, Field


class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )
app/backend/setup.cfg (new file, 2 lines)
[flake8]
max-line-length = 119
app/backend/tests/__init__.py (new file, empty)
app/backend/tests/api/test_chatbot.py (new file, 51 lines)
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
from fastapi import WebSocket
import sys
import types

# Patch langchain and other heavy dependencies for import
sys.modules['langchain_deepseek'] = MagicMock()
sys.modules['langchain_huggingface'] = MagicMock()
sys.modules['langchain_community.vectorstores.chroma'] = MagicMock()
sys.modules['langchain_community.tools.tavily_search'] = MagicMock()
sys.modules['langchain_core.prompts'] = MagicMock()
sys.modules['langchain_core.output_parsers'] = MagicMock()
sys.modules['langchain.prompts'] = MagicMock()
sys.modules['langchain.schema'] = MagicMock()
sys.modules['langgraph.graph'] = MagicMock()

from api import chatbot


@pytest.fixture
def client():
    from fastapi import FastAPI
    app = FastAPI()
    app.include_router(chatbot.router)
    return TestClient(app)


def test_router_exists():
    assert hasattr(chatbot, 'router')


def test_env_vars_loaded(monkeypatch):
    monkeypatch.setenv('DEEPSEEK_API_KEY', 'dummy')
    monkeypatch.setenv('TAVILY_API_KEY', 'dummy')
    # Re-import to trigger env loading
    import importlib
    importlib.reload(chatbot)
    assert True


def test_websocket_endpoint_accepts(monkeypatch):
    # Patch ConnectionManager
    mock_manager = MagicMock()
    monkeypatch.setattr(chatbot, 'manager', mock_manager)
    ws = MagicMock(spec=WebSocket)
    # Raise StopIteration on receive_text to end the loop immediately
    ws.receive_text = MagicMock(side_effect=StopIteration)
    ws.accept = MagicMock()
    # Should not raise
    try:
        coro = chatbot.websocket_endpoint(ws)
        assert hasattr(coro, '__await__')
    except Exception as e:
        pytest.fail(f"websocket_endpoint raised: {e}")
app/backend/tests/api/test_utils.py (new file, 51 lines)
import os
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock

from fastapi import WebSocket

# Make the backend package importable before importing from it
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from api.utils import ConnectionManager


# Test for ConnectionManager class
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.manager = ConnectionManager()

    async def test_connect(self):
        mock_websocket = AsyncMock(spec=WebSocket)
        await self.manager.connect(mock_websocket)
        self.assertIn(mock_websocket, self.manager.active_connections)
        mock_websocket.accept.assert_awaited_once()

    async def test_disconnect(self):
        mock_websocket = MagicMock(spec=WebSocket)
        self.manager.active_connections.append(mock_websocket)
        self.manager.disconnect(mock_websocket)
        self.assertNotIn(mock_websocket, self.manager.active_connections)

    async def test_send_personal_message(self):
        mock_websocket = AsyncMock(spec=WebSocket)
        message = "Test message"
        await self.manager.send_personal_message(message, mock_websocket)
        mock_websocket.send_text.assert_awaited_once_with(message)

    async def test_broadcast(self):
        mock_websocket1 = AsyncMock(spec=WebSocket)
        mock_websocket2 = AsyncMock(spec=WebSocket)
        self.manager.active_connections = [mock_websocket1, mock_websocket2]
        message = "Broadcast message"
        await self.manager.broadcast(message)
        mock_websocket1.send_text.assert_awaited_once_with(
            '{"type": "message", "payload": "Broadcast message"}'
        )
        mock_websocket2.send_text.assert_awaited_once_with(
            '{"type": "message", "payload": "Broadcast message"}'
        )


if __name__ == "__main__":
    unittest.main()
app/backend/tests/conftest.py (new file, 21 lines)
import pytest
from starlette.testclient import TestClient

from config import Settings, get_settings
from main import create_application


def get_settings_override():
    return Settings(testing=1)


@pytest.fixture(scope="module")
def test_app():
    # set up
    app = create_application()
    app.dependency_overrides[get_settings] = get_settings_override
    with TestClient(app) as test_client:
        # testing
        yield test_client

    # tear down
app/backend/tests/models/adaptive_rag/test_grading.py (new file, 14 lines)
import pytest
from models.adaptive_rag import grading


def test_grade_documents_class():
    doc = grading.GradeDocuments(binary_score='yes')
    assert doc.binary_score == 'yes'


def test_grade_hallucinations_class():
    doc = grading.GradeHallucinations(binary_score='no')
    assert doc.binary_score == 'no'


def test_grade_answer_class():
    doc = grading.GradeAnswer(binary_score='yes')
    assert doc.binary_score == 'yes'
@ -0,0 +1,10 @@
import pytest
from models.adaptive_rag import prompts_library


def test_prompts_are_strings():
    assert isinstance(prompts_library.system_router, str)
    assert isinstance(prompts_library.system_retriever_grader, str)
    assert isinstance(prompts_library.system_hallucination_grader, str)
    assert isinstance(prompts_library.system_answer_grader, str)
    assert isinstance(prompts_library.system_question_rewriter, str)
    assert isinstance(prompts_library.qa_prompt_template, str)
app/backend/tests/models/adaptive_rag/test_query.py (new file, 8 lines)
import pytest
from models.adaptive_rag import query


def test_query_request_and_response():
    req = query.QueryRequest(query="What is AI?")
    assert req.query == "What is AI?"
    resp = query.QueryResponse(response="Artificial Intelligence")
    assert resp.response == "Artificial Intelligence"
app/backend/tests/models/adaptive_rag/test_routing.py (new file, 6 lines)
import pytest
from models.adaptive_rag import routing


def test_route_query_class():
    route = routing.RouteQuery(datasource="vectorstore")
    assert route.datasource == "vectorstore"
app/backend/tests/test_config_and_main.py (new file, 10 lines)
import pytest
from importlib import import_module


def test_config_import():
    mod = import_module('config')
    assert mod is not None


def test_main_import():
    mod = import_module('main')
    assert mod is not None
app/backend/tests/test_ping.py (new file, 4 lines)
def test_ping(test_app):
    response = test_app.get("/ping")
    assert response.status_code == 200
    assert response.json() == {"environment": "dev", "ping": "pong!", "testing": True}
@ -0,0 +1,6 @@
import pytest
from importlib import import_module


def test_initialize_sentence_transformer_import():
    mod = import_module('utils.initialize_sentence_transformer')
    assert mod is not None
app/backend/utils/initialize_sentence_transformer.py (new file, 15 lines)
from decouple import config
from sentence_transformers import SentenceTransformer
import os

EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")

# Initialize embedding model
model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")

# create directory if it does not exist
if not os.path.exists("./transformer_model"):
    os.makedirs("./transformer_model")

# save the model under the configured model name
model.save(f"./transformer_model/{EMBEDDING_MODEL}")
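Once this script has run, the serialized model can be loaded from the local directory instead of being downloaded again. A minimal sketch, assuming the default model name was used:

# Consumer of the locally saved embedding model.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "./transformer_model/paraphrase-multilingual-mpnet-base-v2", device="cpu"
)
embedding = model.encode("How is breast cancer treated?")
print(embedding.shape)  # (768,) for this mpnet-base model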
@ -1,11 +1,49 @@
version: "3.9"
services:
  chroma:
    image: ghcr.io/chroma-core/chroma:latest
    ports:
      - "8000:8000"
    volumes:
      - chroma_data:/chroma
  # streamlit:
  #   build: ./streamlit
  #   platform: linux/amd64
  #   ports:
  #     - "8501:8501"
  #   volumes:
  #     - ./llmops/src/rag_cot_evaluation/chroma_db:/app/llmops/src/rag_cot_evaluation/chroma_db

  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    container_name: backend-aimingmedai
    platform: linux/amd64
    # command: pipenv run uvicorn main:app --reload --workers 1 --host 0.0.0.0 --port 8765
    volumes:
      - ./backend:/home/app/backend
    ports:
      - "8004:80"
    environment:
      - ENVIRONMENT=dev
      - TESTING=0

  frontend:
    build:
      context: ./frontend
      dockerfile: Dockerfile.test
    container_name: frontend-aimingmedai
    volumes:
      - ./frontend:/usr/src/app
      - /usr/src/app/node_modules
    ports:
      - "3004:80"
    depends_on:
      - backend
    environment:
      LOG_LEVEL: "DEBUG"

  # tests:
  #   build:
  #     context: ./tests
  #   container_name: tests-aimingmedai
  #   # depends_on:
  #   #   - backend
  #   #   - frontend
  #   environment:
  #     FRONTEND_URL: http://frontend:80
  #     BACKEND_URL: http://backend:80

volumes:
  chroma_data:
app/frontend/.dockerignore (new file, 1 line)
node_modules
app/frontend/.env.production (new file, 1 line)
REACT_APP_BASE_URL=https://backend.aimingmed.com/
app/frontend/.env.test (new file, 1 line)
REACT_APP_BASE_DOMAIN_NAME_PORT=localhost:8004
app/frontend/.gitignore (vendored, new file, 24 lines)
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
app/frontend/Dockerfile.test (new file, 18 lines)
####### BUILDER IMAGE #######
# Build stage
FROM node:alpine

WORKDIR /usr/src/app

# Copy the frontend source
COPY . /usr/src/app

# Build the app with a specific .env file
ARG ENV_FILE=.env.test
COPY ${ENV_FILE} /usr/src/app/.env

# Install dependencies (plus jest types used by the tests)
RUN npm install && npm install --save-dev @types/jest

EXPOSE 80
CMD [ "npm", "run", "dev", "--", "--host", "0.0.0.0", "--port", "80" ]
app/frontend/README.md (new file, 54 lines)
# React + TypeScript + Vite

This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.

Currently, two official plugins are available:

- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh

## Expanding the ESLint configuration

If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:

```js
export default tseslint.config({
  extends: [
    // Remove ...tseslint.configs.recommended and replace with this
    ...tseslint.configs.recommendedTypeChecked,
    // Alternatively, use this for stricter rules
    ...tseslint.configs.strictTypeChecked,
    // Optionally, add this for stylistic rules
    ...tseslint.configs.stylisticTypeChecked,
  ],
  languageOptions: {
    // other options...
    parserOptions: {
      project: ['./tsconfig.node.json', './tsconfig.app.json'],
      tsconfigRootDir: import.meta.dirname,
    },
  },
})
```

You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:

```js
// eslint.config.js
import reactX from 'eslint-plugin-react-x'
import reactDom from 'eslint-plugin-react-dom'

export default tseslint.config({
  plugins: {
    // Add the react-x and react-dom plugins
    'react-x': reactX,
    'react-dom': reactDom,
  },
  rules: {
    // other rules...
    // Enable its recommended typescript rules
    ...reactX.configs['recommended-typescript'].rules,
    ...reactDom.configs.recommended.rules,
  },
})
```
app/frontend/eslint.config.js (new file, 28 lines)
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'

export default tseslint.config(
  { ignores: ['dist'] },
  {
    extends: [js.configs.recommended, ...tseslint.configs.recommended],
    files: ['**/*.{ts,tsx}'],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
    },
    plugins: {
      'react-hooks': reactHooks,
      'react-refresh': reactRefresh,
    },
    rules: {
      ...reactHooks.configs.recommended.rules,
      'react-refresh/only-export-components': [
        'warn',
        { allowConstantExport: true },
      ],
    },
  },
)
app/frontend/index.html (new file, 13 lines)
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <link rel="icon" type="image/svg+xml" href="/vite.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Vite + React + TS</title>
  </head>
  <body>
    <div id="root"></div>
    <script type="module" src="/src/main.tsx"></script>
  </body>
</html>
app/frontend/package-lock.json (generated, new file, 7085 lines; diff suppressed because it is too large)
app/frontend/package.json (new file, 43 lines)
{
  "name": "frontend",
  "private": true,
  "version": "0.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "tsc -b && vite build",
    "lint": "eslint .",
    "preview": "vite preview",
    "test": "vitest",
    "test:run": "vitest run"
  },
  "dependencies": {
    "@tailwindcss/typography": "^0.5.16",
    "daisyui": "^5.0.17",
    "react": "^19.0.0",
    "react-dom": "^19.0.0",
    "react-markdown": "^10.1.0",
    "remark-gfm": "^4.0.1"
  },
  "devDependencies": {
    "@eslint/js": "^9.21.0",
    "@testing-library/jest-dom": "^6.6.3",
    "@testing-library/react": "^16.3.0",
    "@types/node": "^22.14.0",
    "@types/react": "^19.0.10",
    "@types/react-dom": "^19.0.4",
    "@vitejs/plugin-react": "^4.3.4",
    "autoprefixer": "^10.4.21",
    "eslint": "^9.21.0",
    "eslint-plugin-react-hooks": "^5.1.0",
    "eslint-plugin-react-refresh": "^0.4.19",
    "globals": "^15.15.0",
    "jsdom": "^26.0.0",
    "postcss": "^8.5.3",
    "tailwindcss": "^3.4.17",
    "typescript": "~5.7.2",
    "typescript-eslint": "^8.24.1",
    "vite": "^6.2.0",
    "vitest": "^3.1.1"
  }
}
app/frontend/postcss.config.js (new file, 6 lines)
export default {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}
app/frontend/src/App.test.tsx (new file, 22 lines)
import { render, screen, fireEvent, waitFor } from '@testing-library/react';
import App from './App';
import { vi } from 'vitest';

it('renders initial state', () => {
  render(<App />);
  expect(screen.getByText('Simple Chatbot')).toBeInTheDocument();
  expect(screen.getByRole('textbox')).toBeInTheDocument();
  expect(screen.getByRole('button', { name: /send/i })).toBeInTheDocument();
});

it('sends a message', () => {
  const mockSend = vi.fn();
  vi.spyOn(WebSocket.prototype, 'send').mockImplementation(mockSend);
  render(<App />);
  const inputElement = screen.getByRole('textbox');
  fireEvent.change(inputElement, { target: { value: 'Hello' } });
  const buttonElement = screen.getByRole('button', { name: /send/i });
  fireEvent.click(buttonElement);
  expect(mockSend).toHaveBeenCalledWith(JSON.stringify([{ role: 'user', content: 'Hello' }]));
  expect(screen.getByText('Hello')).toBeInTheDocument();
});
app/frontend/src/App.tsx (new file, 201 lines)
import React, { useState, useEffect, useRef } from 'react';
import ReactMarkdown from 'react-markdown';
import remarkGfm from 'remark-gfm';

const BASE_DOMAIN_NAME_PORT = import.meta.env.REACT_APP_BASE_DOMAIN_NAME_PORT || 'localhost:8004';

interface Message {
  sender: 'user' | 'bot';
  text: string;
}

interface ChatTurn {
  question: string;
  intermediateMessages: { title: string; payload: string }[];
  finalAnswer: string | null;
  isLoading: boolean;
  showIntermediate: boolean;
}

const App: React.FC = () => {
  const [chatTurns, setChatTurns] = useState<ChatTurn[]>([]);
  const [newMessage, setNewMessage] = useState('');
  const [socket, setSocket] = useState<WebSocket | null>(null);
  const mounted = useRef(false);

  // Disable input/button if any job is running
  const isJobRunning = chatTurns.some(turn => turn.isLoading);

  useEffect(() => {
    mounted.current = true;
    const ws = new WebSocket(`ws://${BASE_DOMAIN_NAME_PORT}/ws`);
    setSocket(ws);
    ws.onopen = () => {
      console.log('WebSocket connection opened');
    };
    ws.onmessage = (event) => {
      try {
        const data = JSON.parse(event.data);
        setChatTurns((prevTurns) => {
          if (prevTurns.length === 0) return prevTurns;
          const lastTurn = prevTurns[prevTurns.length - 1];
          if (data.type === 'intermediate') {
            // Add intermediate message to the last turn
            const updatedTurn = {
              ...lastTurn,
              intermediateMessages: [...lastTurn.intermediateMessages, { title: data.title, payload: data.payload }],
            };
            return [...prevTurns.slice(0, -1), updatedTurn];
          } else if (data.type === 'final') {
            // Set final answer for the last turn
            const updatedTurn = {
              ...lastTurn,
              finalAnswer: data.payload,
            };
            return [...prevTurns.slice(0, -1), updatedTurn];
          } else if (data.type === 'done') {
            // Mark last turn as not loading
            const updatedTurn = {
              ...lastTurn,
              isLoading: false,
            };
            return [...prevTurns.slice(0, -1), updatedTurn];
          } else if (data.type === 'message' && data.payload && mounted.current) {
            // legacy support, treat as final
            const updatedTurn = {
              ...lastTurn,
              finalAnswer: (lastTurn.finalAnswer || '') + data.payload,
            };
            return [...prevTurns.slice(0, -1), updatedTurn];
          }
          return prevTurns;
        });
      } catch (error) {
        console.error('Error parsing message:', error);
      }
    };
    ws.onclose = () => {
      console.log('WebSocket connection closed');
    };
    ws.onerror = (error) => {
      console.error('WebSocket error:', error);
    };
    return () => {
      mounted.current = false;
      ws.close();
    };
  }, []);

  const sendMessage = () => {
    if (newMessage.trim() !== '') {
      setChatTurns((prev) => [
        ...prev,
        {
          question: newMessage,
          intermediateMessages: [],
          finalAnswer: null,
          isLoading: true,
          showIntermediate: false,
        },
      ]);
      const message = [{ role: 'user', content: newMessage }];
      socket?.send(JSON.stringify(message));
      setNewMessage('');
    }
  };

  const toggleShowIntermediate = (idx: number) => {
    setChatTurns((prev) => prev.map((turn, i) => i === idx ? { ...turn, showIntermediate: !turn.showIntermediate } : turn));
  };

  return (
    <div className="flex flex-col h-screen bg-gray-100">
      <div className="p-4">
        <h1 className="text-3xl font-bold text-center text-gray-800">Simple Chatbot</h1>
      </div>
      <div className="flex-grow overflow-y-auto p-4">
        {chatTurns.map((turn, idx) => (
          <React.Fragment key={idx}>
            {/* User question */}
            <div className="p-4 rounded-lg mb-2 bg-blue-100 text-blue-800">{turn.question}</div>
            {/* Status box for this question */}
            {turn.intermediateMessages.length > 0 && (
              <div className="mb-4">
                <div className="bg-blue-50 border border-blue-300 rounded-lg p-3 shadow-sm flex items-center">
                  {/* Spinner icon */}
                  {turn.isLoading && (
                    <svg className="animate-spin h-5 w-5 text-blue-500 mr-2" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
                      <circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
                      <path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v8z"></path>
                    </svg>
                  )}
                  <span className="font-semibold text-blue-700 mr-2">Working on:</span>
                  {/* Key steps summary */}
                  <div className="flex flex-wrap gap-2">
                    {turn.intermediateMessages.map((msg, i) => (
                      <span key={i} className="bg-blue-100 text-blue-700 px-2 py-1 rounded text-xs font-medium border border-blue-200">
                        {msg.title}
                      </span>
                    ))}
                  </div>
                  <button
                    className="ml-auto text-xs text-blue-600 flex items-center gap-1 px-2 py-1 rounded hover:bg-blue-100 focus:outline-none border border-transparent focus:border-blue-300 transition"
                    onClick={() => toggleShowIntermediate(idx)}
                    aria-expanded={turn.showIntermediate}
                    title={turn.showIntermediate ? 'Hide details' : 'Show details'}
                  >
                    <svg
                      className={`w-4 h-4 transition-transform duration-200 ${turn.showIntermediate ? 'rotate-180' : ''}`}
                      fill="none"
                      stroke="currentColor"
                      viewBox="0 0 24 24"
                      xmlns="http://www.w3.org/2000/svg"
                    >
                      <path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M19 9l-7 7-7-7" />
                    </svg>
                  </button>
                </div>
                {/* Expanded details */}
                {turn.showIntermediate && (
                  <div className="bg-white border border-blue-200 rounded-b-lg p-3 mt-1 text-xs max-h-64 overflow-y-auto">
                    {turn.intermediateMessages.map((msg, i) => (
                      <div key={i} className="mb-3">
                        <div className="font-bold text-blue-700 mb-1">{msg.title}</div>
                        <pre className="whitespace-pre-wrap break-words text-gray-800">{msg.payload}</pre>
                      </div>
                    ))}
                  </div>
                )}
              </div>
            )}
            {/* Final answer for this question */}
            {turn.finalAnswer && (
              <div className="prose p-4 rounded-lg mb-2 bg-gray-200 text-gray-800">
                <ReactMarkdown remarkPlugins={[remarkGfm]}>{turn.finalAnswer}</ReactMarkdown>
              </div>
            )}
          </React.Fragment>
        ))}
      </div>
      <div className="p-4 border-t border-gray-300">
        <div className="flex">
          <input
            type="text"
            value={newMessage}
            onChange={(e) => setNewMessage(e.target.value)}
            className="flex-grow p-2 border border-gray-300 rounded-lg mr-2"
            disabled={isJobRunning}
          />
          <button
            onClick={sendMessage}
            className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded-lg"
            disabled={isJobRunning}
          >
            Send
          </button>
        </div>
      </div>
    </div>
  );
};

export default App;
app/frontend/src/assets/react.svg (new file; inline React logo SVG, 35.93x32 viewBox, 4.0 KiB; markup not reproduced here)
app/frontend/src/index.css (new file, 3 lines)
@tailwind base;
@tailwind components;
@tailwind utilities;
app/frontend/src/main.tsx (new file, 10 lines)
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.tsx'

createRoot(document.getElementById('root')!).render(
  <StrictMode>
    <App />
  </StrictMode>,
)
app/frontend/src/vite-env.d.ts (vendored, new file, 9 lines)
/// <reference types="vite/client" />
import type { TestingLibraryMatchers } from "@testing-library/jest-dom/matchers";

declare global {
  namespace jest {
    interface Matchers<R = void>
      extends TestingLibraryMatchers<typeof expect.stringContaining, R> {}
  }
}
app/frontend/tailwind.config.js (new file, 14 lines)
/** @type {import('tailwindcss').Config} */
export default {
  content: [
    "./src/**/*.{js,jsx,ts,tsx}",
  ],
  theme: {
    extend: {},
  },
  plugins: [
    require('@tailwindcss/typography'),
    require("daisyui"),
  ],
}
app/frontend/tests/setup.ts (new file, 9 lines)
import { expect, afterEach } from "vitest";
import { cleanup } from "@testing-library/react";
import * as matchers from "@testing-library/jest-dom/matchers";

expect.extend(matchers);

afterEach(() => {
  cleanup();
});
app/frontend/tsconfig.app.json (new file, 26 lines)
{
  "compilerOptions": {
    "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
    "target": "ES2020",
    "useDefineForClassFields": true,
    "lib": ["ES2020", "DOM", "DOM.Iterable"],
    "module": "ESNext",
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "isolatedModules": true,
    "moduleDetection": "force",
    "noEmit": true,
    "jsx": "react-jsx",

    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "noFallthroughCasesInSwitch": true,
    "noUncheckedSideEffectImports": true
  },
  "include": ["src"]
}
app/frontend/tsconfig.json (new file, 7 lines)
{
  "files": [],
  "references": [
    { "path": "./tsconfig.app.json" },
    { "path": "./tsconfig.node.json" }
  ]
}
app/frontend/tsconfig.node.json (new file, 24 lines)
{
  "compilerOptions": {
    "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
    "target": "ES2022",
    "lib": ["ES2023"],
    "module": "ESNext",
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "isolatedModules": true,
    "moduleDetection": "force",
    "noEmit": true,

    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "noFallthroughCasesInSwitch": true,
    "noUncheckedSideEffectImports": true
  },
  "include": ["vite.config.ts"]
}
app/frontend/vite.config.ts (new file, 17 lines)
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'

// https://vite.dev/config/
export default defineConfig({
  plugins: [react()],
  server: {
    host: true,
    strictPort: true,
    port: 8004
  },
  test: {
    globals: true,
    environment: "jsdom",
    setupFiles: "./tests/setup.ts",
  },
});
@ -16,6 +16,7 @@ docker = "*"
ipywidgets = "*"
ipykernel = "*"
jupyter = "*"
chromadb = "*"

[dev-packages]
pytest = "==8.0.0"
app/llmops/Pipfile.lock (generated, 960 lines; diff suppressed because it is too large)
@ -10,5 +10,4 @@ build_dependencies:
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
  - wandb==0.16.0
  - git+https://github.com/udacity/nd0821-c2-build-model-workflow-starter.git#egg=wandb-utils&subdirectory=components
@ -5,33 +5,33 @@ This script download a URL to a local destination
import argparse
import logging
import os

import wandb

from wandb_utils.log_artifact import log_artifact
import mlflow
import shutil

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()


def go(args):

    zip_path = os.path.join(args.path_document_folder, f"{args.document_folder}.zip")
    shutil.make_archive(zip_path.replace('.zip', ''), 'zip', args.path_document_folder, args.document_folder)

    run = wandb.init(job_type="get_documents", entity='aimingmed')
    run.config.update(args)
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id):

        logger.info(f"Uploading {args.artifact_name} to Weights & Biases")
        log_artifact(
            args.artifact_name,
            args.artifact_type,
            args.artifact_description,
            zip_path,
            run,
        )
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'artifact_description' not in existing_params:
            mlflow.log_param('artifact_description', args.artifact_description)
        if 'artifact_types' not in existing_params:
            mlflow.log_param('artifact_types', args.artifact_type)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_artifact": args.artifact_name,
        })

        logger.info(f"Uploading {args.artifact_name} to MLFlow")
        mlflow.log_artifact(zip_path, args.artifact_name)


if __name__ == "__main__":
app/llmops/components/test_rag_cot/MLproject (new file, 29 lines)
name: test_rag_cot
python_env: python_env.yml

entry_points:
  main:
    parameters:

      query:
        description: Query to run
        type: string

      input_chromadb_local:
        description: path to input chromadb local
        type: string

      embedding_model:
        description: Fully-qualified name for the embedding model
        type: string

      chat_model_provider:
        description: Fully-qualified name for the chat model provider
        type: string


    command: >-
      python run.py --query {query} \
        --input_chromadb_local {input_chromadb_local} \
        --embedding_model {embedding_model} \
        --chat_model_provider {chat_model_provider}
@ -14,5 +14,4 @@ build_dependencies:
  - langchain-community
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
  - wandb==0.16.0
app/llmops/components/test_rag_cot/run.py (new file, 157 lines)
import os
import logging
import argparse
import mlflow
import chromadb
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
import sys


logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)


def stream_output(text):
    for char in text:
        print(char, end="")
        sys.stdout.flush()


def go(args):

    # start a new MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'query' not in existing_params:
            mlflow.log_param('query', args.query)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_chromadb_local": args.input_chromadb_local,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider
        })

        # Load data from ChromaDB
        db_path = args.input_chromadb_local
        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        collection = chroma_client.get_collection(name=collection_name)

        # Formulate a question
        question = args.query

        if args.chat_model_provider == "deepseek":
            # Initialize DeepSeek model
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=DEEKSEEK_API_KEY
            )

        elif args.chat_model_provider == "gemini":
            # Initialize Gemini model
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                google_api_key=GEMINI_API_KEY,
                temperature=0,
                max_retries=3
            )

        elif args.chat_model_provider == "moonshot":
            # Initialize Moonshot model
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=MOONSHOT_API_KEY
            )

        # Chain of Thought Prompt
        cot_template = """Let's think step by step.
        Given the following document in text: {documents_text}
        Question: {question}
        Reply in a language similar to the language of the asked question.
        """
        cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
        cot_chain = cot_prompt | llm

        # Initialize embedding model (do this ONCE)
        model = SentenceTransformer(args.embedding_model)

        # Query (prompt)
        query_embedding = model.encode(question)  # Embed the query using the SAME model

        # Search ChromaDB
        documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

        # Generate chain of thought
        cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
        print("Chain of Thought: ", end="")
        stream_output(cot_output.content)
        print()

        # Answer Prompt
        answer_template = """Given the chain of thought: {cot}
        Provide a concise answer to the question: {question}
        Provide the answer in a language similar to that of the question asked.
        """
        answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
        answer_chain = answer_prompt | llm

        # Generate answer
        answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
        print("Answer: ", end="")
        stream_output(answer_output.content)
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_local",
        type=str,
        help="Path to input chromadb local directory",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)
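This component is intended to be launched through MLflow Projects against the MLproject entry point above. A sketch of a driver call, assuming a ChromaDB directory already produced by the etl steps (the path here is illustrative) and the relevant API keys in the environment:

# Hypothetical driver for the test_rag_cot component via MLflow Projects.
# The chromadb path is illustrative; point it at the output of the
# etl_chromadb_pdf step.
import mlflow

mlflow.run(
    "app/llmops/components/test_rag_cot",
    "main",
    parameters={
        "query": "如何治疗乳腺癌?",
        "input_chromadb_local": "./chroma_db",
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
        "chat_model_provider": "deepseek",
    },
)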
@ -1,6 +1,4 @@
import wandb
import mlflow


def log_artifact(artifact_name, artifact_type, artifact_description, filename, wandb_run):
    """
@ -7,8 +7,18 @@ etl:
  input_artifact_name: documents
  document_folder: documents
  path_document_folder: "../../../../data"
  run_id_documents: None
  embedding_model: paraphrase-multilingual-mpnet-base-v2
prompt_engineering:
  chat_model_provider: kimi
  query: "怎么治疗有kras的肺癌?"

rag:
  run_id_chromadb: None
  chat_model_provider: deepseek
  testing:
    query: "如何治疗乳腺癌?"
  evaluation:
    evaluation_dataset_csv_path: "../../../../data/qa_dataset_20250409_onlyBreast.csv"
    evaluation_dataset_column_question: question
    evaluation_dataset_column_answer: answer
    ls_chat_model_provider:
      - gemini
      - deepseek
      - moonshot
@@ -1,17 +1,16 @@
import json

import mlflow
import tempfile
import os
import hydra
from omegaconf import DictConfig
from decouple import config

_steps = [
    "get_documents",
    "etl_chromdb_pdf",
    "etl_chromdb_scanned_pdf",  # the performance for scanned pdf may not be good
    "chain_of_thought"
    "etl_chromadb_pdf",
    "etl_chromadb_scanned_pdf",  # the performance for scanned pdf may not be good
    "rag_cot_evaluation",
    "rag_adaptive_evaluation",
    "test_rag_cot"
]


@@ -19,16 +18,15 @@ _steps = [
@hydra.main(config_name='config')
def go(config: DictConfig):

    # Setup the wandb experiment. All runs will be grouped under this name
    os.environ["WANDB_PROJECT"] = config["main"]["project_name"]
    os.environ["WANDB_RUN_GROUP"] = config["main"]["experiment_name"]
    # Setup the MLflow experiment. All runs will be grouped under this name
    mlflow.set_experiment(config["main"]["experiment_name"])

    # Steps to execute
    steps_par = config['main']['steps']
    active_steps = steps_par.split(",") if steps_par != "all" else _steps

    # Move to a temporary directory
    with tempfile.TemporaryDirectory() as tmp_dir:
    with tempfile.TemporaryDirectory():

        if "get_documents" in active_steps:
            # Download file and load in W&B
@@ -43,41 +41,144 @@ def go(config: DictConfig):
                    "artifact_description": "Raw file as downloaded"
                },
            )
        if "etl_chromdb_pdf" in active_steps:
        if "etl_chromadb_pdf" in active_steps:
            if config["etl"]["run_id_documents"] == "None":
                # Look for run_id that has artifact logged as documents
                run_id = None
                client = mlflow.tracking.MlflowClient()
                for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
                    for artifact in client.list_artifacts(run.info.run_id):
                        if artifact.path == "documents":
                            run_id = run.info.run_id
                            break
                    if run_id:
                        break

                if run_id is None:
                    raise ValueError("No run_id found with artifact logged as documents")
            else:
                run_id = config["etl"]["run_id_documents"]


            _ = mlflow.run(
                os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromdb_pdf"),
                os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromadb_pdf"),
                "main",
                parameters={
                    "input_artifact": f'{config["etl"]["input_artifact_name"]}:latest',
                    "output_artifact": "chromdb.zip",
                    "output_type": "chromdb",
                    "input_artifact": f'runs:/{run_id}/documents/documents.zip',
                    "output_artifact": "chromadb",
                    "output_type": "chromadb",
                    "output_description": "Documents in pdf to be read and stored in chromdb",
                    "embedding_model": config["etl"]["embedding_model"]
                },
            )
        if "etl_chromdb_scanned_pdf" in active_steps:

        if "etl_chromadb_scanned_pdf" in active_steps:

            if config["etl"]["run_id_documents"] == "None":
                # Look for run_id that has artifact logged as documents
                run_id = None
                client = mlflow.tracking.MlflowClient()
                for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
                    for artifact in client.list_artifacts(run.info.run_id):
                        if artifact.path == "documents":
                            run_id = run.info.run_id
                            break
                    if run_id:
                        break

                if run_id is None:
                    raise ValueError("No run_id found with artifact logged as documents")
            else:
                run_id = config["etl"]["run_id_documents"]

            _ = mlflow.run(
                os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromdb_scanned_pdf"),
                os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromadb_scanned_pdf"),
                "main",
                parameters={
                    "input_artifact": f'{config["etl"]["input_artifact_name"]}:latest',
                    "output_artifact": "chromdb.zip",
                    "output_type": "chromdb",
                    "input_artifact": f'runs:/{run_id}/documents/documents.zip',
                    "output_artifact": "chromadb",
                    "output_type": "chromadb",
                    "output_description": "Scanned Documents in pdf to be read and stored in chromdb",
                    "embedding_model": config["etl"]["embedding_model"]
                },
            )
        if "chain_of_thought" in active_steps:
        if "rag_cot_evaluation" in active_steps:

            if config["rag"]["run_id_chromadb"] == "None":
                # Look for run_id that has artifact logged as documents
                run_id = None
                client = mlflow.tracking.MlflowClient()
                for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
                    for artifact in client.list_artifacts(run.info.run_id):
                        if artifact.path == "chromadb":
                            run_id = run.info.run_id
                            break
                    if run_id:
                        break

                if run_id is None:
                    raise ValueError("No run_id found with artifact logged as documents")
            else:
                run_id = config["rag"]["run_id_chromadb"]

            _ = mlflow.run(
                os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
                os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"),
                "main",
                parameters={
                    "query": config["prompt_engineering"]["query"],
                    "input_chromadb_artifact": "chromdb.zip:latest",
                    "query": config["testing"]["query"],
                    "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
                    "embedding_model": config["etl"]["embedding_model"],
                    "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
                    "chat_model_provider": config["rag"]["chat_model_provider"]
                },
            )

        if "rag_adaptive_evaluation" in active_steps:

            if config["rag"]["run_id_chromadb"] == "None":
                # Look for run_id that has artifact logged as documents
                run_id = None
                client = mlflow.tracking.MlflowClient()
                for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
                    for artifact in client.list_artifacts(run.info.run_id):
                        if artifact.path == "chromadb":
                            run_id = run.info.run_id
                            break
                    if run_id:
                        break

                if run_id is None:
                    raise ValueError("No run_id found with artifact logged as documents")
            else:
                run_id = config["rag"]["run_id_chromadb"]

            _ = mlflow.run(
                os.path.join(hydra.utils.get_original_cwd(), "src", "rag_adaptive_evaluation"),
                "main",
                parameters={
                    "query": config["testing"]["query"],
                    "evaluation_dataset_csv_path": config["evaluation"]["evaluation_dataset_csv_path"],
                    "evaluation_dataset_column_question": config["evaluation"]["evaluation_dataset_column_question"],
                    "evaluation_dataset_column_answer": config["evaluation"]["evaluation_dataset_column_answer"],
                    "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
                    "embedding_model": config["etl"]["embedding_model"],
                    "chat_model_provider": config["rag"]["chat_model_provider"],
                    "ls_chat_model_evaluator": ','.join(config["evaluation"]["ls_chat_model_provider"]) if config["evaluation"]["ls_chat_model_provider"] is not None else 'None',
                },
            )

        if "test_rag_cot" in active_steps:

            _ = mlflow.run(
                os.path.join(hydra.utils.get_original_cwd(), "components", "test_rag_cot"),
                "main",
                parameters={
                    "query": config["testing"]["query"],
                    "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"),
                    "embedding_model": config["etl"]["embedding_model"],
                    "chat_model_provider": config["rag"]["chat_model_provider"]
                },
            )


if __name__ == "__main__":
    go()
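The run-id lookup above is repeated verbatim in four pipeline steps; a hedged refactoring sketch (a hypothetical helper, not part of this diff) that would factor it out:

import mlflow

def find_run_id_with_artifact(experiment_name: str, artifact_path: str) -> str:
    """Return the first run in the experiment that logged the given artifact path."""
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    for run in client.search_runs(experiment_ids=[experiment.experiment_id]):
        for artifact in client.list_artifacts(run.info.run_id):
            if artifact.path == artifact_path:
                return run.info.run_id
    raise ValueError(f"No run_id found with artifact logged as {artifact_path}")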
@@ -1,144 +0,0 @@
import os
import logging
import argparse
import wandb
import chromadb
import shutil
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)

def go(args):
    run = wandb.init(job_type="chain_of_thought", entity='aimingmed')
    run.config.update(args)

    logger.info("Downloading chromadb artifact")
    artifact_chromadb_local_path = run.use_artifact(args.input_chromadb_artifact).file()

    # unzip the artifact
    logger.info("Unzipping the artifact")
    shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")

    # Load data from ChromaDB
    db_folder = "chroma_db"
    db_path = os.path.join(os.getcwd(), db_folder)
    chroma_client = chromadb.PersistentClient(path=db_path)
    collection_name = "rag_experiment"
    collection = chroma_client.get_collection(name=collection_name)

    # Formulate a question
    question = args.query

    if args.chat_model_provider == "deepseek":
        # Initialize DeepSeek model
        llm = ChatDeepSeek(
            model="deepseek-chat",
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=DEEKSEEK_API_KEY
        )

    elif args.chat_model_provider == "gemini":
        # Initialize Gemini model
        llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            google_api_key=GEMINI_API_KEY,
            temperature=0,
            max_retries=3
        )

    elif args.chat_model_provider == "moonshot":
        # Initialize Moonshot model
        llm = Moonshot(
            model="moonshot-v1-128k",
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            api_key=MOONSHOT_API_KEY
        )


    # Chain of Thought Prompt
    cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply with language that is similar to the language used with asked question.
"""
    cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
    cot_chain = cot_prompt | llm

    # Initialize embedding model (do this ONCE)
    model = SentenceTransformer(args.embedding_model)

    # Query (prompt)
    query_embedding = model.encode(question)  # Embed the query using the SAME model

    # Search ChromaDB
    documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

    # Generate chain of thought
    cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
    print("Chain of Thought: ", cot_output)

    # Answer Prompt
    answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer with language that is similar to the question asked.
"""
    answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
    answer_chain = answer_prompt | llm

    # Generate answer
    answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
    print("Answer: ", answer_output)

    run.finish()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_artifact",
        type=str,
        help="Fully-qualified name for the chromadb artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)
@@ -1,4 +1,4 @@
name: etl_chromdb_pdf
name: etl_chromadb_pdf
python_env: python_env.yml

entry_points:
@@ -10,7 +10,11 @@ build_dependencies:
  - pdfminer.six
  - langchain
  - sentence_transformers
  - langchain-text-splitters
  - langchain_huggingface
  - langchain-community
  - tiktoken

# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
  - wandb==0.16.0
  - mlflow==2.8.1
202  app/llmops/src/etl_chromadb_pdf/run.py  Normal file
@@ -0,0 +1,202 @@
#!/usr/bin/env python
"""
Download from MLflow the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import glob
import json
import logging
import os
import mlflow
import shutil

import io
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfpage import PDFPage
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter


logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def extract_chinese_text_from_pdf(pdf_path):
    """
    Extracts Chinese text from a PDF file.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        str: The extracted Chinese text, or None if an error occurs.
    """
    resource_manager = PDFResourceManager()
    fake_file_handle = io.StringIO()
    converter = TextConverter(resource_manager, fake_file_handle)
    page_interpreter = PDFPageInterpreter(resource_manager, converter)

    try:
        with open(pdf_path, 'rb') as fh:
            for page in PDFPage.get_pages(fh, caching=True, check_extractable=True):
                page_interpreter.process_page(page)

            text = fake_file_handle.getvalue()

        return text

    except FileNotFoundError:
        print(f"Error: PDF file not found at {pdf_path}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    finally:
        converter.close()
        fake_file_handle.close()


def go(args):
    """
    Run the ETL for ChromaDB with readable (text-based) PDF
    """

    # Start an MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'output_description' not in existing_params:
            mlflow.log_param('output_description', args.output_description)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_artifact": args.input_artifact,
            "output_artifact": args.output_artifact,
            "output_type": args.output_type,
            "embedding_model": args.embedding_model
        })


        # Initialize embedding model (do this ONCE)
        model_embedding = HuggingFaceEmbeddings(model_name=args.embedding_model)  # Or a multilingual model


        # Create database, delete the database directory if it exists
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
        os.makedirs(db_path)


        logger.info("Downloading artifact")
        artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)

        logger.info("Reading data")

        # unzip the downloaded artifact
        import zipfile
        with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
            zip_ref.extractall(".")

        # show the unzipped folder
        documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=15000, chunk_overlap=7500
        )

        # read the dictionary json for word replacement in the read text
        with open(f'./{documents_folder}/2023CACA/CACA英文缩写.json', 'r', encoding='utf-8') as f:
            df_dict_json = json.load(f)

        ls_docs = []
        pdf_files = glob.glob(f"./{documents_folder}/**/*.pdf", recursive=True)

        for pdf_file in pdf_files:
            read_text = extract_chinese_text_from_pdf(pdf_file)
            relative_path = os.path.relpath(pdf_file, start=f"./{documents_folder}")

            # if the parent directory of the pdf file is 2023CACA, then replace the shortform text with the dictionary value
            if '2023CACA' in relative_path:
                # get the pdf filename without the extension
                pdf_filename = os.path.splitext(os.path.basename(pdf_file))[0]
                # replace the text with the dictionary
                dict_file = df_dict_json.get(pdf_filename)
                if dict_file:
                    for key, value in dict_file.items():
                        read_text = read_text.replace(key, value)


            document = Document(metadata={"file": relative_path}, page_content=read_text)
            ls_docs.append(document)



        doc_splits = text_splitter.split_documents(ls_docs)

        # Add to vectorDB
        _vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name="rag-chroma",
            embedding=model_embedding,
            persist_directory=db_path
        )

        logger.info("Logging artifact with mlflow")
        shutil.make_archive(db_path, 'zip', db_path)
        mlflow.log_artifact(db_path + '.zip', args.output_artifact)

        # clean up
        os.remove(db_path + '.zip')
        shutil.rmtree(db_path)
        shutil.rmtree(documents_folder)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="ETL for ChromaDB with readable PDF")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )


    args = parser.parse_args()

    go(args)
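A minimal sketch of reading the resulting store back (assuming the same embedding model and the "rag-chroma" collection name used above; the rag_adaptive_evaluation/run.py further below does the equivalent via MLflow artifacts):

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma

# The store is only meaningful when queried with the embedding model it was built with.
embedding = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-mpnet-base-v2")
vectorstore = Chroma(
    persist_directory="chroma_db",
    collection_name="rag-chroma",
    embedding_function=embedding,
)
for doc in vectorstore.similarity_search("如何治疗乳腺癌?", k=3):
    print(doc.metadata["file"], doc.page_content[:80])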
@@ -1,4 +1,4 @@
name: etl_chromdb_scanned_pdf
name: etl_chromadb_scanned_pdf
python_env: python_env.yml

entry_points:
@@ -14,4 +14,3 @@ build_dependencies:
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
  - wandb==0.16.0
160  app/llmops/src/etl_chromadb_scanned_pdf/run.py  Normal file
@@ -0,0 +1,160 @@
#!/usr/bin/env python
"""
Download from MLflow the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import mlflow
import shutil

import chromadb
# from openai import OpenAI
import pytesseract as pt
from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def extract_text_from_pdf_ocr(pdf_path):
    try:
        images = convert_from_path(pdf_path)  # Convert PDF pages to images
        extracted_text = ""
        for image in images:
            text = pt.image_to_string(image, lang="chi_sim+eng")  # chi_sim for Simplified Chinese, chi_tra for Traditional

            extracted_text += text + "\n"
        return extracted_text

    except ImportError:
        print("Error: pdf2image or pytesseract not installed. Please install them: pip install pdf2image pytesseract")
        return ""
    except Exception as e:
        print(f"OCR failed: {e}")
        return ""



def go(args):
    """
    Run the ETL for ChromaDB with scanned PDF
    """

    # Start an MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_scanned_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'output_description' not in existing_params:
            mlflow.log_param('output_description', args.output_description)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_artifact": args.input_artifact,
            "output_artifact": args.output_artifact,
            "output_type": args.output_type,
            "embedding_model": args.embedding_model
        })


        # Initialize embedding model
        model_embedding = SentenceTransformer(args.embedding_model)  # Or a multilingual model


        # Create database, delete the database directory if it exists
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
        os.makedirs(db_path)

        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        db = chroma_client.create_collection(name=collection_name)


        logger.info("Downloading artifact")
        artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)

        logger.info("Reading data")

        # unzip the downloaded artifact
        import zipfile
        with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
            zip_ref.extractall(".")

        # show the unzipped folder
        documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

        for root, _dir, files in os.walk(f"./{documents_folder}"):
            for file in files:
                if file.endswith(".pdf"):
                    read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
                    document = Document(page_content=read_text)
                    all_splits = text_splitter.split_documents([document])

                    for i, split in enumerate(all_splits):
                        db.add(documents=[split.page_content],
                               metadatas=[{"filename": file}],
                               ids=[f'{file[:-4]}-{str(i)}'],
                               embeddings=[model_embedding.encode(split.page_content)]
                               )

        logger.info("Uploading artifact to MLFlow")
        shutil.make_archive(db_path, 'zip', db_path)
        mlflow.log_artifact(db_path + '.zip', args.output_artifact)

        # clean up
        os.remove(db_path + '.zip')
        shutil.rmtree(db_path)
        shutil.rmtree(documents_folder)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="ETL for ChromaDB with scanned PDF (OCR)")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )


    args = parser.parse_args()

    go(args)
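The OCR step above depends on the Tesseract chi_sim language pack being installed on the host. A hedged pre-flight check (a hypothetical addition, not part of this diff):

import pytesseract

# Fail fast if the Simplified Chinese traineddata is missing.
available = pytesseract.get_languages(config="")
if "chi_sim" not in available:
    raise RuntimeError(
        "Tesseract language pack 'chi_sim' not found; "
        "install it (e.g. apt-get install tesseract-ocr-chi-sim) before running the OCR ETL."
    )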
@@ -1,184 +0,0 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import wandb
import shutil

import chromadb
# from openai import OpenAI
import io
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfpage import PDFPage
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def extract_chinese_text_from_pdf(pdf_path):
    """
    Extracts Chinese text from a PDF file.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        str: The extracted Chinese text, or None if an error occurs.
    """
    resource_manager = PDFResourceManager()
    fake_file_handle = io.StringIO()
    converter = TextConverter(resource_manager, fake_file_handle)
    page_interpreter = PDFPageInterpreter(resource_manager, converter)

    try:
        with open(pdf_path, 'rb') as fh:
            for page in PDFPage.get_pages(fh, caching=True, check_extractable=True):
                page_interpreter.process_page(page)

            text = fake_file_handle.getvalue()

        return text

    except FileNotFoundError:
        print(f"Error: PDF file not found at {pdf_path}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    finally:
        converter.close()
        fake_file_handle.close()


def go(args):
    """
    Run the etl for chromdb with scanned pdf
    """

    run = wandb.init(job_type="etl_chromdb_scanned_pdf", entity='aimingmed')
    run.config.update(args)


    # Initialize embedding model (do this ONCE)
    model_embedding = SentenceTransformer(args.embedding_model)  # Or a multilingual model


    # Create database, delete the database directory if it exists
    db_folder = "chroma_db"
    db_path = os.path.join(os.getcwd(), db_folder)
    if os.path.exists(db_path):
        shutil.rmtree(db_path)
    os.makedirs(db_path)

    chroma_client = chromadb.PersistentClient(path=db_path)
    collection_name = "rag_experiment"
    db = chroma_client.create_collection(name=collection_name)


    logger.info("Downloading artifact")
    artifact_local_path = run.use_artifact(args.input_artifact).file()

    logger.info("Reading data")

    # unzip the downloaded artifact
    import zipfile
    with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
        zip_ref.extractall(".")
    os.remove(artifact_local_path)

    # show the unzipped folder
    documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for root, _dir, files in os.walk(f"./{documents_folder}"):
        for file in files:
            if file.endswith(".pdf"):
                read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
                document = Document(page_content=read_text)
                all_splits = text_splitter.split_documents([document])

                for i, split in enumerate(all_splits):
                    db.add(documents=[split.page_content],
                           metadatas=[{"filename": file}],
                           ids=[f'{file[:-4]}-{str(i)}'],
                           embeddings=[model_embedding.encode(split.page_content)]
                           )

    # Create a new artifact
    artifact = wandb.Artifact(
        args.output_artifact,
        type=args.output_type,
        description=args.output_description
    )

    # zip the database folder first
    shutil.make_archive(db_path, 'zip', db_path)

    # Add the database to the artifact
    artifact.add_file(db_path + '.zip')

    # Log the artifact
    run.log_artifact(artifact)

    # Finish the run
    run.finish()

    # clean up
    os.remove(db_path + '.zip')
    os.remove(db_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="A very basic data cleaning")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )


    args = parser.parse_args()

    go(args)
@@ -1,173 +0,0 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import wandb
import shutil

import chromadb
# from openai import OpenAI
from typing import List
import numpy as np
import pytesseract as pt
from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def extract_text_from_pdf_ocr(pdf_path):
    try:
        images = convert_from_path(pdf_path)  # Convert PDF pages to images
        extracted_text = ""
        for image in images:
            text = pt.image_to_string(image, lang="chi_sim+eng")  # chi_sim for Simplified Chinese, chi_tra for Traditional

            extracted_text += text + "\n"
        return extracted_text

    except ImportError:
        print("Error: pdf2image or pytesseract not installed. Please install them: pip install pdf2image pytesseract")
        return ""
    except Exception as e:
        print(f"OCR failed: {e}")
        return ""



def go(args):
    """
    Run the etl for chromdb with scanned pdf
    """

    run = wandb.init(job_type="etl_chromdb_scanned_pdf", entity='aimingmed')
    run.config.update(args)

    # Setup the Gemini client
    # client = OpenAI(
    #     api_key=args.gemini_api_key,
    #     base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
    # )


    # def get_google_embedding(text: str) -> List[float]:
    #     response = client.embeddings.create(
    #         model="text-embedding-004",
    #         input=text
    #     )
    #     return response.data[0].embedding

    # class GeminiEmbeddingFunction(object):
    #     def __init__(self, api_key: str, base_url: str, model_name: str):
    #         self.client = OpenAI(
    #             api_key=args.gemini_api_key,
    #             base_url=base_url
    #         )
    #         self.model_name = model_name

    #     def __call__(self, input: List[str]) -> List[List[float]]:
    #         all_embeddings = []
    #         for text in input:
    #             response = self.client.embeddings.create(input=text, model=self.model_name)
    #             embeddings = [record.embedding for record in response.data]
    #             all_embeddings.append(np.array(embeddings[0]))
    #         return all_embeddings


    # Initialize embedding model (do this ONCE)
    model_embedding = SentenceTransformer('all-mpnet-base-v2')  # Or a multilingual model


    # Create database, delete the database directory if it exists
    db_folder = "chroma_db"
    db_path = os.path.join(os.getcwd(), db_folder)
    if os.path.exists(db_path):
        shutil.rmtree(db_path)
    os.makedirs(db_path)

    chroma_client = chromadb.PersistentClient(path=db_path)
    collection_name = "rag_experiment"
    db = chroma_client.create_collection(name=collection_name)


    logger.info("Downloading artifact")
    artifact_local_path = run.use_artifact(args.input_artifact).file()

    logger.info("Reading data")

    # unzip the downloaded artifact
    import zipfile
    with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
        zip_ref.extractall(".")
    os.remove(artifact_local_path)

    # show the unzipped folder
    documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for root, _dir, files in os.walk(f"./{documents_folder}"):
        for file in files:
            if file.endswith(".pdf"):
                read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
                document = Document(page_content=read_text)
                all_splits = text_splitter.split_documents([document])

                for i, split in enumerate(all_splits):
                    db.add(documents=[split.page_content],
                           metadatas=[{"filename": file}],
                           ids=[f'{file[:-4]}-{str(i)}'],
                           embeddings=[model_embedding.encode(split.page_content)]
                           )

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="A very basic data cleaning")

    parser.add_argument(
        "--input_artifact",
        type=str,
        help="Fully-qualified name for the input artifact",
        required=True
    )

    parser.add_argument(
        "--output_artifact",
        type=str,
        help="Name for the output artifact",
        required=True
    )

    parser.add_argument(
        "--output_type",
        type=str,
        help="Type for the artifact output",
        required=True
    )

    parser.add_argument(
        "--output_description",
        type=str,
        help="Description for the artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )


    args = parser.parse_args()

    go(args)
49  app/llmops/src/rag_adaptive_evaluation/MLproject  Normal file
@@ -0,0 +1,49 @@
name: rag_adaptive_evaluation
python_env: python_env.yml

entry_points:
  main:
    parameters:

      query:
        description: Query to run
        type: string

      evaluation_dataset_csv_path:
        description: query evaluation dataset csv path
        type: string

      evaluation_dataset_column_question:
        description: query evaluation dataset column question
        type: string

      evaluation_dataset_column_answer:
        description: query evaluation dataset column groundtruth
        type: string

      input_chromadb_artifact:
        description: Fully-qualified name for the input artifact
        type: string

      embedding_model:
        description: Fully-qualified name for the embedding model
        type: string

      chat_model_provider:
        description: Fully-qualified name for the chat model provider
        type: string

      ls_chat_model_evaluator:
        description: list of chat model providers for evaluation
        type: string


    command: >-
      python run.py --query {query} \
        --evaluation_dataset_csv_path {evaluation_dataset_csv_path} \
        --evaluation_dataset_column_question {evaluation_dataset_column_question} \
        --evaluation_dataset_column_answer {evaluation_dataset_column_answer} \
        --input_chromadb_artifact {input_chromadb_artifact} \
        --embedding_model {embedding_model} \
        --chat_model_provider {chat_model_provider} \
        --ls_chat_model_evaluator {ls_chat_model_evaluator}
32  app/llmops/src/rag_adaptive_evaluation/data_models.py  Normal file
@@ -0,0 +1,32 @@
from typing import Literal
from pydantic import BaseModel, Field


class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )

class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""

    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )
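These schemas are consumed via with_structured_output in the rag_adaptive_evaluation/run.py below; a minimal standalone sketch (gemini-1.5-flash is chosen here only as an example provider, and the question string is invented):

from langchain_google_genai import ChatGoogleGenerativeAI
from data_models import RouteQuery

llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
router = llm.with_structured_output(RouteQuery)

# The model is constrained to return a RouteQuery instance.
result = router.invoke("What is the first-line treatment for KRAS-mutant lung cancer?")
print(result.datasource)  # expected: "vectorstore"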
141  app/llmops/src/rag_adaptive_evaluation/evaluators.py  Normal file
@@ -0,0 +1,141 @@
import os
from decouple import config

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

from pydantic import BaseModel, Field

from prompts_library import CORRECTNESS_PROMPT, FAITHFULNESS_PROMPT

os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)


# Define output schema for the evaluation
class CorrectnessGrade(BaseModel):
    score: int = Field(description="Numerical score (1-5) indicating the correctness of the response.")

class FaithfulnessGrade(BaseModel):
    score: int = Field(description="Numerical score (1-5) indicating the faithfulness of the response.")



# Evaluators
def gemini_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.5,
    )

    messages = [
        {"role": "system", "content": CORRECTNESS_PROMPT},
        {"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
        Student's Answer: {outputs['response']}
        """}
    ]

    response = llm.invoke(messages)

    return CorrectnessGrade(score=int(response.content)).score


def deepseek_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0.5,
    )

    messages = [
        {"role": "system", "content": CORRECTNESS_PROMPT},
        {"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
        Student's Answer: {outputs['response']}
        """}
    ]

    response = llm.invoke(messages)

    return CorrectnessGrade(score=int(response.content)).score


def moonshot_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
    llm = Moonshot(
        model="moonshot-v1-128k",
        temperature=0.5,
    )

    messages = [
        {"role": "system", "content": CORRECTNESS_PROMPT},
        {"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
        Student's Answer: {outputs['response']}
        """}
    ]

    response = llm.invoke(messages)

    try:
        return CorrectnessGrade(score=int(response)).score
    except ValueError:
        score_str = response.split(":")[1].strip()
        return CorrectnessGrade(score=int(score_str)).score


def gemini_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        temperature=0.5,
    )

    messages = [
        {"role": "system", "content": FAITHFULNESS_PROMPT},
        {"role": "user", "content": f"""Context: {reference_outputs["answer"]};
        Output: {outputs['response']}
        """}
    ]

    response = llm.invoke(messages)

    return FaithfulnessGrade(score=int(response.content)).score


def deepseek_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0.5,
    )

    messages = [
        {"role": "system", "content": FAITHFULNESS_PROMPT},
        {"role": "user", "content": f"""Context: {reference_outputs["answer"]};
        Output: {outputs['response']}
        """}
    ]

    response = llm.invoke(messages)

    return FaithfulnessGrade(score=int(response.content)).score


def moonshot_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
    llm = Moonshot(
        model="moonshot-v1-128k",
        temperature=0.5,
    )

    messages = [
        {"role": "system", "content": FAITHFULNESS_PROMPT},
        {"role": "user", "content": f"""Context: {reference_outputs["answer"]};
        Output: {outputs['response']}
        """}
    ]

    response = llm.invoke(messages)

    try:
        return FaithfulnessGrade(score=int(response)).score
    except ValueError:
        score_str = response.split(":")[1].strip()
        return FaithfulnessGrade(score=int(score_str)).score

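A hedged usage sketch for these evaluators (the dict shapes mirror the outputs/reference_outputs arguments above; the sample strings are invented and the call hits the live Gemini API):

from evaluators import gemini_evaluator_correctness

# Hypothetical inputs in the shape the evaluators expect.
outputs = {"response": "First-line treatment is surgery followed by adjuvant chemotherapy."}
reference_outputs = {"answer": "Early-stage disease is managed with surgery plus adjuvant chemotherapy."}

score = gemini_evaluator_correctness(outputs, reference_outputs)
print(f"correctness: {score}")  # an int between 1 and 5, per the rubric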
98  app/llmops/src/rag_adaptive_evaluation/prompts_library.py  Normal file
@@ -0,0 +1,98 @@
system_router = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""

system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
You must make sure to read carefully that the document contains a sentence or chunk of sentences that is exactly related (not merely closely related) to the question subject (e.g. it must be the exact disease or subject in question). \n
The goal is to filter out erroneous retrievals. \n
Must return a binary 'yes' or 'no' score to indicate whether the document is relevant to the question."""

system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""

system_question_rewriter = """You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""

# prompt for question answering based on retrieved documents
qa_prompt_template = """You are an expert at answering questions based on the following retrieved context.\n
Before answering the question, you must have your own thought process about the general scopes to cover when answering this question, step-by-step. Do not include this thought process in the answer.\n
Then, given your thought process, you must read the provided context carefully and extract the relevant information.\n

If the question is a medical question, you must answer it in a medical way and assume that the audience is a junior doctor or a medical student: \n
1. For cancer diseases, you must include comprehensive treatment advice that encompasses multidisciplinary treatment options, including but not limited to surgery, chemotherapy, radiology, internal medicine (drugs), nutritional ratio (protein), etc. You must lay out the treatment options, e.g. what are the first-line, second-line treatments, etc.\n
2. For cancer diseases, don't consider context that is not primary tumor/cancer related, unless the question specifically mentions it is secondary tumor/cancer related.\n
3. If the question didn't state the stage of the cancer disease, you must reply with treatment options for each stage of the cancer disease, if they are available in the provided context. If they are not available in the provided context, give a general one.\n

You must not use any information that is not present in the provided context to answer the question. Make sure to remove information not present in the provided context.\n
If you don't know the answer, just say that you don't know.\n
Provide the answer in a concise and organized manner. \n

Question: {question} \n
Context: {context} \n
Answer:
"""


# Evaluation
CORRECTNESS_PROMPT = """You are an impartial judge. Evaluate the Student Answer against the Ground Truth for conceptual similarity and correctness.
You may also be given additional information that was used by the model to generate the output.

Your task is to determine a numerical score called correctness based on the Student Answer and Ground Truth.
A definition of correctness and a grading rubric are provided below.
You must use the grading rubric to determine your score.

Metric definition:
Correctness assesses the degree to which a provided Student Answer aligns with the factual accuracy, completeness, logical
consistency, and precise terminology of the Ground Truth. It evaluates the intrinsic validity of the Student Answer, independent of any
external context. A higher score indicates a higher adherence to factual accuracy, completeness, logical consistency,
and precise terminology of the Ground Truth.

Grading rubric:
Correctness: Below are the details for different scores:
- 1: Major factual errors, highly incomplete, illogical, and uses incorrect terminology.
- 2: Significant factual errors, incomplete, noticeable logical flaws, and frequent terminology errors.
- 3: Minor factual errors, somewhat incomplete, minor logical inconsistencies, and occasional terminology errors.
- 4: Few to no factual errors, mostly complete, strong logical consistency, and accurate terminology.
- 5: Accurate, complete, logically consistent, and uses precise terminology.

Reminder:
- Carefully read the Student Answer and Ground Truth
- Check for factual accuracy and completeness of the Student Answer compared to the Ground Truth
- Focus on correctness of information rather than style or verbosity
- The goal is to evaluate the factual correctness and completeness of the Student Answer.
- Please provide your answer score only as a numerical number between 1 and 5. No "score:" prefix or other text is allowed.

"""

FAITHFULNESS_PROMPT = """You are an impartial judge. Evaluate the output against the context for faithfulness.
You may also be given additional information that was used by the model to generate the Output.

Your task is to determine a numerical score called faithfulness based on the output and context.
A definition of faithfulness and a grading rubric are provided below.
You must use the grading rubric to determine your score.

Metric definition:
Faithfulness is only evaluated with the provided output and context. Faithfulness assesses how much of the
provided output is factually consistent with the provided context. A higher score indicates that a higher proportion of
claims present in the output can be derived from the provided context. Faithfulness does not consider how much extra
information from the context is not present in the output.

Grading rubric:
Faithfulness: Below are the details for different scores:
- Score 1: None of the claims in the output can be inferred from the provided context.
- Score 2: Some of the claims in the output can be inferred from the provided context, but the majority of the output is missing from, inconsistent with, or contradictory to the provided context.
- Score 3: Half or more of the claims in the output can be inferred from the provided context.
- Score 4: Most of the claims in the output can be inferred from the provided context, with very little information that is not directly supported by the provided context.
- Score 5: All of the claims in the output are directly supported by the provided context, demonstrating high faithfulness to the provided context.

Reminder:
- Carefully read the output and context
- Focus on the information instead of the writing style or verbosity.
- Please provide your answer score only as a numerical number between 1 and 5, according to the grading rubric above. No "score:" prefix or other text is allowed.
"""
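Both rubrics ask the judge for a bare 1-5 integer, but models sometimes prepend text anyway; the moonshot evaluators above already work around this with a split on ':'. A more defensive parsing sketch (a hypothetical helper, not part of this diff):

import re

def parse_rubric_score(raw: str) -> int:
    """Extract the first 1-5 digit from a judge response, tolerating stray prose."""
    match = re.search(r"\b([1-5])\b", raw)
    if match is None:
        raise ValueError(f"No rubric score found in judge output: {raw!r}")
    return int(match.group(1))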
29  app/llmops/src/rag_adaptive_evaluation/python_env.yml  Normal file
@@ -0,0 +1,29 @@
# Python version required to run the project.
python: "3.11.11"
# Dependencies required to build packages. This field is optional.
build_dependencies:
  - pip==23.3.1
  - setuptools
  - wheel==0.37.1
  - chromadb
  - langchain
  - sentence_transformers
  - python-decouple
  - langchain_google_genai
  - langchain-deepseek
  - langchain-openai
  - langchain-community
  - mlflow[genai]
  - langsmith
  - openai
  - tiktoken
  - langchainhub
  - langgraph
  - langchain-text-splitters
  - langchain-cohere
  - tavily-python
  - langchain_huggingface
  - pydantic
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1
608  app/llmops/src/rag_adaptive_evaluation/run.py  Normal file
@ -0,0 +1,608 @@
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import mlflow
|
||||
import shutil
|
||||
import langsmith
|
||||
|
||||
from decouple import config
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_community.llms.moonshot import Moonshot
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
|
||||
|
||||
from typing import List
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
from langchain.prompts import PromptTemplate, HumanMessagePromptTemplate
|
||||
|
||||
from langchain.schema import Document
|
||||
from pprint import pprint
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
from langsmith import Client
|
||||
|
||||
|
||||
from data_models import (
|
||||
RouteQuery,
|
||||
GradeDocuments,
|
||||
GradeHallucinations,
|
||||
GradeAnswer
|
||||
)
|
||||
from prompts_library import (
|
||||
system_router,
|
||||
system_retriever_grader,
|
||||
system_hallucination_grader,
|
||||
system_answer_grader,
|
||||
system_question_rewriter,
|
||||
qa_prompt_template
|
||||
)
|
||||
|
||||
from evaluators import (
|
||||
gemini_evaluator_correctness,
|
||||
deepseek_evaluator_correctness,
|
||||
moonshot_evaluator_correctness,
|
||||
gemini_evaluator_faithfulness,
|
||||
deepseek_evaluator_faithfulness,
|
||||
moonshot_evaluator_faithfulness
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
|
||||
logger = logging.getLogger()
|
||||
|
||||
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
|
||||
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
|
||||
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
|
||||
os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["LANGSMITH_API_KEY"] = config("LANGSMITH_API_KEY", cast=str)
|
||||
os.environ["LANGSMITH_TRACING"] = config("LANGSMITH_TRACING", cast=str)
|
||||
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
os.environ["LANGSMITH_PROJECT"] = config("LANGSMITH_PROJECT", cast=str)

def go(args):

    # Start a new MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromdb_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'query' not in existing_params:
            mlflow.log_param('query', args.query)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_chromadb_artifact": args.input_chromadb_artifact,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider
        })

        logger.info("Downloading chromadb artifact")
        artifact_chromadb_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_chromadb_artifact)

        # Unzip the artifact
        logger.info("Unzipping the artifact")
        shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")

        # Initialize embedding model (do this ONCE)
        embedding_model = HuggingFaceEmbeddings(model_name=args.embedding_model)

        if args.chat_model_provider == 'deepseek':
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
            )
        elif args.chat_model_provider == 'gemini':
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,
                max_retries=3,
                streaming=True
            )
        elif args.chat_model_provider == 'moonshot':
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
            )
        else:
            # Fail fast on an unrecognized provider instead of continuing
            # with llm undefined
            raise ValueError(f"Unknown chat_model_provider: {args.chat_model_provider}")

        # Load data from ChromaDB
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        collection_name = "rag-chroma"
        vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
        retriever = vectorstore.as_retriever()
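
        # as_retriever() with no arguments uses similarity search with
        # LangChain's default of k=4 chunks per query; a tighter variant
        # (illustrative, not what the code above does) would be:
        #     retriever = vectorstore.as_retriever(search_kwargs={"k": 2})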

        ##########################################
        # Routing to vectorstore or web search
        structured_llm_router = llm.with_structured_output(RouteQuery)

        # Prompt
        route_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_router),
                ("human", "{question}"),
            ]
        )
        question_router = route_prompt | structured_llm_router
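
        # Invoking the router returns a RouteQuery instance, e.g.
        #     question_router.invoke({"question": "..."}).datasource
        # is either "vectorstore" or "web_search"; route_question() below
        # branches on exactly this field.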

        ##########################################
        ### Retrieval Grader
        structured_llm_grader = llm.with_structured_output(GradeDocuments)

        # Prompt
        grade_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_retriever_grader),
                ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
            ]
        )
        retrieval_grader = grade_prompt | structured_llm_grader

        ##########################################
        ### Generate
        # Create a PromptTemplate with the QA prompt from the prompts library
        new_prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template=qa_prompt_template,
        )

        # Wrap it in a HumanMessagePromptTemplate so it can be used in a chat prompt
        new_human_message_prompt_template = HumanMessagePromptTemplate(
            prompt=new_prompt_template
        )
        prompt_qa = ChatPromptTemplate.from_messages([new_human_message_prompt_template])

        # Chain
        rag_chain = prompt_qa | llm | StrOutputParser()
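
        # Because the chain ends in StrOutputParser(), rag_chain.invoke(...)
        # returns a plain string rather than a message object, e.g.:
        #     rag_chain.invoke({"context": docs, "question": "..."})  # -> str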

        ##########################################
        ### Hallucination Grader
        structured_llm_grader = llm.with_structured_output(GradeHallucinations)

        # Prompt
        hallucination_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_hallucination_grader),
                ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
            ]
        )
        hallucination_grader = hallucination_prompt | structured_llm_grader

        ##########################################
        ### Answer Grader
        structured_llm_grader = llm.with_structured_output(GradeAnswer)

        # Prompt
        answer_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_answer_grader),
                ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
            ]
        )
        answer_grader = answer_prompt | structured_llm_grader

        ##########################################
        ### Question Re-writer
        # Prompt
        re_write_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_question_rewriter),
                (
                    "human",
                    "Here is the initial question: \n\n {question} \n Formulate an improved question.",
                ),
            ]
        )
        question_rewriter = re_write_prompt | llm | StrOutputParser()

        ### Search
        web_search_tool = TavilySearchResults(k=3)
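
        # Note: k=3 follows the upstream adaptive-RAG example; in recent
        # langchain_community releases the result-count knob on
        # TavilySearchResults is max_results, so k may be silently ignored.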

        class GraphState(TypedDict):
            """
            Represents the state of our graph.

            Attributes:
                question: question
                generation: LLM generation
                documents: list of documents
            """

            question: str
            generation: str
            documents: List[str]

        def retrieve(state):
            """
            Retrieve documents

            Args:
                state (dict): The current graph state

            Returns:
                state (dict): New key added to state, documents, that contains retrieved documents
            """
            print("---RETRIEVE---")
            question = state["question"]

            # Retrieval
            documents = retriever.invoke(question)

            print(documents)
            return {"documents": documents, "question": question}

        def generate(state):
            """
            Generate answer

            Args:
                state (dict): The current graph state

            Returns:
                state (dict): New key added to state, generation, that contains LLM generation
            """
            print("---GENERATE---")
            question = state["question"]
            documents = state["documents"]

            # RAG generation
            generation = rag_chain.invoke({"context": documents, "question": question})
            return {"documents": documents, "question": question, "generation": generation}

        def grade_documents(state):
            """
            Determines whether the retrieved documents are relevant to the question.

            Args:
                state (dict): The current graph state

            Returns:
                state (dict): Updates documents key with only filtered relevant documents
            """
            print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
            question = state["question"]
            documents = state["documents"]

            # Score each doc
            filtered_docs = []
            for d in documents:
                score = retrieval_grader.invoke(
                    {"question": question, "document": d.page_content}
                )
                grade = score.binary_score
                if grade == "yes":
                    print("---GRADE: DOCUMENT RELEVANT---")
                    filtered_docs.append(d)
                else:
                    print("---GRADE: DOCUMENT NOT RELEVANT---")
                    continue
            return {"documents": filtered_docs, "question": question}

        def transform_query(state):
            """
            Transform the query to produce a better question.

            Args:
                state (dict): The current graph state

            Returns:
                state (dict): Updates question key with a re-phrased question
            """
            print("---TRANSFORM QUERY---")
            question = state["question"]
            documents = state["documents"]

            # Re-write question
            better_question = question_rewriter.invoke({"question": question})
            return {"documents": documents, "question": better_question}

        def web_search(state):
            """
            Web search based on the re-phrased question.

            Args:
                state (dict): The current graph state

            Returns:
                state (dict): Updates documents key with appended web results
            """
            print("---WEB SEARCH---")
            question = state["question"]

            # Web search
            docs = web_search_tool.invoke({"query": question})
            web_results = "\n".join([d["content"] for d in docs])
            # Wrap the joined results in a one-element list so the shape
            # matches GraphState.documents
            web_results = [Document(page_content=web_results)]

            return {"documents": web_results, "question": question}

        ### Edges ###
        def route_question(state):
            """
            Route question to web search or RAG.

            Args:
                state (dict): The current graph state

            Returns:
                str: Next node to call
            """
            print("---ROUTE QUESTION---")
            question = state["question"]
            source = question_router.invoke({"question": question})
            if source.datasource == "web_search":
                print("---ROUTE QUESTION TO WEB SEARCH---")
                return "web_search"
            elif source.datasource == "vectorstore":
                print("---ROUTE QUESTION TO RAG---")
                return "vectorstore"

        def decide_to_generate(state):
            """
            Determines whether to generate an answer, or re-generate a question.

            Args:
                state (dict): The current graph state

            Returns:
                str: Binary decision for next node to call
            """
            print("---ASSESS GRADED DOCUMENTS---")
            filtered_documents = state["documents"]

            if not filtered_documents:
                # All documents were filtered out by grade_documents,
                # so re-generate a new query
                print(
                    "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
                )
                return "transform_query"
            else:
                # We have relevant documents, so generate an answer
                print("---DECISION: GENERATE---")
                return "generate"

        def grade_generation_v_documents_and_question(state):
            """
            Determines whether the generation is grounded in the documents and answers the question.

            Args:
                state (dict): The current graph state

            Returns:
                str: Decision for next node to call
            """
            print("---CHECK HALLUCINATIONS---")
            question = state["question"]
            documents = state["documents"]
            generation = state["generation"]

            score = hallucination_grader.invoke(
                {"documents": documents, "generation": generation}
            )
            grade = score.binary_score

            # Check hallucination
            if grade == "yes":
                print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
                # Check question-answering
                print("---GRADE GENERATION vs QUESTION---")
                score = answer_grader.invoke({"question": question, "generation": generation})
                grade = score.binary_score
                if grade == "yes":
                    print("---DECISION: GENERATION ADDRESSES QUESTION---")
                    return "useful"
                else:
                    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
                    return "not useful"
            else:
                print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
                return "not supported"

        workflow = StateGraph(GraphState)

        # Define the nodes
        workflow.add_node("web_search", web_search)  # web search
        workflow.add_node("retrieve", retrieve)  # retrieve
        workflow.add_node("grade_documents", grade_documents)  # grade documents
        workflow.add_node("generate", generate)  # generate
        workflow.add_node("transform_query", transform_query)  # transform query

        # Build graph
        workflow.add_conditional_edges(
            START,
            route_question,
            {
                "web_search": "web_search",
                "vectorstore": "retrieve",
            },
        )
        workflow.add_edge("web_search", "generate")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            decide_to_generate,
            {
                "transform_query": "transform_query",
                "generate": "generate",
            },
        )
        workflow.add_edge("transform_query", "retrieve")
        workflow.add_conditional_edges(
            "generate",
            grade_generation_v_documents_and_question,
            {
                "not supported": "generate",
                "useful": END,
                "not useful": "transform_query",
            },
        )
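
        # The resulting control flow, for reference:
        #   START -> route_question -> "web_search"  -> generate
        #   START -> route_question -> "vectorstore" -> retrieve -> grade_documents
        #   grade_documents -> generate                  (some docs relevant)
        #   grade_documents -> transform_query -> retrieve   (none relevant)
        #   generate -> END                              ("useful")
        #   generate -> transform_query                  ("not useful")
        #   generate -> generate                         ("not supported", retry)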

        # Compile
        app = workflow.compile()

        # Run
        inputs = {
            "question": args.query
        }
        for output in app.stream(inputs):
            for key, value in output.items():
                # Node
                pprint(f"Node '{key}':")
                # Optional: print the full state at each node, e.g.
                # pprint(value, indent=2, width=80)
            pprint("\n---\n")

        # 'value' now holds the state emitted by the last node,
        # which contains the final generation
        print(value["generation"])

        return {"response": value["generation"]}


def go_evaluation(args):
    if args.evaluation_dataset_csv_path:

        import pandas as pd

        df = pd.read_csv(args.evaluation_dataset_csv_path)
        dataset_name = os.path.basename(args.evaluation_dataset_csv_path).split('.')[0]

        # df contains a question column and a groundtruth answer column
        examples = df[[args.evaluation_dataset_column_question, args.evaluation_dataset_column_answer]].values.tolist()
        inputs = [{"question": input_prompt} for input_prompt, _ in examples]
        outputs = [{"answer": output_answer} for _, output_answer in examples]

        # Programmatically create a dataset in LangSmith
        client = Client()

        try:
            # Create a dataset
            dataset = client.create_dataset(
                dataset_name=dataset_name,
                description="An evaluation dataset in LangSmith."
            )
            # Add examples to the dataset
            client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)
        except langsmith.utils.LangSmithConflictError:
            # The dataset already exists in LangSmith; reuse it
            pass

        args.ls_chat_model_evaluator = None if args.ls_chat_model_evaluator == 'None' else args.ls_chat_model_evaluator.split(',')

        def target(inputs: dict) -> dict:
            new_args = argparse.Namespace(**vars(args))
            new_args.query = inputs["question"]
            return go(new_args)
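
        # client.evaluate() below calls target() once per dataset example,
        # passing that example's inputs dict; the returned {"response": ...}
        # is what the configured evaluators score against the reference answer.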

        ls_evaluators = []
        if args.ls_chat_model_evaluator:
            for evaluator in args.ls_chat_model_evaluator:
                if evaluator == 'moonshot':
                    ls_evaluators.append(moonshot_evaluator_correctness)
                    ls_evaluators.append(moonshot_evaluator_faithfulness)
                elif evaluator == 'deepseek':
                    ls_evaluators.append(deepseek_evaluator_correctness)
                    ls_evaluators.append(deepseek_evaluator_faithfulness)
                elif evaluator == 'gemini':
                    ls_evaluators.append(gemini_evaluator_correctness)
                    ls_evaluators.append(gemini_evaluator_faithfulness)

        # After running the evaluation, a link is provided to view the results in LangSmith
        _ = client.evaluate(
            target,
            data=dataset_name,
            evaluators=ls_evaluators,
            experiment_prefix="first-eval-in-langsmith",
            max_concurrency=1,
        )
    else:
        # No evaluation dataset provided: just answer the single --query
        go(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Adaptive RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--evaluation_dataset_csv_path",
        type=str,
        help="Path to the query evaluation dataset",
        default=None,
    )

    parser.add_argument(
        "--evaluation_dataset_column_question",
        type=str,
        help="Column name for the questions in the evaluation dataset",
        default="question",
    )

    parser.add_argument(
        "--evaluation_dataset_column_answer",
        type=str,
        help="Column name for the groundtruth answers in the evaluation dataset",
        default="groundtruth",
    )

    parser.add_argument(
        "--input_chromadb_artifact",
        type=str,
        help="Fully-qualified name for the chromadb artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    parser.add_argument(
        "--ls_chat_model_evaluator",
        type=str,
        help="Comma-separated list of chat model providers for evaluation",
        required=False,
        default="None"
    )

    args = parser.parse_args()

    go_evaluation(args)

@@ -1,4 +1,4 @@
-name: chain_of_thought
+name: rag_cot
 python_env: python_env.yml

 entry_points:

app/llmops/src/rag_cot_evaluation/python_env.yml (new file, 18 lines)
@@ -0,0 +1,18 @@
# Python version required to run the project.
python: "3.11.11"
# Dependencies required to build packages. This field is optional.
build_dependencies:
  - pip==23.3.1
  - setuptools
  - wheel==0.37.1
  - chromadb
  - langchain
  - sentence_transformers
  - python-decouple
  - langchain_google_genai
  - langchain-deepseek
  - langchain-community
  - mlflow[genai]
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1

app/llmops/src/rag_cot_evaluation/run.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import os
import logging
import argparse
import mlflow
import chromadb
import shutil
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()

os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["LANGSMITH_API_KEY"] = config("LANGSMITH_API_KEY", cast=str)
os.environ["LANGSMITH_TRACING"] = config("LANGSMITH_TRACING", cast=str)
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = config("LANGSMITH_PROJECT", cast=str)


def go(args):

    # Start a new MLflow run
    with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromdb_pdf"):
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'query' not in existing_params:
            mlflow.log_param('query', args.query)

        # Log parameters to MLflow
        mlflow.log_params({
            "input_chromadb_artifact": args.input_chromadb_artifact,
            "embedding_model": args.embedding_model,
            "chat_model_provider": args.chat_model_provider
        })

        logger.info("Downloading chromadb artifact")
        artifact_chromadb_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_chromadb_artifact)

        # Unzip the artifact
        logger.info("Unzipping the artifact")
        shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")

        # Load data from ChromaDB
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag-chroma"
        collection = chroma_client.get_collection(name=collection_name)

        # Formulate a question
        question = args.query

        if args.chat_model_provider == "deepseek":
            # Initialize DeepSeek model
            llm = ChatDeepSeek(
                model="deepseek-chat",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
            )

        elif args.chat_model_provider == "gemini":
            # Initialize Gemini model
            llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0,
                max_retries=3
            )

        elif args.chat_model_provider == "moonshot":
            # Initialize Moonshot model
            llm = Moonshot(
                model="moonshot-v1-128k",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
            )

        else:
            # Fail fast on an unrecognized provider instead of continuing
            # with llm undefined
            raise ValueError(f"Unknown chat_model_provider: {args.chat_model_provider}")

        # Chain of Thought Prompt
        cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply in the same language as the question asked.
"""
        cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
        cot_chain = cot_prompt | llm

        # Initialize embedding model (do this ONCE)
        model = SentenceTransformer(args.embedding_model)

        # Query (prompt)
        query_embedding = model.encode(question)  # Embed the query with the SAME model used at indexing time

        # Search ChromaDB
        documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)
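
        # Note: collection.query() returns a dict of parallel lists, e.g.
        # {"ids": [[...]], "documents": [[...]], "distances": [[...]], ...},
        # so the prompt below sees the whole structure, metadata included.
        # To feed only the passage text, one could instead do:
        #     documents_text = "\n\n".join(documents_text["documents"][0])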

        # Generate the chain of thought
        cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
        print("Chain of Thought: ", cot_output)

        # Answer Prompt
        answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer in the same language as the question asked.
"""
        answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
        answer_chain = answer_prompt | llm
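
        # The two chains form a simple two-stage pipeline: cot_chain drafts
        # the step-by-step reasoning over the retrieved text, and answer_chain
        # condenses that reasoning into the final reply.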

        # Generate the final answer
        answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
        print("Answer: ", answer_output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chain of Thought RAG")

    parser.add_argument(
        "--query",
        type=str,
        help="Question to ask the model",
        required=True
    )

    parser.add_argument(
        "--input_chromadb_artifact",
        type=str,
        help="Fully-qualified name for the chromadb artifact",
        required=True
    )

    parser.add_argument(
        "--embedding_model",
        type=str,
        default="paraphrase-multilingual-mpnet-base-v2",
        help="Sentence Transformer model name"
    )

    parser.add_argument(
        "--chat_model_provider",
        type=str,
        default="gemini",
        help="Chat model provider"
    )

    args = parser.parse_args()

    go(args)

app/streamlit/Chatbot.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import os
import streamlit as st
import chromadb
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

import torch
# Workaround: stop Streamlit's module watcher from tripping over torch.classes
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]


os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str, default="123456")
DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str, default="123456")
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str, default="123456")
CHAT_MODEL_PROVIDER = config("CHAT_MODEL_PROVIDER", cast=str, default="gemini")
INPUT_CHROMADB_LOCAL = config("INPUT_CHROMADB_LOCAL", cast=str, default="../llmops/src/rag_cot_evaluation/chroma_db")
EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
COLLECTION_NAME = config("COLLECTION_NAME", cast=str, default="rag-chroma")

st.title("💬 RAG AI for Medical Guideline")
st.caption(f"🚀 A RAG AI for Medical Guideline powered by {CHAT_MODEL_PROVIDER}")
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Load data from ChromaDB
chroma_client = chromadb.PersistentClient(path=INPUT_CHROMADB_LOCAL)
collection = chroma_client.get_collection(name=COLLECTION_NAME)

# Initialize embedding model
model = SentenceTransformer(EMBEDDING_MODEL)

if CHAT_MODEL_PROVIDER == "deepseek":
    # Initialize DeepSeek model
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=DEEPSEEK_API_KEY
    )

elif CHAT_MODEL_PROVIDER == "gemini":
    # Initialize Gemini model
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=GEMINI_API_KEY,
        temperature=0,
        max_retries=3
    )

elif CHAT_MODEL_PROVIDER == "moonshot":
    # Initialize Moonshot model
    llm = Moonshot(
        model="moonshot-v1-128k",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=MOONSHOT_API_KEY
    )

# Chain of Thought Prompt
cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply in the same language as the question asked.
"""
cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
cot_chain = cot_prompt | llm

# Answer Prompt
answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer in the same language as the question asked.
"""
answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
answer_chain = answer_prompt | llm

if prompt := st.chat_input():

    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    # Query (prompt)
    query_embedding = model.encode(prompt)  # Embed the query with the SAME model used at indexing time

    # Search ChromaDB
    documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)

    # Generate and display the chain of thought
    cot_output = cot_chain.invoke({"documents_text": documents_text, "question": prompt})
    # .content assumes a chat-style response object; the Moonshot completion
    # LLM above returns a plain string, so that branch would need a str output
    msg = cot_output.content
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)

    # Generate and display the final answer
    answer_output = answer_chain.invoke({"cot": cot_output, "question": prompt})
    msg = answer_output.content
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)

app/streamlit/Dockerfile (new file, 24 lines)
@@ -0,0 +1,24 @@
FROM python:3.11-slim

WORKDIR /app/streamlit

COPY Pipfile ./

RUN pip install --upgrade pip setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip install pipenv -i https://pypi.tuna.tsinghua.edu.cn/simple
# Note: --deploy expects an up-to-date Pipfile.lock, which is not copied above
RUN pipenv install --deploy

COPY Chatbot.py .
COPY .env .

# Run python to pre-download the SentenceTransformer model at build time
COPY initialize_sentence_transformer.py .
RUN pipenv run python initialize_sentence_transformer.py

COPY pages ./pages

EXPOSE 8501

ENTRYPOINT ["pipenv", "run", "streamlit", "run", "Chatbot.py"]
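
# Typical usage (assumed commands, not part of the repo's tooling):
#   docker build -t aimingmed-streamlit .
#   docker run -p 8501:8501 aimingmed-streamlit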

app/streamlit/Pipfile (new file, 29 lines)
@@ -0,0 +1,29 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
streamlit = "*"
langchain = "*"
duckduckgo-search = "*"
anthropic = "*"
trubrics = "*"
streamlit-feedback = "*"
langchain-community = "*"
watchdog = "*"
mlflow = "*"
python-decouple = "*"
langchain_google_genai = "*"
langchain-deepseek = "*"
sentence_transformers = "*"
chromadb = "*"

[dev-packages]
pytest = "==8.0.0"
pytest-cov = "==4.1.0"
pytest-mock = "==3.10.0"
pytest-asyncio = "*"

[requires]
python_version = "3.11"

app/streamlit/Pipfile.lock (generated, 4165 lines; diff suppressed because it is too large)

app/streamlit/app_test.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from unittest.mock import patch
from streamlit.testing.v1 import AppTest


@patch("langchain.llms.OpenAI.__call__")
def test_Langchain_Quickstart(langchain_llm):
    at = AppTest.from_file("pages/3_Langchain_Quickstart.py").run()
    assert at.info[0].value == "Please add your OpenAI API key to continue."

    RESPONSE = "1. The best way to learn how to code is by practicing..."
    langchain_llm.return_value = RESPONSE
    at.sidebar.text_input[0].set_value("sk-...")
    at.button[0].set_value(True).run()
    print(at)
    assert at.info[0].value == RESPONSE

app/streamlit/initialize_sentence_transformer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from decouple import config
from sentence_transformers import SentenceTransformer

EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")

# Initialize embedding model
model = SentenceTransformer(EMBEDDING_MODEL)

model.save("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
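
# Loading the model here at image-build time also warms the Hugging Face
# cache, so the by-name SentenceTransformer(EMBEDDING_MODEL) call in
# Chatbot.py does not need to download weights at container start-up; the
# explicit save() above writes an additional local copy.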

app/streamlit/pages/3_Langchain_Quickstart.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import streamlit as st
from langchain.llms import OpenAI

st.title("🦜🔗 Langchain Quickstart App")

with st.sidebar:
    openai_api_key = st.text_input("OpenAI API Key", type="password")
    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"


def generate_response(input_text):
    llm = OpenAI(temperature=0.7, openai_api_key=openai_api_key)
    st.info(llm(input_text))


with st.form("my_form"):
    text = st.text_area("Enter text:", "What are 3 key pieces of advice for learning how to code?")
    submitted = st.form_submit_button("Submit")
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
    elif submitted:
        generate_response(text)

app/streamlit/requirements-dev.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
black==23.3.0
mypy==1.4.1
pre-commit==3.3.3
watchdog
pytest