Mirror of https://github.com/RYDE-WORK/visual-med-alpaca.git (synced 2026-02-08 17:56:26 +08:00)

training code and data

This commit is contained in:
parent 478e4f7f55
commit 21b0c7417d
201  code/med-alpaca-lora/LICENSE  Normal file
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
6  code/med-alpaca-lora/finetune-med.sh  Executable file
@@ -0,0 +1,6 @@
python finetune.py \
    --base_model 'decapoda-research/llama-7b-hf' \
    --data_path '/path/to/med_alpaca_data_clean.json' \
    --micro_batch_size 32 \
    --output_dir './med-alpaca-lora'
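The --data_path file is expected to follow the Alpaca instruction format consumed by generate_prompt in finetune.py below: a JSON list of records with "instruction", "input", and "output" fields. A minimal sketch of producing such a file; the example record is hypothetical and not taken from the released dataset.

import json

# Hypothetical record illustrating the expected keys; the real
# med_alpaca_data_clean.json entries are not reproduced here.
records = [
    {
        "instruction": "Summarize the key findings of this radiology report.",
        "input": "The cardiac silhouette is within normal limits. No focal consolidation.",
        "output": "Normal heart size with clear lungs and no acute findings.",
    }
]

with open("med_alpaca_data_clean.json", "w") as f:
    json.dump(records, f, indent=4)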
217  code/med-alpaca-lora/finetune.py  Normal file
@@ -0,0 +1,217 @@
import os
import sys
from typing import List

import fire
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
)


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "./alpaca_data_cleaned.json",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 512,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = True,  # faster, but produces an odd training loss curve
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    data = load_dataset("json", data_files=data_path)

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train()

    model.save_pretrained(output_dir)

    print("\n If there's a warning about missing keys above, please disregard :)")


def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""


if __name__ == "__main__":
    fire.Fire(train)
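finetune.py saves only the LoRA adapter weights to output_dir (via get_peft_model_state_dict), so inference requires reloading the base model and attaching the adapter with peft. A minimal sketch, assuming the adapter was written to ./med-alpaca-lora as in finetune-med.sh above; this script and the example instruction are illustrative, not part of the commit.

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base_model = "decapoda-research/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(base_model, load_in_8bit=True, device_map="auto")
# Attach the LoRA adapter produced by finetune.py.
model = PeftModel.from_pretrained(model, "./med-alpaca-lora")
model.eval()

# Same no-input template as generate_prompt() above; the instruction is a hypothetical example.
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nWhat are the common symptoms of anemia?\n\n### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))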
201  code/med-alpaca/LICENSE  Normal file
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
231  code/med-alpaca/train.py  Normal file
@@ -0,0 +1,231 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

import utils

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
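Unlike the LoRA variant, train.py writes a full fine-tuned checkpoint (via safe_save_model_for_hf_trainer), so the output directory can be loaded directly with the Auto classes. A minimal inference sketch, assuming the checkpoint was saved to ./med-alpaca as in train.sh below; the instruction shown is a hypothetical example.

import torch
import transformers

model_dir = "./med-alpaca"  # output_dir used in train.sh
tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir, use_fast=False)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.float16, device_map="auto"
)
model.eval()

# Same "prompt_no_input" template as PROMPT_DICT above.
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nList the typical side effects of metformin.\n\n### Response:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))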
21  code/med-alpaca/train.sh  Executable file
@@ -0,0 +1,21 @@
torchrun --nproc_per_node=1 train.py \
    --model_name_or_path decapoda-research/llama-7b-hf \
    --data_path /path/to/med_alpaca_data_clean.json \
    --bf16 True \
    --output_dir ./med-alpaca \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True
173  code/med-alpaca/utils.py  Normal file
@@ -0,0 +1,173 @@
import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object
import copy

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate-limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non chat model.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of full completion object (which contains things like logprob).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions.
        Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead."
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return non-tuple if only 1 input and 1 generation.
        (completions,) = completions
    return completions


def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
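For reference, a minimal sketch of how utils.py is typically driven (for example, for data generation or evaluation with the legacy OpenAI completions API). The prompt and file paths are illustrative, and an OPENAI_API_KEY must be set in the environment for the call to succeed.

import utils

# Query the OpenAI completion API using the batching/retry logic above.
decoding_args = utils.OpenAIDecodingArguments(max_tokens=256, temperature=0.2)
answers = utils.openai_completion(
    prompts=["Explain in one sentence what a chest X-ray is used for."],  # hypothetical prompt
    decoding_args=decoding_args,
    model_name="text-davinci-003",
    return_text=True,
)

# Persist and reload results with the JSON helpers.
utils.jdump(answers, "outputs/answers.json")
print(utils.jload("outputs/answers.json"))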
134  code/med-git/Fine_tune_GIT_on_an_image_captioning_dataset.py  Normal file
@@ -0,0 +1,134 @@
# Fine-tune GIT on a custom dataset for image captioning

import pandas as pd
import json
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoProcessor
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
import torch
from tqdm import tqdm

# path to the csv containing training data directories
train_data_csv = ""
# path to the folder containing the training data images
train_data_folder = ""
# path to the csv containing validation data directories
validation_data_csv = ""
# path to the folder containing the validation data images
validation_data_folder = ""
# save pretrained model to
output_dir = ""


df = pd.read_csv(train_data_csv)
captions = [{"file_name": df.iloc[i]["name"],
             "text": df.iloc[i]["caption"].strip()} for i in range(len(df))]

# add metadata.jsonl file to this folder
with open(train_data_folder + "metadata.jsonl", 'w') as f:
    for item in captions:
        f.write(json.dumps(item) + "\n")


df_val = pd.read_csv(validation_data_csv)
captions = [{"file_name": df_val.iloc[i]["name"],
             "text": df_val.iloc[i]["caption"].strip()} for i in range(len(df_val))]

# add metadata.jsonl file to this folder
with open(validation_data_folder + "metadata.jsonl", 'w') as f:
    for item in captions:
        f.write(json.dumps(item) + "\n")


dataset = load_dataset("imagefolder", data_dir=train_data_folder, split="train")
val_dataset = load_dataset("imagefolder", data_dir=validation_data_folder, split="train")


# We use `GitProcessor` to turn each (image, text) pair into the expected inputs. Basically, the text gets turned into `input_ids` and `attention_mask`, and the image gets turned into `pixel_values`.
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding


processor = AutoProcessor.from_pretrained("microsoft/git-base")
train_dataset = ImageCaptioningDataset(dataset, processor)
validation_dataset = ImageCaptioningDataset(val_dataset, processor)

# Next, we create a corresponding [PyTorch DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), which allows us to get batches of data from the dataset.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4)
validation_dataloader = DataLoader(validation_dataset, shuffle=False, batch_size=4)

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")


# Dummy forward pass
batch = next(iter(train_dataloader))
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                labels=batch["input_ids"])
print(outputs.loss)


# Train the model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)


num_epochs = 30
train_loss_history = []

for epoch in range(num_epochs):
    print("Epoch:", epoch)
    avg_loss = 0
    with tqdm(total=len(train_dataloader)) as pbar:
        model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)
            loss = outputs.loss
            train_loss_history.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss = (avg_loss * batch_idx + loss.item()) / (batch_idx + 1)
            pbar.update(1)
            pbar.set_description(f"Epoch {epoch}, Loss {loss:.4f}, Avg Loss {avg_loss:.4f}")
    with torch.no_grad():
        model.eval()
        validation_loss = 0
        for batch_idx, batch in enumerate(validation_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)
            loss = outputs.loss
            validation_loss += loss.item()
        validation_loss /= len(validation_dataloader)
        print(f"Epoch {epoch}, Validation Loss {validation_loss:.4f}")

model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
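Once the fine-tuned GIT checkpoint and processor are saved to output_dir, captions for new (e.g., radiology) images can be generated with the standard GIT generation API. A minimal sketch; the checkpoint path and image path are hypothetical placeholders, not values from the training script.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

output_dir = "./med-git"  # hypothetical path where the fine-tuned checkpoint was saved
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(output_dir)
model = AutoModelForCausalLM.from_pretrained(output_dir).to(device)
model.eval()

image = Image.open("example_radiology_image.png").convert("RGB")  # hypothetical image path
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

with torch.no_grad():
    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])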
21  code/med-git/LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 NielsRogge

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
17  code/requirements.txt  Normal file
@@ -0,0 +1,17 @@
datasets
loralib
sentencepiece
git+https://github.com/huggingface/transformers.git
accelerate
bitsandbytes
git+https://github.com/huggingface/peft.git
gradio
appdirs
fire
numpy
rouge_score
openai
torch
sentencepiece
tokenizers==0.12.1
wandb
1  data/med_alpaca_data_clean.json  Normal file
File diff suppressed because one or more lines are too long
130889  data/radiologytraindata_cleaned.csv  Normal file
File diff suppressed because it is too large
16359  data/radiologyvaldata_cleaned.csv  Normal file
File diff suppressed because it is too large
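The radiology CSVs are consumed by Fine_tune_GIT_on_an_image_captioning_dataset.py above, which reads a "name" column (image file name) and a "caption" column per row. A minimal sketch of that expected layout; the rows shown are hypothetical and not taken from the dataset.

import pandas as pd

# Hypothetical rows illustrating the columns the GIT fine-tuning script expects.
df = pd.DataFrame(
    [
        {"name": "img_000001.png", "caption": "Heart size is normal. Lungs are clear."},
        {"name": "img_000002.png", "caption": "No acute cardiopulmonary abnormality."},
    ]
)
df.to_csv("radiologytraindata_cleaned.csv", index=False)
print(pd.read_csv("radiologytraindata_cleaned.csv").head())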