-
Notifications
You must be signed in to change notification settings - Fork 9
Support PEtab SciML problems and enable linting #482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: lint_mapping_table
Are you sure you want to change the base?
Changes from all commits
bfcb673
8daf3b6
88cbfe2
fe86e39
28a407e
66035d9
24a7c2c
f02c21f
e1c3d3d
f153916
ef86844
1eb9dec
1165218
f2c4193
96858f2
dd3f87b
8c1a8a9
8c5cd2b
238829b
d4d8a94
653ccc7
15c4486
3d8901b
4ef0a10
58c549a
f1d4540
2f76d3e
81292a7
840ae7e
9af9a54
e342f26
254a596
d95f09a
239e1ff
3c1a09f
118aed7
e133f21
81148e6
7759b50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| Annotated, | ||
| Any, | ||
| Generic, | ||
| Literal, | ||
| Self, | ||
| TypeVar, | ||
| get_args, | ||
|
|
@@ -308,6 +309,22 @@ def __iadd__(self, other: T) -> BaseTable[T]: | |
| return self | ||
|
|
||
|
|
||
| # SciML extension classes — imported after BaseTable is defined to avoid | ||
| # circular imports (sciml.py does not import from core.py). | ||
| from .extensions.sciml import ( # noqa: E402 | ||
| HybridizationTable, | ||
| SciMLConfig, | ||
| SciMLExt, | ||
| ) | ||
|
|
||
|
|
||
| class ProblemExtensions: | ||
| """Runtime extension state attached to a :class:`Problem`.""" | ||
|
|
||
| def __init__(self, sciml: SciMLExt = None): | ||
| self.sciml: SciMLExt = sciml or SciMLExt() | ||
|
|
||
|
|
||
| class Observable(BaseModel): | ||
| """Observable definition.""" | ||
|
|
||
|
|
@@ -318,9 +335,9 @@ class Observable(BaseModel): | |
| #: Observable name. | ||
| name: str | None = Field(alias=C.OBSERVABLE_NAME, default=None) | ||
| #: Observable formula. | ||
| formula: sp.Basic | None = Field(alias=C.OBSERVABLE_FORMULA, default=None) | ||
| formula: sp.Basic = Field(alias=C.OBSERVABLE_FORMULA) | ||
| #: Noise formula. | ||
| noise_formula: sp.Basic | None = Field(alias=C.NOISE_FORMULA, default=None) | ||
| noise_formula: sp.Basic = Field(alias=C.NOISE_FORMULA) | ||
| #: Noise distribution. | ||
| noise_distribution: NoiseDistribution = Field( | ||
| alias=C.NOISE_DISTRIBUTION, default=NoiseDistribution.NORMAL | ||
|
|
@@ -926,7 +943,8 @@ class Parameter(BaseModel): | |
| ) | ||
| #: Nominal value. | ||
| nominal_value: Annotated[ | ||
| float | None, BeforeValidator(_convert_nan_to_none) | ||
| # PEtab SciML supports arrays via "array" nominal values | ||
| float | Literal["array"] | None, BeforeValidator(_convert_nan_to_none) | ||
| ] = Field(alias=C.NOMINAL_VALUE, default=None) | ||
| #: Is the parameter to be estimated? | ||
| estimate: bool = Field(alias=C.ESTIMATE, default=True) | ||
|
|
@@ -1133,22 +1151,33 @@ def __init__( | |
| measurement_tables: list[MeasurementTable] = None, | ||
| parameter_tables: list[ParameterTable] = None, | ||
| mapping_tables: list[MappingTable] = None, | ||
| extensions: ProblemExtensions = None, | ||
| config: ProblemConfig = None, | ||
| ): | ||
| from ..v2.lint import default_validation_tasks | ||
| from ..v2.lint import default_validation_tasks, sciml_validation_tasks | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also this, import |
||
|
|
||
| self.config = config | ||
| self.models: list[Model] = models or [] | ||
| self.validation_tasks: list[ValidationTask] = ( | ||
| default_validation_tasks.copy() | ||
| ) | ||
| if ( | ||
| config | ||
| and config.extensions | ||
| and config.extensions.get(C.EXT_ID_SCIML) | ||
| ): | ||
| self.validation_tasks: list[ValidationTask] = ( | ||
| sciml_validation_tasks.copy() | ||
| ) | ||
| else: | ||
| self.validation_tasks: list[ValidationTask] = ( | ||
| default_validation_tasks.copy() | ||
| ) | ||
|
|
||
| self.observable_tables = observable_tables or [ObservableTable()] | ||
| self.condition_tables = condition_tables or [ConditionTable()] | ||
| self.experiment_tables = experiment_tables or [ExperimentTable()] | ||
| self.measurement_tables = measurement_tables or [MeasurementTable()] | ||
| self.mapping_tables = mapping_tables or [MappingTable()] | ||
| self.parameter_tables = parameter_tables or [ParameterTable()] | ||
| self.extensions = extensions or ProblemExtensions() | ||
|
|
||
| def __repr__(self): | ||
| return f"<{self.__class__.__name__} id={self.id!r}>" | ||
|
|
@@ -1321,6 +1350,45 @@ def from_yaml( | |
| else None | ||
| ) | ||
|
|
||
| extensions = ProblemExtensions() | ||
| if config.extensions and config.extensions.get(C.EXT_ID_SCIML): | ||
| from petab_sciml import ArrayDataStandard, NNModel, NNModelStandard | ||
|
|
||
| # Neural network classes are constructed via pytorch for now to get | ||
| # the proper inputs | ||
| neural_networks = [ | ||
| NNModel.from_pytorch_module( | ||
| NNModelStandard.load_data( | ||
| _generate_path( | ||
| file_path=nn_config.location, | ||
| base_path=base_path, | ||
| ) | ||
| ).to_pytorch_module(), | ||
| nn_model_id=nn_id, | ||
| ) | ||
| for nn_id, nn_config in ( | ||
| config.extensions[C.EXT_ID_SCIML].neural_networks or {} | ||
| ).items() | ||
| ] | ||
|
|
||
| hybridization_tables = [ | ||
| HybridizationTable.from_tsv(f, base_path) | ||
| for f in config.extensions[C.EXT_ID_SCIML].hybridization_files | ||
| ] | ||
|
|
||
| array_data_files = [ | ||
| ArrayDataStandard.load_data(_generate_path(f, base_path)) | ||
| for f in config.extensions[C.EXT_ID_SCIML].array_files | ||
| ] | ||
|
Comment on lines
+1355
to
+1382
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please move to a factory method like sciml = SciMLExt.from_config(config) if config.extensions and config.extensions.get(C.EXT_ID_SCIML) else None
extensions = ProblemExtensions(sciml=sciml) |
||
|
|
||
| extensions = ProblemExtensions( | ||
| sciml=SciMLExt( | ||
| neural_networks=neural_networks, | ||
| hybridization_tables=hybridization_tables, | ||
| array_data_files=array_data_files, | ||
| ) | ||
| ) | ||
|
|
||
| return Problem( | ||
| config=config, | ||
| models=models, | ||
|
|
@@ -1330,6 +1398,7 @@ def from_yaml( | |
| measurement_tables=measurement_tables, | ||
| parameter_tables=parameter_tables, | ||
| mapping_tables=mapping_tables, | ||
| extensions=extensions, | ||
| ) | ||
|
|
||
| @staticmethod | ||
|
|
@@ -1940,14 +2009,21 @@ def validate( | |
|
|
||
| validation_results = ValidationResultList() | ||
|
|
||
| if self.config and self.config.extensions: | ||
| extensions = ",".join(self.config.extensions.keys()) | ||
| supported_extensions = {C.EXT_ID_SCIML} | ||
| if ( | ||
| self.config | ||
| and self.config.extensions | ||
| and (self.config.extensions.keys() - supported_extensions) | ||
| ): | ||
| extensions_without_support = ",".join( | ||
| self.config.extensions.keys() - supported_extensions | ||
| ) | ||
| validation_results.append( | ||
| ValidationIssue( | ||
| ValidationIssueSeverity.WARNING, | ||
| "Validation of PEtab extensions is not yet implemented, " | ||
| "but the given problem uses the following extensions: " | ||
| f"{extensions}", | ||
| "The given problem uses the following extensions for " | ||
| "which validation is not yet implemented: " | ||
| f"{extensions_without_support}", | ||
| ) | ||
| ) | ||
|
|
||
|
|
@@ -2505,6 +2581,23 @@ class ProblemConfig(BaseModel): | |
| validate_assignment=True, | ||
| ) | ||
|
|
||
| @field_validator("extensions", mode="before") | ||
| @classmethod | ||
| def _parse_extensions(cls, v): | ||
| """Parse extensions dict and convert known extensions to their specific | ||
| config classes.""" | ||
| if isinstance(v, dict): | ||
| parsed_extensions = {} | ||
| for ext_name, ext_config in v.items(): | ||
| if ext_name == C.EXT_ID_SCIML: | ||
| # Convert sciml extension to SciMLConfig | ||
| parsed_extensions[ext_name] = SciMLConfig(**ext_config) | ||
| else: | ||
| # Keep other extensions as ExtensionConfig | ||
| parsed_extensions[ext_name] = ExtensionConfig(**ext_config) | ||
| return parsed_extensions | ||
| return v | ||
|
|
||
| # convert parameter_file to list | ||
| @field_validator( | ||
| "parameter_files", | ||
|
|
@@ -2542,12 +2635,22 @@ def to_yaml(self, filename: str | Path): | |
|
|
||
| for model_id in data.get("model_files", {}): | ||
| data["model_files"][model_id][C.MODEL_LOCATION] = str( | ||
| data["model_files"][model_id]["location"] | ||
| data["model_files"][model_id][C.MODEL_LOCATION] | ||
| ) | ||
| if data["id"] is None: | ||
| # The schema requires a valid id or no id field at all. | ||
| del data["id"] | ||
|
|
||
| for ext_id, d_ext in data[C.EXTENSIONS].items(): | ||
| if ext_id == C.EXT_ID_SCIML: | ||
| # convert Paths to strings | ||
| for key in ("array_files", "hybridization_files"): | ||
| d_ext[key] = list(map(str, d_ext[key])) | ||
| for nn in d_ext["neural_networks"]: | ||
| d_ext["neural_networks"][nn][C.MODEL_LOCATION] = str( | ||
| d_ext["neural_networks"][nn][C.MODEL_LOCATION] | ||
| ) | ||
|
Comment on lines
+2646
to
+2652
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move to |
||
|
|
||
| write_yaml(data, filename) | ||
|
|
||
| @property | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.