mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-01-19 11:53:13 +08:00
Initial commit
This commit is contained in:
commit
4c2fdc395f
191
.gitignore
vendored
Normal file
191
.gitignore
vendored
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
# Data
|
||||||
|
/data/
|
||||||
|
|
||||||
|
# Mac OS-specific storage files
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# vim
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
## https://github.com/github/gitignore/blob/e8554d85bf62e38d6db966a50d2064ac025fd82a/Python.gitignore
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# MkDocs documentation
|
||||||
|
docs/site/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
#uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# pixi.lock should be committed to version control for reproducibility
|
||||||
|
# .pixi/ contains the environments and should not be committed
|
||||||
|
.pixi/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
114
ARCHITECTURE.md
Normal file
114
ARCHITECTURE.md
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
## Model Architecture
|
||||||
|
|
||||||
|
1. Input(8 tokens):
|
||||||
|
```
|
||||||
|
# chem token([B, 600])
|
||||||
|
SMILES(string) + mpnn_encoder -> chem token
|
||||||
|
|
||||||
|
# morgan token([B, 1024]), maccs token([B, 167]), rdkit token([B, 210])
|
||||||
|
SMILES(string) + rdkit_encoder-> chem token, morgan token, maccs token, rdkit token
|
||||||
|
|
||||||
|
# comp token([B, 5])
|
||||||
|
Cationic_Lipid_to_mRNA_weight_ratio(float)
|
||||||
|
Cationic_Lipid_Mol_Ratio(float)
|
||||||
|
Phospholipid_Mol_Ratio(float)
|
||||||
|
Cholesterol_Mol_Ratio(float)
|
||||||
|
PEG_Lipid_Mol_Ratio(float)
|
||||||
|
|
||||||
|
# phys token([B, 12])
|
||||||
|
Purity_Pure(one-hot for Purity)
|
||||||
|
Purity_Crude(one-hot for Purity)
|
||||||
|
Mix_type_Microfluidic(one-hot for Mix_type)
|
||||||
|
Mix_type_Microfluidic(one-hot for Mix_type)
|
||||||
|
Cargo_type_mRNA(one-hot for Cargo_type)
|
||||||
|
Cargo_type_pDNA(one-hot for Cargo_type)
|
||||||
|
Cargo_type_siRNA(one-hot for Cargo_type)
|
||||||
|
Target_or_delivered_gene_FFL(one-hot for Target_or_delivered_gene)
|
||||||
|
Target_or_delivered_gene_Peptide_barcode(one-hot for Target_or_delivered_gene)
|
||||||
|
Target_or_delivered_gene_hEPO(one-hot for Target_or_delivered_gene)
|
||||||
|
Target_or_delivered_gene_FVII(one-hot for Target_or_delivered_gene)
|
||||||
|
Target_or_delivered_gene_GFP(one-hot for Target_or_delivered_gene)
|
||||||
|
|
||||||
|
# help token([B, 4])
|
||||||
|
Helper_lipid_ID_DOPE(one-hot for Helper_lipid_ID)
|
||||||
|
Helper_lipid_ID_DOTAP(one-hot for Helper_lipid_ID)
|
||||||
|
Helper_lipid_ID_DSPC(one-hot for Helper_lipid_ID)
|
||||||
|
Helper_lipid_ID_MDOA(one-hot for Helper_lipid_ID)
|
||||||
|
|
||||||
|
# exp token([B, 32])
|
||||||
|
Model_type_A549(one-hot for Model_type)
|
||||||
|
Model_type_BDMC(one-hot for Model_type)
|
||||||
|
Model_type_BMDM(one-hot for Model_type)
|
||||||
|
Model_type_HBEC_ALI(one-hot for Model_type)
|
||||||
|
Model_type_HEK293T(one-hot for Model_type)
|
||||||
|
Model_type_HeLa(one-hot for Model_type)
|
||||||
|
Model_type_IGROV1(one-hot for Model_type)
|
||||||
|
Model_type_Mouse(one-hot for Model_type)
|
||||||
|
Model_type_RAW264p7(one-hot for Model_type)
|
||||||
|
Delivery_target_dendritic_cell(one-hot for Delivery_target)
|
||||||
|
Delivery_target_generic_cell(one-hot for Delivery_target)
|
||||||
|
Delivery_target_liver(one-hot for Delivery_target)
|
||||||
|
Delivery_target_lung(one-hot for Delivery_target)
|
||||||
|
Delivery_target_lung_epithelium(one-hot for Delivery_target)
|
||||||
|
Delivery_target_macrophage(one-hot for Delivery_target)
|
||||||
|
Delivery_target_muscle(one-hot for Delivery_target)
|
||||||
|
Delivery_target_spleen(one-hot for Delivery_target)
|
||||||
|
Delivery_target_body(one-hot for Delivery_target)
|
||||||
|
Route_of_administration_in_vitro(one-hot for Route_of_administration)
|
||||||
|
Route_of_administration_intravenous(one-hot for Route_of_administration)
|
||||||
|
Route_of_administration_intramuscular(one-hot for Route_of_administration)
|
||||||
|
Route_of_administration_intratracheal(one-hot for Route_of_administration)
|
||||||
|
Sample_organization_type_individual(one-hot for Sample_organization_type)
|
||||||
|
Sample_organization_type_barcoded(one-hot for Sample_organization_type)
|
||||||
|
Value_name_log_luminescence(one-hot for Value_name)
|
||||||
|
Value_name_luminescence(one-hot for Value_name)
|
||||||
|
Value_name_FFL_silencing(one-hot for Value_name)
|
||||||
|
Value_name_Peptide_abundance(one-hot for Value_name)
|
||||||
|
Value_name_hEPO(one-hot for Value_name)
|
||||||
|
Value_name_FVII_silencing(one-hot for Value_name)
|
||||||
|
Value_name_GFP_delivery(one-hot for Value_name)
|
||||||
|
Value_name_Discretized_luminescence(one-hot for Value_name)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. token projector
|
||||||
|
|
||||||
|
3. Channel split: to A[chem, Morgan, Maccs, rdkit] and B[comp, pays, help, exp]
|
||||||
|
|
||||||
|
4. Bi-directional cross attention
|
||||||
|
|
||||||
|
5. Token re-compose: back to [chem, Morgan, maccs, rdkit, comp, physical, help, exp]
|
||||||
|
|
||||||
|
6. Fusion layer(depends on user choice)
|
||||||
|
|
||||||
|
7. heads:
|
||||||
|
```
|
||||||
|
# regression head
|
||||||
|
size(float, training data already logged)
|
||||||
|
|
||||||
|
# classification head
|
||||||
|
toxic(boolean, 0/1)
|
||||||
|
|
||||||
|
# regression head
|
||||||
|
quantified_delivery(float, training data already z-scored)
|
||||||
|
|
||||||
|
# classification head
|
||||||
|
PDI_0_0to0_2(one-hot classification for PDI)
|
||||||
|
PDI_0_2to0_3(one-hot classification for PDI)
|
||||||
|
PDI_0_3to0_4(one-hot classification for PDI)
|
||||||
|
PDI_0_4to0_5(one-hot classification for PDI)
|
||||||
|
|
||||||
|
# classification head
|
||||||
|
Encapsulation_Efficiency_EE<50(one-hot classification for Encapsulation_Efficiency)
|
||||||
|
Encapsulation_Efficiency_50<=EE<80(one-hot classification for Encapsulation_Efficiency)
|
||||||
|
Encapsulation_Efficiency_80<EE<=100(one-hot classification for Encapsulation_Efficiency)
|
||||||
|
|
||||||
|
# distribution head
|
||||||
|
Biodistribution_lymph_nodes(float, sum of Biodistribution is 1)
|
||||||
|
Biodistribution_heart(float, sum of Biodistribution is 1)
|
||||||
|
Biodistribution_liver(float, sum of Biodistribution is 1)
|
||||||
|
Biodistribution_spleen(float, sum of Biodistribution is 1)
|
||||||
|
Biodistribution_lung(float, sum of Biodistribution is 1)
|
||||||
|
Biodistribution_kidney(float, sum of Biodistribution is 1)
|
||||||
|
Biodistribution_muscle(float, sum of Biodistribution is 1)
|
||||||
|
```
|
||||||
|
|
||||||
10
LICENSE
Normal file
10
LICENSE
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
The MIT License (MIT)
|
||||||
|
Copyright (c) 2026, Your name (or your organization/company/team)
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
108
Makefile
Normal file
108
Makefile
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#################################################################################
|
||||||
|
# GLOBALS #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
PROJECT_NAME = lnp-ml
|
||||||
|
PYTHON_VERSION = 3.8
|
||||||
|
PYTHON_INTERPRETER = python
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# COMMANDS #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
## Install Python dependencies
|
||||||
|
.PHONY: requirements
|
||||||
|
requirements:
|
||||||
|
pixi install
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Delete all compiled Python files
|
||||||
|
.PHONY: clean
|
||||||
|
clean:
|
||||||
|
find . -type f -name "*.py[co]" -delete
|
||||||
|
find . -type d -name "__pycache__" -delete
|
||||||
|
|
||||||
|
|
||||||
|
## Lint using ruff (use `make format` to do formatting)
|
||||||
|
.PHONY: lint
|
||||||
|
lint:
|
||||||
|
ruff format --check
|
||||||
|
ruff check
|
||||||
|
|
||||||
|
## Format source code with ruff
|
||||||
|
.PHONY: format
|
||||||
|
format:
|
||||||
|
ruff check --fix
|
||||||
|
ruff format
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Set up Python interpreter environment
|
||||||
|
.PHONY: create_environment
|
||||||
|
create_environment:
|
||||||
|
|
||||||
|
@echo ">>> Pixi environment will be created when running 'make requirements'"
|
||||||
|
|
||||||
|
@echo ">>> Activate with:\npixi shell"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# PROJECT RULES #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
## Clean raw data (raw -> interim)
|
||||||
|
.PHONY: clean_data
|
||||||
|
clean_data: requirements
|
||||||
|
$(PYTHON_INTERPRETER) scripts/data_cleaning.py
|
||||||
|
|
||||||
|
## Process dataset (interim -> processed)
|
||||||
|
.PHONY: data
|
||||||
|
data: requirements
|
||||||
|
$(PYTHON_INTERPRETER) scripts/process_data.py
|
||||||
|
|
||||||
|
## Train model
|
||||||
|
.PHONY: train
|
||||||
|
train: requirements
|
||||||
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train
|
||||||
|
|
||||||
|
## Train with hyperparameter tuning
|
||||||
|
.PHONY: tune
|
||||||
|
tune: requirements
|
||||||
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune
|
||||||
|
|
||||||
|
## Run predictions
|
||||||
|
.PHONY: predict
|
||||||
|
predict: requirements
|
||||||
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict
|
||||||
|
|
||||||
|
## Test model on test set (with detailed metrics)
|
||||||
|
.PHONY: test
|
||||||
|
test: requirements
|
||||||
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Self Documenting Commands #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
|
define PRINT_HELP_PYSCRIPT
|
||||||
|
import re, sys; \
|
||||||
|
lines = '\n'.join([line for line in sys.stdin]); \
|
||||||
|
matches = re.findall(r'\n## (.*)\n[\s\S]+?\n([a-zA-Z_-]+):', lines); \
|
||||||
|
print('Available rules:\n'); \
|
||||||
|
print('\n'.join(['{:25}{}'.format(*reversed(match)) for match in matches]))
|
||||||
|
endef
|
||||||
|
export PRINT_HELP_PYSCRIPT
|
||||||
|
|
||||||
|
help:
|
||||||
|
@$(PYTHON_INTERPRETER) -c "${PRINT_HELP_PYSCRIPT}" < $(MAKEFILE_LIST)
|
||||||
61
README.md
Normal file
61
README.md
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
# lnp-ml
|
||||||
|
|
||||||
|
<a target="_blank" href="https://cookiecutter-data-science.drivendata.org/">
|
||||||
|
<img src="https://img.shields.io/badge/CCDS-Project%20template-328F97?logo=cookiecutter" />
|
||||||
|
</a>
|
||||||
|
|
||||||
|
A short description of the project.
|
||||||
|
|
||||||
|
## Project Organization
|
||||||
|
|
||||||
|
```
|
||||||
|
├── LICENSE <- Open-source license if one is chosen
|
||||||
|
├── Makefile <- Makefile with convenience commands like `make data` or `make train`
|
||||||
|
├── README.md <- The top-level README for developers using this project.
|
||||||
|
├── data
|
||||||
|
│ ├── external <- Data from third party sources.
|
||||||
|
│ ├── interim <- Intermediate data that has been transformed.
|
||||||
|
│ ├── processed <- The final, canonical data sets for modeling.
|
||||||
|
│ └── raw <- The original, immutable data dump.
|
||||||
|
│
|
||||||
|
├── docs <- A default mkdocs project; see www.mkdocs.org for details
|
||||||
|
│
|
||||||
|
├── models <- Trained and serialized models, model predictions, or model summaries
|
||||||
|
│
|
||||||
|
├── notebooks <- Jupyter notebooks. Naming convention is a number (for ordering),
|
||||||
|
│ the creator's initials, and a short `-` delimited description, e.g.
|
||||||
|
│ `1.0-jqp-initial-data-exploration`.
|
||||||
|
│
|
||||||
|
├── pyproject.toml <- Project configuration file with package metadata for
|
||||||
|
│ lnp_ml and configuration for tools like black
|
||||||
|
│
|
||||||
|
├── references <- Data dictionaries, manuals, and all other explanatory materials.
|
||||||
|
│
|
||||||
|
├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
|
||||||
|
│ └── figures <- Generated graphics and figures to be used in reporting
|
||||||
|
│
|
||||||
|
├── requirements.txt <- The requirements file for reproducing the analysis environment, e.g.
|
||||||
|
│ generated with `pip freeze > requirements.txt`
|
||||||
|
│
|
||||||
|
├── setup.cfg <- Configuration file for flake8
|
||||||
|
│
|
||||||
|
└── lnp_ml <- Source code for use in this project.
|
||||||
|
│
|
||||||
|
├── __init__.py <- Makes lnp_ml a Python module
|
||||||
|
│
|
||||||
|
├── config.py <- Store useful variables and configuration
|
||||||
|
│
|
||||||
|
├── dataset.py <- Scripts to download or generate data
|
||||||
|
│
|
||||||
|
├── features.py <- Code to create features for modeling
|
||||||
|
│
|
||||||
|
├── modeling
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── predict.py <- Code to run model inference with trained models
|
||||||
|
│ └── train.py <- Code to train models
|
||||||
|
│
|
||||||
|
└── plots.py <- Code to create visualizations
|
||||||
|
```
|
||||||
|
|
||||||
|
--------
|
||||||
|
|
||||||
35
cal_features.py
Normal file
35
cal_features.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import multiprocessing as mp
|
||||||
|
from typing import List
|
||||||
|
from tqdm import tqdm
|
||||||
|
from rdkit import Chem
|
||||||
|
from rdkit.Chem import (
|
||||||
|
Mol,
|
||||||
|
AllChem,
|
||||||
|
MACCSkeys,
|
||||||
|
Descriptors
|
||||||
|
)
|
||||||
|
|
||||||
|
mp.set_start_method('fork') # Screw MacOS
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_morgan(mol: Mol, radius: int = 2, nBits: int = 1024) -> List[int]:
|
||||||
|
return AllChem.GetMorganFingerprintAsBitVect(
|
||||||
|
mol,
|
||||||
|
radius=radius,
|
||||||
|
nBits=nBits,
|
||||||
|
useChirality=False
|
||||||
|
).ToList()
|
||||||
|
|
||||||
|
|
||||||
|
def get_maccs(mol: Mol) -> List[int]:
|
||||||
|
return MACCSkeys.GenMACCSKeys(mol).ToList()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rdkit_descriptors(mol: Mol) -> List[float]:
|
||||||
|
desc_dict = Descriptors.CalcMolDescriptors(mol)
|
||||||
|
return list(desc_dict.values())
|
||||||
0
docs/.gitkeep
Normal file
0
docs/.gitkeep
Normal file
4
lnp_ml/__init__.py
Normal file
4
lnp_ml/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Lazy imports to avoid loading heavy dependencies (chemprop, rdkit) on every import
|
||||||
|
# Use explicit imports when needed:
|
||||||
|
# from lnp_ml.featurization import RDKitFeaturizer
|
||||||
|
# from lnp_ml.modeling import LNPModel
|
||||||
32
lnp_ml/config.py
Normal file
32
lnp_ml/config.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# Load environment variables from .env file if it exists
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
PROJ_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
logger.info(f"PROJ_ROOT path is: {PROJ_ROOT}")
|
||||||
|
|
||||||
|
DATA_DIR = PROJ_ROOT / "data"
|
||||||
|
RAW_DATA_DIR = DATA_DIR / "raw"
|
||||||
|
INTERIM_DATA_DIR = DATA_DIR / "interim"
|
||||||
|
PROCESSED_DATA_DIR = DATA_DIR / "processed"
|
||||||
|
EXTERNAL_DATA_DIR = DATA_DIR / "external"
|
||||||
|
|
||||||
|
MODELS_DIR = PROJ_ROOT / "models"
|
||||||
|
|
||||||
|
REPORTS_DIR = PROJ_ROOT / "reports"
|
||||||
|
FIGURES_DIR = REPORTS_DIR / "figures"
|
||||||
|
|
||||||
|
# If tqdm is installed, configure loguru with tqdm.write
|
||||||
|
# https://github.com/Delgan/loguru/issues/135
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logger.remove(0)
|
||||||
|
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
348
lnp_ml/dataset.py
Normal file
348
lnp_ml/dataset.py
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
"""数据集处理模块"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 列名配置 ============
|
||||||
|
|
||||||
|
# SMILES 列
|
||||||
|
SMILES_COL = "smiles"
|
||||||
|
|
||||||
|
# comp token: 配方比例 [5]
|
||||||
|
COMP_COLS = [
|
||||||
|
"Cationic_Lipid_to_mRNA_weight_ratio",
|
||||||
|
"Cationic_Lipid_Mol_Ratio",
|
||||||
|
"Phospholipid_Mol_Ratio",
|
||||||
|
"Cholesterol_Mol_Ratio",
|
||||||
|
"PEG_Lipid_Mol_Ratio",
|
||||||
|
]
|
||||||
|
|
||||||
|
# phys token: 物理/实验参数 one-hot [12]
|
||||||
|
# 需要从原始列生成 one-hot
|
||||||
|
PHYS_ONEHOT_SPECS = {
|
||||||
|
"Purity": ["Pure", "Crude"],
|
||||||
|
"Mix_type": ["Microfluidic", "Pipetting"],
|
||||||
|
"Cargo_type": ["mRNA", "pDNA", "siRNA"],
|
||||||
|
"Target_or_delivered_gene": ["FFL", "Peptide_barcode", "hEPO", "FVII", "GFP"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# help token: Helper lipid one-hot [4]
|
||||||
|
HELP_COLS = [
|
||||||
|
"Helper_lipid_ID_DOPE",
|
||||||
|
"Helper_lipid_ID_DOTAP",
|
||||||
|
"Helper_lipid_ID_DSPC",
|
||||||
|
"Helper_lipid_ID_MDOA",
|
||||||
|
]
|
||||||
|
|
||||||
|
# exp token: 实验条件 one-hot [32]
|
||||||
|
EXP_ONEHOT_SPECS = {
|
||||||
|
"Model_type": ["A549", "BDMC", "BMDM", "HBEC_ALI", "HEK293T", "HeLa", "IGROV1", "Mouse", "RAW264p7"],
|
||||||
|
"Delivery_target": ["body", "dendritic_cell", "generic_cell", "liver", "lung", "lung_epithelium", "macrophage", "muscle", "spleen"],
|
||||||
|
"Route_of_administration": ["in_vitro", "intramuscular", "intratracheal", "intravenous"],
|
||||||
|
"Batch_or_individual_or_barcoded": ["Barcoded", "Individual"],
|
||||||
|
"Value_name": ["log_luminescence", "luminescence", "FFL_silencing", "Peptide_abundance", "hEPO", "FVII_silencing", "GFP_delivery", "Discretized_luminescence"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Target 列
|
||||||
|
TARGET_REGRESSION = ["size", "quantified_delivery"]
|
||||||
|
TARGET_CLASSIFICATION_PDI = ["PDI_0_0to0_2", "PDI_0_2to0_3", "PDI_0_3to0_4", "PDI_0_4to0_5"]
|
||||||
|
TARGET_CLASSIFICATION_EE = ["Encapsulation_Efficiency_EE<50", "Encapsulation_Efficiency_50<=EE<80", "Encapsulation_Efficiency_80<EE<=100"]
|
||||||
|
TARGET_TOXIC = "toxic"
|
||||||
|
TARGET_BIODIST = [
|
||||||
|
"Biodistribution_lymph_nodes",
|
||||||
|
"Biodistribution_heart",
|
||||||
|
"Biodistribution_liver",
|
||||||
|
"Biodistribution_spleen",
|
||||||
|
"Biodistribution_lung",
|
||||||
|
"Biodistribution_kidney",
|
||||||
|
"Biodistribution_muscle",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_onehot_cols(prefix: str, values: List[str]) -> List[str]:
|
||||||
|
"""生成 one-hot 列名"""
|
||||||
|
return [f"{prefix}_{v}" for v in values]
|
||||||
|
|
||||||
|
|
||||||
|
def process_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
处理原始 DataFrame,生成模型所需的所有列。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 原始 DataFrame
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的 DataFrame,包含所有需要的列
|
||||||
|
"""
|
||||||
|
df = df.copy()
|
||||||
|
|
||||||
|
# 1. 处理 phys token 的 one-hot 列(如果不存在则生成)
|
||||||
|
for col, values in PHYS_ONEHOT_SPECS.items():
|
||||||
|
for v in values:
|
||||||
|
onehot_col = f"{col}_{v}"
|
||||||
|
if onehot_col not in df.columns:
|
||||||
|
if col in df.columns:
|
||||||
|
df[onehot_col] = (df[col] == v).astype(float)
|
||||||
|
else:
|
||||||
|
df[onehot_col] = 0.0
|
||||||
|
|
||||||
|
# 2. 处理 exp token 的 one-hot 列(如果不存在则生成)
|
||||||
|
for col, values in EXP_ONEHOT_SPECS.items():
|
||||||
|
for v in values:
|
||||||
|
onehot_col = f"{col}_{v}"
|
||||||
|
if onehot_col not in df.columns:
|
||||||
|
if col in df.columns:
|
||||||
|
df[onehot_col] = (df[col] == v).astype(float)
|
||||||
|
else:
|
||||||
|
df[onehot_col] = 0.0
|
||||||
|
|
||||||
|
# 3. 确保 comp 列存在且为 float
|
||||||
|
for col in COMP_COLS:
|
||||||
|
if col in df.columns:
|
||||||
|
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
|
||||||
|
else:
|
||||||
|
df[col] = 0.0
|
||||||
|
|
||||||
|
# 4. 确保 help 列存在
|
||||||
|
for col in HELP_COLS:
|
||||||
|
if col not in df.columns:
|
||||||
|
df[col] = 0.0
|
||||||
|
else:
|
||||||
|
df[col] = df[col].fillna(0.0).astype(float)
|
||||||
|
|
||||||
|
# 5. 处理 target 列
|
||||||
|
# size: 已经 log 过,填充缺失值
|
||||||
|
if "size" in df.columns:
|
||||||
|
df["size"] = pd.to_numeric(df["size"], errors="coerce")
|
||||||
|
|
||||||
|
# quantified_delivery: 已经 z-score 过
|
||||||
|
if "quantified_delivery" in df.columns:
|
||||||
|
df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
|
||||||
|
|
||||||
|
# toxic: 0/1
|
||||||
|
if TARGET_TOXIC in df.columns:
|
||||||
|
df[TARGET_TOXIC] = pd.to_numeric(df[TARGET_TOXIC], errors="coerce").fillna(-1).astype(int)
|
||||||
|
|
||||||
|
# PDI 和 EE 的 one-hot 分类
|
||||||
|
for col in TARGET_CLASSIFICATION_PDI + TARGET_CLASSIFICATION_EE:
|
||||||
|
if col in df.columns:
|
||||||
|
df[col] = df[col].fillna(0).astype(float)
|
||||||
|
|
||||||
|
# Biodistribution
|
||||||
|
for col in TARGET_BIODIST:
|
||||||
|
if col in df.columns:
|
||||||
|
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def get_phys_cols() -> List[str]:
|
||||||
|
"""获取 phys token 的所有列名"""
|
||||||
|
cols = []
|
||||||
|
for col, values in PHYS_ONEHOT_SPECS.items():
|
||||||
|
cols.extend(get_onehot_cols(col, values))
|
||||||
|
return cols
|
||||||
|
|
||||||
|
|
||||||
|
def get_exp_cols() -> List[str]:
|
||||||
|
"""获取 exp token 的所有列名"""
|
||||||
|
cols = []
|
||||||
|
for col, values in EXP_ONEHOT_SPECS.items():
|
||||||
|
cols.extend(get_onehot_cols(col, values))
|
||||||
|
return cols
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LNPDatasetConfig:
|
||||||
|
"""数据集配置"""
|
||||||
|
comp_cols: List[str] = None
|
||||||
|
phys_cols: List[str] = None
|
||||||
|
help_cols: List[str] = None
|
||||||
|
exp_cols: List[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.comp_cols = self.comp_cols or COMP_COLS
|
||||||
|
self.phys_cols = self.phys_cols or get_phys_cols()
|
||||||
|
self.help_cols = self.help_cols or HELP_COLS
|
||||||
|
self.exp_cols = self.exp_cols or get_exp_cols()
|
||||||
|
|
||||||
|
|
||||||
|
class LNPDataset(Dataset):
|
||||||
|
"""
|
||||||
|
LNP 数据集,用于 PyTorch DataLoader。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- smiles: str
|
||||||
|
- tabular: Dict[str, Tensor] with keys "comp", "phys", "help", "exp"
|
||||||
|
- targets: Dict[str, Tensor] with keys "size", "pdi", "ee", "delivery", "biodist", "toxic"
|
||||||
|
- mask: Dict[str, Tensor] 标记哪些 target 有效(非缺失)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
config: Optional[LNPDatasetConfig] = None,
|
||||||
|
):
|
||||||
|
self.config = config or LNPDatasetConfig()
|
||||||
|
self.df = process_dataframe(df)
|
||||||
|
|
||||||
|
# 提取数据
|
||||||
|
self.smiles = self.df[SMILES_COL].tolist()
|
||||||
|
|
||||||
|
# Tabular features
|
||||||
|
self.comp = self.df[self.config.comp_cols].values.astype(np.float32)
|
||||||
|
self.phys = self.df[self.config.phys_cols].values.astype(np.float32)
|
||||||
|
self.help = self.df[self.config.help_cols].values.astype(np.float32)
|
||||||
|
self.exp = self.df[self.config.exp_cols].values.astype(np.float32)
|
||||||
|
|
||||||
|
# Targets
|
||||||
|
self.size = self.df["size"].values.astype(np.float32) if "size" in self.df.columns else None
|
||||||
|
self.delivery = self.df["quantified_delivery"].values.astype(np.float32) if "quantified_delivery" in self.df.columns else None
|
||||||
|
self.toxic = self.df[TARGET_TOXIC].values.astype(np.int64) if TARGET_TOXIC in self.df.columns else None
|
||||||
|
|
||||||
|
# PDI: one-hot -> class index
|
||||||
|
if all(col in self.df.columns for col in TARGET_CLASSIFICATION_PDI):
|
||||||
|
pdi_onehot = self.df[TARGET_CLASSIFICATION_PDI].values
|
||||||
|
self.pdi = np.argmax(pdi_onehot, axis=1).astype(np.int64)
|
||||||
|
self.pdi_valid = pdi_onehot.sum(axis=1) > 0
|
||||||
|
else:
|
||||||
|
self.pdi = None
|
||||||
|
self.pdi_valid = None
|
||||||
|
|
||||||
|
# EE: one-hot -> class index
|
||||||
|
if all(col in self.df.columns for col in TARGET_CLASSIFICATION_EE):
|
||||||
|
ee_onehot = self.df[TARGET_CLASSIFICATION_EE].values
|
||||||
|
self.ee = np.argmax(ee_onehot, axis=1).astype(np.int64)
|
||||||
|
self.ee_valid = ee_onehot.sum(axis=1) > 0
|
||||||
|
else:
|
||||||
|
self.ee = None
|
||||||
|
self.ee_valid = None
|
||||||
|
|
||||||
|
# Biodistribution
|
||||||
|
if all(col in self.df.columns for col in TARGET_BIODIST):
|
||||||
|
self.biodist = self.df[TARGET_BIODIST].values.astype(np.float32)
|
||||||
|
self.biodist_valid = self.biodist.sum(axis=1) > 0
|
||||||
|
else:
|
||||||
|
self.biodist = None
|
||||||
|
self.biodist_valid = None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.smiles)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict:
|
||||||
|
item = {
|
||||||
|
"smiles": self.smiles[idx],
|
||||||
|
"tabular": {
|
||||||
|
"comp": torch.from_numpy(self.comp[idx]),
|
||||||
|
"phys": torch.from_numpy(self.phys[idx]),
|
||||||
|
"help": torch.from_numpy(self.help[idx]),
|
||||||
|
"exp": torch.from_numpy(self.exp[idx]),
|
||||||
|
},
|
||||||
|
"targets": {},
|
||||||
|
"mask": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Targets and masks
|
||||||
|
if self.size is not None:
|
||||||
|
item["targets"]["size"] = torch.tensor(self.size[idx], dtype=torch.float32)
|
||||||
|
item["mask"]["size"] = torch.tensor(not np.isnan(self.size[idx]), dtype=torch.bool)
|
||||||
|
|
||||||
|
if self.delivery is not None:
|
||||||
|
item["targets"]["delivery"] = torch.tensor(self.delivery[idx], dtype=torch.float32)
|
||||||
|
item["mask"]["delivery"] = torch.tensor(not np.isnan(self.delivery[idx]), dtype=torch.bool)
|
||||||
|
|
||||||
|
if self.toxic is not None:
|
||||||
|
item["targets"]["toxic"] = torch.tensor(self.toxic[idx], dtype=torch.long)
|
||||||
|
item["mask"]["toxic"] = torch.tensor(self.toxic[idx] >= 0, dtype=torch.bool)
|
||||||
|
|
||||||
|
if self.pdi is not None:
|
||||||
|
item["targets"]["pdi"] = torch.tensor(self.pdi[idx], dtype=torch.long)
|
||||||
|
item["mask"]["pdi"] = torch.tensor(self.pdi_valid[idx], dtype=torch.bool)
|
||||||
|
|
||||||
|
if self.ee is not None:
|
||||||
|
item["targets"]["ee"] = torch.tensor(self.ee[idx], dtype=torch.long)
|
||||||
|
item["mask"]["ee"] = torch.tensor(self.ee_valid[idx], dtype=torch.bool)
|
||||||
|
|
||||||
|
if self.biodist is not None:
|
||||||
|
item["targets"]["biodist"] = torch.from_numpy(self.biodist[idx])
|
||||||
|
item["mask"]["biodist"] = torch.tensor(self.biodist_valid[idx], dtype=torch.bool)
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch: List[Dict]) -> Dict:
|
||||||
|
"""
|
||||||
|
自定义 collate 函数,用于 DataLoader。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- smiles: List[str]
|
||||||
|
- tabular: Dict[str, Tensor] with batched tensors
|
||||||
|
- targets: Dict[str, Tensor] with batched tensors
|
||||||
|
- mask: Dict[str, Tensor] with batched masks
|
||||||
|
"""
|
||||||
|
smiles = [item["smiles"] for item in batch]
|
||||||
|
|
||||||
|
tabular = {
|
||||||
|
key: torch.stack([item["tabular"][key] for item in batch])
|
||||||
|
for key in batch[0]["tabular"].keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
targets = {}
|
||||||
|
mask = {}
|
||||||
|
for key in batch[0]["targets"].keys():
|
||||||
|
targets[key] = torch.stack([item["targets"][key] for item in batch])
|
||||||
|
mask[key] = torch.stack([item["mask"][key] for item in batch])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"smiles": smiles,
|
||||||
|
"tabular": tabular,
|
||||||
|
"targets": targets,
|
||||||
|
"mask": mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(
|
||||||
|
path: Path,
|
||||||
|
train_ratio: float = 0.8,
|
||||||
|
val_ratio: float = 0.1,
|
||||||
|
seed: int = 42,
|
||||||
|
) -> Tuple[LNPDataset, LNPDataset, LNPDataset]:
|
||||||
|
"""
|
||||||
|
加载并划分数据集。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: CSV 文件路径
|
||||||
|
train_ratio: 训练集比例
|
||||||
|
val_ratio: 验证集比例
|
||||||
|
seed: 随机种子
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(train_dataset, val_dataset, test_dataset)
|
||||||
|
"""
|
||||||
|
df = pd.read_csv(path)
|
||||||
|
|
||||||
|
# 随机打乱
|
||||||
|
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
|
||||||
|
|
||||||
|
n = len(df)
|
||||||
|
n_train = int(n * train_ratio)
|
||||||
|
n_val = int(n * val_ratio)
|
||||||
|
|
||||||
|
train_df = df.iloc[:n_train]
|
||||||
|
val_df = df.iloc[n_train:n_train + n_val]
|
||||||
|
test_df = df.iloc[n_train + n_val:]
|
||||||
|
|
||||||
|
config = LNPDatasetConfig()
|
||||||
|
|
||||||
|
return (
|
||||||
|
LNPDataset(train_df, config),
|
||||||
|
LNPDataset(val_df, config),
|
||||||
|
LNPDataset(test_df, config),
|
||||||
|
)
|
||||||
29
lnp_ml/features.py
Normal file
29
lnp_ml/features.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from tqdm import tqdm
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from lnp_ml.config import PROCESSED_DATA_DIR
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
# ---- REPLACE DEFAULT PATHS AS APPROPRIATE ----
|
||||||
|
input_path: Path = PROCESSED_DATA_DIR / "dataset.csv",
|
||||||
|
output_path: Path = PROCESSED_DATA_DIR / "features.csv",
|
||||||
|
# -----------------------------------------
|
||||||
|
):
|
||||||
|
# ---- REPLACE THIS WITH YOUR OWN CODE ----
|
||||||
|
logger.info("Generating features from dataset...")
|
||||||
|
for i in tqdm(range(10), total=10):
|
||||||
|
if i == 5:
|
||||||
|
logger.info("Something happened for iteration 5.")
|
||||||
|
logger.success("Features generation complete.")
|
||||||
|
# -----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
4
lnp_ml/featurization/__init__.py
Normal file
4
lnp_ml/featurization/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from lnp_ml.featurization.smiles import RDKitFeaturizer
|
||||||
|
|
||||||
|
__all__ = ["RDKitFeaturizer"]
|
||||||
|
|
||||||
156
lnp_ml/featurization/smiles.py
Normal file
156
lnp_ml/featurization/smiles.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
"""SMILES 分子特征提取器"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Suppress RDKit deprecation warnings
|
||||||
|
from rdkit import RDLogger
|
||||||
|
RDLogger.DisableLog("rdApp.*")
|
||||||
|
|
||||||
|
from rdkit import Chem
|
||||||
|
from rdkit.Chem import AllChem, MACCSkeys, Descriptors
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from chemprop.utils import load_checkpoint
|
||||||
|
from chemprop.features import mol2graph
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RDKitFeaturizer:
|
||||||
|
"""
|
||||||
|
SMILES -> RDKit 特征向量,返回 Dict[str, np.ndarray]。
|
||||||
|
"""
|
||||||
|
|
||||||
|
morgan_radius: int = 2
|
||||||
|
morgan_nbits: int = 1024
|
||||||
|
|
||||||
|
def transform(self, smiles_list: List[str]) -> Dict[str, np.ndarray]:
|
||||||
|
"""SMILES 特征字典 -> value: (N, D_i) arrays"""
|
||||||
|
encoded = [self._encode_one(s) for s in smiles_list]
|
||||||
|
return {
|
||||||
|
"morgan": np.vstack([enc["morgan"] for enc in encoded]),
|
||||||
|
"maccs": np.vstack([enc["maccs"] for enc in encoded]),
|
||||||
|
"desc": np.vstack([enc["desc"] for enc in encoded])
|
||||||
|
}
|
||||||
|
|
||||||
|
def _encode_morgan(self, mol: Chem.Mol) -> np.ndarray:
|
||||||
|
return np.array(AllChem.GetMorganFingerprintAsBitVect(
|
||||||
|
mol, radius=self.morgan_radius, nBits=self.morgan_nbits
|
||||||
|
).ToList(), dtype=np.float32)
|
||||||
|
|
||||||
|
def _encode_maccs(self, mol: Chem.Mol) -> np.ndarray:
|
||||||
|
return np.array(MACCSkeys.GenMACCSKeys(mol).ToList(), dtype=np.float32)
|
||||||
|
|
||||||
|
def _encode_desc(self, mol: Chem.Mol) -> np.ndarray:
|
||||||
|
return np.array(list(Descriptors.CalcMolDescriptors(mol).values()), dtype=np.float32)
|
||||||
|
|
||||||
|
def _encode_one(self, smiles: str) -> Dict[str, np.ndarray]:
|
||||||
|
mol = Chem.MolFromSmiles(smiles)
|
||||||
|
if mol is None:
|
||||||
|
raise ValueError(f"Invalid SMILES: {smiles!r}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"morgan": self._encode_morgan(mol),
|
||||||
|
"maccs": self._encode_maccs(mol),
|
||||||
|
"desc": self._encode_desc(mol)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MPNNFeaturizer:
|
||||||
|
"""
|
||||||
|
SMILES -> D-MPNN 预训练特征向量 (N, hidden_size=600)。
|
||||||
|
|
||||||
|
从训练好的 chemprop 模型中提取 D-MPNN 编码器的输出作为分子特征。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: 模型检查点路径(.pt文件)
|
||||||
|
device: 计算设备 ("cpu" 或 "cuda")
|
||||||
|
ensemble_paths: 可选,多个模型路径列表用于集成(取平均)
|
||||||
|
"""
|
||||||
|
checkpoint_path: Optional[str] = None
|
||||||
|
device: str = "cpu"
|
||||||
|
ensemble_paths: Optional[List[str]] = None
|
||||||
|
|
||||||
|
# 内部状态(不由用户设置)
|
||||||
|
_encoders: List = field(default_factory=list, init=False, repr=False)
|
||||||
|
_hidden_size: int = field(default=0, init=False, repr=False)
|
||||||
|
_initialized: bool = field(default=False, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""延迟初始化,在首次调用 transform 时加载模型"""
|
||||||
|
if self.checkpoint_path is None and self.ensemble_paths is None:
|
||||||
|
raise ValueError("必须提供 checkpoint_path 或 ensemble_paths")
|
||||||
|
|
||||||
|
def _lazy_init(self) -> None:
|
||||||
|
"""延迟加载模型,避免在创建对象时就加载大模型"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
device = torch.device(self.device)
|
||||||
|
paths = self.ensemble_paths or [self.checkpoint_path]
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
model = load_checkpoint(path, device=device)
|
||||||
|
model.eval()
|
||||||
|
# 提取 MPNEncoder(D-MPNN 核心部分)
|
||||||
|
encoder = model.encoder.encoder[0]
|
||||||
|
# 冻结参数
|
||||||
|
for param in encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self._encoders.append(encoder)
|
||||||
|
self._hidden_size = encoder.hidden_size
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def transform(self, smiles_list: List[str]) -> Dict[str, np.ndarray]:
|
||||||
|
"""
|
||||||
|
SMILES 列表 -> tuple of (N, hidden_size) array
|
||||||
|
|
||||||
|
Args:
|
||||||
|
smiles_list: SMILES 字符串列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple 包含一个形状为 (N, hidden_size) 的 numpy 数组
|
||||||
|
如果使用集成模型,返回所有模型输出的平均值
|
||||||
|
"""
|
||||||
|
self._lazy_init()
|
||||||
|
|
||||||
|
# 验证 SMILES 有效性
|
||||||
|
for smi in smiles_list:
|
||||||
|
mol = Chem.MolFromSmiles(smi)
|
||||||
|
if mol is None:
|
||||||
|
raise ValueError(f"Invalid SMILES: {smi!r}")
|
||||||
|
|
||||||
|
# 构建分子图(批量处理)
|
||||||
|
batch_mol_graph = mol2graph(smiles_list)
|
||||||
|
|
||||||
|
# 从所有编码器提取特征
|
||||||
|
all_features = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for encoder in self._encoders:
|
||||||
|
features = encoder(batch_mol_graph)
|
||||||
|
all_features.append(features.cpu().numpy())
|
||||||
|
|
||||||
|
# 如果是集成模型,取平均
|
||||||
|
if len(all_features) > 1:
|
||||||
|
features_array = np.mean(all_features, axis=0).astype(np.float32)
|
||||||
|
else:
|
||||||
|
features_array = all_features[0].astype(np.float32)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mpnn": features_array
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_size(self) -> int:
|
||||||
|
"""返回特征维度"""
|
||||||
|
self._lazy_init()
|
||||||
|
return self._hidden_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_models(self) -> int:
|
||||||
|
"""返回集成模型数量"""
|
||||||
|
self._lazy_init()
|
||||||
|
return len(self._encoders)
|
||||||
11
lnp_ml/modeling/__init__.py
Normal file
11
lnp_ml/modeling/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||||
|
from lnp_ml.modeling.heads import MultiTaskHead, RegressionHead, ClassificationHead, DistributionHead
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LNPModel",
|
||||||
|
"LNPModelWithoutMPNN",
|
||||||
|
"MultiTaskHead",
|
||||||
|
"RegressionHead",
|
||||||
|
"ClassificationHead",
|
||||||
|
"DistributionHead",
|
||||||
|
]
|
||||||
5
lnp_ml/modeling/encoders/__init__.py
Normal file
5
lnp_ml/modeling/encoders/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
|
||||||
|
from lnp_ml.modeling.encoders.mpnn_encoder import CachedMPNNEncoder
|
||||||
|
|
||||||
|
__all__ = ["CachedRDKitEncoder", "CachedMPNNEncoder"]
|
||||||
|
|
||||||
67
lnp_ml/modeling/encoders/mpnn_encoder.py
Normal file
67
lnp_ml/modeling/encoders/mpnn_encoder.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from typing import List, Optional, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lnp_ml.featurization.smiles import MPNNFeaturizer
|
||||||
|
|
||||||
|
|
||||||
|
class CachedMPNNEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
带内存缓存的 D-MPNN 特征提取模块。
|
||||||
|
|
||||||
|
- 使用预训练 chemprop 模型的 encoder 提取特征
|
||||||
|
- 不可训练,不参与反向传播
|
||||||
|
- 缓存已计算的 SMILES 特征,避免重复计算
|
||||||
|
- forward 返回 Dict[str, Tensor],key: "mpnn"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
checkpoint_path: Optional[str] = None,
|
||||||
|
ensemble_paths: Optional[List[str]] = None,
|
||||||
|
device: str = "cpu",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._featurizer = MPNNFeaturizer(
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
ensemble_paths=ensemble_paths,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self._cache: Dict[str, np.ndarray] = {}
|
||||||
|
|
||||||
|
def forward(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
SMILES 列表 -> Dict[str, Tensor]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"mpnn": (N, hidden_size)}
|
||||||
|
"""
|
||||||
|
# 分离:已缓存 vs 需计算
|
||||||
|
to_compute = [s for s in smiles_list if s not in self._cache]
|
||||||
|
|
||||||
|
# 批量计算未缓存的
|
||||||
|
if to_compute:
|
||||||
|
new_features = self._featurizer.transform(to_compute)
|
||||||
|
for idx, smiles in enumerate(to_compute):
|
||||||
|
self._cache[smiles] = new_features["mpnn"][idx]
|
||||||
|
|
||||||
|
# 按原顺序组装结果
|
||||||
|
return {
|
||||||
|
"mpnn": torch.from_numpy(np.stack([self._cache[s] for s in smiles_list]))
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清空缓存"""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_size(self) -> int:
|
||||||
|
"""当前缓存的 SMILES 数量"""
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_size(self) -> int:
|
||||||
|
"""特征维度"""
|
||||||
|
return self._featurizer.hidden_size
|
||||||
58
lnp_ml/modeling/encoders/rdkit_encoder.py
Normal file
58
lnp_ml/modeling/encoders/rdkit_encoder.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from lnp_ml.featurization.smiles import RDKitFeaturizer
|
||||||
|
|
||||||
|
|
||||||
|
class CachedRDKitEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
带内存缓存的 RDKit 特征提取模块。
|
||||||
|
|
||||||
|
- 不可训练,不参与反向传播
|
||||||
|
- 缓存已计算的 SMILES 特征,避免重复计算
|
||||||
|
- forward 返回 Dict[str, Tensor],keys: "morgan", "maccs", "desc"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, morgan_radius: int = 2, morgan_nbits: int = 1024) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._featurizer = RDKitFeaturizer(
|
||||||
|
morgan_radius=morgan_radius,
|
||||||
|
morgan_nbits=morgan_nbits,
|
||||||
|
)
|
||||||
|
self._cache: Dict[str, Dict[str, np.ndarray]] = {}
|
||||||
|
|
||||||
|
def forward(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
SMILES 列表 -> Dict[str, Tensor]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"morgan": (N, 1024), "maccs": (N, 167), "desc": (N, 210)}
|
||||||
|
"""
|
||||||
|
# 分离:已缓存 vs 需计算
|
||||||
|
to_compute = [s for s in smiles_list if s not in self._cache]
|
||||||
|
|
||||||
|
# 批量计算未缓存的
|
||||||
|
if to_compute:
|
||||||
|
new_features = self._featurizer.transform(to_compute)
|
||||||
|
for idx, smiles in enumerate(to_compute):
|
||||||
|
self._cache[smiles] = {
|
||||||
|
k: new_features[k][idx] for k in new_features
|
||||||
|
}
|
||||||
|
|
||||||
|
# 按原顺序组装结果
|
||||||
|
keys = ["morgan", "maccs", "desc"]
|
||||||
|
return {
|
||||||
|
k: torch.from_numpy(np.stack([self._cache[s][k] for s in smiles_list]))
|
||||||
|
for k in keys
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清空缓存"""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_size(self) -> int:
|
||||||
|
"""当前缓存的 SMILES 数量"""
|
||||||
|
return len(self._cache)
|
||||||
24
lnp_ml/modeling/encoders/tabular_encoder.py
Normal file
24
lnp_ml/modeling/encoders/tabular_encoder.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
class TabularEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, tabular_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
# Input: Dict with keys 'comp', 'phys', 'help', 'exp'
|
||||||
|
# Each value is a tensor [B, D_i] where D_i is the feature dimension
|
||||||
|
# Output: Same dict (pass-through, features already grouped by DataLoader)
|
||||||
|
|
||||||
|
# The DataLoader (trainer.py) already groups features correctly:
|
||||||
|
# - 'comp': [B, 9] - composition features
|
||||||
|
# - 'phys': [B, 9] - physical features (including processed PDI)
|
||||||
|
# - 'help': [B, 4] - helper lipid one-hot features
|
||||||
|
# - 'exp': [B, 20] - experimental condition one-hot features (including processed Purity)
|
||||||
|
|
||||||
|
# Simply return the dict as-is
|
||||||
|
# If we wanted to add learned transformations, we could add linear layers here
|
||||||
|
return tabular_data
|
||||||
118
lnp_ml/modeling/heads.py
Normal file
118
lnp_ml/modeling/heads.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
class RegressionHead(nn.Module):
|
||||||
|
"""回归任务头:输出单个 float 值"""
|
||||||
|
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int = 128, dropout: float = 0.1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(in_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""[B, in_dim] -> [B, 1]"""
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Module):
|
||||||
|
"""分类任务头:输出 logits"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_dim: int, num_classes: int, hidden_dim: int = 128, dropout: float = 0.1
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(in_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, num_classes),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""[B, in_dim] -> [B, num_classes] (logits)"""
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionHead(nn.Module):
|
||||||
|
"""分布任务头:输出和为 1 的概率分布(用于 Biodistribution)"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_dim: int, num_outputs: int, hidden_dim: int = 128, dropout: float = 0.1
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(in_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, num_outputs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""[B, in_dim] -> [B, num_outputs] (softmax, sum=1)"""
|
||||||
|
logits = self.net(x)
|
||||||
|
return F.softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskHead(nn.Module):
|
||||||
|
"""
|
||||||
|
多任务预测头,根据任务配置自动创建对应的子头。
|
||||||
|
|
||||||
|
输出:
|
||||||
|
- size: [B, 1] 回归
|
||||||
|
- pdi: [B, 4] 分类 logits
|
||||||
|
- ee: [B, 3] 分类 logits
|
||||||
|
- delivery: [B, 1] 回归
|
||||||
|
- biodist: [B, 7] softmax 分布
|
||||||
|
- toxic: [B, 2] 二分类 logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int = 128, dropout: float = 0.1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# size: 回归 (log-transformed)
|
||||||
|
self.size_head = RegressionHead(in_dim, hidden_dim, dropout)
|
||||||
|
|
||||||
|
# PDI: 4 分类
|
||||||
|
self.pdi_head = ClassificationHead(in_dim, num_classes=4, hidden_dim=hidden_dim, dropout=dropout)
|
||||||
|
|
||||||
|
# Encapsulation Efficiency: 3 分类
|
||||||
|
self.ee_head = ClassificationHead(in_dim, num_classes=3, hidden_dim=hidden_dim, dropout=dropout)
|
||||||
|
|
||||||
|
# quantified_delivery: 回归 (z-scored)
|
||||||
|
self.delivery_head = RegressionHead(in_dim, hidden_dim, dropout)
|
||||||
|
|
||||||
|
# Biodistribution: 7 输出,softmax (sum=1)
|
||||||
|
self.biodist_head = DistributionHead(in_dim, num_outputs=7, hidden_dim=hidden_dim, dropout=dropout)
|
||||||
|
|
||||||
|
# toxic: 二分类
|
||||||
|
self.toxic_head = ClassificationHead(in_dim, num_classes=2, hidden_dim=hidden_dim, dropout=dropout)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [B, in_dim] fusion 层输出
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with keys:
|
||||||
|
- "size": [B, 1]
|
||||||
|
- "pdi": [B, 4] logits
|
||||||
|
- "ee": [B, 3] logits
|
||||||
|
- "delivery": [B, 1]
|
||||||
|
- "biodist": [B, 7] probabilities (sum=1)
|
||||||
|
- "toxic": [B, 2] logits
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"size": self.size_head(x),
|
||||||
|
"pdi": self.pdi_head(x),
|
||||||
|
"ee": self.ee_head(x),
|
||||||
|
"delivery": self.delivery_head(x),
|
||||||
|
"biodist": self.biodist_head(x),
|
||||||
|
"toxic": self.toxic_head(x),
|
||||||
|
}
|
||||||
6
lnp_ml/modeling/layers/__init__.py
Normal file
6
lnp_ml/modeling/layers/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from lnp_ml.modeling.layers.token_projector import TokenProjector
|
||||||
|
from lnp_ml.modeling.layers.bidirectional_cross_attention import CrossModalAttention
|
||||||
|
from lnp_ml.modeling.layers.fusion import FusionLayer
|
||||||
|
|
||||||
|
__all__ = ["TokenProjector", "CrossModalAttention", "FusionLayer"]
|
||||||
|
|
||||||
124
lnp_ml/modeling/layers/bidirectional_cross_attention.py
Normal file
124
lnp_ml/modeling/layers/bidirectional_cross_attention.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionLayer(nn.Module):
|
||||||
|
"""单层双向交叉注意力"""
|
||||||
|
|
||||||
|
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||||
|
|
||||||
|
# A -> B: A as Q, B as K/V
|
||||||
|
self.cross_attn_a2b = nn.MultiheadAttention(
|
||||||
|
embed_dim=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
# B -> A: B as Q, A as K/V
|
||||||
|
self.cross_attn_b2a = nn.MultiheadAttention(
|
||||||
|
embed_dim=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# LayerNorm + FFN for channel A
|
||||||
|
self.norm_a1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm_a2 = nn.LayerNorm(d_model)
|
||||||
|
self.ffn_a = nn.Sequential(
|
||||||
|
nn.Linear(d_model, d_model * 4),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(d_model * 4, d_model),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
# LayerNorm + FFN for channel B
|
||||||
|
self.norm_b1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm_b2 = nn.LayerNorm(d_model)
|
||||||
|
self.ffn_b = nn.Sequential(
|
||||||
|
nn.Linear(d_model, d_model * 4),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(d_model * 4, d_model),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, a: torch.Tensor, b: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
a: [B, seq_len, d_model]
|
||||||
|
b: [B, seq_len, d_model]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(a_out, b_out): 更新后的两个 channel
|
||||||
|
"""
|
||||||
|
# Cross attention: A attends to B
|
||||||
|
a_attn, _ = self.cross_attn_a2b(query=a, key=b, value=b)
|
||||||
|
a = self.norm_a1(a + a_attn)
|
||||||
|
a = self.norm_a2(a + self.ffn_a(a))
|
||||||
|
|
||||||
|
# Cross attention: B attends to A
|
||||||
|
b_attn, _ = self.cross_attn_b2a(query=b, key=a, value=a)
|
||||||
|
b = self.norm_b1(b + b_attn)
|
||||||
|
b = self.norm_b2(b + self.ffn_b(b))
|
||||||
|
|
||||||
|
return a, b
|
||||||
|
|
||||||
|
|
||||||
|
class CrossModalAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
双向交叉注意力模块。
|
||||||
|
|
||||||
|
输入 stacked tokens [B, 8, d_model],split 成两个 channel 后执行
|
||||||
|
n_layers 层双向交叉注意力。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
num_heads: int,
|
||||||
|
n_layers: int,
|
||||||
|
split_idx: int = 4,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
d_model: 特征维度
|
||||||
|
num_heads: 注意力头数,d_head = d_model / num_heads
|
||||||
|
n_layers: 交叉注意力层数
|
||||||
|
split_idx: channel split 的位置,默认 4 (0:4, 4:)
|
||||||
|
dropout: dropout 比例
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.split_idx = split_idx
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
CrossAttentionLayer(d_model, num_heads, dropout)
|
||||||
|
for _ in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [B, 8, d_model] stacked tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[B, 8, d_model] 融合后的 tokens
|
||||||
|
"""
|
||||||
|
# Split: [B, 8, d_model] -> [B, 4, d_model], [B, 4, d_model]
|
||||||
|
a = x[:, : self.split_idx, :]
|
||||||
|
b = x[:, self.split_idx :, :]
|
||||||
|
|
||||||
|
# N layers of bidirectional cross attention
|
||||||
|
for layer in self.layers:
|
||||||
|
a, b = layer(a, b)
|
||||||
|
|
||||||
|
# Concat back: [B, 8, d_model]
|
||||||
|
return torch.cat([a, b], dim=1)
|
||||||
|
|
||||||
99
lnp_ml/modeling/layers/fusion.py
Normal file
99
lnp_ml/modeling/layers/fusion.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Dict, Literal, Union
|
||||||
|
|
||||||
|
|
||||||
|
PoolingStrategy = Literal["concat", "avg", "max", "attention"]
|
||||||
|
|
||||||
|
|
||||||
|
class FusionLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
将多个 token 融合成单个向量。
|
||||||
|
|
||||||
|
输入: Dict[str, Tensor] 或 [B, n_tokens, d_model]
|
||||||
|
输出: [B, fusion_dim]
|
||||||
|
|
||||||
|
策略:
|
||||||
|
- concat: [B, n_tokens, d_model] -> [B, n_tokens * d_model]
|
||||||
|
- avg: [B, n_tokens, d_model] -> [B, d_model]
|
||||||
|
- max: [B, n_tokens, d_model] -> [B, d_model]
|
||||||
|
- attention: [B, n_tokens, d_model] -> [B, d_model] (learnable attention pooling)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_tokens: int,
|
||||||
|
strategy: PoolingStrategy = "attention",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
d_model: 每个 token 的维度
|
||||||
|
n_tokens: token 数量(如 8)
|
||||||
|
strategy: 融合策略
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.n_tokens = n_tokens
|
||||||
|
self.strategy = strategy
|
||||||
|
|
||||||
|
if strategy == "concat":
|
||||||
|
self.fusion_dim = n_tokens * d_model
|
||||||
|
else:
|
||||||
|
self.fusion_dim = d_model
|
||||||
|
|
||||||
|
# Attention pooling: learnable query
|
||||||
|
if strategy == "attention":
|
||||||
|
self.attn_query = nn.Parameter(torch.randn(1, 1, d_model))
|
||||||
|
self.attn_proj = nn.Linear(d_model, d_model)
|
||||||
|
|
||||||
|
def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: Dict[str, Tensor] 每个 [B, d_model],或已 stack 的 [B, n_tokens, d_model]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[B, fusion_dim]
|
||||||
|
"""
|
||||||
|
# 如果输入是 dict,先 stack
|
||||||
|
if isinstance(x, dict):
|
||||||
|
x = torch.stack(list(x.values()), dim=1) # [B, n_tokens, d_model]
|
||||||
|
|
||||||
|
if self.strategy == "concat":
|
||||||
|
return x.flatten(start_dim=1) # [B, n_tokens * d_model]
|
||||||
|
|
||||||
|
elif self.strategy == "avg":
|
||||||
|
return x.mean(dim=1) # [B, d_model]
|
||||||
|
|
||||||
|
elif self.strategy == "max":
|
||||||
|
return x.max(dim=1).values # [B, d_model]
|
||||||
|
|
||||||
|
elif self.strategy == "attention":
|
||||||
|
return self._attention_pooling(x)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown strategy: {self.strategy}")
|
||||||
|
|
||||||
|
def _attention_pooling(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Attention pooling: 用可学习 query 对 tokens 做加权求和
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: [B, n_tokens, d_model]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[B, d_model]
|
||||||
|
"""
|
||||||
|
B = x.size(0)
|
||||||
|
# query: [1, 1, d_model] -> [B, 1, d_model]
|
||||||
|
query = self.attn_query.expand(B, -1, -1)
|
||||||
|
|
||||||
|
# Attention scores: [B, 1, n_tokens]
|
||||||
|
keys = self.attn_proj(x) # [B, n_tokens, d_model]
|
||||||
|
scores = torch.bmm(query, keys.transpose(1, 2)) / (self.d_model ** 0.5)
|
||||||
|
attn_weights = F.softmax(scores, dim=-1) # [B, 1, n_tokens]
|
||||||
|
|
||||||
|
# Weighted sum: [B, 1, d_model] -> [B, d_model]
|
||||||
|
out = torch.bmm(attn_weights, x).squeeze(1)
|
||||||
|
return out
|
||||||
59
lnp_ml/modeling/layers/token_projector.py
Normal file
59
lnp_ml/modeling/layers/token_projector.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
class TokenProjector(nn.Module):
|
||||||
|
"""
|
||||||
|
将不同维度的特征投影到统一的 d_model 维度。
|
||||||
|
|
||||||
|
每个特征分支的流程:
|
||||||
|
[B, input_dim_i] -> BN -> Linear -> [B, d_model] -> ReLU -> BN -> Dropout -> * sigmoid(weight_i)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dims: Dict[str, int],
|
||||||
|
d_model: int,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_dims: 各特征的输入维度,如 {"morgan": 1024, "maccs": 167, "desc": 210}
|
||||||
|
d_model: 统一的输出维度
|
||||||
|
dropout: dropout 比例
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.keys = list(input_dims.keys())
|
||||||
|
|
||||||
|
# 为每个特征分支创建投影层
|
||||||
|
self.projectors = nn.ModuleDict()
|
||||||
|
for key, in_dim in input_dims.items():
|
||||||
|
self.projectors[key] = nn.Sequential(
|
||||||
|
nn.BatchNorm1d(in_dim),
|
||||||
|
nn.Linear(in_dim, d_model),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.BatchNorm1d(d_model),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每个分支的可学习权重(初始化为 0,sigmoid 后为 0.5)
|
||||||
|
self.weights = nn.ParameterDict({
|
||||||
|
key: nn.Parameter(torch.zeros(1)) for key in self.keys
|
||||||
|
})
|
||||||
|
|
||||||
|
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
features: Dict[str, Tensor],每个 tensor 形状为 (B, input_dim_i)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Tensor],每个 tensor 形状为 (B, d_model)
|
||||||
|
"""
|
||||||
|
out = {}
|
||||||
|
for key in self.keys:
|
||||||
|
x = self.projectors[key](features[key])
|
||||||
|
w = torch.sigmoid(self.weights[key])
|
||||||
|
out[key] = x * w
|
||||||
|
return out
|
||||||
|
|
||||||
227
lnp_ml/modeling/models.py
Normal file
227
lnp_ml/modeling/models.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
"""LNP 多任务预测模型"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Dict, List, Optional, Literal
|
||||||
|
|
||||||
|
from lnp_ml.modeling.encoders import CachedRDKitEncoder, CachedMPNNEncoder
|
||||||
|
from lnp_ml.modeling.layers import TokenProjector, CrossModalAttention, FusionLayer
|
||||||
|
from lnp_ml.modeling.heads import MultiTaskHead
|
||||||
|
|
||||||
|
|
||||||
|
PoolingStrategy = Literal["concat", "avg", "max", "attention"]
|
||||||
|
|
||||||
|
|
||||||
|
# Token 维度配置(根据 ARCHITECTURE.md)
|
||||||
|
DEFAULT_INPUT_DIMS = {
|
||||||
|
# Channel A: 化学特征
|
||||||
|
"mpnn": 600, # D-MPNN embedding
|
||||||
|
"morgan": 1024, # Morgan fingerprint
|
||||||
|
"maccs": 167, # MACCS keys
|
||||||
|
"desc": 210, # RDKit descriptors
|
||||||
|
# Channel B: 配方/实验条件
|
||||||
|
"comp": 5, # 配方比例
|
||||||
|
"phys": 12, # 物理参数 one-hot
|
||||||
|
"help": 4, # Helper lipid one-hot
|
||||||
|
"exp": 32, # 实验条件 one-hot
|
||||||
|
}
|
||||||
|
|
||||||
|
# Token 顺序(前 4 个为 Channel A,后 4 个为 Channel B)
|
||||||
|
TOKEN_ORDER = ["mpnn", "morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
|
||||||
|
|
||||||
|
|
||||||
|
class LNPModel(nn.Module):
|
||||||
|
"""
|
||||||
|
LNP 药物递送性能预测模型。
|
||||||
|
|
||||||
|
架构流程:
|
||||||
|
1. Encoders: SMILES -> 化学特征; tabular -> 配方/实验特征
|
||||||
|
2. TokenProjector: 统一到 d_model
|
||||||
|
3. Stack: [B, 8, d_model]
|
||||||
|
4. CrossModalAttention: Channel A (化学) <-> Channel B (配方/实验)
|
||||||
|
5. FusionLayer: [B, 8, d_model] -> [B, fusion_dim]
|
||||||
|
6. MultiTaskHead: 多任务预测
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# 模型维度
|
||||||
|
d_model: int = 256,
|
||||||
|
# Cross attention
|
||||||
|
num_heads: int = 8,
|
||||||
|
n_attn_layers: int = 4,
|
||||||
|
# Fusion
|
||||||
|
fusion_strategy: PoolingStrategy = "attention",
|
||||||
|
# Head
|
||||||
|
head_hidden_dim: int = 128,
|
||||||
|
# Dropout
|
||||||
|
dropout: float = 0.1,
|
||||||
|
# MPNN encoder (可选,如果不用 MPNN 可以设为 None)
|
||||||
|
mpnn_checkpoint: Optional[str] = None,
|
||||||
|
mpnn_ensemble_paths: Optional[List[str]] = None,
|
||||||
|
mpnn_device: str = "cpu",
|
||||||
|
# 输入维度配置
|
||||||
|
input_dims: Optional[Dict[str, int]] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_dims = input_dims or DEFAULT_INPUT_DIMS
|
||||||
|
self.d_model = d_model
|
||||||
|
self.use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None
|
||||||
|
|
||||||
|
# ============ Encoders ============
|
||||||
|
# RDKit encoder (always used)
|
||||||
|
self.rdkit_encoder = CachedRDKitEncoder()
|
||||||
|
|
||||||
|
# MPNN encoder (optional)
|
||||||
|
if self.use_mpnn:
|
||||||
|
self.mpnn_encoder = CachedMPNNEncoder(
|
||||||
|
checkpoint_path=mpnn_checkpoint,
|
||||||
|
ensemble_paths=mpnn_ensemble_paths,
|
||||||
|
device=mpnn_device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mpnn_encoder = None
|
||||||
|
|
||||||
|
# ============ Token Projector ============
|
||||||
|
# 根据是否使用 MPNN 调整输入维度
|
||||||
|
proj_input_dims = {k: v for k, v in self.input_dims.items()}
|
||||||
|
if not self.use_mpnn:
|
||||||
|
proj_input_dims.pop("mpnn", None)
|
||||||
|
|
||||||
|
self.token_projector = TokenProjector(
|
||||||
|
input_dims=proj_input_dims,
|
||||||
|
d_model=d_model,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Cross Modal Attention ============
|
||||||
|
n_tokens = 8 if self.use_mpnn else 7
|
||||||
|
split_idx = 4 if self.use_mpnn else 3 # Channel A 的 token 数量
|
||||||
|
|
||||||
|
self.cross_attention = CrossModalAttention(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_layers=n_attn_layers,
|
||||||
|
split_idx=split_idx,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Fusion Layer ============
|
||||||
|
self.fusion = FusionLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
n_tokens=n_tokens,
|
||||||
|
strategy=fusion_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Multi-Task Head ============
|
||||||
|
self.head = MultiTaskHead(
|
||||||
|
in_dim=self.fusion.fusion_dim,
|
||||||
|
hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
smiles: List[str],
|
||||||
|
tabular: Dict[str, torch.Tensor],
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
smiles: SMILES 字符串列表,长度为 B
|
||||||
|
tabular: Dict[str, Tensor],包含:
|
||||||
|
- "comp": [B, 5] 配方比例
|
||||||
|
- "phys": [B, 12] 物理参数
|
||||||
|
- "help": [B, 4] Helper lipid
|
||||||
|
- "exp": [B, 32] 实验条件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Tensor]:
|
||||||
|
- "size": [B, 1]
|
||||||
|
- "pdi": [B, 4]
|
||||||
|
- "ee": [B, 3]
|
||||||
|
- "delivery": [B, 1]
|
||||||
|
- "biodist": [B, 7]
|
||||||
|
- "toxic": [B, 2]
|
||||||
|
"""
|
||||||
|
# 1. Encode SMILES
|
||||||
|
rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"}
|
||||||
|
|
||||||
|
# 2. 合并所有特征
|
||||||
|
all_features: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
# MPNN 特征(如果启用)
|
||||||
|
if self.use_mpnn:
|
||||||
|
mpnn_features = self.mpnn_encoder(smiles)
|
||||||
|
all_features["mpnn"] = mpnn_features["mpnn"]
|
||||||
|
|
||||||
|
# RDKit 特征
|
||||||
|
all_features["morgan"] = rdkit_features["morgan"]
|
||||||
|
all_features["maccs"] = rdkit_features["maccs"]
|
||||||
|
all_features["desc"] = rdkit_features["desc"]
|
||||||
|
|
||||||
|
# Tabular 特征
|
||||||
|
all_features["comp"] = tabular["comp"]
|
||||||
|
all_features["phys"] = tabular["phys"]
|
||||||
|
all_features["help"] = tabular["help"]
|
||||||
|
all_features["exp"] = tabular["exp"]
|
||||||
|
|
||||||
|
# 3. Token Projector: 统一维度
|
||||||
|
projected = self.token_projector(all_features) # Dict[str, [B, d_model]]
|
||||||
|
|
||||||
|
# 4. Stack tokens: [B, n_tokens, d_model]
|
||||||
|
# 按顺序 stack:Channel A (化学) + Channel B (配方/实验)
|
||||||
|
if self.use_mpnn:
|
||||||
|
token_order = ["mpnn", "morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
|
||||||
|
else:
|
||||||
|
token_order = ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
|
||||||
|
|
||||||
|
stacked = torch.stack([projected[k] for k in token_order], dim=1)
|
||||||
|
|
||||||
|
# 5. Cross Modal Attention
|
||||||
|
attended = self.cross_attention(stacked)
|
||||||
|
|
||||||
|
# 6. Fusion
|
||||||
|
fused = self.fusion(attended)
|
||||||
|
|
||||||
|
# 7. Multi-Task Head
|
||||||
|
outputs = self.head(fused)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清空所有 encoder 的缓存"""
|
||||||
|
self.rdkit_encoder.clear_cache()
|
||||||
|
if self.mpnn_encoder is not None:
|
||||||
|
self.mpnn_encoder.clear_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class LNPModelWithoutMPNN(LNPModel):
|
||||||
|
"""不使用 MPNN 的简化版本"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int = 256,
|
||||||
|
num_heads: int = 8,
|
||||||
|
n_attn_layers: int = 4,
|
||||||
|
fusion_strategy: PoolingStrategy = "attention",
|
||||||
|
head_hidden_dim: int = 128,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
input_dims: Optional[Dict[str, int]] = None,
|
||||||
|
) -> None:
|
||||||
|
# 移除 mpnn 维度
|
||||||
|
dims = input_dims or DEFAULT_INPUT_DIMS.copy()
|
||||||
|
dims.pop("mpnn", None)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_attn_layers=n_attn_layers,
|
||||||
|
fusion_strategy=fusion_strategy,
|
||||||
|
head_hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
mpnn_checkpoint=None,
|
||||||
|
mpnn_ensemble_paths=None,
|
||||||
|
input_dims=dims,
|
||||||
|
)
|
||||||
|
|
||||||
313
lnp_ml/modeling/predict.py
Normal file
313
lnp_ml/modeling/predict.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
"""预测脚本:使用训练好的模型进行推理"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from loguru import logger
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||||
|
from lnp_ml.dataset import LNPDataset, collate_fn
|
||||||
|
from lnp_ml.modeling.models import LNPModelWithoutMPNN
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: Path, device: torch.device) -> LNPModelWithoutMPNN:
|
||||||
|
"""加载训练好的模型"""
|
||||||
|
checkpoint = torch.load(model_path, map_location=device)
|
||||||
|
config = checkpoint["config"]
|
||||||
|
|
||||||
|
model = LNPModelWithoutMPNN(
|
||||||
|
d_model=config["d_model"],
|
||||||
|
num_heads=config["num_heads"],
|
||||||
|
n_attn_layers=config["n_attn_layers"],
|
||||||
|
fusion_strategy=config["fusion_strategy"],
|
||||||
|
head_hidden_dim=config["head_hidden_dim"],
|
||||||
|
dropout=config["dropout"],
|
||||||
|
)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
logger.info(f"Loaded model from {model_path}")
|
||||||
|
logger.info(f"Model config: {config}")
|
||||||
|
logger.info(f"Best val_loss: {checkpoint.get('best_val_loss', 'N/A')}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_batch(
|
||||||
|
model: LNPModelWithoutMPNN,
|
||||||
|
loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, List]:
|
||||||
|
"""对整个数据集进行预测"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
all_preds = {
|
||||||
|
"size": [],
|
||||||
|
"pdi": [],
|
||||||
|
"ee": [],
|
||||||
|
"delivery": [],
|
||||||
|
"biodist": [],
|
||||||
|
"toxic": [],
|
||||||
|
}
|
||||||
|
all_smiles = []
|
||||||
|
|
||||||
|
for batch in loader:
|
||||||
|
smiles = batch["smiles"]
|
||||||
|
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||||
|
|
||||||
|
outputs = model(smiles, tabular)
|
||||||
|
|
||||||
|
all_smiles.extend(smiles)
|
||||||
|
|
||||||
|
# 回归任务
|
||||||
|
all_preds["size"].extend(outputs["size"].squeeze(-1).cpu().tolist())
|
||||||
|
all_preds["delivery"].extend(outputs["delivery"].squeeze(-1).cpu().tolist())
|
||||||
|
|
||||||
|
# 分类任务:取 argmax
|
||||||
|
all_preds["pdi"].extend(outputs["pdi"].argmax(dim=-1).cpu().tolist())
|
||||||
|
all_preds["ee"].extend(outputs["ee"].argmax(dim=-1).cpu().tolist())
|
||||||
|
all_preds["toxic"].extend(outputs["toxic"].argmax(dim=-1).cpu().tolist())
|
||||||
|
|
||||||
|
# 分布任务
|
||||||
|
all_preds["biodist"].extend(outputs["biodist"].cpu().tolist())
|
||||||
|
|
||||||
|
return {"smiles": all_smiles, **all_preds}
|
||||||
|
|
||||||
|
|
||||||
|
def predictions_to_dataframe(predictions: Dict) -> pd.DataFrame:
|
||||||
|
"""将预测结果转换为 DataFrame"""
|
||||||
|
# 基本列
|
||||||
|
df = pd.DataFrame({
|
||||||
|
"smiles": predictions["smiles"],
|
||||||
|
"pred_size": predictions["size"],
|
||||||
|
"pred_delivery": predictions["delivery"],
|
||||||
|
"pred_pdi_class": predictions["pdi"],
|
||||||
|
"pred_ee_class": predictions["ee"],
|
||||||
|
"pred_toxic": predictions["toxic"],
|
||||||
|
})
|
||||||
|
|
||||||
|
# PDI 类别映射
|
||||||
|
pdi_labels = ["0_0to0_2", "0_2to0_3", "0_3to0_4", "0_4to0_5"]
|
||||||
|
df["pred_pdi_label"] = df["pred_pdi_class"].map(lambda x: pdi_labels[x])
|
||||||
|
|
||||||
|
# EE 类别映射
|
||||||
|
ee_labels = ["EE<50", "50<=EE<80", "80<EE<=100"]
|
||||||
|
df["pred_ee_label"] = df["pred_ee_class"].map(lambda x: ee_labels[x])
|
||||||
|
|
||||||
|
# Biodistribution 展开为多列
|
||||||
|
biodist_cols = [
|
||||||
|
"pred_biodist_lymph_nodes",
|
||||||
|
"pred_biodist_heart",
|
||||||
|
"pred_biodist_liver",
|
||||||
|
"pred_biodist_spleen",
|
||||||
|
"pred_biodist_lung",
|
||||||
|
"pred_biodist_kidney",
|
||||||
|
"pred_biodist_muscle",
|
||||||
|
]
|
||||||
|
biodist_df = pd.DataFrame(predictions["biodist"], columns=biodist_cols)
|
||||||
|
df = pd.concat([df, biodist_df], axis=1)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
test_path: Path = PROCESSED_DATA_DIR / "test.parquet",
|
||||||
|
model_path: Path = MODELS_DIR / "model.pt",
|
||||||
|
output_path: Path = PROCESSED_DATA_DIR / "predictions.parquet",
|
||||||
|
batch_size: int = 64,
|
||||||
|
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
使用训练好的模型进行预测。
|
||||||
|
"""
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
device = torch.device(device)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
model = load_model(model_path, device)
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
logger.info(f"Loading test data from {test_path}")
|
||||||
|
test_df = pd.read_parquet(test_path)
|
||||||
|
test_dataset = LNPDataset(test_df)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
logger.info(f"Test samples: {len(test_dataset)}")
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
logger.info("Running predictions...")
|
||||||
|
predictions = predict_batch(model, test_loader, device)
|
||||||
|
|
||||||
|
# 转换为 DataFrame
|
||||||
|
pred_df = predictions_to_dataframe(predictions)
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
pred_df.to_parquet(output_path, index=False)
|
||||||
|
logger.success(f"Saved predictions to {output_path}")
|
||||||
|
|
||||||
|
# 打印统计
|
||||||
|
logger.info("\n=== Prediction Statistics ===")
|
||||||
|
logger.info(f"Total samples: {len(pred_df)}")
|
||||||
|
logger.info(f"\nSize (pred): mean={pred_df['pred_size'].mean():.4f}, std={pred_df['pred_size'].std():.4f}")
|
||||||
|
logger.info(f"Delivery (pred): mean={pred_df['pred_delivery'].mean():.4f}, std={pred_df['pred_delivery'].std():.4f}")
|
||||||
|
logger.info(f"\nPDI class distribution:\n{pred_df['pred_pdi_label'].value_counts()}")
|
||||||
|
logger.info(f"\nEE class distribution:\n{pred_df['pred_ee_label'].value_counts()}")
|
||||||
|
logger.info(f"\nToxic distribution:\n{pred_df['pred_toxic'].value_counts()}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def test(
|
||||||
|
test_path: Path = PROCESSED_DATA_DIR / "test.parquet",
|
||||||
|
model_path: Path = MODELS_DIR / "model.pt",
|
||||||
|
output_path: Path = PROCESSED_DATA_DIR / "test_results.json",
|
||||||
|
batch_size: int = 64,
|
||||||
|
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
在测试集上完整评估模型性能,输出详细指标。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import (
|
||||||
|
mean_squared_error,
|
||||||
|
mean_absolute_error,
|
||||||
|
r2_score,
|
||||||
|
accuracy_score,
|
||||||
|
classification_report,
|
||||||
|
)
|
||||||
|
from lnp_ml.modeling.trainer import validate
|
||||||
|
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
device_obj = torch.device(device)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
model = load_model(model_path, device_obj)
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
logger.info(f"Loading test data from {test_path}")
|
||||||
|
test_df = pd.read_parquet(test_path)
|
||||||
|
test_dataset = LNPDataset(test_df)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
logger.info(f"Test samples: {len(test_dataset)}")
|
||||||
|
|
||||||
|
# 基础损失指标
|
||||||
|
logger.info("Computing loss metrics...")
|
||||||
|
loss_metrics = validate(model, test_loader, device_obj)
|
||||||
|
|
||||||
|
# 获取预测和真实值
|
||||||
|
logger.info("Computing detailed metrics...")
|
||||||
|
predictions = predict_batch(model, test_loader, device_obj)
|
||||||
|
|
||||||
|
results = {"loss_metrics": loss_metrics, "detailed_metrics": {}}
|
||||||
|
|
||||||
|
# 回归指标:size
|
||||||
|
if "size" in test_df.columns:
|
||||||
|
mask = ~test_df["size"].isna()
|
||||||
|
if mask.any():
|
||||||
|
y_true = test_df.loc[mask, "size"].values
|
||||||
|
y_pred = np.array(predictions["size"])[mask.values]
|
||||||
|
results["detailed_metrics"]["size"] = {
|
||||||
|
"mse": float(mean_squared_error(y_true, y_pred)),
|
||||||
|
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
|
||||||
|
"mae": float(mean_absolute_error(y_true, y_pred)),
|
||||||
|
"r2": float(r2_score(y_true, y_pred)),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 回归指标:delivery
|
||||||
|
if "quantified_delivery" in test_df.columns:
|
||||||
|
mask = ~test_df["quantified_delivery"].isna()
|
||||||
|
if mask.any():
|
||||||
|
y_true = test_df.loc[mask, "quantified_delivery"].values
|
||||||
|
y_pred = np.array(predictions["delivery"])[mask.values]
|
||||||
|
results["detailed_metrics"]["delivery"] = {
|
||||||
|
"mse": float(mean_squared_error(y_true, y_pred)),
|
||||||
|
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
|
||||||
|
"mae": float(mean_absolute_error(y_true, y_pred)),
|
||||||
|
"r2": float(r2_score(y_true, y_pred)),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 分类指标:PDI
|
||||||
|
pdi_cols = ["PDI_0_0to0_2", "PDI_0_2to0_3", "PDI_0_3to0_4", "PDI_0_4to0_5"]
|
||||||
|
if all(c in test_df.columns for c in pdi_cols):
|
||||||
|
pdi_true = test_df[pdi_cols].values.argmax(axis=1)
|
||||||
|
mask = test_df[pdi_cols].sum(axis=1) > 0
|
||||||
|
if mask.any():
|
||||||
|
y_true = pdi_true[mask]
|
||||||
|
y_pred = np.array(predictions["pdi"])[mask]
|
||||||
|
results["detailed_metrics"]["pdi"] = {
|
||||||
|
"accuracy": float(accuracy_score(y_true, y_pred)),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 分类指标:EE
|
||||||
|
ee_cols = ["Encapsulation_Efficiency_EE<50", "Encapsulation_Efficiency_50<=EE<80", "Encapsulation_Efficiency_80<EE<=100"]
|
||||||
|
if all(c in test_df.columns for c in ee_cols):
|
||||||
|
ee_true = test_df[ee_cols].values.argmax(axis=1)
|
||||||
|
mask = test_df[ee_cols].sum(axis=1) > 0
|
||||||
|
if mask.any():
|
||||||
|
y_true = ee_true[mask]
|
||||||
|
y_pred = np.array(predictions["ee"])[mask]
|
||||||
|
results["detailed_metrics"]["ee"] = {
|
||||||
|
"accuracy": float(accuracy_score(y_true, y_pred)),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 分类指标:toxic
|
||||||
|
if "toxic" in test_df.columns:
|
||||||
|
mask = test_df["toxic"].notna() & (test_df["toxic"] >= 0)
|
||||||
|
if mask.any():
|
||||||
|
y_true = test_df.loc[mask, "toxic"].astype(int).values
|
||||||
|
y_pred = np.array(predictions["toxic"])[mask.values]
|
||||||
|
results["detailed_metrics"]["toxic"] = {
|
||||||
|
"accuracy": float(accuracy_score(y_true, y_pred)),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
logger.info("\n" + "=" * 50)
|
||||||
|
logger.info("TEST RESULTS")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
logger.info("\n[Loss Metrics]")
|
||||||
|
for k, v in loss_metrics.items():
|
||||||
|
logger.info(f" {k}: {v:.4f}")
|
||||||
|
|
||||||
|
logger.info("\n[Detailed Metrics]")
|
||||||
|
for task, metrics in results["detailed_metrics"].items():
|
||||||
|
logger.info(f"\n {task}:")
|
||||||
|
for k, v in metrics.items():
|
||||||
|
logger.info(f" {k}: {v:.4f}")
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(output_path, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
logger.success(f"\nSaved test results to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# 保留旧的 evaluate 作为 test 的别名
|
||||||
|
@app.command()
|
||||||
|
def evaluate(
|
||||||
|
test_path: Path = PROCESSED_DATA_DIR / "test.parquet",
|
||||||
|
model_path: Path = MODELS_DIR / "model.pt",
|
||||||
|
batch_size: int = 64,
|
||||||
|
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[已废弃] 请使用 'test' 命令。
|
||||||
|
"""
|
||||||
|
test(test_path, model_path, PROCESSED_DATA_DIR / "test_results.json", batch_size, device)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
315
lnp_ml/modeling/train.py
Normal file
315
lnp_ml/modeling/train.py
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
"""训练脚本:支持超参数调优"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from loguru import logger
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||||
|
from lnp_ml.dataset import LNPDataset, collate_fn
|
||||||
|
from lnp_ml.modeling.models import LNPModelWithoutMPNN
|
||||||
|
from lnp_ml.modeling.trainer import (
|
||||||
|
train_epoch,
|
||||||
|
validate,
|
||||||
|
EarlyStopping,
|
||||||
|
LossWeights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
|
||||||
|
d_model: int = 256,
|
||||||
|
num_heads: int = 8,
|
||||||
|
n_attn_layers: int = 4,
|
||||||
|
fusion_strategy: str = "attention",
|
||||||
|
head_hidden_dim: int = 128,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
) -> LNPModelWithoutMPNN:
|
||||||
|
"""创建模型"""
|
||||||
|
return LNPModelWithoutMPNN(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_attn_layers=n_attn_layers,
|
||||||
|
fusion_strategy=fusion_strategy,
|
||||||
|
head_hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
device: torch.device,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-5,
|
||||||
|
epochs: int = 100,
|
||||||
|
patience: int = 15,
|
||||||
|
loss_weights: Optional[LossWeights] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
训练模型。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
训练历史和最佳验证损失
|
||||||
|
"""
|
||||||
|
model = model.to(device)
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer, mode="min", factor=0.5, patience=5, verbose=True
|
||||||
|
)
|
||||||
|
early_stopping = EarlyStopping(patience=patience)
|
||||||
|
|
||||||
|
history = {"train": [], "val": []}
|
||||||
|
best_val_loss = float("inf")
|
||||||
|
best_state = None
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
# Train
|
||||||
|
train_metrics = train_epoch(model, train_loader, optimizer, device, loss_weights)
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_metrics = validate(model, val_loader, device, loss_weights)
|
||||||
|
|
||||||
|
# Log
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {epoch+1}/{epochs} | "
|
||||||
|
f"Train Loss: {train_metrics['loss']:.4f} | "
|
||||||
|
f"Val Loss: {val_metrics['loss']:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
history["train"].append(train_metrics)
|
||||||
|
history["val"].append(val_metrics)
|
||||||
|
|
||||||
|
# Learning rate scheduling
|
||||||
|
scheduler.step(val_metrics["loss"])
|
||||||
|
|
||||||
|
# Save best model
|
||||||
|
if val_metrics["loss"] < best_val_loss:
|
||||||
|
best_val_loss = val_metrics["loss"]
|
||||||
|
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||||
|
logger.info(f" -> New best model (val_loss={best_val_loss:.4f})")
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if early_stopping(val_metrics["loss"]):
|
||||||
|
logger.info(f"Early stopping at epoch {epoch+1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Restore best model
|
||||||
|
if best_state is not None:
|
||||||
|
model.load_state_dict(best_state)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"history": history,
|
||||||
|
"best_val_loss": best_val_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_hyperparameter_tuning(
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
n_trials: int = 20,
|
||||||
|
epochs_per_trial: int = 30,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
使用 Optuna 进行超参数调优。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最佳超参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import optuna
|
||||||
|
except ImportError:
|
||||||
|
logger.error("Optuna not installed. Run: pip install optuna")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def objective(trial: optuna.Trial) -> float:
|
||||||
|
# 采样超参数
|
||||||
|
d_model = trial.suggest_categorical("d_model", [128, 256, 512])
|
||||||
|
num_heads = trial.suggest_categorical("num_heads", [4, 8])
|
||||||
|
n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
|
||||||
|
fusion_strategy = trial.suggest_categorical(
|
||||||
|
"fusion_strategy", ["attention", "avg", "max"]
|
||||||
|
)
|
||||||
|
head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
|
||||||
|
dropout = trial.suggest_float("dropout", 0.05, 0.3)
|
||||||
|
lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
|
||||||
|
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
model = create_model(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_attn_layers=n_attn_layers,
|
||||||
|
fusion_strategy=fusion_strategy,
|
||||||
|
head_hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练
|
||||||
|
result = train_model(
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
model=model,
|
||||||
|
device=device,
|
||||||
|
lr=lr,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
epochs=epochs_per_trial,
|
||||||
|
patience=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result["best_val_loss"]
|
||||||
|
|
||||||
|
# 运行优化
|
||||||
|
study = optuna.create_study(direction="minimize")
|
||||||
|
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
||||||
|
|
||||||
|
logger.info(f"Best trial: {study.best_trial.number}")
|
||||||
|
logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
|
||||||
|
logger.info(f"Best params: {study.best_trial.params}")
|
||||||
|
|
||||||
|
return study.best_trial.params
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
train_path: Path = PROCESSED_DATA_DIR / "train.parquet",
|
||||||
|
val_path: Path = PROCESSED_DATA_DIR / "val.parquet",
|
||||||
|
output_dir: Path = MODELS_DIR,
|
||||||
|
# 模型参数
|
||||||
|
d_model: int = 256,
|
||||||
|
num_heads: int = 8,
|
||||||
|
n_attn_layers: int = 4,
|
||||||
|
fusion_strategy: str = "attention",
|
||||||
|
head_hidden_dim: int = 128,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
# 训练参数
|
||||||
|
batch_size: int = 32,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-5,
|
||||||
|
epochs: int = 100,
|
||||||
|
patience: int = 15,
|
||||||
|
# 超参数调优
|
||||||
|
tune: bool = False,
|
||||||
|
n_trials: int = 20,
|
||||||
|
epochs_per_trial: int = 30,
|
||||||
|
# 设备
|
||||||
|
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
训练 LNP 预测模型。
|
||||||
|
|
||||||
|
使用 --tune 启用超参数调优。
|
||||||
|
"""
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
device = torch.device(device)
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
logger.info(f"Loading train data from {train_path}")
|
||||||
|
train_df = pd.read_parquet(train_path)
|
||||||
|
train_dataset = LNPDataset(train_df)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading val data from {val_path}")
|
||||||
|
val_df = pd.read_parquet(val_path)
|
||||||
|
val_dataset = LNPDataset(val_df)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
|
||||||
|
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 超参数调优
|
||||||
|
if tune:
|
||||||
|
logger.info(f"Starting hyperparameter tuning with {n_trials} trials...")
|
||||||
|
best_params = run_hyperparameter_tuning(
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
device=device,
|
||||||
|
n_trials=n_trials,
|
||||||
|
epochs_per_trial=epochs_per_trial,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存最佳参数
|
||||||
|
params_path = output_dir / "best_params.json"
|
||||||
|
with open(params_path, "w") as f:
|
||||||
|
json.dump(best_params, f, indent=2)
|
||||||
|
logger.success(f"Saved best params to {params_path}")
|
||||||
|
|
||||||
|
# 使用最佳参数重新训练
|
||||||
|
d_model = best_params["d_model"]
|
||||||
|
num_heads = best_params["num_heads"]
|
||||||
|
n_attn_layers = best_params["n_attn_layers"]
|
||||||
|
fusion_strategy = best_params["fusion_strategy"]
|
||||||
|
head_hidden_dim = best_params["head_hidden_dim"]
|
||||||
|
dropout = best_params["dropout"]
|
||||||
|
lr = best_params["lr"]
|
||||||
|
weight_decay = best_params["weight_decay"]
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
logger.info("Creating model...")
|
||||||
|
model = create_model(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_attn_layers=n_attn_layers,
|
||||||
|
fusion_strategy=fusion_strategy,
|
||||||
|
head_hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打印模型信息
|
||||||
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
logger.info(f"Model parameters: {n_params:,}")
|
||||||
|
|
||||||
|
# 训练
|
||||||
|
logger.info("Starting training...")
|
||||||
|
result = train_model(
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
model=model,
|
||||||
|
device=device,
|
||||||
|
lr=lr,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
epochs=epochs,
|
||||||
|
patience=patience,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
model_path = output_dir / "model.pt"
|
||||||
|
torch.save({
|
||||||
|
"model_state_dict": model.state_dict(),
|
||||||
|
"config": {
|
||||||
|
"d_model": d_model,
|
||||||
|
"num_heads": num_heads,
|
||||||
|
"n_attn_layers": n_attn_layers,
|
||||||
|
"fusion_strategy": fusion_strategy,
|
||||||
|
"head_hidden_dim": head_hidden_dim,
|
||||||
|
"dropout": dropout,
|
||||||
|
},
|
||||||
|
"best_val_loss": result["best_val_loss"],
|
||||||
|
}, model_path)
|
||||||
|
logger.success(f"Saved model to {model_path}")
|
||||||
|
|
||||||
|
# 保存训练历史
|
||||||
|
history_path = output_dir / "history.json"
|
||||||
|
with open(history_path, "w") as f:
|
||||||
|
json.dump(result["history"], f, indent=2)
|
||||||
|
logger.success(f"Saved training history to {history_path}")
|
||||||
|
|
||||||
|
logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
212
lnp_ml/modeling/trainer.py
Normal file
212
lnp_ml/modeling/trainer.py
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
"""训练器:封装训练、验证、损失计算逻辑"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LossWeights:
|
||||||
|
"""各任务的损失权重"""
|
||||||
|
size: float = 1.0
|
||||||
|
pdi: float = 1.0
|
||||||
|
ee: float = 1.0
|
||||||
|
delivery: float = 1.0
|
||||||
|
biodist: float = 1.0
|
||||||
|
toxic: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_multitask_loss(
|
||||||
|
outputs: Dict[str, torch.Tensor],
|
||||||
|
targets: Dict[str, torch.Tensor],
|
||||||
|
mask: Dict[str, torch.Tensor],
|
||||||
|
weights: Optional[LossWeights] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
计算多任务损失。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs: 模型输出
|
||||||
|
targets: 真实标签
|
||||||
|
mask: 有效样本掩码
|
||||||
|
weights: 各任务权重
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(total_loss, loss_dict) 总损失和各任务损失
|
||||||
|
"""
|
||||||
|
weights = weights or LossWeights()
|
||||||
|
losses = {}
|
||||||
|
total_loss = torch.tensor(0.0, device=next(iter(outputs.values())).device)
|
||||||
|
|
||||||
|
# size: MSE loss
|
||||||
|
if "size" in targets and mask["size"].any():
|
||||||
|
m = mask["size"]
|
||||||
|
pred = outputs["size"][m].squeeze(-1)
|
||||||
|
tgt = targets["size"][m]
|
||||||
|
losses["size"] = F.mse_loss(pred, tgt)
|
||||||
|
total_loss = total_loss + weights.size * losses["size"]
|
||||||
|
|
||||||
|
# delivery: MSE loss
|
||||||
|
if "delivery" in targets and mask["delivery"].any():
|
||||||
|
m = mask["delivery"]
|
||||||
|
pred = outputs["delivery"][m].squeeze(-1)
|
||||||
|
tgt = targets["delivery"][m]
|
||||||
|
losses["delivery"] = F.mse_loss(pred, tgt)
|
||||||
|
total_loss = total_loss + weights.delivery * losses["delivery"]
|
||||||
|
|
||||||
|
# pdi: CrossEntropy
|
||||||
|
if "pdi" in targets and mask["pdi"].any():
|
||||||
|
m = mask["pdi"]
|
||||||
|
pred = outputs["pdi"][m]
|
||||||
|
tgt = targets["pdi"][m]
|
||||||
|
losses["pdi"] = F.cross_entropy(pred, tgt)
|
||||||
|
total_loss = total_loss + weights.pdi * losses["pdi"]
|
||||||
|
|
||||||
|
# ee: CrossEntropy
|
||||||
|
if "ee" in targets and mask["ee"].any():
|
||||||
|
m = mask["ee"]
|
||||||
|
pred = outputs["ee"][m]
|
||||||
|
tgt = targets["ee"][m]
|
||||||
|
losses["ee"] = F.cross_entropy(pred, tgt)
|
||||||
|
total_loss = total_loss + weights.ee * losses["ee"]
|
||||||
|
|
||||||
|
# toxic: CrossEntropy
|
||||||
|
if "toxic" in targets and mask["toxic"].any():
|
||||||
|
m = mask["toxic"]
|
||||||
|
pred = outputs["toxic"][m]
|
||||||
|
tgt = targets["toxic"][m]
|
||||||
|
losses["toxic"] = F.cross_entropy(pred, tgt)
|
||||||
|
total_loss = total_loss + weights.toxic * losses["toxic"]
|
||||||
|
|
||||||
|
# biodist: KL divergence
|
||||||
|
if "biodist" in targets and mask["biodist"].any():
|
||||||
|
m = mask["biodist"]
|
||||||
|
pred = outputs["biodist"][m]
|
||||||
|
tgt = targets["biodist"][m]
|
||||||
|
# KL divergence: KL(target || pred)
|
||||||
|
losses["biodist"] = F.kl_div(
|
||||||
|
pred.log().clamp(min=-100),
|
||||||
|
tgt,
|
||||||
|
reduction="batchmean",
|
||||||
|
)
|
||||||
|
total_loss = total_loss + weights.biodist * losses["biodist"]
|
||||||
|
|
||||||
|
return total_loss, losses
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(
|
||||||
|
model: nn.Module,
|
||||||
|
loader: DataLoader,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
device: torch.device,
|
||||||
|
weights: Optional[LossWeights] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""训练一个 epoch"""
|
||||||
|
model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
|
||||||
|
n_batches = 0
|
||||||
|
|
||||||
|
for batch in tqdm(loader, desc="Training", leave=False):
|
||||||
|
smiles = batch["smiles"]
|
||||||
|
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||||
|
targets = {k: v.to(device) for k, v in batch["targets"].items()}
|
||||||
|
mask = {k: v.to(device) for k, v in batch["mask"].items()}
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(smiles, tabular)
|
||||||
|
loss, losses = compute_multitask_loss(outputs, targets, mask, weights)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
for k, v in losses.items():
|
||||||
|
task_losses[k] += v.item()
|
||||||
|
n_batches += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"loss": total_loss / n_batches,
|
||||||
|
**{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate(
|
||||||
|
model: nn.Module,
|
||||||
|
loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
weights: Optional[LossWeights] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""验证"""
|
||||||
|
model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
|
||||||
|
n_batches = 0
|
||||||
|
|
||||||
|
# 用于计算准确率
|
||||||
|
correct = {k: 0 for k in ["pdi", "ee", "toxic"]}
|
||||||
|
total = {k: 0 for k in ["pdi", "ee", "toxic"]}
|
||||||
|
|
||||||
|
for batch in tqdm(loader, desc="Validating", leave=False):
|
||||||
|
smiles = batch["smiles"]
|
||||||
|
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||||
|
targets = {k: v.to(device) for k, v in batch["targets"].items()}
|
||||||
|
mask = {k: v.to(device) for k, v in batch["mask"].items()}
|
||||||
|
|
||||||
|
outputs = model(smiles, tabular)
|
||||||
|
loss, losses = compute_multitask_loss(outputs, targets, mask, weights)
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
for k, v in losses.items():
|
||||||
|
task_losses[k] += v.item()
|
||||||
|
n_batches += 1
|
||||||
|
|
||||||
|
# 计算分类准确率
|
||||||
|
for k in ["pdi", "ee", "toxic"]:
|
||||||
|
if k in targets and mask[k].any():
|
||||||
|
m = mask[k]
|
||||||
|
pred = outputs[k][m].argmax(dim=-1)
|
||||||
|
tgt = targets[k][m]
|
||||||
|
correct[k] += (pred == tgt).sum().item()
|
||||||
|
total[k] += m.sum().item()
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"loss": total_loss / n_batches,
|
||||||
|
**{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加准确率
|
||||||
|
for k in ["pdi", "ee", "toxic"]:
|
||||||
|
if total[k] > 0:
|
||||||
|
metrics[f"acc_{k}"] = correct[k] / total[k]
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStopping:
|
||||||
|
"""早停机制"""
|
||||||
|
|
||||||
|
def __init__(self, patience: int = 10, min_delta: float = 0.0):
|
||||||
|
self.patience = patience
|
||||||
|
self.min_delta = min_delta
|
||||||
|
self.counter = 0
|
||||||
|
self.best_loss = float("inf")
|
||||||
|
self.should_stop = False
|
||||||
|
|
||||||
|
def __call__(self, val_loss: float) -> bool:
|
||||||
|
if val_loss < self.best_loss - self.min_delta:
|
||||||
|
self.best_loss = val_loss
|
||||||
|
self.counter = 0
|
||||||
|
else:
|
||||||
|
self.counter += 1
|
||||||
|
if self.counter >= self.patience:
|
||||||
|
self.should_stop = True
|
||||||
|
return self.should_stop
|
||||||
|
|
||||||
29
lnp_ml/plots.py
Normal file
29
lnp_ml/plots.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from tqdm import tqdm
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from lnp_ml.config import FIGURES_DIR, PROCESSED_DATA_DIR
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
# ---- REPLACE DEFAULT PATHS AS APPROPRIATE ----
|
||||||
|
input_path: Path = PROCESSED_DATA_DIR / "dataset.csv",
|
||||||
|
output_path: Path = FIGURES_DIR / "plot.png",
|
||||||
|
# -----------------------------------------
|
||||||
|
):
|
||||||
|
# ---- REPLACE THIS WITH YOUR OWN CODE ----
|
||||||
|
logger.info("Generating plot from data...")
|
||||||
|
for i in tqdm(range(10), total=10):
|
||||||
|
if i == 5:
|
||||||
|
logger.info("Something happened for iteration 5.")
|
||||||
|
logger.success("Plot generation complete.")
|
||||||
|
# -----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
0
models/.gitkeep
Normal file
0
models/.gitkeep
Normal file
657
models/history.json
Normal file
657
models/history.json
Normal file
@ -0,0 +1,657 @@
|
|||||||
|
{
|
||||||
|
"train": [
|
||||||
|
{
|
||||||
|
"loss": 19.801735162734985,
|
||||||
|
"loss_size": 14.35811984539032,
|
||||||
|
"loss_pdi": 1.419530838727951,
|
||||||
|
"loss_ee": 1.0615579634904861,
|
||||||
|
"loss_delivery": 1.1513914689421654,
|
||||||
|
"loss_biodist": 1.238842561841011,
|
||||||
|
"loss_toxic": 0.5722923278808594
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 8.425787031650543,
|
||||||
|
"loss_size": 3.660134330391884,
|
||||||
|
"loss_pdi": 1.175126627087593,
|
||||||
|
"loss_ee": 0.9683727994561195,
|
||||||
|
"loss_delivery": 1.1206831969320774,
|
||||||
|
"loss_biodist": 1.0895558297634125,
|
||||||
|
"loss_toxic": 0.4119141288101673
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 4.890879034996033,
|
||||||
|
"loss_size": 0.5715616792440414,
|
||||||
|
"loss_pdi": 0.9125325158238411,
|
||||||
|
"loss_ee": 0.924556627869606,
|
||||||
|
"loss_delivery": 1.2105156518518925,
|
||||||
|
"loss_biodist": 0.9783133715391159,
|
||||||
|
"loss_toxic": 0.29339898377656937
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 4.220313906669617,
|
||||||
|
"loss_size": 0.24587923847138882,
|
||||||
|
"loss_pdi": 0.77839545160532,
|
||||||
|
"loss_ee": 0.910746157169342,
|
||||||
|
"loss_delivery": 1.1220976933836937,
|
||||||
|
"loss_biodist": 0.9082718268036842,
|
||||||
|
"loss_toxic": 0.25492355413734913
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.821678251028061,
|
||||||
|
"loss_size": 0.22942939028143883,
|
||||||
|
"loss_pdi": 0.6982513815164566,
|
||||||
|
"loss_ee": 0.8703903555870056,
|
||||||
|
"loss_delivery": 0.9583318457007408,
|
||||||
|
"loss_biodist": 0.8230277448892593,
|
||||||
|
"loss_toxic": 0.24224759358912706
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.6422846913337708,
|
||||||
|
"loss_size": 0.2943352358415723,
|
||||||
|
"loss_pdi": 0.642115443944931,
|
||||||
|
"loss_ee": 0.834287479519844,
|
||||||
|
"loss_delivery": 0.9906296711415052,
|
||||||
|
"loss_biodist": 0.7021987289190292,
|
||||||
|
"loss_toxic": 0.1787180369719863
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.2558620870113373,
|
||||||
|
"loss_size": 0.2535699959844351,
|
||||||
|
"loss_pdi": 0.6135993227362633,
|
||||||
|
"loss_ee": 0.7736481726169586,
|
||||||
|
"loss_delivery": 0.89691730029881,
|
||||||
|
"loss_biodist": 0.5883788056671619,
|
||||||
|
"loss_toxic": 0.12974846456199884
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.095754951238632,
|
||||||
|
"loss_size": 0.3008947093039751,
|
||||||
|
"loss_pdi": 0.5887892656028271,
|
||||||
|
"loss_ee": 0.730943076312542,
|
||||||
|
"loss_delivery": 0.8772530537098646,
|
||||||
|
"loss_biodist": 0.5044719725847244,
|
||||||
|
"loss_toxic": 0.0934028816409409
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.853863760828972,
|
||||||
|
"loss_size": 0.24649357236921787,
|
||||||
|
"loss_pdi": 0.5792484246194363,
|
||||||
|
"loss_ee": 0.6907523274421692,
|
||||||
|
"loss_delivery": 0.7996992748230696,
|
||||||
|
"loss_biodist": 0.4447066970169544,
|
||||||
|
"loss_toxic": 0.09296349296346307
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.7370710372924805,
|
||||||
|
"loss_size": 0.2393215736374259,
|
||||||
|
"loss_pdi": 0.5310847833752632,
|
||||||
|
"loss_ee": 0.6698194481432438,
|
||||||
|
"loss_delivery": 0.8438112968578935,
|
||||||
|
"loss_biodist": 0.37191740795969963,
|
||||||
|
"loss_toxic": 0.08111646771430969
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5620490461587906,
|
||||||
|
"loss_size": 0.2313800435513258,
|
||||||
|
"loss_pdi": 0.5440232343971729,
|
||||||
|
"loss_ee": 0.6346339136362076,
|
||||||
|
"loss_delivery": 0.7632925817742944,
|
||||||
|
"loss_biodist": 0.311477467417717,
|
||||||
|
"loss_toxic": 0.07724180119112134
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4396377950906754,
|
||||||
|
"loss_size": 0.22545795049518347,
|
||||||
|
"loss_pdi": 0.5102365277707577,
|
||||||
|
"loss_ee": 0.5724069699645042,
|
||||||
|
"loss_delivery": 0.7819116842001677,
|
||||||
|
"loss_biodist": 0.2837779298424721,
|
||||||
|
"loss_toxic": 0.0658467230387032
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.285309001803398,
|
||||||
|
"loss_size": 0.23147230129688978,
|
||||||
|
"loss_pdi": 0.4668895788490772,
|
||||||
|
"loss_ee": 0.6054624281823635,
|
||||||
|
"loss_delivery": 0.6695475745946169,
|
||||||
|
"loss_biodist": 0.2563781104981899,
|
||||||
|
"loss_toxic": 0.05555897764861584
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.2476960122585297,
|
||||||
|
"loss_size": 0.23254966363310814,
|
||||||
|
"loss_pdi": 0.48378554731607437,
|
||||||
|
"loss_ee": 0.5625484213232994,
|
||||||
|
"loss_delivery": 0.6714967219159007,
|
||||||
|
"loss_biodist": 0.22976274229586124,
|
||||||
|
"loss_toxic": 0.06755285989493132
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.075557619333267,
|
||||||
|
"loss_size": 0.1991214593872428,
|
||||||
|
"loss_pdi": 0.4666460305452347,
|
||||||
|
"loss_ee": 0.5353769697248936,
|
||||||
|
"loss_delivery": 0.5960409259423614,
|
||||||
|
"loss_biodist": 0.21837125346064568,
|
||||||
|
"loss_toxic": 0.060001003788784146
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5380745232105255,
|
||||||
|
"loss_size": 0.26692971400916576,
|
||||||
|
"loss_pdi": 0.46107836067676544,
|
||||||
|
"loss_ee": 0.5409572347998619,
|
||||||
|
"loss_delivery": 1.0045308656990528,
|
||||||
|
"loss_biodist": 0.20893055945634842,
|
||||||
|
"loss_toxic": 0.05564788053743541
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.187355950474739,
|
||||||
|
"loss_size": 0.19620049092918634,
|
||||||
|
"loss_pdi": 0.4322536773979664,
|
||||||
|
"loss_ee": 0.5325545407831669,
|
||||||
|
"loss_delivery": 0.791555093601346,
|
||||||
|
"loss_biodist": 0.19126823358237743,
|
||||||
|
"loss_toxic": 0.0435238650534302
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.0472205132246017,
|
||||||
|
"loss_size": 0.2132822722196579,
|
||||||
|
"loss_pdi": 0.4357164613902569,
|
||||||
|
"loss_ee": 0.4921276159584522,
|
||||||
|
"loss_delivery": 0.6838537249714136,
|
||||||
|
"loss_biodist": 0.17821292020380497,
|
||||||
|
"loss_toxic": 0.04402748285792768
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.8377329260110855,
|
||||||
|
"loss_size": 0.1911225076764822,
|
||||||
|
"loss_pdi": 0.40041540563106537,
|
||||||
|
"loss_ee": 0.48244743049144745,
|
||||||
|
"loss_delivery": 0.5407265722751617,
|
||||||
|
"loss_biodist": 0.18436198495328426,
|
||||||
|
"loss_toxic": 0.03865901718381792
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.767472356557846,
|
||||||
|
"loss_size": 0.1669206041842699,
|
||||||
|
"loss_pdi": 0.41016123816370964,
|
||||||
|
"loss_ee": 0.48856623098254204,
|
||||||
|
"loss_delivery": 0.4990109819918871,
|
||||||
|
"loss_biodist": 0.17120445892214775,
|
||||||
|
"loss_toxic": 0.03160886780824512
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.7883769124746323,
|
||||||
|
"loss_size": 0.18442536424845457,
|
||||||
|
"loss_pdi": 0.40120166912674904,
|
||||||
|
"loss_ee": 0.46751197054982185,
|
||||||
|
"loss_delivery": 0.537370229139924,
|
||||||
|
"loss_biodist": 0.16241099871695042,
|
||||||
|
"loss_toxic": 0.03545669896993786
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.7724643051624298,
|
||||||
|
"loss_size": 0.1618126342073083,
|
||||||
|
"loss_pdi": 0.40923435613512993,
|
||||||
|
"loss_ee": 0.4502934068441391,
|
||||||
|
"loss_delivery": 0.5432828362099826,
|
||||||
|
"loss_biodist": 0.17081433162093163,
|
||||||
|
"loss_toxic": 0.03702673775842413
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.8993548452854156,
|
||||||
|
"loss_size": 0.15662006475031376,
|
||||||
|
"loss_pdi": 0.37853332981467247,
|
||||||
|
"loss_ee": 0.4373016320168972,
|
||||||
|
"loss_delivery": 0.7414433392696083,
|
||||||
|
"loss_biodist": 0.15876813046634197,
|
||||||
|
"loss_toxic": 0.02668837201781571
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.7090002000331879,
|
||||||
|
"loss_size": 0.18559002690017223,
|
||||||
|
"loss_pdi": 0.3948023244738579,
|
||||||
|
"loss_ee": 0.4514715373516083,
|
||||||
|
"loss_delivery": 0.5189001243561506,
|
||||||
|
"loss_biodist": 0.13513169158250093,
|
||||||
|
"loss_toxic": 0.023104518884792924
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.5656199902296066,
|
||||||
|
"loss_size": 0.1486712908372283,
|
||||||
|
"loss_pdi": 0.37246324494481087,
|
||||||
|
"loss_ee": 0.4362662769854069,
|
||||||
|
"loss_delivery": 0.4543162193149328,
|
||||||
|
"loss_biodist": 0.13266231305897236,
|
||||||
|
"loss_toxic": 0.021240679430775344
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.6873565018177032,
|
||||||
|
"loss_size": 0.16270869225263596,
|
||||||
|
"loss_pdi": 0.3765699379146099,
|
||||||
|
"loss_ee": 0.4488464966416359,
|
||||||
|
"loss_delivery": 0.5208693165332079,
|
||||||
|
"loss_biodist": 0.15251764748245478,
|
||||||
|
"loss_toxic": 0.025844466988928616
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.611741527915001,
|
||||||
|
"loss_size": 0.1608524639159441,
|
||||||
|
"loss_pdi": 0.3773089461028576,
|
||||||
|
"loss_ee": 0.4431660957634449,
|
||||||
|
"loss_delivery": 0.47964918427169323,
|
||||||
|
"loss_biodist": 0.13358874432742596,
|
||||||
|
"loss_toxic": 0.017176096327602863
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.547684259712696,
|
||||||
|
"loss_size": 0.15734383463859558,
|
||||||
|
"loss_pdi": 0.3490111008286476,
|
||||||
|
"loss_ee": 0.4290902689099312,
|
||||||
|
"loss_delivery": 0.45983999967575073,
|
||||||
|
"loss_biodist": 0.1339349802583456,
|
||||||
|
"loss_toxic": 0.01846407217090018
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.4728069305419922,
|
||||||
|
"loss_size": 0.16153435036540031,
|
||||||
|
"loss_pdi": 0.3516402244567871,
|
||||||
|
"loss_ee": 0.4158446751534939,
|
||||||
|
"loss_delivery": 0.3929086276330054,
|
||||||
|
"loss_biodist": 0.12818495463579893,
|
||||||
|
"loss_toxic": 0.02269406005507335
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.494736596941948,
|
||||||
|
"loss_size": 0.13391354773193598,
|
||||||
|
"loss_pdi": 0.3454095683991909,
|
||||||
|
"loss_ee": 0.3995618261396885,
|
||||||
|
"loss_delivery": 0.47225130116567016,
|
||||||
|
"loss_biodist": 0.12276446260511875,
|
||||||
|
"loss_toxic": 0.02083588083041832
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 1.5983823090791702,
|
||||||
|
"loss_size": 0.15655907429754734,
|
||||||
|
"loss_pdi": 0.3302378598600626,
|
||||||
|
"loss_ee": 0.40332265198230743,
|
||||||
|
"loss_delivery": 0.558618601411581,
|
||||||
|
"loss_biodist": 0.13051889464259148,
|
||||||
|
"loss_toxic": 0.019125229271594435
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"val": [
|
||||||
|
{
|
||||||
|
"loss": 10.642529010772705,
|
||||||
|
"loss_size": 6.192671060562134,
|
||||||
|
"loss_pdi": 1.2811625599861145,
|
||||||
|
"loss_ee": 1.0301620960235596,
|
||||||
|
"loss_delivery": 0.41728606820106506,
|
||||||
|
"loss_biodist": 1.2410815358161926,
|
||||||
|
"loss_toxic": 0.4801655113697052,
|
||||||
|
"acc_pdi": 0.45,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 0.9743589743589743
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 4.37042236328125,
|
||||||
|
"loss_size": 0.6340591907501221,
|
||||||
|
"loss_pdi": 0.9936424493789673,
|
||||||
|
"loss_ee": 0.9678087830543518,
|
||||||
|
"loss_delivery": 0.41530802845954895,
|
||||||
|
"loss_biodist": 1.0821772813796997,
|
||||||
|
"loss_toxic": 0.2774266004562378,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 0.9743589743589743
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.4235814809799194,
|
||||||
|
"loss_size": 0.07745273411273956,
|
||||||
|
"loss_pdi": 0.7955547869205475,
|
||||||
|
"loss_ee": 0.9630871117115021,
|
||||||
|
"loss_delivery": 0.41978873312473297,
|
||||||
|
"loss_biodist": 1.0081366300582886,
|
||||||
|
"loss_toxic": 0.1595613993704319,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 0.9743589743589743
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.3056893348693848,
|
||||||
|
"loss_size": 0.13385668396949768,
|
||||||
|
"loss_pdi": 0.717946320772171,
|
||||||
|
"loss_ee": 0.9484646320343018,
|
||||||
|
"loss_delivery": 0.41560181975364685,
|
||||||
|
"loss_biodist": 0.9615574777126312,
|
||||||
|
"loss_toxic": 0.12826236337423325,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 0.9743589743589743
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 3.1493594646453857,
|
||||||
|
"loss_size": 0.07856455072760582,
|
||||||
|
"loss_pdi": 0.668080747127533,
|
||||||
|
"loss_ee": 0.9540894627571106,
|
||||||
|
"loss_delivery": 0.474899023771286,
|
||||||
|
"loss_biodist": 0.8731786608695984,
|
||||||
|
"loss_toxic": 0.10054689273238182,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 0.9743589743589743
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.8707590103149414,
|
||||||
|
"loss_size": 0.1222907267510891,
|
||||||
|
"loss_pdi": 0.6401174664497375,
|
||||||
|
"loss_ee": 0.9136309623718262,
|
||||||
|
"loss_delivery": 0.3875949829816818,
|
||||||
|
"loss_biodist": 0.7330293655395508,
|
||||||
|
"loss_toxic": 0.07409549504518509,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 0.9743589743589743
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.6709553003311157,
|
||||||
|
"loss_size": 0.14854157716035843,
|
||||||
|
"loss_pdi": 0.6254143714904785,
|
||||||
|
"loss_ee": 0.8874242305755615,
|
||||||
|
"loss_delivery": 0.37635043263435364,
|
||||||
|
"loss_biodist": 0.5931203663349152,
|
||||||
|
"loss_toxic": 0.04010416939854622,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.5666666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.622614026069641,
|
||||||
|
"loss_size": 0.1486668586730957,
|
||||||
|
"loss_pdi": 0.6269595921039581,
|
||||||
|
"loss_ee": 0.8829602599143982,
|
||||||
|
"loss_delivery": 0.40370240807533264,
|
||||||
|
"loss_biodist": 0.535428449511528,
|
||||||
|
"loss_toxic": 0.024896428920328617,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4939963817596436,
|
||||||
|
"loss_size": 0.11941515654325485,
|
||||||
|
"loss_pdi": 0.624951958656311,
|
||||||
|
"loss_ee": 0.882407933473587,
|
||||||
|
"loss_delivery": 0.38461272418498993,
|
||||||
|
"loss_biodist": 0.462909460067749,
|
||||||
|
"loss_toxic": 0.019699251279234886,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4623881578445435,
|
||||||
|
"loss_size": 0.14110726118087769,
|
||||||
|
"loss_pdi": 0.5920329689979553,
|
||||||
|
"loss_ee": 0.8816524147987366,
|
||||||
|
"loss_delivery": 0.41789302229881287,
|
||||||
|
"loss_biodist": 0.40996842086315155,
|
||||||
|
"loss_toxic": 0.019733915105462074,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.65,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.396706700325012,
|
||||||
|
"loss_size": 0.11815935745835304,
|
||||||
|
"loss_pdi": 0.6046018600463867,
|
||||||
|
"loss_ee": 0.8763127326965332,
|
||||||
|
"loss_delivery": 0.4112636297941208,
|
||||||
|
"loss_biodist": 0.36830006539821625,
|
||||||
|
"loss_toxic": 0.018068938050419092,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.65,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.457483172416687,
|
||||||
|
"loss_size": 0.13330847769975662,
|
||||||
|
"loss_pdi": 0.6174256205558777,
|
||||||
|
"loss_ee": 0.8725559711456299,
|
||||||
|
"loss_delivery": 0.46076954901218414,
|
||||||
|
"loss_biodist": 0.358450323343277,
|
||||||
|
"loss_toxic": 0.014973167330026627,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.6333333333333333,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.384940981864929,
|
||||||
|
"loss_size": 0.11283988133072853,
|
||||||
|
"loss_pdi": 0.5969351530075073,
|
||||||
|
"loss_ee": 0.8822423815727234,
|
||||||
|
"loss_delivery": 0.4492803066968918,
|
||||||
|
"loss_biodist": 0.3306152671575546,
|
||||||
|
"loss_toxic": 0.013027888257056475,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.395999550819397,
|
||||||
|
"loss_size": 0.1267366074025631,
|
||||||
|
"loss_pdi": 0.609168529510498,
|
||||||
|
"loss_ee": 0.8720199763774872,
|
||||||
|
"loss_delivery": 0.45844390988349915,
|
||||||
|
"loss_biodist": 0.31721335649490356,
|
||||||
|
"loss_toxic": 0.01241705659776926,
|
||||||
|
"acc_pdi": 0.75,
|
||||||
|
"acc_ee": 0.6333333333333333,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.432988166809082,
|
||||||
|
"loss_size": 0.13193047046661377,
|
||||||
|
"loss_pdi": 0.6282809674739838,
|
||||||
|
"loss_ee": 0.889969140291214,
|
||||||
|
"loss_delivery": 0.45955638587474823,
|
||||||
|
"loss_biodist": 0.3108634203672409,
|
||||||
|
"loss_toxic": 0.012387921568006277,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.361279010772705,
|
||||||
|
"loss_size": 0.09955761209130287,
|
||||||
|
"loss_pdi": 0.6252501904964447,
|
||||||
|
"loss_ee": 0.8960326313972473,
|
||||||
|
"loss_delivery": 0.44326628744602203,
|
||||||
|
"loss_biodist": 0.285326212644577,
|
||||||
|
"loss_toxic": 0.011845993809401989,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6166666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.420892119407654,
|
||||||
|
"loss_size": 0.13956347852945328,
|
||||||
|
"loss_pdi": 0.6046904623508453,
|
||||||
|
"loss_ee": 0.9216890335083008,
|
||||||
|
"loss_delivery": 0.4730597734451294,
|
||||||
|
"loss_biodist": 0.27191491425037384,
|
||||||
|
"loss_toxic": 0.009974355343729258,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.5666666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.453316569328308,
|
||||||
|
"loss_size": 0.11874271184206009,
|
||||||
|
"loss_pdi": 0.6105562448501587,
|
||||||
|
"loss_ee": 0.923934280872345,
|
||||||
|
"loss_delivery": 0.5215270966291428,
|
||||||
|
"loss_biodist": 0.27011625468730927,
|
||||||
|
"loss_toxic": 0.008439893601462245,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4440466165542603,
|
||||||
|
"loss_size": 0.13506918773055077,
|
||||||
|
"loss_pdi": 0.6187410056591034,
|
||||||
|
"loss_ee": 0.9395149946212769,
|
||||||
|
"loss_delivery": 0.48008452355861664,
|
||||||
|
"loss_biodist": 0.26326628774404526,
|
||||||
|
"loss_toxic": 0.007370669860392809,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.5666666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4298155307769775,
|
||||||
|
"loss_size": 0.11216039955615997,
|
||||||
|
"loss_pdi": 0.6475824415683746,
|
||||||
|
"loss_ee": 0.9234196841716766,
|
||||||
|
"loss_delivery": 0.46182475984096527,
|
||||||
|
"loss_biodist": 0.2774467319250107,
|
||||||
|
"loss_toxic": 0.007381373317912221,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6333333333333333,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4662675857543945,
|
||||||
|
"loss_size": 0.07973097264766693,
|
||||||
|
"loss_pdi": 0.6542999148368835,
|
||||||
|
"loss_ee": 0.9407779574394226,
|
||||||
|
"loss_delivery": 0.4921555370092392,
|
||||||
|
"loss_biodist": 0.29242587089538574,
|
||||||
|
"loss_toxic": 0.006877265637740493,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.394904613494873,
|
||||||
|
"loss_size": 0.1067984513938427,
|
||||||
|
"loss_pdi": 0.6353477835655212,
|
||||||
|
"loss_ee": 0.9535529017448425,
|
||||||
|
"loss_delivery": 0.4185742735862732,
|
||||||
|
"loss_biodist": 0.27519282698631287,
|
||||||
|
"loss_toxic": 0.005438495893031359,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.424455165863037,
|
||||||
|
"loss_size": 0.09576653316617012,
|
||||||
|
"loss_pdi": 0.6532827019691467,
|
||||||
|
"loss_ee": 0.9595111310482025,
|
||||||
|
"loss_delivery": 0.4392053782939911,
|
||||||
|
"loss_biodist": 0.27055467665195465,
|
||||||
|
"loss_toxic": 0.006134827388450503,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.5666666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5053725242614746,
|
||||||
|
"loss_size": 0.10888586193323135,
|
||||||
|
"loss_pdi": 0.656737744808197,
|
||||||
|
"loss_ee": 0.9638712704181671,
|
||||||
|
"loss_delivery": 0.4994397610425949,
|
||||||
|
"loss_biodist": 0.27049925923347473,
|
||||||
|
"loss_toxic": 0.005938541842624545,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.6166666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.539380669593811,
|
||||||
|
"loss_size": 0.11638890951871872,
|
||||||
|
"loss_pdi": 0.6549257636070251,
|
||||||
|
"loss_ee": 0.9761019647121429,
|
||||||
|
"loss_delivery": 0.5175990015268326,
|
||||||
|
"loss_biodist": 0.26853571832180023,
|
||||||
|
"loss_toxic": 0.0058291994500905275,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.4885005950927734,
|
||||||
|
"loss_size": 0.12185845524072647,
|
||||||
|
"loss_pdi": 0.667609304189682,
|
||||||
|
"loss_ee": 0.9739555716514587,
|
||||||
|
"loss_delivery": 0.45484381914138794,
|
||||||
|
"loss_biodist": 0.26541487127542496,
|
||||||
|
"loss_toxic": 0.00481854728423059,
|
||||||
|
"acc_pdi": 0.7166666666666667,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.514465093612671,
|
||||||
|
"loss_size": 0.11423381045460701,
|
||||||
|
"loss_pdi": 0.6905614733695984,
|
||||||
|
"loss_ee": 0.9758535623550415,
|
||||||
|
"loss_delivery": 0.47175830602645874,
|
||||||
|
"loss_biodist": 0.257549487054348,
|
||||||
|
"loss_toxic": 0.004508488811552525,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5609203577041626,
|
||||||
|
"loss_size": 0.11414870619773865,
|
||||||
|
"loss_pdi": 0.6677174866199493,
|
||||||
|
"loss_ee": 0.9934183955192566,
|
||||||
|
"loss_delivery": 0.5167151391506195,
|
||||||
|
"loss_biodist": 0.2642747312784195,
|
||||||
|
"loss_toxic": 0.0046457564458251,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6166666666666667,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5480626821517944,
|
||||||
|
"loss_size": 0.10506367310881615,
|
||||||
|
"loss_pdi": 0.6713496148586273,
|
||||||
|
"loss_ee": 0.9947319328784943,
|
||||||
|
"loss_delivery": 0.5066179037094116,
|
||||||
|
"loss_biodist": 0.2658369243144989,
|
||||||
|
"loss_toxic": 0.004462693585082889,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.6,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5397441387176514,
|
||||||
|
"loss_size": 0.10522815585136414,
|
||||||
|
"loss_pdi": 0.6821762025356293,
|
||||||
|
"loss_ee": 0.9780809879302979,
|
||||||
|
"loss_delivery": 0.5069089531898499,
|
||||||
|
"loss_biodist": 0.2627965956926346,
|
||||||
|
"loss_toxic": 0.0045532952062785625,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 2.5472241640090942,
|
||||||
|
"loss_size": 0.11468468606472015,
|
||||||
|
"loss_pdi": 0.6848073601722717,
|
||||||
|
"loss_ee": 0.9830659925937653,
|
||||||
|
"loss_delivery": 0.5032574832439423,
|
||||||
|
"loss_biodist": 0.25694920122623444,
|
||||||
|
"loss_toxic": 0.0044593592174351215,
|
||||||
|
"acc_pdi": 0.7333333333333333,
|
||||||
|
"acc_ee": 0.5833333333333334,
|
||||||
|
"acc_toxic": 1.0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
BIN
models/model.pt
Normal file
BIN
models/model.pt
Normal file
Binary file not shown.
165
models/pretrained/all_amine_split_for_LiON/cv_0/args.json
Normal file
165
models/pretrained/all_amine_split_for_LiON/cv_0/args.json
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
{
|
||||||
|
"activation": "ReLU",
|
||||||
|
"adding_bond_types": true,
|
||||||
|
"adding_h": false,
|
||||||
|
"aggregation": "mean",
|
||||||
|
"aggregation_norm": 100,
|
||||||
|
"atom_constraints": [],
|
||||||
|
"atom_descriptor_scaling": true,
|
||||||
|
"atom_descriptors": null,
|
||||||
|
"atom_descriptors_path": null,
|
||||||
|
"atom_descriptors_size": 0,
|
||||||
|
"atom_features_size": 0,
|
||||||
|
"atom_messages": false,
|
||||||
|
"atom_targets": [],
|
||||||
|
"batch_size": 50,
|
||||||
|
"bias": false,
|
||||||
|
"bias_solvent": false,
|
||||||
|
"bond_constraints": [],
|
||||||
|
"bond_descriptor_scaling": true,
|
||||||
|
"bond_descriptors": null,
|
||||||
|
"bond_descriptors_path": null,
|
||||||
|
"bond_descriptors_size": 0,
|
||||||
|
"bond_features_size": 0,
|
||||||
|
"bond_targets": [],
|
||||||
|
"cache_cutoff": 10000,
|
||||||
|
"checkpoint_dir": null,
|
||||||
|
"checkpoint_frzn": null,
|
||||||
|
"checkpoint_path": null,
|
||||||
|
"checkpoint_paths": null,
|
||||||
|
"class_balance": false,
|
||||||
|
"config_path": "../data/args_files/optimized_configs.json",
|
||||||
|
"constraints_path": null,
|
||||||
|
"crossval_index_dir": null,
|
||||||
|
"crossval_index_file": null,
|
||||||
|
"crossval_index_sets": null,
|
||||||
|
"cuda": true,
|
||||||
|
"data_path": "../data/crossval_splits/all_amine_split_for_paper/cv_0/train.csv",
|
||||||
|
"data_weights_path": "../data/crossval_splits/all_amine_split_for_paper/cv_0/train_weights.csv",
|
||||||
|
"dataset_type": "regression",
|
||||||
|
"depth": 4,
|
||||||
|
"depth_solvent": 3,
|
||||||
|
"device": {
|
||||||
|
"_string": "cuda",
|
||||||
|
"_type": "python_object (type = device)",
|
||||||
|
"_value": "gASVHwAAAAAAAACMBXRvcmNolIwGZGV2aWNllJOUjARjdWRhlIWUUpQu"
|
||||||
|
},
|
||||||
|
"dropout": 0.1,
|
||||||
|
"empty_cache": false,
|
||||||
|
"ensemble_size": 1,
|
||||||
|
"epochs": 50,
|
||||||
|
"evidential_regularization": 0,
|
||||||
|
"explicit_h": false,
|
||||||
|
"extra_metrics": [],
|
||||||
|
"features_generator": null,
|
||||||
|
"features_only": false,
|
||||||
|
"features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_0/train_extra_x.csv"
|
||||||
|
],
|
||||||
|
"features_scaling": true,
|
||||||
|
"features_size": null,
|
||||||
|
"ffn_hidden_size": 600,
|
||||||
|
"ffn_num_layers": 3,
|
||||||
|
"final_lr": 0.0001,
|
||||||
|
"folds_file": null,
|
||||||
|
"freeze_first_only": false,
|
||||||
|
"frzn_ffn_layers": 0,
|
||||||
|
"gpu": null,
|
||||||
|
"grad_clip": null,
|
||||||
|
"hidden_size": 600,
|
||||||
|
"hidden_size_solvent": 300,
|
||||||
|
"ignore_columns": null,
|
||||||
|
"ignore_nan_metrics": false,
|
||||||
|
"init_lr": 0.0001,
|
||||||
|
"is_atom_bond_targets": false,
|
||||||
|
"keeping_atom_map": false,
|
||||||
|
"log_frequency": 10,
|
||||||
|
"loss_function": "mse",
|
||||||
|
"max_data_size": null,
|
||||||
|
"max_lr": 0.001,
|
||||||
|
"metric": "rmse",
|
||||||
|
"metrics": [
|
||||||
|
"rmse"
|
||||||
|
],
|
||||||
|
"minimize_score": true,
|
||||||
|
"mpn_shared": false,
|
||||||
|
"multiclass_num_classes": 3,
|
||||||
|
"no_adding_bond_types": false,
|
||||||
|
"no_atom_descriptor_scaling": false,
|
||||||
|
"no_bond_descriptor_scaling": false,
|
||||||
|
"no_cache_mol": false,
|
||||||
|
"no_cuda": false,
|
||||||
|
"no_features_scaling": false,
|
||||||
|
"no_shared_atom_bond_ffn": false,
|
||||||
|
"num_folds": 1,
|
||||||
|
"num_lrs": 1,
|
||||||
|
"num_tasks": 1,
|
||||||
|
"num_workers": 8,
|
||||||
|
"number_of_molecules": 1,
|
||||||
|
"overwrite_default_atom_features": false,
|
||||||
|
"overwrite_default_bond_features": false,
|
||||||
|
"phase_features_path": null,
|
||||||
|
"pytorch_seed": 0,
|
||||||
|
"quantile_loss_alpha": 0.1,
|
||||||
|
"quantiles": [],
|
||||||
|
"quiet": false,
|
||||||
|
"reaction": false,
|
||||||
|
"reaction_mode": "reac_diff",
|
||||||
|
"reaction_solvent": false,
|
||||||
|
"reproducibility": {
|
||||||
|
"command_line": "python main_script.py train all_amine_split_for_paper",
|
||||||
|
"git_has_uncommitted_changes": true,
|
||||||
|
"git_root": "/media/andersonxps/wd_4tb/evan/LNP_ML",
|
||||||
|
"git_url": "https://github.com/jswitten/LNP_ML/tree/167822980dc26ba65c5c14539c4ce12b81b0b8f3",
|
||||||
|
"time": "Tue Jul 30 10:15:25 2024"
|
||||||
|
},
|
||||||
|
"resume_experiment": false,
|
||||||
|
"save_dir": "../data/crossval_splits/all_amine_split_for_paper/cv_0",
|
||||||
|
"save_preds": false,
|
||||||
|
"save_smiles_splits": false,
|
||||||
|
"seed": 42,
|
||||||
|
"separate_test_atom_descriptors_path": null,
|
||||||
|
"separate_test_bond_descriptors_path": null,
|
||||||
|
"separate_test_constraints_path": null,
|
||||||
|
"separate_test_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_0/test_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_test_path": "../data/crossval_splits/all_amine_split_for_paper/cv_0/test.csv",
|
||||||
|
"separate_test_phase_features_path": null,
|
||||||
|
"separate_val_atom_descriptors_path": null,
|
||||||
|
"separate_val_bond_descriptors_path": null,
|
||||||
|
"separate_val_constraints_path": null,
|
||||||
|
"separate_val_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_0/valid_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_val_path": "../data/crossval_splits/all_amine_split_for_paper/cv_0/valid.csv",
|
||||||
|
"separate_val_phase_features_path": null,
|
||||||
|
"shared_atom_bond_ffn": true,
|
||||||
|
"show_individual_scores": false,
|
||||||
|
"smiles_columns": [
|
||||||
|
"smiles"
|
||||||
|
],
|
||||||
|
"spectra_activation": "exp",
|
||||||
|
"spectra_phase_mask_path": null,
|
||||||
|
"spectra_target_floor": 1e-08,
|
||||||
|
"split_key_molecule": 0,
|
||||||
|
"split_sizes": [
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"split_type": "random",
|
||||||
|
"target_columns": null,
|
||||||
|
"target_weights": null,
|
||||||
|
"task_names": [
|
||||||
|
"quantified_delivery"
|
||||||
|
],
|
||||||
|
"test": false,
|
||||||
|
"test_fold_index": null,
|
||||||
|
"train_data_size": null,
|
||||||
|
"undirected": false,
|
||||||
|
"use_input_features": true,
|
||||||
|
"val_fold_index": null,
|
||||||
|
"warmup_epochs": 2.0,
|
||||||
|
"weights_ffn_num_layers": 2
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"rmse": [
|
||||||
|
0.8880622451903801
|
||||||
|
]
|
||||||
|
}
|
||||||
165
models/pretrained/all_amine_split_for_LiON/cv_1/args.json
Normal file
165
models/pretrained/all_amine_split_for_LiON/cv_1/args.json
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
{
|
||||||
|
"activation": "ReLU",
|
||||||
|
"adding_bond_types": true,
|
||||||
|
"adding_h": false,
|
||||||
|
"aggregation": "mean",
|
||||||
|
"aggregation_norm": 100,
|
||||||
|
"atom_constraints": [],
|
||||||
|
"atom_descriptor_scaling": true,
|
||||||
|
"atom_descriptors": null,
|
||||||
|
"atom_descriptors_path": null,
|
||||||
|
"atom_descriptors_size": 0,
|
||||||
|
"atom_features_size": 0,
|
||||||
|
"atom_messages": false,
|
||||||
|
"atom_targets": [],
|
||||||
|
"batch_size": 50,
|
||||||
|
"bias": false,
|
||||||
|
"bias_solvent": false,
|
||||||
|
"bond_constraints": [],
|
||||||
|
"bond_descriptor_scaling": true,
|
||||||
|
"bond_descriptors": null,
|
||||||
|
"bond_descriptors_path": null,
|
||||||
|
"bond_descriptors_size": 0,
|
||||||
|
"bond_features_size": 0,
|
||||||
|
"bond_targets": [],
|
||||||
|
"cache_cutoff": 10000,
|
||||||
|
"checkpoint_dir": null,
|
||||||
|
"checkpoint_frzn": null,
|
||||||
|
"checkpoint_path": null,
|
||||||
|
"checkpoint_paths": null,
|
||||||
|
"class_balance": false,
|
||||||
|
"config_path": "../data/args_files/optimized_configs.json",
|
||||||
|
"constraints_path": null,
|
||||||
|
"crossval_index_dir": null,
|
||||||
|
"crossval_index_file": null,
|
||||||
|
"crossval_index_sets": null,
|
||||||
|
"cuda": true,
|
||||||
|
"data_path": "../data/crossval_splits/all_amine_split_for_paper/cv_1/train.csv",
|
||||||
|
"data_weights_path": "../data/crossval_splits/all_amine_split_for_paper/cv_1/train_weights.csv",
|
||||||
|
"dataset_type": "regression",
|
||||||
|
"depth": 4,
|
||||||
|
"depth_solvent": 3,
|
||||||
|
"device": {
|
||||||
|
"_string": "cuda",
|
||||||
|
"_type": "python_object (type = device)",
|
||||||
|
"_value": "gASVHwAAAAAAAACMBXRvcmNolIwGZGV2aWNllJOUjARjdWRhlIWUUpQu"
|
||||||
|
},
|
||||||
|
"dropout": 0.1,
|
||||||
|
"empty_cache": false,
|
||||||
|
"ensemble_size": 1,
|
||||||
|
"epochs": 50,
|
||||||
|
"evidential_regularization": 0,
|
||||||
|
"explicit_h": false,
|
||||||
|
"extra_metrics": [],
|
||||||
|
"features_generator": null,
|
||||||
|
"features_only": false,
|
||||||
|
"features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_1/train_extra_x.csv"
|
||||||
|
],
|
||||||
|
"features_scaling": true,
|
||||||
|
"features_size": null,
|
||||||
|
"ffn_hidden_size": 600,
|
||||||
|
"ffn_num_layers": 3,
|
||||||
|
"final_lr": 0.0001,
|
||||||
|
"folds_file": null,
|
||||||
|
"freeze_first_only": false,
|
||||||
|
"frzn_ffn_layers": 0,
|
||||||
|
"gpu": null,
|
||||||
|
"grad_clip": null,
|
||||||
|
"hidden_size": 600,
|
||||||
|
"hidden_size_solvent": 300,
|
||||||
|
"ignore_columns": null,
|
||||||
|
"ignore_nan_metrics": false,
|
||||||
|
"init_lr": 0.0001,
|
||||||
|
"is_atom_bond_targets": false,
|
||||||
|
"keeping_atom_map": false,
|
||||||
|
"log_frequency": 10,
|
||||||
|
"loss_function": "mse",
|
||||||
|
"max_data_size": null,
|
||||||
|
"max_lr": 0.001,
|
||||||
|
"metric": "rmse",
|
||||||
|
"metrics": [
|
||||||
|
"rmse"
|
||||||
|
],
|
||||||
|
"minimize_score": true,
|
||||||
|
"mpn_shared": false,
|
||||||
|
"multiclass_num_classes": 3,
|
||||||
|
"no_adding_bond_types": false,
|
||||||
|
"no_atom_descriptor_scaling": false,
|
||||||
|
"no_bond_descriptor_scaling": false,
|
||||||
|
"no_cache_mol": false,
|
||||||
|
"no_cuda": false,
|
||||||
|
"no_features_scaling": false,
|
||||||
|
"no_shared_atom_bond_ffn": false,
|
||||||
|
"num_folds": 1,
|
||||||
|
"num_lrs": 1,
|
||||||
|
"num_tasks": 1,
|
||||||
|
"num_workers": 8,
|
||||||
|
"number_of_molecules": 1,
|
||||||
|
"overwrite_default_atom_features": false,
|
||||||
|
"overwrite_default_bond_features": false,
|
||||||
|
"phase_features_path": null,
|
||||||
|
"pytorch_seed": 0,
|
||||||
|
"quantile_loss_alpha": 0.1,
|
||||||
|
"quantiles": [],
|
||||||
|
"quiet": false,
|
||||||
|
"reaction": false,
|
||||||
|
"reaction_mode": "reac_diff",
|
||||||
|
"reaction_solvent": false,
|
||||||
|
"reproducibility": {
|
||||||
|
"command_line": "python main_script.py train all_amine_split_for_paper",
|
||||||
|
"git_has_uncommitted_changes": true,
|
||||||
|
"git_root": "/media/andersonxps/wd_4tb/evan/LNP_ML",
|
||||||
|
"git_url": "https://github.com/jswitten/LNP_ML/tree/167822980dc26ba65c5c14539c4ce12b81b0b8f3",
|
||||||
|
"time": "Tue Jul 30 10:21:40 2024"
|
||||||
|
},
|
||||||
|
"resume_experiment": false,
|
||||||
|
"save_dir": "../data/crossval_splits/all_amine_split_for_paper/cv_1",
|
||||||
|
"save_preds": false,
|
||||||
|
"save_smiles_splits": false,
|
||||||
|
"seed": 42,
|
||||||
|
"separate_test_atom_descriptors_path": null,
|
||||||
|
"separate_test_bond_descriptors_path": null,
|
||||||
|
"separate_test_constraints_path": null,
|
||||||
|
"separate_test_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_1/test_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_test_path": "../data/crossval_splits/all_amine_split_for_paper/cv_1/test.csv",
|
||||||
|
"separate_test_phase_features_path": null,
|
||||||
|
"separate_val_atom_descriptors_path": null,
|
||||||
|
"separate_val_bond_descriptors_path": null,
|
||||||
|
"separate_val_constraints_path": null,
|
||||||
|
"separate_val_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_1/valid_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_val_path": "../data/crossval_splits/all_amine_split_for_paper/cv_1/valid.csv",
|
||||||
|
"separate_val_phase_features_path": null,
|
||||||
|
"shared_atom_bond_ffn": true,
|
||||||
|
"show_individual_scores": false,
|
||||||
|
"smiles_columns": [
|
||||||
|
"smiles"
|
||||||
|
],
|
||||||
|
"spectra_activation": "exp",
|
||||||
|
"spectra_phase_mask_path": null,
|
||||||
|
"spectra_target_floor": 1e-08,
|
||||||
|
"split_key_molecule": 0,
|
||||||
|
"split_sizes": [
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"split_type": "random",
|
||||||
|
"target_columns": null,
|
||||||
|
"target_weights": null,
|
||||||
|
"task_names": [
|
||||||
|
"quantified_delivery"
|
||||||
|
],
|
||||||
|
"test": false,
|
||||||
|
"test_fold_index": null,
|
||||||
|
"train_data_size": null,
|
||||||
|
"undirected": false,
|
||||||
|
"use_input_features": true,
|
||||||
|
"val_fold_index": null,
|
||||||
|
"warmup_epochs": 2.0,
|
||||||
|
"weights_ffn_num_layers": 2
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"rmse": [
|
||||||
|
1.01673724295223
|
||||||
|
]
|
||||||
|
}
|
||||||
165
models/pretrained/all_amine_split_for_LiON/cv_2/args.json
Normal file
165
models/pretrained/all_amine_split_for_LiON/cv_2/args.json
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
{
|
||||||
|
"activation": "ReLU",
|
||||||
|
"adding_bond_types": true,
|
||||||
|
"adding_h": false,
|
||||||
|
"aggregation": "mean",
|
||||||
|
"aggregation_norm": 100,
|
||||||
|
"atom_constraints": [],
|
||||||
|
"atom_descriptor_scaling": true,
|
||||||
|
"atom_descriptors": null,
|
||||||
|
"atom_descriptors_path": null,
|
||||||
|
"atom_descriptors_size": 0,
|
||||||
|
"atom_features_size": 0,
|
||||||
|
"atom_messages": false,
|
||||||
|
"atom_targets": [],
|
||||||
|
"batch_size": 50,
|
||||||
|
"bias": false,
|
||||||
|
"bias_solvent": false,
|
||||||
|
"bond_constraints": [],
|
||||||
|
"bond_descriptor_scaling": true,
|
||||||
|
"bond_descriptors": null,
|
||||||
|
"bond_descriptors_path": null,
|
||||||
|
"bond_descriptors_size": 0,
|
||||||
|
"bond_features_size": 0,
|
||||||
|
"bond_targets": [],
|
||||||
|
"cache_cutoff": 10000,
|
||||||
|
"checkpoint_dir": null,
|
||||||
|
"checkpoint_frzn": null,
|
||||||
|
"checkpoint_path": null,
|
||||||
|
"checkpoint_paths": null,
|
||||||
|
"class_balance": false,
|
||||||
|
"config_path": "../data/args_files/optimized_configs.json",
|
||||||
|
"constraints_path": null,
|
||||||
|
"crossval_index_dir": null,
|
||||||
|
"crossval_index_file": null,
|
||||||
|
"crossval_index_sets": null,
|
||||||
|
"cuda": true,
|
||||||
|
"data_path": "../data/crossval_splits/all_amine_split_for_paper/cv_2/train.csv",
|
||||||
|
"data_weights_path": "../data/crossval_splits/all_amine_split_for_paper/cv_2/train_weights.csv",
|
||||||
|
"dataset_type": "regression",
|
||||||
|
"depth": 4,
|
||||||
|
"depth_solvent": 3,
|
||||||
|
"device": {
|
||||||
|
"_string": "cuda",
|
||||||
|
"_type": "python_object (type = device)",
|
||||||
|
"_value": "gASVHwAAAAAAAACMBXRvcmNolIwGZGV2aWNllJOUjARjdWRhlIWUUpQu"
|
||||||
|
},
|
||||||
|
"dropout": 0.1,
|
||||||
|
"empty_cache": false,
|
||||||
|
"ensemble_size": 1,
|
||||||
|
"epochs": 50,
|
||||||
|
"evidential_regularization": 0,
|
||||||
|
"explicit_h": false,
|
||||||
|
"extra_metrics": [],
|
||||||
|
"features_generator": null,
|
||||||
|
"features_only": false,
|
||||||
|
"features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_2/train_extra_x.csv"
|
||||||
|
],
|
||||||
|
"features_scaling": true,
|
||||||
|
"features_size": null,
|
||||||
|
"ffn_hidden_size": 600,
|
||||||
|
"ffn_num_layers": 3,
|
||||||
|
"final_lr": 0.0001,
|
||||||
|
"folds_file": null,
|
||||||
|
"freeze_first_only": false,
|
||||||
|
"frzn_ffn_layers": 0,
|
||||||
|
"gpu": null,
|
||||||
|
"grad_clip": null,
|
||||||
|
"hidden_size": 600,
|
||||||
|
"hidden_size_solvent": 300,
|
||||||
|
"ignore_columns": null,
|
||||||
|
"ignore_nan_metrics": false,
|
||||||
|
"init_lr": 0.0001,
|
||||||
|
"is_atom_bond_targets": false,
|
||||||
|
"keeping_atom_map": false,
|
||||||
|
"log_frequency": 10,
|
||||||
|
"loss_function": "mse",
|
||||||
|
"max_data_size": null,
|
||||||
|
"max_lr": 0.001,
|
||||||
|
"metric": "rmse",
|
||||||
|
"metrics": [
|
||||||
|
"rmse"
|
||||||
|
],
|
||||||
|
"minimize_score": true,
|
||||||
|
"mpn_shared": false,
|
||||||
|
"multiclass_num_classes": 3,
|
||||||
|
"no_adding_bond_types": false,
|
||||||
|
"no_atom_descriptor_scaling": false,
|
||||||
|
"no_bond_descriptor_scaling": false,
|
||||||
|
"no_cache_mol": false,
|
||||||
|
"no_cuda": false,
|
||||||
|
"no_features_scaling": false,
|
||||||
|
"no_shared_atom_bond_ffn": false,
|
||||||
|
"num_folds": 1,
|
||||||
|
"num_lrs": 1,
|
||||||
|
"num_tasks": 1,
|
||||||
|
"num_workers": 8,
|
||||||
|
"number_of_molecules": 1,
|
||||||
|
"overwrite_default_atom_features": false,
|
||||||
|
"overwrite_default_bond_features": false,
|
||||||
|
"phase_features_path": null,
|
||||||
|
"pytorch_seed": 0,
|
||||||
|
"quantile_loss_alpha": 0.1,
|
||||||
|
"quantiles": [],
|
||||||
|
"quiet": false,
|
||||||
|
"reaction": false,
|
||||||
|
"reaction_mode": "reac_diff",
|
||||||
|
"reaction_solvent": false,
|
||||||
|
"reproducibility": {
|
||||||
|
"command_line": "python main_script.py train all_amine_split_for_paper",
|
||||||
|
"git_has_uncommitted_changes": true,
|
||||||
|
"git_root": "/media/andersonxps/wd_4tb/evan/LNP_ML",
|
||||||
|
"git_url": "https://github.com/jswitten/LNP_ML/tree/167822980dc26ba65c5c14539c4ce12b81b0b8f3",
|
||||||
|
"time": "Tue Jul 30 10:28:04 2024"
|
||||||
|
},
|
||||||
|
"resume_experiment": false,
|
||||||
|
"save_dir": "../data/crossval_splits/all_amine_split_for_paper/cv_2",
|
||||||
|
"save_preds": false,
|
||||||
|
"save_smiles_splits": false,
|
||||||
|
"seed": 42,
|
||||||
|
"separate_test_atom_descriptors_path": null,
|
||||||
|
"separate_test_bond_descriptors_path": null,
|
||||||
|
"separate_test_constraints_path": null,
|
||||||
|
"separate_test_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_2/test_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_test_path": "../data/crossval_splits/all_amine_split_for_paper/cv_2/test.csv",
|
||||||
|
"separate_test_phase_features_path": null,
|
||||||
|
"separate_val_atom_descriptors_path": null,
|
||||||
|
"separate_val_bond_descriptors_path": null,
|
||||||
|
"separate_val_constraints_path": null,
|
||||||
|
"separate_val_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_2/valid_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_val_path": "../data/crossval_splits/all_amine_split_for_paper/cv_2/valid.csv",
|
||||||
|
"separate_val_phase_features_path": null,
|
||||||
|
"shared_atom_bond_ffn": true,
|
||||||
|
"show_individual_scores": false,
|
||||||
|
"smiles_columns": [
|
||||||
|
"smiles"
|
||||||
|
],
|
||||||
|
"spectra_activation": "exp",
|
||||||
|
"spectra_phase_mask_path": null,
|
||||||
|
"spectra_target_floor": 1e-08,
|
||||||
|
"split_key_molecule": 0,
|
||||||
|
"split_sizes": [
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"split_type": "random",
|
||||||
|
"target_columns": null,
|
||||||
|
"target_weights": null,
|
||||||
|
"task_names": [
|
||||||
|
"quantified_delivery"
|
||||||
|
],
|
||||||
|
"test": false,
|
||||||
|
"test_fold_index": null,
|
||||||
|
"train_data_size": null,
|
||||||
|
"undirected": false,
|
||||||
|
"use_input_features": true,
|
||||||
|
"val_fold_index": null,
|
||||||
|
"warmup_epochs": 2.0,
|
||||||
|
"weights_ffn_num_layers": 2
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"rmse": [
|
||||||
|
0.8788072588544181
|
||||||
|
]
|
||||||
|
}
|
||||||
165
models/pretrained/all_amine_split_for_LiON/cv_3/args.json
Normal file
165
models/pretrained/all_amine_split_for_LiON/cv_3/args.json
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
{
|
||||||
|
"activation": "ReLU",
|
||||||
|
"adding_bond_types": true,
|
||||||
|
"adding_h": false,
|
||||||
|
"aggregation": "mean",
|
||||||
|
"aggregation_norm": 100,
|
||||||
|
"atom_constraints": [],
|
||||||
|
"atom_descriptor_scaling": true,
|
||||||
|
"atom_descriptors": null,
|
||||||
|
"atom_descriptors_path": null,
|
||||||
|
"atom_descriptors_size": 0,
|
||||||
|
"atom_features_size": 0,
|
||||||
|
"atom_messages": false,
|
||||||
|
"atom_targets": [],
|
||||||
|
"batch_size": 50,
|
||||||
|
"bias": false,
|
||||||
|
"bias_solvent": false,
|
||||||
|
"bond_constraints": [],
|
||||||
|
"bond_descriptor_scaling": true,
|
||||||
|
"bond_descriptors": null,
|
||||||
|
"bond_descriptors_path": null,
|
||||||
|
"bond_descriptors_size": 0,
|
||||||
|
"bond_features_size": 0,
|
||||||
|
"bond_targets": [],
|
||||||
|
"cache_cutoff": 10000,
|
||||||
|
"checkpoint_dir": null,
|
||||||
|
"checkpoint_frzn": null,
|
||||||
|
"checkpoint_path": null,
|
||||||
|
"checkpoint_paths": null,
|
||||||
|
"class_balance": false,
|
||||||
|
"config_path": "../data/args_files/optimized_configs.json",
|
||||||
|
"constraints_path": null,
|
||||||
|
"crossval_index_dir": null,
|
||||||
|
"crossval_index_file": null,
|
||||||
|
"crossval_index_sets": null,
|
||||||
|
"cuda": true,
|
||||||
|
"data_path": "../data/crossval_splits/all_amine_split_for_paper/cv_3/train.csv",
|
||||||
|
"data_weights_path": "../data/crossval_splits/all_amine_split_for_paper/cv_3/train_weights.csv",
|
||||||
|
"dataset_type": "regression",
|
||||||
|
"depth": 4,
|
||||||
|
"depth_solvent": 3,
|
||||||
|
"device": {
|
||||||
|
"_string": "cuda",
|
||||||
|
"_type": "python_object (type = device)",
|
||||||
|
"_value": "gASVHwAAAAAAAACMBXRvcmNolIwGZGV2aWNllJOUjARjdWRhlIWUUpQu"
|
||||||
|
},
|
||||||
|
"dropout": 0.1,
|
||||||
|
"empty_cache": false,
|
||||||
|
"ensemble_size": 1,
|
||||||
|
"epochs": 50,
|
||||||
|
"evidential_regularization": 0,
|
||||||
|
"explicit_h": false,
|
||||||
|
"extra_metrics": [],
|
||||||
|
"features_generator": null,
|
||||||
|
"features_only": false,
|
||||||
|
"features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_3/train_extra_x.csv"
|
||||||
|
],
|
||||||
|
"features_scaling": true,
|
||||||
|
"features_size": null,
|
||||||
|
"ffn_hidden_size": 600,
|
||||||
|
"ffn_num_layers": 3,
|
||||||
|
"final_lr": 0.0001,
|
||||||
|
"folds_file": null,
|
||||||
|
"freeze_first_only": false,
|
||||||
|
"frzn_ffn_layers": 0,
|
||||||
|
"gpu": null,
|
||||||
|
"grad_clip": null,
|
||||||
|
"hidden_size": 600,
|
||||||
|
"hidden_size_solvent": 300,
|
||||||
|
"ignore_columns": null,
|
||||||
|
"ignore_nan_metrics": false,
|
||||||
|
"init_lr": 0.0001,
|
||||||
|
"is_atom_bond_targets": false,
|
||||||
|
"keeping_atom_map": false,
|
||||||
|
"log_frequency": 10,
|
||||||
|
"loss_function": "mse",
|
||||||
|
"max_data_size": null,
|
||||||
|
"max_lr": 0.001,
|
||||||
|
"metric": "rmse",
|
||||||
|
"metrics": [
|
||||||
|
"rmse"
|
||||||
|
],
|
||||||
|
"minimize_score": true,
|
||||||
|
"mpn_shared": false,
|
||||||
|
"multiclass_num_classes": 3,
|
||||||
|
"no_adding_bond_types": false,
|
||||||
|
"no_atom_descriptor_scaling": false,
|
||||||
|
"no_bond_descriptor_scaling": false,
|
||||||
|
"no_cache_mol": false,
|
||||||
|
"no_cuda": false,
|
||||||
|
"no_features_scaling": false,
|
||||||
|
"no_shared_atom_bond_ffn": false,
|
||||||
|
"num_folds": 1,
|
||||||
|
"num_lrs": 1,
|
||||||
|
"num_tasks": 1,
|
||||||
|
"num_workers": 8,
|
||||||
|
"number_of_molecules": 1,
|
||||||
|
"overwrite_default_atom_features": false,
|
||||||
|
"overwrite_default_bond_features": false,
|
||||||
|
"phase_features_path": null,
|
||||||
|
"pytorch_seed": 0,
|
||||||
|
"quantile_loss_alpha": 0.1,
|
||||||
|
"quantiles": [],
|
||||||
|
"quiet": false,
|
||||||
|
"reaction": false,
|
||||||
|
"reaction_mode": "reac_diff",
|
||||||
|
"reaction_solvent": false,
|
||||||
|
"reproducibility": {
|
||||||
|
"command_line": "python main_script.py train all_amine_split_for_paper",
|
||||||
|
"git_has_uncommitted_changes": true,
|
||||||
|
"git_root": "/media/andersonxps/wd_4tb/evan/LNP_ML",
|
||||||
|
"git_url": "https://github.com/jswitten/LNP_ML/tree/167822980dc26ba65c5c14539c4ce12b81b0b8f3",
|
||||||
|
"time": "Tue Jul 30 10:34:31 2024"
|
||||||
|
},
|
||||||
|
"resume_experiment": false,
|
||||||
|
"save_dir": "../data/crossval_splits/all_amine_split_for_paper/cv_3",
|
||||||
|
"save_preds": false,
|
||||||
|
"save_smiles_splits": false,
|
||||||
|
"seed": 42,
|
||||||
|
"separate_test_atom_descriptors_path": null,
|
||||||
|
"separate_test_bond_descriptors_path": null,
|
||||||
|
"separate_test_constraints_path": null,
|
||||||
|
"separate_test_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_3/test_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_test_path": "../data/crossval_splits/all_amine_split_for_paper/cv_3/test.csv",
|
||||||
|
"separate_test_phase_features_path": null,
|
||||||
|
"separate_val_atom_descriptors_path": null,
|
||||||
|
"separate_val_bond_descriptors_path": null,
|
||||||
|
"separate_val_constraints_path": null,
|
||||||
|
"separate_val_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_3/valid_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_val_path": "../data/crossval_splits/all_amine_split_for_paper/cv_3/valid.csv",
|
||||||
|
"separate_val_phase_features_path": null,
|
||||||
|
"shared_atom_bond_ffn": true,
|
||||||
|
"show_individual_scores": false,
|
||||||
|
"smiles_columns": [
|
||||||
|
"smiles"
|
||||||
|
],
|
||||||
|
"spectra_activation": "exp",
|
||||||
|
"spectra_phase_mask_path": null,
|
||||||
|
"spectra_target_floor": 1e-08,
|
||||||
|
"split_key_molecule": 0,
|
||||||
|
"split_sizes": [
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"split_type": "random",
|
||||||
|
"target_columns": null,
|
||||||
|
"target_weights": null,
|
||||||
|
"task_names": [
|
||||||
|
"quantified_delivery"
|
||||||
|
],
|
||||||
|
"test": false,
|
||||||
|
"test_fold_index": null,
|
||||||
|
"train_data_size": null,
|
||||||
|
"undirected": false,
|
||||||
|
"use_input_features": true,
|
||||||
|
"val_fold_index": null,
|
||||||
|
"warmup_epochs": 2.0,
|
||||||
|
"weights_ffn_num_layers": 2
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"rmse": [
|
||||||
|
0.9245934905333985
|
||||||
|
]
|
||||||
|
}
|
||||||
165
models/pretrained/all_amine_split_for_LiON/cv_4/args.json
Normal file
165
models/pretrained/all_amine_split_for_LiON/cv_4/args.json
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
{
|
||||||
|
"activation": "ReLU",
|
||||||
|
"adding_bond_types": true,
|
||||||
|
"adding_h": false,
|
||||||
|
"aggregation": "mean",
|
||||||
|
"aggregation_norm": 100,
|
||||||
|
"atom_constraints": [],
|
||||||
|
"atom_descriptor_scaling": true,
|
||||||
|
"atom_descriptors": null,
|
||||||
|
"atom_descriptors_path": null,
|
||||||
|
"atom_descriptors_size": 0,
|
||||||
|
"atom_features_size": 0,
|
||||||
|
"atom_messages": false,
|
||||||
|
"atom_targets": [],
|
||||||
|
"batch_size": 50,
|
||||||
|
"bias": false,
|
||||||
|
"bias_solvent": false,
|
||||||
|
"bond_constraints": [],
|
||||||
|
"bond_descriptor_scaling": true,
|
||||||
|
"bond_descriptors": null,
|
||||||
|
"bond_descriptors_path": null,
|
||||||
|
"bond_descriptors_size": 0,
|
||||||
|
"bond_features_size": 0,
|
||||||
|
"bond_targets": [],
|
||||||
|
"cache_cutoff": 10000,
|
||||||
|
"checkpoint_dir": null,
|
||||||
|
"checkpoint_frzn": null,
|
||||||
|
"checkpoint_path": null,
|
||||||
|
"checkpoint_paths": null,
|
||||||
|
"class_balance": false,
|
||||||
|
"config_path": "../data/args_files/optimized_configs.json",
|
||||||
|
"constraints_path": null,
|
||||||
|
"crossval_index_dir": null,
|
||||||
|
"crossval_index_file": null,
|
||||||
|
"crossval_index_sets": null,
|
||||||
|
"cuda": true,
|
||||||
|
"data_path": "../data/crossval_splits/all_amine_split_for_paper/cv_4/train.csv",
|
||||||
|
"data_weights_path": "../data/crossval_splits/all_amine_split_for_paper/cv_4/train_weights.csv",
|
||||||
|
"dataset_type": "regression",
|
||||||
|
"depth": 4,
|
||||||
|
"depth_solvent": 3,
|
||||||
|
"device": {
|
||||||
|
"_string": "cuda",
|
||||||
|
"_type": "python_object (type = device)",
|
||||||
|
"_value": "gASVHwAAAAAAAACMBXRvcmNolIwGZGV2aWNllJOUjARjdWRhlIWUUpQu"
|
||||||
|
},
|
||||||
|
"dropout": 0.1,
|
||||||
|
"empty_cache": false,
|
||||||
|
"ensemble_size": 1,
|
||||||
|
"epochs": 50,
|
||||||
|
"evidential_regularization": 0,
|
||||||
|
"explicit_h": false,
|
||||||
|
"extra_metrics": [],
|
||||||
|
"features_generator": null,
|
||||||
|
"features_only": false,
|
||||||
|
"features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_4/train_extra_x.csv"
|
||||||
|
],
|
||||||
|
"features_scaling": true,
|
||||||
|
"features_size": null,
|
||||||
|
"ffn_hidden_size": 600,
|
||||||
|
"ffn_num_layers": 3,
|
||||||
|
"final_lr": 0.0001,
|
||||||
|
"folds_file": null,
|
||||||
|
"freeze_first_only": false,
|
||||||
|
"frzn_ffn_layers": 0,
|
||||||
|
"gpu": null,
|
||||||
|
"grad_clip": null,
|
||||||
|
"hidden_size": 600,
|
||||||
|
"hidden_size_solvent": 300,
|
||||||
|
"ignore_columns": null,
|
||||||
|
"ignore_nan_metrics": false,
|
||||||
|
"init_lr": 0.0001,
|
||||||
|
"is_atom_bond_targets": false,
|
||||||
|
"keeping_atom_map": false,
|
||||||
|
"log_frequency": 10,
|
||||||
|
"loss_function": "mse",
|
||||||
|
"max_data_size": null,
|
||||||
|
"max_lr": 0.001,
|
||||||
|
"metric": "rmse",
|
||||||
|
"metrics": [
|
||||||
|
"rmse"
|
||||||
|
],
|
||||||
|
"minimize_score": true,
|
||||||
|
"mpn_shared": false,
|
||||||
|
"multiclass_num_classes": 3,
|
||||||
|
"no_adding_bond_types": false,
|
||||||
|
"no_atom_descriptor_scaling": false,
|
||||||
|
"no_bond_descriptor_scaling": false,
|
||||||
|
"no_cache_mol": false,
|
||||||
|
"no_cuda": false,
|
||||||
|
"no_features_scaling": false,
|
||||||
|
"no_shared_atom_bond_ffn": false,
|
||||||
|
"num_folds": 1,
|
||||||
|
"num_lrs": 1,
|
||||||
|
"num_tasks": 1,
|
||||||
|
"num_workers": 8,
|
||||||
|
"number_of_molecules": 1,
|
||||||
|
"overwrite_default_atom_features": false,
|
||||||
|
"overwrite_default_bond_features": false,
|
||||||
|
"phase_features_path": null,
|
||||||
|
"pytorch_seed": 0,
|
||||||
|
"quantile_loss_alpha": 0.1,
|
||||||
|
"quantiles": [],
|
||||||
|
"quiet": false,
|
||||||
|
"reaction": false,
|
||||||
|
"reaction_mode": "reac_diff",
|
||||||
|
"reaction_solvent": false,
|
||||||
|
"reproducibility": {
|
||||||
|
"command_line": "python main_script.py train all_amine_split_for_paper",
|
||||||
|
"git_has_uncommitted_changes": true,
|
||||||
|
"git_root": "/media/andersonxps/wd_4tb/evan/LNP_ML",
|
||||||
|
"git_url": "https://github.com/jswitten/LNP_ML/tree/167822980dc26ba65c5c14539c4ce12b81b0b8f3",
|
||||||
|
"time": "Tue Jul 30 10:40:44 2024"
|
||||||
|
},
|
||||||
|
"resume_experiment": false,
|
||||||
|
"save_dir": "../data/crossval_splits/all_amine_split_for_paper/cv_4",
|
||||||
|
"save_preds": false,
|
||||||
|
"save_smiles_splits": false,
|
||||||
|
"seed": 42,
|
||||||
|
"separate_test_atom_descriptors_path": null,
|
||||||
|
"separate_test_bond_descriptors_path": null,
|
||||||
|
"separate_test_constraints_path": null,
|
||||||
|
"separate_test_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_4/test_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_test_path": "../data/crossval_splits/all_amine_split_for_paper/cv_4/test.csv",
|
||||||
|
"separate_test_phase_features_path": null,
|
||||||
|
"separate_val_atom_descriptors_path": null,
|
||||||
|
"separate_val_bond_descriptors_path": null,
|
||||||
|
"separate_val_constraints_path": null,
|
||||||
|
"separate_val_features_path": [
|
||||||
|
"../data/crossval_splits/all_amine_split_for_paper/cv_4/valid_extra_x.csv"
|
||||||
|
],
|
||||||
|
"separate_val_path": "../data/crossval_splits/all_amine_split_for_paper/cv_4/valid.csv",
|
||||||
|
"separate_val_phase_features_path": null,
|
||||||
|
"shared_atom_bond_ffn": true,
|
||||||
|
"show_individual_scores": false,
|
||||||
|
"smiles_columns": [
|
||||||
|
"smiles"
|
||||||
|
],
|
||||||
|
"spectra_activation": "exp",
|
||||||
|
"spectra_phase_mask_path": null,
|
||||||
|
"spectra_target_floor": 1e-08,
|
||||||
|
"split_key_molecule": 0,
|
||||||
|
"split_sizes": [
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0
|
||||||
|
],
|
||||||
|
"split_type": "random",
|
||||||
|
"target_columns": null,
|
||||||
|
"target_weights": null,
|
||||||
|
"task_names": [
|
||||||
|
"quantified_delivery"
|
||||||
|
],
|
||||||
|
"test": false,
|
||||||
|
"test_fold_index": null,
|
||||||
|
"train_data_size": null,
|
||||||
|
"undirected": false,
|
||||||
|
"use_input_features": true,
|
||||||
|
"val_fold_index": null,
|
||||||
|
"warmup_epochs": 2.0,
|
||||||
|
"weights_ffn_num_layers": 2
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"rmse": [
|
||||||
|
0.8268900471469541
|
||||||
|
]
|
||||||
|
}
|
||||||
0
notebooks/.gitkeep
Normal file
0
notebooks/.gitkeep
Normal file
26
pixi.toml
Normal file
26
pixi.toml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
[workspace]
|
||||||
|
authors = ["RYDE-WORK <rydewu@primedigitaltech.com>"]
|
||||||
|
channels = ["conda-forge"]
|
||||||
|
name = "LNP_ML"
|
||||||
|
platforms = ["linux-64", "osx-arm64"]
|
||||||
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[tasks]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
loguru = "*"
|
||||||
|
ruff = "*"
|
||||||
|
tqdm = "*"
|
||||||
|
typer = "*"
|
||||||
|
pip = "*"
|
||||||
|
python = "3.8.*"
|
||||||
|
|
||||||
|
[pypi-dependencies]
|
||||||
|
lnp_ml = { path = ".", editable = true }
|
||||||
|
chemprop = "==1.7.0"
|
||||||
|
setuptools = "*"
|
||||||
|
pandas = ">=2.0.3, <3"
|
||||||
|
openpyxl = ">=3.1.5, <4"
|
||||||
|
python-dotenv = ">=1.0.1, <2"
|
||||||
|
pyarrow = ">=17.0.0, <18"
|
||||||
|
fastparquet = ">=2024.2.0, <2025"
|
||||||
32
pyproject.toml
Normal file
32
pyproject.toml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["flit_core >=3.2,<4"]
|
||||||
|
build-backend = "flit_core.buildapi"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "lnp_ml"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "A short description of the project."
|
||||||
|
authors = [
|
||||||
|
{ name = "Wu Dinghong" },
|
||||||
|
]
|
||||||
|
license = { file = "LICENSE" }
|
||||||
|
readme = "README.md"
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"License :: OSI Approved :: MIT License"
|
||||||
|
]
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 99
|
||||||
|
src = ["lnp_ml"]
|
||||||
|
include = ["pyproject.toml", "lnp_ml/**/*.py"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
extend-select = ["I"] # Add import sorting
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["lnp_ml"]
|
||||||
|
force-sort-within-sections = true
|
||||||
|
|
||||||
0
references/.gitkeep
Normal file
0
references/.gitkeep
Normal file
0
reports/.gitkeep
Normal file
0
reports/.gitkeep
Normal file
0
reports/figures/.gitkeep
Normal file
0
reports/figures/.gitkeep
Normal file
85
scripts/data_cleaning.py
Normal file
85
scripts/data_cleaning.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
"""数据清洗脚本:修正原始数据中的问题"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import typer
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lnp_ml.config import RAW_DATA_DIR, INTERIM_DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
input_path: Path = RAW_DATA_DIR / "internal_deleted_uncorrected.xlsx",
|
||||||
|
output_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
清洗原始数据,修正已知问题。
|
||||||
|
|
||||||
|
修正内容:
|
||||||
|
1. 修正肌肉注射组 Biodistribution_muscle=0.7745 的数据
|
||||||
|
2. 修复阳性对照组 (Amine="Crtl") 的数据
|
||||||
|
3. 按给药途径分组进行 z-score 标准化
|
||||||
|
4. 对 size 列取 log
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading data from {input_path}")
|
||||||
|
df = pd.read_excel(input_path, header=2)
|
||||||
|
logger.info(f"Loaded {len(df)} samples")
|
||||||
|
|
||||||
|
# 修正肌肉注射组 0.7745 的数据
|
||||||
|
logger.info("Correcting Biodistribution_muscle=0.7745 rows...")
|
||||||
|
rows_to_correct = df[df["Biodistribution_muscle"] == 0.7745]
|
||||||
|
for index, row in rows_to_correct.iterrows():
|
||||||
|
total_biodistribution = pd.to_numeric(row[[
|
||||||
|
"Biodistribution_lymph_nodes",
|
||||||
|
"Biodistribution_heart",
|
||||||
|
"Biodistribution_liver",
|
||||||
|
"Biodistribution_spleen",
|
||||||
|
"Biodistribution_lung",
|
||||||
|
"Biodistribution_kidney",
|
||||||
|
"Biodistribution_muscle"
|
||||||
|
]]).sum()
|
||||||
|
df.at[index, "Biodistribution_lymph_nodes"] = pd.to_numeric(row["Biodistribution_lymph_nodes"]) / total_biodistribution
|
||||||
|
df.at[index, "Biodistribution_heart"] = pd.to_numeric(row["Biodistribution_heart"]) / total_biodistribution
|
||||||
|
df.at[index, "Biodistribution_liver"] = pd.to_numeric(row["Biodistribution_liver"]) / total_biodistribution
|
||||||
|
df.at[index, "Biodistribution_spleen"] = pd.to_numeric(row["Biodistribution_spleen"]) / total_biodistribution
|
||||||
|
df.at[index, "Biodistribution_lung"] = pd.to_numeric(row["Biodistribution_lung"]) / total_biodistribution
|
||||||
|
df.at[index, "Biodistribution_kidney"] = pd.to_numeric(row["Biodistribution_kidney"]) / total_biodistribution
|
||||||
|
df.at[index, "Biodistribution_muscle"] = pd.to_numeric(row["Biodistribution_muscle"]) / total_biodistribution
|
||||||
|
df.at[index, "quantified_total_luminescence"] = pd.to_numeric(row["quantified_total_luminescence"]) / (1 - 0.7745)
|
||||||
|
df.at[index, "unnormalized_delivery"] = df.at[index, "quantified_total_luminescence"]
|
||||||
|
logger.info(f" Corrected {len(rows_to_correct)} rows")
|
||||||
|
|
||||||
|
# 修复阳性对照组的数据
|
||||||
|
logger.info("Fixing control group (Amine='Crtl')...")
|
||||||
|
rows_to_override = df["Amine"] == "Crtl"
|
||||||
|
df.loc[rows_to_override, "quantified_total_luminescence"] = 1
|
||||||
|
df.loc[rows_to_override, "unnormalized_delivery"] = 1
|
||||||
|
logger.info(f" Fixed {rows_to_override.sum()} rows")
|
||||||
|
|
||||||
|
# 分别对肌肉注射组和静脉注射组重新进行 z-score 标准化
|
||||||
|
logger.info("Z-score normalizing delivery by Route_of_administration...")
|
||||||
|
df["unnormalized_delivery"] = pd.to_numeric(df["unnormalized_delivery"], errors="coerce")
|
||||||
|
df["quantified_delivery"] = (
|
||||||
|
df.groupby("Route_of_administration")["unnormalized_delivery"]
|
||||||
|
.transform(lambda x: (x - x.mean()) / x.std())
|
||||||
|
)
|
||||||
|
|
||||||
|
# 对 size 列取 log
|
||||||
|
logger.info("Log-transforming size column...")
|
||||||
|
df["size"] = pd.to_numeric(df["size"], errors="coerce")
|
||||||
|
df["size"] = np.log(df["size"].replace(0, np.nan)) # 避免 log(0)
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
df.to_csv(output_path, index=False)
|
||||||
|
logger.success(f"Saved cleaned data to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
158
scripts/process_data.py
Normal file
158
scripts/process_data.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
"""数据处理脚本:将原始数据转换为模型可用的格式"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import typer
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR
|
||||||
|
from lnp_ml.dataset import (
|
||||||
|
process_dataframe,
|
||||||
|
SMILES_COL,
|
||||||
|
COMP_COLS,
|
||||||
|
HELP_COLS,
|
||||||
|
TARGET_REGRESSION,
|
||||||
|
TARGET_CLASSIFICATION_PDI,
|
||||||
|
TARGET_CLASSIFICATION_EE,
|
||||||
|
TARGET_TOXIC,
|
||||||
|
TARGET_BIODIST,
|
||||||
|
get_phys_cols,
|
||||||
|
get_exp_cols,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
|
||||||
|
output_dir: Path = PROCESSED_DATA_DIR,
|
||||||
|
train_ratio: float = 0.56,
|
||||||
|
val_ratio: float = 0.14,
|
||||||
|
seed: int = 42,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
处理原始数据并划分训练/验证/测试集。
|
||||||
|
|
||||||
|
输出文件:
|
||||||
|
- train.parquet: 训练集
|
||||||
|
- val.parquet: 验证集
|
||||||
|
- test.parquet: 测试集
|
||||||
|
- feature_columns.txt: 特征列名配置
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading data from {input_path}")
|
||||||
|
df = pd.read_csv(input_path)
|
||||||
|
logger.info(f"Loaded {len(df)} samples")
|
||||||
|
|
||||||
|
# 处理数据
|
||||||
|
logger.info("Processing dataframe...")
|
||||||
|
df = process_dataframe(df)
|
||||||
|
|
||||||
|
# 定义要保留的列
|
||||||
|
phys_cols = get_phys_cols()
|
||||||
|
exp_cols = get_exp_cols()
|
||||||
|
|
||||||
|
keep_cols = (
|
||||||
|
[SMILES_COL]
|
||||||
|
+ COMP_COLS
|
||||||
|
+ phys_cols
|
||||||
|
+ HELP_COLS
|
||||||
|
+ exp_cols
|
||||||
|
+ TARGET_REGRESSION
|
||||||
|
+ TARGET_CLASSIFICATION_PDI
|
||||||
|
+ TARGET_CLASSIFICATION_EE
|
||||||
|
+ [TARGET_TOXIC]
|
||||||
|
+ TARGET_BIODIST
|
||||||
|
)
|
||||||
|
|
||||||
|
# 只保留存在的列
|
||||||
|
keep_cols = [c for c in keep_cols if c in df.columns]
|
||||||
|
df = df[keep_cols]
|
||||||
|
|
||||||
|
# 随机打乱并划分
|
||||||
|
logger.info("Splitting dataset...")
|
||||||
|
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
|
||||||
|
|
||||||
|
n = len(df)
|
||||||
|
n_train = int(n * train_ratio)
|
||||||
|
n_val = int(n * val_ratio)
|
||||||
|
|
||||||
|
train_df = df.iloc[:n_train]
|
||||||
|
val_df = df.iloc[n_train:n_train + n_val]
|
||||||
|
test_df = df.iloc[n_train + n_val:]
|
||||||
|
|
||||||
|
logger.info(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
train_path = output_dir / "train.parquet"
|
||||||
|
val_path = output_dir / "val.parquet"
|
||||||
|
test_path = output_dir / "test.parquet"
|
||||||
|
|
||||||
|
train_df.to_parquet(train_path, index=False)
|
||||||
|
val_df.to_parquet(val_path, index=False)
|
||||||
|
test_df.to_parquet(test_path, index=False)
|
||||||
|
|
||||||
|
logger.success(f"Saved train to {train_path}")
|
||||||
|
logger.success(f"Saved val to {val_path}")
|
||||||
|
logger.success(f"Saved test to {test_path}")
|
||||||
|
|
||||||
|
# 保存列名配置
|
||||||
|
config_path = output_dir / "feature_columns.txt"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
f.write("# Feature columns configuration\n\n")
|
||||||
|
f.write(f"# SMILES\n{SMILES_COL}\n\n")
|
||||||
|
f.write(f"# comp token [{len(COMP_COLS)}]\n")
|
||||||
|
f.write("\n".join(COMP_COLS) + "\n\n")
|
||||||
|
f.write(f"# phys token [{len(phys_cols)}]\n")
|
||||||
|
f.write("\n".join(phys_cols) + "\n\n")
|
||||||
|
f.write(f"# help token [{len(HELP_COLS)}]\n")
|
||||||
|
f.write("\n".join(HELP_COLS) + "\n\n")
|
||||||
|
f.write(f"# exp token [{len(exp_cols)}]\n")
|
||||||
|
f.write("\n".join(exp_cols) + "\n\n")
|
||||||
|
f.write("# Targets\n")
|
||||||
|
f.write("## Regression\n")
|
||||||
|
f.write("\n".join(TARGET_REGRESSION) + "\n")
|
||||||
|
f.write("## PDI classification\n")
|
||||||
|
f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n")
|
||||||
|
f.write("## EE classification\n")
|
||||||
|
f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n")
|
||||||
|
f.write("## Toxic\n")
|
||||||
|
f.write(f"{TARGET_TOXIC}\n")
|
||||||
|
f.write("## Biodistribution\n")
|
||||||
|
f.write("\n".join(TARGET_BIODIST) + "\n")
|
||||||
|
|
||||||
|
logger.success(f"Saved feature config to {config_path}")
|
||||||
|
|
||||||
|
# 打印统计信息
|
||||||
|
logger.info("\n=== Dataset Statistics ===")
|
||||||
|
logger.info(f"Total samples: {n}")
|
||||||
|
logger.info(f"SMILES unique: {df[SMILES_COL].nunique()}")
|
||||||
|
|
||||||
|
# 缺失值统计
|
||||||
|
logger.info("\nMissing values in targets:")
|
||||||
|
for col in TARGET_REGRESSION + [TARGET_TOXIC]:
|
||||||
|
if col in df.columns:
|
||||||
|
missing = df[col].isna().sum()
|
||||||
|
logger.info(f" {col}: {missing} ({100*missing/n:.1f}%)")
|
||||||
|
|
||||||
|
# PDI 分布
|
||||||
|
if all(c in df.columns for c in TARGET_CLASSIFICATION_PDI):
|
||||||
|
pdi_sum = df[TARGET_CLASSIFICATION_PDI].sum()
|
||||||
|
logger.info(f"\nPDI distribution:")
|
||||||
|
for col, count in pdi_sum.items():
|
||||||
|
logger.info(f" {col}: {int(count)}")
|
||||||
|
|
||||||
|
# EE 分布
|
||||||
|
if all(c in df.columns for c in TARGET_CLASSIFICATION_EE):
|
||||||
|
ee_sum = df[TARGET_CLASSIFICATION_EE].sum()
|
||||||
|
logger.info(f"\nEE distribution:")
|
||||||
|
for col, count in ee_sum.items():
|
||||||
|
logger.info(f" {col}: {int(count)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user