Mirror of https://github.com/RYDE-WORK/visual-med-alpaca.git (synced 2026-01-19 14:28:49 +08:00)
training code and data
This commit is contained in:
parent 478e4f7f55
commit 21b0c7417d
201  code/med-alpaca-lora/LICENSE  Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
6  code/med-alpaca-lora/finetune-med.sh  Executable file
@@ -0,0 +1,6 @@
python finetune.py \
    --base_model 'decapoda-research/llama-7b-hf' \
    --data_path '/path/to/med_alpaca_data_clean.json' \
    --micro_batch_size 32 \
    --output_dir './med-alpaca-lora'
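For reference, finetune.py (next file) loads --data_path with load_dataset("json", ...) and builds its prompt from the "instruction", "input", and "output" fields of each record, so med_alpaca_data_clean.json is expected to be a list of objects with that schema. A minimal sketch of one record; the field values below are hypothetical placeholders, not taken from the actual dataset:

# Hypothetical example of the record shape expected by finetune.py / train.py.
example_record = {
    "instruction": "Summarize the key finding in the following radiology report sentence.",
    "input": "Chest X-ray shows a small right-sided pleural effusion.",
    "output": "Small right pleural effusion.",
}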
217  code/med-alpaca-lora/finetune.py  Normal file
@@ -0,0 +1,217 @@
import os
import sys
from typing import List

import fire
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
)


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "./alpaca_data_cleaned.json",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 512,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = True,  # faster, but produces an odd training loss curve
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    data = load_dataset("json", data_files=data_path)

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train()

    model.save_pretrained(output_dir)

    print("\n If there's a warning about missing keys above, please disregard :)")


def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""


if __name__ == "__main__":
    fire.Fire(train)
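Note that finetune.py saves only the LoRA adapter weights to output_dir (via get_peft_model_state_dict), not a merged model. Below is a minimal inference sketch, assuming the peft and transformers versions pinned in code/requirements.txt; the adapter path, example instruction, and generation settings are illustrative and not part of this commit:

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base_model = "decapoda-research/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(base_model, load_in_8bit=True, device_map="auto")
model = PeftModel.from_pretrained(model, "./med-alpaca-lora")  # adapter saved by finetune.py
model.eval()

prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nWhat does a small pleural effusion on a chest X-ray suggest?\n\n### Response:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))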
201  code/med-alpaca/LICENSE  Normal file
(Apache License, Version 2.0: text identical to code/med-alpaca-lora/LICENSE above, including the 2023 copyright notice for Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, and Xuechen Li.)
231  code/med-alpaca/train.py  Normal file
@@ -0,0 +1,231 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

import utils

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
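The key supervision detail in train.py is in preprocess: prompt (source) tokens in labels are overwritten with IGNORE_INDEX, so the cross-entropy loss only covers the response tokens. A small self-contained sketch of that masking step, using made-up token ids instead of a real tokenizer:

import copy
import torch

IGNORE_INDEX = -100  # same sentinel as train.py; positions with this value are skipped by the loss

# Pretend these are tokenized "prompt + response" sequences, with the prompt lengths given below.
input_ids = [torch.tensor([5, 6, 7, 8, 9, 10]), torch.tensor([5, 6, 11, 12])]
source_lens = [3, 2]  # number of prompt tokens in each example

labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, source_lens):
    label[:source_len] = IGNORE_INDEX  # mask the prompt; loss is computed on the response only

print(labels)  # [tensor([-100, -100, -100, 8, 9, 10]), tensor([-100, -100, 11, 12])]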
21  code/med-alpaca/train.sh  Executable file
@@ -0,0 +1,21 @@
torchrun --nproc_per_node=1 train.py \
    --model_name_or_path decapoda-research/llama-7b-hf \
    --data_path /path/to/med_alpaca_data_clean.json \
    --bf16 True \
    --output_dir ./med-alpaca \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True
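As configured above, train.sh launches a single process (--nproc_per_node=1) with a per-device batch of 4 and 8 gradient accumulation steps, so each optimizer update sees 1 * 4 * 8 = 32 examples. A quick sanity check of that arithmetic:

n_procs = 1                        # --nproc_per_node
per_device_train_batch_size = 4    # --per_device_train_batch_size
gradient_accumulation_steps = 8    # --gradient_accumulation_steps

effective_batch_size = n_procs * per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 32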
173  code/med-alpaca/utils.py  Normal file
@@ -0,0 +1,173 @@
import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object
import copy

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate-limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non chat model.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of full completion object (which contains things like logprob).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions.
        Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead."
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return non-tuple if only 1 input and 1 generation.
        (completions,) = completions
    return completions


def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
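Of the helpers above, train.py only calls utils.jload (to read --data_path); jdump is its writing counterpart. A short usage sketch with a hypothetical file name:

import utils

# Write a small instruction-tuning style record list, then read it back.
records = [{"instruction": "Say hello.", "input": "", "output": "Hello!"}]
utils.jdump(records, "example_data.json")  # hypothetical path
assert utils.jload("example_data.json") == records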
134  code/med-git/Fine_tune_GIT_on_an_image_captioning_dataset.py  Normal file
@@ -0,0 +1,134 @@
# Fine-tune GIT on a custom dataset for image captioning

import pandas as pd
import json
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoProcessor
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
import torch
from tqdm import tqdm

# path to the csv containing training data file names and captions
train_data_csv = ""
# path to the folder containing the training data images
train_data_folder = ""
# path to the csv containing validation data file names and captions
validation_data_csv = ""
# path to the folder containing the validation data images
validation_data_folder = ""
# save pretrained model to
output_dir = ""


df = pd.read_csv(train_data_csv)
captions = [{"file_name": df.iloc[i]["name"],
             "text": df.iloc[i]["caption"].strip()} for i in range(len(df))]

# add metadata.jsonl file to the training image folder
with open(train_data_folder + "metadata.jsonl", 'w') as f:
    for item in captions:
        f.write(json.dumps(item) + "\n")


df_val = pd.read_csv(validation_data_csv)
captions = [{"file_name": df_val.iloc[i]["name"],
             "text": df_val.iloc[i]["caption"].strip()} for i in range(len(df_val))]

# add metadata.jsonl file to the validation image folder
with open(validation_data_folder + "metadata.jsonl", 'w') as f:
    for item in captions:
        f.write(json.dumps(item) + "\n")


dataset = load_dataset("imagefolder", data_dir=train_data_folder, split="train")
val_dataset = load_dataset("imagefolder", data_dir=validation_data_folder, split="train")


# We use the GIT processor to turn each (image, text) pair into the expected inputs:
# the text becomes `input_ids` and `attention_mask`, and the image becomes `pixel_values`.
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding


processor = AutoProcessor.from_pretrained("microsoft/git-base")
train_dataset = ImageCaptioningDataset(dataset, processor)
validation_dataset = ImageCaptioningDataset(val_dataset, processor)

# Next, we create corresponding PyTorch DataLoaders
# (https://pytorch.org/tutorials/beginner/basics/data_tutorial.html),
# which allow us to get batches of data from the datasets.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4)
validation_dataloader = DataLoader(validation_dataset, shuffle=False, batch_size=4)

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")


# Dummy forward pass
batch = next(iter(train_dataloader))
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                labels=batch["input_ids"])
print(outputs.loss)


# Train the model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)


num_epochs = 30
train_loss_history = []

for epoch in range(num_epochs):
    print("Epoch:", epoch)
    avg_loss = 0
    with tqdm(total=len(train_dataloader)) as pbar:
        model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)
            loss = outputs.loss
            train_loss_history.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss = (avg_loss * batch_idx + loss.item()) / (batch_idx + 1)
            pbar.update(1)
            pbar.set_description(f"Epoch {epoch}, Loss {loss:.4f}, Avg Loss {avg_loss:.4f}")
    with torch.no_grad():
        model.eval()
        validation_loss = 0
        for batch_idx, batch in enumerate(validation_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)
            loss = outputs.loss
            validation_loss += loss.item()
        validation_loss /= len(validation_dataloader)
        print(f"Epoch {epoch}, Validation Loss {validation_loss:.4f}")

model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
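The script above only trains and saves the captioning model. A minimal sketch of generating a caption from the saved checkpoint; the output directory, image path, and max_length below are placeholders rather than values taken from this commit:

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

output_dir = "./med-git"  # wherever the fine-tuned model and processor were saved (placeholder)
processor = AutoProcessor.from_pretrained(output_dir)
model = AutoModelForCausalLM.from_pretrained(output_dir)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

image = Image.open("example_image.png")  # hypothetical input image
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])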
21  code/med-git/LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 NielsRogge

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
17  code/requirements.txt  Normal file
@@ -0,0 +1,17 @@
datasets
loralib
sentencepiece
git+https://github.com/huggingface/transformers.git
accelerate
bitsandbytes
git+https://github.com/huggingface/peft.git
gradio
appdirs
fire
numpy
rouge_score
openai
torch
sentencepiece
tokenizers==0.12.1
wandb
1  data/med_alpaca_data_clean.json  Normal file
(File diff suppressed because one or more lines are too long.)
130889  data/radiologytraindata_cleaned.csv  Normal file
(File diff suppressed because it is too large.)
16359  data/radiologyvaldata_cleaned.csv  Normal file
(File diff suppressed because it is too large.)