mirror of
https://github.com/RYDE-WORK/visual-med-alpaca.git
synced 2026-01-19 14:28:49 +08:00
gradio-example
This commit is contained in:
parent
2d49088f49
commit
ebcf783330
217
code/gradio-example/app.py
Normal file
217
code/gradio-example/app.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
import gradio as gr
|
||||||
|
import transformers
|
||||||
|
import numpy as np
|
||||||
|
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
|
||||||
|
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||||
|
|
||||||
|
auth_username = os.environ["AUTH_USERNAME"]
|
||||||
|
auth_password = os.environ["AUTH_PASSWORD"]
|
||||||
|
cambridgeltl_access_token = os.environ['CAMBRIDGELTL_ACCESS_TOKEN']
|
||||||
|
|
||||||
|
## med-alpaca
|
||||||
|
|
||||||
|
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
|
||||||
|
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained("cambridgeltl/med-alpaca-fp16", use_auth_token=cambridgeltl_access_token)
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
"cambridgeltl/med-alpaca-fp16",
|
||||||
|
use_auth_token=cambridgeltl_access_token,
|
||||||
|
load_in_8bit=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
# model.half()
|
||||||
|
else:
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
"cambridgeltl/med-alpaca-fp16", use_auth_token=cambridgeltl_access_token, device_map={"": device}, low_cpu_mem_usage=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
if torch.__version__ >= "2":
|
||||||
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
## OpenAI models
|
||||||
|
openai.api_key = os.environ.get("OPENAI_TOKEN", None)
|
||||||
|
def set_openai_api_key(api_key):
|
||||||
|
if api_key and api_key.startswith("sk-") and len(api_key) > 50:
|
||||||
|
openai.api_key = api_key
|
||||||
|
|
||||||
|
def get_response_from_openai(prompt, model="gpt-3.5-turbo", max_output_tokens=512):
|
||||||
|
messages = [{"role": "assistant", "content": prompt}]
|
||||||
|
response = openai.ChatCompletion.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=max_output_tokens,
|
||||||
|
top_p=1,
|
||||||
|
frequency_penalty=0,
|
||||||
|
presence_penalty=0,
|
||||||
|
)
|
||||||
|
ret = response.choices[0].message['content']
|
||||||
|
return ret
|
||||||
|
|
||||||
|
torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
||||||
|
|
||||||
|
## deplot models
|
||||||
|
model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch_dtype).to(device)
|
||||||
|
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
|
||||||
|
## med-git models
|
||||||
|
model_med_git = AutoModelForCausalLM.from_pretrained('cambridgeltl/med-git-base', use_auth_token=cambridgeltl_access_token, torch_dtype=torch_dtype).to(device)
|
||||||
|
processor_med_git = AutoProcessor.from_pretrained('cambridgeltl/med-git-base', use_auth_token=cambridgeltl_access_token)
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
table,
|
||||||
|
question,
|
||||||
|
llm="med-alpaca",
|
||||||
|
input=None,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.75,
|
||||||
|
top_k=40,
|
||||||
|
num_beams=4,
|
||||||
|
max_new_tokens=512,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
prompt_input = f"Below is an instruction that describes a task, paired with an input that provides further context of an uploaded image. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{table}\n\n### Response:\n"
|
||||||
|
prompt_no_input = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n"
|
||||||
|
|
||||||
|
prompt = prompt_input if len(table) > 0 else prompt_no_input
|
||||||
|
|
||||||
|
output = "UNKNOWN ERROR"
|
||||||
|
if llm == "med-alpaca":
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
num_beams=num_beams,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
generation_output = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
generation_config=generation_config,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
)
|
||||||
|
s = generation_output.sequences[0]
|
||||||
|
output = tokenizer.decode(s)
|
||||||
|
output = output.split("### Response:")[1].strip()
|
||||||
|
elif llm == "gpt-3.5-turbo":
|
||||||
|
try:
|
||||||
|
output = get_response_from_openai(prompt)
|
||||||
|
except:
|
||||||
|
output = "<Remember to input your OpenAI API key ☺>"
|
||||||
|
else:
|
||||||
|
RuntimeError(f"No such LLM: {llm}")
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def deplot(image, question, llm):
|
||||||
|
# image = Image.open(image)
|
||||||
|
inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(device, torch_dtype)
|
||||||
|
predictions = model_deplot.generate(**inputs, max_new_tokens=512)
|
||||||
|
table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def med_git(image, question, llm):
|
||||||
|
# image = Image.open(image)
|
||||||
|
inputs = processor_med_git(images=image, return_tensors="pt").to(device)
|
||||||
|
pixel_values = inputs.pixel_values.to(torch_dtype)
|
||||||
|
generated_ids = model_med_git.generate(pixel_values=pixel_values, max_length=512)
|
||||||
|
captions = processor_med_git.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
|
return captions
|
||||||
|
|
||||||
|
|
||||||
|
def process_document(image, question, llm):
|
||||||
|
# image = Image.open(image)
|
||||||
|
if image:
|
||||||
|
if np.mean(image) >= 128:
|
||||||
|
table = deplot(image, question, llm)
|
||||||
|
else:
|
||||||
|
table = med_git(image, question, llm)
|
||||||
|
else:
|
||||||
|
table = ""
|
||||||
|
|
||||||
|
# send prompt+table to LLM
|
||||||
|
res = evaluate(table, question, llm=llm)
|
||||||
|
return [table, res]
|
||||||
|
|
||||||
|
|
||||||
|
theme = gr.themes.Monochrome(
|
||||||
|
primary_hue="indigo",
|
||||||
|
secondary_hue="blue",
|
||||||
|
neutral_hue="slate",
|
||||||
|
radius_size=gr.themes.sizes.radius_sm,
|
||||||
|
font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Blocks(theme=theme) as demo:
|
||||||
|
with gr.Column():
|
||||||
|
gr.Markdown(
|
||||||
|
"""<h1><center>Visual Med-Alpaca: Bridging Modalities in Biomedical Language Models</center></h1>
|
||||||
|
<p>
|
||||||
|
This is a demo of Visual Med-Alpaca for multi-modal medical foundation model. To use it, simply upload your image and type a question or instruction and click 'submit'.
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
|
||||||
|
#input_image.style(height=512, width=512)
|
||||||
|
instruction = gr.Textbox(placeholder="Enter your instruction/question...", label="Question/Instruction")
|
||||||
|
llm = gr.Dropdown(["med-alpaca", "gpt-3.5-turbo"], label="LLM")
|
||||||
|
openai_api_key_textbox = gr.Textbox(value='',
|
||||||
|
placeholder="Paste your OpenAI API key (sk-...) and hit Enter (if using OpenAI models, otherwise leave empty)",
|
||||||
|
show_label=False, lines=1, type='password')
|
||||||
|
submit = gr.Button("Submit", variant="primary")
|
||||||
|
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
with gr.Accordion("Show intermediate table", open=False):
|
||||||
|
output_table = gr.Textbox(lines=8, label="Intermediate Table")
|
||||||
|
output_text = gr.Textbox(lines=8, label="Output")
|
||||||
|
|
||||||
|
gr.Examples(
|
||||||
|
examples=[
|
||||||
|
[None, "what are the chemicals that treat hair loss?", "med-alpaca"],
|
||||||
|
["case_study_1.jpg", "what is seen in the x-ray and what should be done?", "med-alpaca"],
|
||||||
|
["case_study_2.jpg", "how effective is this treatment on papule?", "med-alpaca"],
|
||||||
|
["case_study_3.png", "is absorbance related to number of cells?", "med-alpaca"],
|
||||||
|
],
|
||||||
|
cache_examples=False,
|
||||||
|
inputs=[input_image, instruction, llm],
|
||||||
|
outputs=[output_table, output_text],
|
||||||
|
fn=process_document
|
||||||
|
)
|
||||||
|
|
||||||
|
# gr.Markdown(
|
||||||
|
# """<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"""
|
||||||
|
# )
|
||||||
|
openai.api_key = ""
|
||||||
|
openai_api_key_textbox.change(set_openai_api_key,
|
||||||
|
inputs=[openai_api_key_textbox],
|
||||||
|
outputs=[])
|
||||||
|
openai_api_key_textbox.submit(set_openai_api_key,
|
||||||
|
inputs=[openai_api_key_textbox],
|
||||||
|
outputs=[])
|
||||||
|
submit.click(process_document, inputs=[input_image, instruction, llm], outputs=[output_table, output_text])
|
||||||
|
instruction.submit(
|
||||||
|
process_document, inputs=[input_image, instruction, llm], outputs=[output_table, output_text]
|
||||||
|
)
|
||||||
|
|
||||||
|
demo.queue(concurrency_count=1).launch(auth=(auth_username, auth_password))
|
||||||
BIN
code/gradio-example/case_study_1.jpg
Normal file
BIN
code/gradio-example/case_study_1.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
BIN
code/gradio-example/case_study_2.jpg
Normal file
BIN
code/gradio-example/case_study_2.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 35 KiB |
BIN
code/gradio-example/case_study_3.png
Normal file
BIN
code/gradio-example/case_study_3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 575 KiB |
10
code/gradio-example/requirements.txt
Normal file
10
code/gradio-example/requirements.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
torch
|
||||||
|
git+https://github.com/huggingface/transformers
|
||||||
|
datasets
|
||||||
|
loralib
|
||||||
|
sentencepiece
|
||||||
|
accelerate
|
||||||
|
bitsandbytes
|
||||||
|
git+https://github.com/huggingface/peft.git
|
||||||
|
gradio
|
||||||
|
openai
|
||||||
Loading…
x
Reference in New Issue
Block a user