diff --git a/docs/docs/tutorials/analysis.ipynb b/docs/docs/tutorials/analysis.ipynb
index 41e6c7ed1..b00d430fb 100644
--- a/docs/docs/tutorials/analysis.ipynb
+++ b/docs/docs/tutorials/analysis.ipynb
@@ -338,9 +338,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -351,10 +351,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.4"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/analysis1d.ipynb b/docs/docs/tutorials/analysis1d.ipynb
index 784f2be50..8d921318f 100644
--- a/docs/docs/tutorials/analysis1d.ipynb
+++ b/docs/docs/tutorials/analysis1d.ipynb
@@ -81,9 +81,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -94,10 +94,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.5"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/component_collection.ipynb b/docs/docs/tutorials/component_collection.ipynb
index 54f763871..4138d5cb7 100644
--- a/docs/docs/tutorials/component_collection.ipynb
+++ b/docs/docs/tutorials/component_collection.ipynb
@@ -67,9 +67,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -80,10 +80,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.4"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/components.ipynb b/docs/docs/tutorials/components.ipynb
index 70070f1e0..3b072d315 100644
--- a/docs/docs/tutorials/components.ipynb
+++ b/docs/docs/tutorials/components.ipynb
@@ -39,7 +39,9 @@
"gaussian = sm.Gaussian(display_name='Gaussian', width=0.5, area=1)\n",
"dho = sm.DampedHarmonicOscillator(display_name='DHO', center=1.0, width=0.3, area=2.0)\n",
"lorentzian = sm.Lorentzian(display_name='Lorentzian', center=-1.0, width=0.2, area=1.0)\n",
- "polynomial = sm.Polynomial(display_name='Polynomial', coefficients=[0.1, 0, 0.5]) # y=0.1+0.5*x^2\n",
+ "polynomial = sm.Polynomial(\n",
+ " display_name='Polynomial', coefficients=[-0.2, 0, 0.5]\n",
+ ") # y=-0.2+0.5*x^2\n",
"exponential = sm.Exponential(display_name='Exponential', amplitude=1.0, rate=-0.5)\n",
"\n",
"x = np.linspace(-2, 2, 100)\n",
@@ -59,6 +61,18 @@
"plt.show()"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6db00a3e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Suppress the warning from the Polynomial:\n",
+ "polynomial.suppress_warnings = True\n",
+ "y = polynomial.evaluate(x)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -172,7 +186,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
@@ -191,4 +205,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/convolution.ipynb b/docs/docs/tutorials/convolution.ipynb
index d37c38c14..90129696e 100644
--- a/docs/docs/tutorials/convolution.ipynb
+++ b/docs/docs/tutorials/convolution.ipynb
@@ -210,9 +210,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -223,10 +223,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.4"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/data/fake_advanced_data.hdf5 b/docs/docs/tutorials/data/fake_advanced_data.hdf5
index d83903fb0..2a21cb285 100644
Binary files a/docs/docs/tutorials/data/fake_advanced_data.hdf5 and b/docs/docs/tutorials/data/fake_advanced_data.hdf5 differ
diff --git a/docs/docs/tutorials/data/fake_simple_data.hdf5 b/docs/docs/tutorials/data/fake_simple_data.hdf5
index 14afb8306..bef9761a8 100644
Binary files a/docs/docs/tutorials/data/fake_simple_data.hdf5 and b/docs/docs/tutorials/data/fake_simple_data.hdf5 differ
diff --git a/docs/docs/tutorials/delta_lorentz.ipynb b/docs/docs/tutorials/delta_lorentz.ipynb
index 351358124..9d41573d6 100644
--- a/docs/docs/tutorials/delta_lorentz.ipynb
+++ b/docs/docs/tutorials/delta_lorentz.ipynb
@@ -96,9 +96,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -109,10 +109,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.4"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/detailed_balance.ipynb b/docs/docs/tutorials/detailed_balance.ipynb
index 12d926a6d..69a22f71a 100644
--- a/docs/docs/tutorials/detailed_balance.ipynb
+++ b/docs/docs/tutorials/detailed_balance.ipynb
@@ -86,9 +86,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "easydynamics_newbase",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -99,8 +99,7 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
diff --git a/docs/docs/tutorials/diffusion_model.ipynb b/docs/docs/tutorials/diffusion_model.ipynb
index ddf9e9a0e..7468efc59 100644
--- a/docs/docs/tutorials/diffusion_model.ipynb
+++ b/docs/docs/tutorials/diffusion_model.ipynb
@@ -85,9 +85,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -98,10 +98,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.4"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/experiment.ipynb b/docs/docs/tutorials/experiment.ipynb
index 0e6b7c5b4..124d83ff1 100644
--- a/docs/docs/tutorials/experiment.ipynb
+++ b/docs/docs/tutorials/experiment.ipynb
@@ -70,9 +70,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python (Pixi)",
"language": "python",
- "name": "pixi-kernel-python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -83,10 +83,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/instrument_model.ipynb b/docs/docs/tutorials/instrument_model.ipynb
index 27ccbb93a..988f56ca5 100644
--- a/docs/docs/tutorials/instrument_model.ipynb
+++ b/docs/docs/tutorials/instrument_model.ipynb
@@ -75,9 +75,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "easydynamics_newbase",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -88,10 +88,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/sample_model.ipynb b/docs/docs/tutorials/sample_model.ipynb
index fa34e8da1..be0d967dc 100644
--- a/docs/docs/tutorials/sample_model.ipynb
+++ b/docs/docs/tutorials/sample_model.ipynb
@@ -61,7 +61,7 @@
" diffusion_models=diffusion_model,\n",
" components=component_collection,\n",
" Q=Q,\n",
- " unit='meV',\n",
+ " x_unit='meV',\n",
" display_name='MySampleModel',\n",
" temperature=10,\n",
")\n",
@@ -127,9 +127,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
"language": "python",
- "name": "python3"
+ "name": "python3",
+ "display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
@@ -140,10 +140,9 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.14.5"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docs/docs/tutorials/tutorial0_basics.ipynb b/docs/docs/tutorials/tutorial0_basics.ipynb
index 548c6c714..176aaf170 100644
--- a/docs/docs/tutorials/tutorial0_basics.ipynb
+++ b/docs/docs/tutorials/tutorial0_basics.ipynb
@@ -386,7 +386,19 @@
"id": "842c1f01",
"metadata": {},
"source": [
- "The final step in this tutorial is to fit the are of the `Gaussian` to a straight line. For this, we use the `ParameterAnalysis` class. We create a `Polynomial` with two coefficients for the fit function. We create a `FitBinding`, telling the class we want to fit the parameter named `Gaussian area` with the fit function that we define."
+ "The final step in this tutorial is to fit the area of the `Gaussian` as a function of Q using a straight line. For this, we use the `ParameterAnalysis` class.\n",
+ "\n",
+ "We create a `Polynomial` with two coefficients as the fit function. We then create a `FitBinding` to connect the parameter named `Gaussian area` to the fit function, and pass both to a `ParameterAnalysis` object:\n",
+ "\n",
+ "\n",
+ " Note
\n",
+ "\n",
+ "Two units must be set on the fit function:\n",
+ "\n",
+ "- `x_unit='1/angstrom'` — because `ParameterAnalysis` always uses Q as its x-axis.\n",
+ "- `y_unit='meV'` — because we are fitting `Gaussian area`, which has unit `meV`.\n",
+ "
\n",
+ " "
]
},
{
@@ -397,10 +409,13 @@
"outputs": [],
"source": [
"fit_func = sm.Polynomial(\n",
- " coefficients=[3.7, -0.5], name='Straight line', display_name='Straight line'\n",
+ " coefficients=[3.7, -0.5],\n",
+ " x_unit='1/angstrom',\n",
+ " y_unit='meV',\n",
+ " name='Straight line',\n",
")\n",
"\n",
- "binding = edyn.FitBinding(parameter_name='Gaussian area', model=fit_func)\n",
+ "binding = edyn.FitBinding(model=fit_func, targets='Gaussian area')\n",
"\n",
"parameter_analysis = edyn.ParameterAnalysis(\n",
" parameters=analysis,\n",
@@ -450,7 +465,7 @@
"id": "dc33728c",
"metadata": {},
"source": [
- "To see the parameters we can use the `get_all_parameters()` method. We can also see only the parameters that can be fitted:"
+ "To see the parameters we can use the `get_all_parameters()` method. We can also see only the parameters that can be fitted using `get_fittable_parameters`:"
]
},
{
@@ -476,7 +491,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
diff --git a/docs/docs/tutorials/tutorial0_more_advanced.ipynb b/docs/docs/tutorials/tutorial0_more_advanced.ipynb
index f203dbdcb..ece987107 100644
--- a/docs/docs/tutorials/tutorial0_more_advanced.ipynb
+++ b/docs/docs/tutorials/tutorial0_more_advanced.ipynb
@@ -144,6 +144,13 @@
"
\n",
" SampleModel, BackgroundModel and ResolutionModel (introduced later) all take components in several different ways: you can add a single component like in the previous tutorial, append components using the `append_component` method, or give a ComponentCollection like we do here. \n",
"
\n",
+ "\n",
+ "\n",
+ "\n",
+ " 💡 Tip
\n",
+ " \n",
+ " Polynomials will warn if they produce negative values, since negative backgrounds are unphysical. To suppress the warning, set suppress_warnings to True, either when constructing the Polynomial (polynomial=sm.Polynomial(coefficients=[0.0], suppress_warnings=True)), or afterwards (polynomial.suppress_warnings = True).\n",
+ "
\n",
" \n"
]
},
@@ -300,7 +307,19 @@
"id": "0eadbd91",
"metadata": {},
"source": [
- "With apologies for the lack of creativity, these all appear like straight lines. We can fit them individually or all together using `ParameterAnalysis`"
+ "With apologies for the lack of creativity, these all appear like straight lines. We can fit them individually or all together using `ParameterAnalysis`.\n",
+ "\n",
+ "For each parameter we want to fit, we create a fit function (here a `Polynomial`) and a `FitBinding` connecting the parameter name to the fit function. We then pass all bindings to a single `ParameterAnalysis` object:\n",
+ "\n",
+ "\n",
+ " Note
\n",
+ "\n",
+ "Two units must be set on each fit function:\n",
+ "\n",
+ "- `x_unit='1/angstrom'` — because `ParameterAnalysis` always uses Q as its x-axis.\n",
+ "- `y_unit='meV'` — because all three parameters (`Gaussian area`, `DHO area`, `DHO center`) have unit `meV`.\n",
+ "
\n",
+ " "
]
},
{
@@ -310,17 +329,21 @@
"metadata": {},
"outputs": [],
"source": [
- "gauss_fit_func = sm.Polynomial(coefficients=[3.7, -0.5], unit='1/angstrom', name='Gauss area fit')\n",
- "dho_area_fit_func = sm.Polynomial(coefficients=[2.0, 0.12], unit='1/angstrom', name='DHO area fit')\n",
+ "gauss_fit_func = sm.Polynomial(\n",
+ " coefficients=[3.7, -0.5], x_unit='1/angstrom', y_unit='meV', name='Gauss area fit'\n",
+ ")\n",
+ "dho_area_fit_func = sm.Polynomial(\n",
+ " coefficients=[2.0, 0.12], x_unit='1/angstrom', y_unit='meV', name='DHO area fit'\n",
+ ")\n",
"dho_center_fit_func = sm.Polynomial(\n",
- " coefficients=[1.1, 0.2], unit='1/angstrom', name='DHO center fit'\n",
+ " coefficients=[1.1, 0.2], x_unit='1/angstrom', y_unit='meV', name='DHO center fit'\n",
")\n",
"\n",
- "binding1 = edyn.FitBinding(parameter_name='Gaussian area', model=gauss_fit_func)\n",
+ "binding1 = edyn.FitBinding(model=gauss_fit_func, targets='Gaussian area')\n",
"\n",
- "binding2 = edyn.FitBinding(parameter_name='DHO area', model=dho_area_fit_func)\n",
+ "binding2 = edyn.FitBinding(model=dho_area_fit_func, targets='DHO area')\n",
"\n",
- "binding3 = edyn.FitBinding(parameter_name='DHO center', model=dho_center_fit_func)\n",
+ "binding3 = edyn.FitBinding(model=dho_center_fit_func, targets='DHO center')\n",
"\n",
"parameter_analysis = edyn.ParameterAnalysis(\n",
" parameters=analysis,\n",
@@ -335,7 +358,7 @@
"id": "32bc1efc",
"metadata": {},
"source": [
- "The start guesses look reasonable, so we fit:"
+ "The start guesses look reasonable, so we fit and plot the result:"
]
},
{
@@ -352,7 +375,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
diff --git a/docs/docs/tutorials/tutorial1_brownian.ipynb b/docs/docs/tutorials/tutorial1_brownian.ipynb
index 269bb2599..0965bfaa7 100644
--- a/docs/docs/tutorials/tutorial1_brownian.ipynb
+++ b/docs/docs/tutorials/tutorial1_brownian.ipynb
@@ -451,7 +451,7 @@
"$$\n",
"where $\\Gamma(Q) = D Q^2$ and $D$ is the diffusion coefficient. $S$ is an overall scale.\n",
"\n",
- "To fit the Brownian translational diffusion model to the data, we use the `ParameterAnalysis`. This time, we wish to fit the `Lorentzian`, and we wish to fit both its area (scale, $S$) and width to the diffusion model. We do this by creating a `FitBinding`, saying we want to fit both the area an width of the component called `Lorentzian`"
+ "To fit the Brownian translational diffusion model to the data, we use the `ParameterAnalysis`. This time, we wish to fit the `Lorentzian`, and we wish to fit both its area (scale, $S$) and width to the diffusion model. We do this by creating a `FitBinding`. Because the diffusion model's `lorentzian_name` matches the fitted component's name, the default targets automatically fit both the area and the width against the `Lorentzian area` and `Lorentzian width` parameters"
]
},
{
@@ -462,14 +462,13 @@
"outputs": [],
"source": [
"brownian_diffusion_model = sm.BrownianTranslationalDiffusion(\n",
- " name='Brownian Translational Diffusion', diffusion_coefficient=2.4e-9, scale=0.5\n",
+ " name='Brownian Translational Diffusion',\n",
+ " lorentzian_name='Lorentzian',\n",
+ " diffusion_coefficient=2.4e-9,\n",
+ " scale=0.5,\n",
")\n",
"\n",
- "binding = edyn.FitBinding(\n",
- " parameter_name='Lorentzian',\n",
- " model=brownian_diffusion_model,\n",
- " modes=['area', 'width'],\n",
- ")\n",
+ "binding = edyn.FitBinding(model=brownian_diffusion_model)\n",
"\n",
"parameter_analysis = edyn.ParameterAnalysis(\n",
" parameters=diffusion_analysis,\n",
@@ -691,7 +690,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
diff --git a/docs/docs/tutorials/tutorial2_nanoparticles.ipynb b/docs/docs/tutorials/tutorial2_nanoparticles.ipynb
index 22de24900..da2188bcf 100644
--- a/docs/docs/tutorials/tutorial2_nanoparticles.ipynb
+++ b/docs/docs/tutorials/tutorial2_nanoparticles.ipynb
@@ -593,7 +593,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
diff --git a/src/easydynamics/analysis/analysis.py b/src/easydynamics/analysis/analysis.py
index 7291ea917..25a600a14 100644
--- a/src/easydynamics/analysis/analysis.py
+++ b/src/easydynamics/analysis/analysis.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause
+from copy import copy
from typing import Any
import numpy as np
@@ -456,10 +457,7 @@ def data_and_model_to_datagroup(
self._verify_bool(include_components, 'include_components')
self._verify_bool(include_residuals, 'include_residuals')
- energy = self._verify_energy(energy)
-
- if energy is None:
- energy = self.energy
+ energy = self._verify_energy(energy) if energy is not None else self.energy
data_and_model = {
'Data': self.experiment.binned_data,
@@ -532,6 +530,7 @@ def parameters_to_dataset(self) -> sc.Dataset:
units[name] = p.unit
elif units[name] != p.unit:
try:
+ p = copy(p)
p.convert_unit(units[name])
except Exception as e:
raise UnitError(
@@ -587,7 +586,7 @@ def plot_parameters(
ds = self.parameters_to_dataset()
- if not names:
+ if names is None:
names = list(ds.keys())
if isinstance(names, str):
@@ -686,9 +685,8 @@ def _on_convolution_settings_changed(self) -> None:
def _ensure_analysis_list_current(self) -> None:
"""Rebuild the analysis list if any dependency has changed since it was last built."""
- if self._analysis_list_is_dirty:
- if self.Q is not None:
- self._create_analysis_list()
+ if self._analysis_list_is_dirty and self.Q is not None:
+ self._create_analysis_list()
self._analysis_list_is_dirty = False
def _create_analysis_list(self) -> None:
@@ -698,6 +696,8 @@ def _create_analysis_list(self) -> None:
"""
self._analysis_list = []
for Q_index in range(len(self.Q)):
+ # The ConvolutionSettings object is shared so user changes reach every Q index;
+ # plan validity is tracked per convolver, not on the settings object.
analysis = Analysis1d(
display_name=f'{self.display_name}_Q{Q_index}',
experiment=self.experiment,
@@ -756,16 +756,17 @@ def _fit_all_Q_simultaneously(self) -> FitResults:
ys = []
ws = []
+ # TODO: consider using scipp built-in masking instead of numpy boolean masks, # noqa: FIX002 TD002 TD003
+ # once the EasyScience fitter accepts scipp Variables directly.
for analysis1d in self.analysis_list:
- x, y, weight, _ = self.experiment._extract_x_y_weights_only_finite( # noqa: SLF001
- analysis1d.Q_index
- )
+ x, y, weight, _ = self.experiment.extract_x_y_weights_only_finite(analysis1d.Q_index)
xs.append(x)
ys.append(y)
ws.append(weight)
- # Make sure the convolver is up to date for this Q index
- analysis1d.refresh_convolver(energy=x)
+ analysis1d.refresh_convolver(
+ energy=self.experiment.get_masked_energy(Q_index=analysis1d.Q_index)
+ )
mf = MultiFitter(
fit_objects=self.analysis_list,
diff --git a/src/easydynamics/analysis/analysis1d.py b/src/easydynamics/analysis/analysis1d.py
index 341103e16..b45a3ebc7 100644
--- a/src/easydynamics/analysis/analysis1d.py
+++ b/src/easydynamics/analysis/analysis1d.py
@@ -170,7 +170,7 @@ def calculate(self, energy: sc.Variable | None = None) -> np.ndarray:
"""
Calculate the model prediction for the chosen Q index.
- Makes sure the convolver is up to date before calculating.
+ Creates a new convolver before calculating without touching the stored convolver.
Parameters
----------
@@ -184,14 +184,12 @@ def calculate(self, energy: sc.Variable | None = None) -> np.ndarray:
The calculated model prediction.
"""
energy = self._verify_energy(energy)
- self._convolver = self._create_convolver(energy=energy)
- # Mark dirty so the next fit() call rebuilds the convolver with the standard
- # (unmasked) energy grid rather than reusing this plot-path grid.
- self._convolver_is_dirty = True
+ convolver = self._create_convolver(energy=energy)
+ return self._calculate(energy=energy, convolver=convolver)
- return self._calculate(energy=energy)
-
- def _calculate(self, energy: sc.Variable | None = None) -> np.ndarray:
+ def _calculate(
+ self, energy: sc.Variable | None = None, convolver: Convolution | None = None
+ ) -> np.ndarray:
"""
Calculate the model prediction for the chosen Q index.
@@ -202,18 +200,21 @@ def _calculate(self, energy: sc.Variable | None = None) -> np.ndarray:
energy : sc.Variable | None, default=None
Optional energy grid to use for calculation. If None, the energy grid from the
experiment is used.
+ convolver : Convolution | None, default=None
+ Optional convolver to use. If None, uses self._convolver.
Returns
-------
np.ndarray
The calculated model prediction.
"""
-
+ if convolver is None:
+ convolver = self._convolver
Q_index = self._require_Q_index()
sample = self._evaluate_with_convolution(
self.sample_model.get_component_collection(Q_index),
energy,
- convolver=self._convolver,
+ convolver=convolver,
)
background = self._evaluate_direct(
self.instrument_model.background_model.get_component_collection(Q_index),
@@ -254,7 +255,7 @@ def fit(self) -> FitResults:
fit_function=self.as_fit_function(),
)
- x, y, weights, _ = self.experiment._extract_x_y_weights_only_finite( # noqa: SLF001
+ x, y, weights, _ = self.experiment.extract_x_y_weights_only_finite(
Q_index=self._require_Q_index()
)
fit_result = fitter.fit(x=x, y=y, weights=weights)
@@ -438,7 +439,7 @@ def data_and_model_to_datagroup(
energy = self._masked_energy
data_and_model = {
- 'Data': self.experiment.binned_data['Q', self.Q_index],
+ 'Data': self.experiment.get_masked_binned_data(Q_index=self.Q_index),
'Model': self._create_model_array(energy=energy),
}
@@ -516,6 +517,10 @@ def _on_Q_index_changed(self) -> None:
This method is called whenever the Q index is changed. It updates the masked energy from
the experiment for the new Q index and marks the convolver as dirty.
"""
+ if self._Q_index is None:
+ self._masked_energy = None
+ self._convolver_is_dirty = True
+ return
masked_energy = self.experiment.get_masked_energy(Q_index=self._Q_index)
self._masked_energy = masked_energy
self._convolver_is_dirty = True
@@ -523,6 +528,9 @@ def _on_Q_index_changed(self) -> None:
def _on_experiment_changed(self) -> None:
"""Mark the convolver as dirty when the experiment changes."""
super()._on_experiment_changed()
+ # Refresh masked energy if Q_index is already set (i.e. post-init experiment swap).
+ if getattr(self, '_Q_index', None) is not None and self.experiment is not None:
+ self._masked_energy = self.experiment.get_masked_energy(Q_index=self._Q_index)
self._convolver_is_dirty = True
def _on_sample_model_changed(self) -> None:
@@ -561,28 +569,15 @@ def _calculate_energy_with_offset(
energy_offset : Parameter
The energy offset to apply.
- Raises
- ------
- sc.UnitError
- If the energy and energy offset have incompatible units.
-
Returns
-------
sc.Variable
The energy grid with the offset applied.
"""
- if energy.unit != energy_offset.unit:
- try:
- energy_offset.convert_unit(str(energy.unit))
- except Exception as e:
- raise sc.UnitError(
- f'Energy and energy offset must have compatible units. '
- f'Got {energy.unit} and {energy_offset.unit}.'
- ) from e
-
- energy_with_offset = energy.copy(deep=True)
- energy_with_offset.values -= energy_offset.value
+ offset_value = sc.to_unit(energy_offset.full_value, energy.unit).value
+ energy_with_offset = energy.copy()
+ energy_with_offset.values = energy.values - offset_value
return energy_with_offset
#############
@@ -640,18 +635,15 @@ def _evaluate_with_convolution(
energy=energy_with_offset,
temperature=self.temperature,
divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance,
- energy_unit=self.unit,
+ energy_unit=self.x_unit,
)
return result
- return Convolution(
- energy=energy,
+ return self._build_convolution(
sample_components=components,
resolution_components=resolution,
+ energy=energy,
energy_offset=energy_offset,
- convolution_settings=self.convolution_settings,
- temperature=self.temperature,
- detailed_balance_settings=self.detailed_balance_settings,
).convolution()
def _evaluate_direct(
@@ -719,11 +711,25 @@ def _create_convolver(
if resolution_components.is_empty:
return None
+ return self._build_convolution(
+ sample_components=sample_components,
+ resolution_components=resolution_components,
+ energy=energy,
+ energy_offset=self.instrument_model.get_energy_offset(Q_index),
+ )
+
+ def _build_convolution(
+ self,
+ sample_components: ComponentCollection | ModelComponent,
+ resolution_components: ComponentCollection,
+ energy: sc.Variable,
+ energy_offset: Parameter,
+ ) -> Convolution:
return Convolution(
energy=energy,
sample_components=sample_components,
resolution_components=resolution_components,
- energy_offset=self.instrument_model.get_energy_offset(Q_index),
+ energy_offset=energy_offset,
convolution_settings=self.convolution_settings,
temperature=self.temperature,
detailed_balance_settings=self.detailed_balance_settings,
@@ -769,7 +775,7 @@ def _create_residuals_array(self) -> sc.DataArray:
if self.Q_index is None:
raise ValueError('Q_index must be set to calculate residuals.')
- data = self.experiment.binned_data['Q', self.Q_index]
+ data = self.experiment.get_masked_binned_data(Q_index=self.Q_index)
model = self._create_model_array()
return data.copy(deep=True) - model
diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py
index 598daea67..80be2833f 100644
--- a/src/easydynamics/analysis/analysis_base.py
+++ b/src/easydynamics/analysis/analysis_base.py
@@ -13,6 +13,7 @@
from easydynamics.sample_model import SampleModel
from easydynamics.settings.convolution_settings import ConvolutionSettings
from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings
+from easydynamics.utils.utils import verify_Q_index
class AnalysisBase(EasyDynamicsModelBase):
@@ -503,13 +504,6 @@ def _verify_Q_index(self, Q_index: int | None) -> int | None:
Q_index : int | None
The Q index to verify.
- Raises
- ------
- TypeError
- If Q_index is not an integer or None.
- IndexError
- If the Q index is not valid.
-
Returns
-------
int | None
@@ -517,12 +511,7 @@ def _verify_Q_index(self, Q_index: int | None) -> int | None:
"""
if Q_index is None:
return None
-
- if not isinstance(Q_index, int):
- raise TypeError('Q_index must be an integer or None.')
-
- if Q_index < 0 or (self.Q is not None and Q_index >= len(self.Q)):
- raise IndexError('Q_index must be a valid index for the Q values.')
+ verify_Q_index(Q_index, self.Q)
return Q_index
def _verify_energy(self, energy: sc.Variable | None) -> sc.Variable | None:
diff --git a/src/easydynamics/analysis/fit_binding.py b/src/easydynamics/analysis/fit_binding.py
index da74e383f..e43ced4ab 100644
--- a/src/easydynamics/analysis/fit_binding.py
+++ b/src/easydynamics/analysis/fit_binding.py
@@ -3,145 +3,110 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from dataclasses import replace
from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase
-
-if TYPE_CHECKING:
- from collections.abc import Callable
+from easydynamics.utils.fit_target import FitTarget
class FitBinding(EasyDynamicsBase):
"""
- Contract between dataset, model, and fit function for ParameterAnalysis. This class
- encapsulates the necessary information to bind a dataset key to a model and convert it into a
- fit function callable.
+ Contract between dataset, model, and fit functions for ParameterAnalysis. A binding maps the
+ model's fittable predictions (its FitTargets) onto keys of the parameters Dataset they should
+ be fitted against.
Examples
--------
- **Basic usage with a ModelComponent**
+ **Fitting a component model to one parameter**
- Bind a fit parameter name to any ``ModelComponent`` or ``ComponentCollection``:
+ Component models (e.g. a Polynomial) have a single prediction — their ``evaluate`` — so
+ ``targets`` is simply the dataset key to fit against:
```python
import easydynamics as edyn
import easydynamics.sample_model as sm
fit_func = sm.Polynomial(coefficients=[3.7, -0.5], display_name='Straight line')
- binding = edyn.FitBinding(parameter_name='Gaussian area', model=fit_func)
+ binding = edyn.FitBinding(model=fit_func, targets='Gaussian area')
```
- **Usage with a DiffusionModelBase and specific modes**
+ **Fitting a diffusion model with default dataset keys**
- For diffusion models, use ``modes`` to select which parameter arrays are fitted:
+ Diffusion models declare their predictions (``'area'``, ``'width'``, and for DeltaLorentz also
+ ``'delta_area'``). With ``targets=None`` all predictions are fitted against default dataset
+ keys derived from the model's component names:
```python
brownian = sm.BrownianTranslationalDiffusion(
- display_name='Brownian Translational Diffusion',
diffusion_coefficient=2.4e-9,
scale=0.5,
+ lorentzian_name='Lorentzian',
)
+ binding = edyn.FitBinding(model=brownian) # fits 'Lorentzian area' and 'Lorentzian width'
+ ```
+
+ **Selecting predictions or mapping them to custom dataset keys**
+
+ Pass a list of prediction names, or a dict mapping prediction names to dataset keys:
+ ```python
+ binding = edyn.FitBinding(model=brownian, targets=['width'])
+
+ delta_lorentz = sm.DeltaLorentz(A_0=0.5, lorentzian_width=0.1)
binding = edyn.FitBinding(
- parameter_name='Lorentzian',
- model=brownian,
- modes=['area', 'width'],
+ model=delta_lorentz,
+ targets={
+ 'width': 'Lorentzian width',
+ 'area': 'Lorentzian area',
+ 'delta_area': 'Elastic area',
+ },
)
```
"""
def __init__(
self,
- parameter_name: str,
model: ModelComponent | ComponentCollection | DiffusionModelBase,
- modes: str | list[str] | None = None,
+ targets: str | list[str] | dict[str, str] | None = None,
display_name: str | None = None,
unique_name: str | None = None,
) -> None:
"""
Initialize a FitBinding.
+ Validation raises ``TypeError`` if model or targets have an invalid type, and
+ ``ValueError`` if targets names a prediction the model does not declare.
+
Parameters
----------
- parameter_name : str
- The name of the parameter to fit. This should correspond to a key in the dataset.
model : ModelComponent | ComponentCollection | DiffusionModelBase
The model to fit. This can be a single ModelComponent, a ComponentCollection, or a
DiffusionModelBase.
- modes : str | list[str] | None, default=None
- The modes to fit for diffusion models. This can be a single string, a list of strings,
- or None (which defaults to ["area", "width"]). Only applicable if the model is a
- DiffusionModelBase. Default is None.
+ targets : str | list[str] | dict[str, str] | None, default=None
+ Which predictions of the model to fit, and against which dataset keys. For component
+ models this must be a string: the dataset key to fit the model's ``evaluate`` against.
+ For diffusion models: None fits all predictions against their default dataset keys; a
+ string or list of strings selects predictions by name (default keys); a dict maps
+ prediction names to custom dataset keys.
display_name : str | None, default=None
An optional display name for the FitBinding. If None, the unique_name will be used.
Default is None.
unique_name : str | None, default=None
An optional unique name for the FitBinding. If None, a unique name will be generated.
Default is None.
-
- Raises
- ------
- TypeError
- If parameter_name is not a string, if model is not a ModelComponent,
- ComponentCollection or DiffusionModelBase, or if modes is not a string, list of
- strings, or None.
"""
super().__init__(display_name=display_name, unique_name=unique_name)
- if not isinstance(parameter_name, str):
- raise TypeError('parameter_name must be a string')
-
- if not isinstance(model, (ModelComponent, ComponentCollection, DiffusionModelBase)):
- raise TypeError(
- 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase'
- )
-
- if modes is not None and not isinstance(modes, (str, list)):
- raise TypeError('modes must be a string, list of strings, or None')
-
- if isinstance(modes, list) and not all(isinstance(mode, str) for mode in modes):
- raise TypeError('All modes in the list must be strings')
-
- self._parameter_name = parameter_name
+ self._validate_model(model)
+ self._validate_targets(model, targets)
self._model = model
- self._modes = modes
+ self._targets = targets
# ------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------
- @property
- def parameter_name(self) -> str:
- """
- The name of the parameter to fit. This should correspond to a key in the dataset.
-
- Returns
- -------
- str
- The name of the parameter to fit.
- """
- return self._parameter_name
-
- @parameter_name.setter
- def parameter_name(self, value: str) -> None:
- """
- Set the name of the parameter to fit.
-
- Parameters
- ----------
- value : str
- The new name of the parameter to fit.
-
- Raises
- ------
- TypeError
- If the value is not a string.
- """
- if not isinstance(value, str):
- raise TypeError('parameter_name must be a string')
- self._parameter_name = value
-
@property
def model(self) -> ModelComponent | ComponentCollection | DiffusionModelBase:
"""
@@ -160,162 +125,168 @@ def model(self, value: ModelComponent | ComponentCollection | DiffusionModelBase
"""
Set the model to fit.
+ Validation raises ``TypeError`` if the value has an invalid type, and ``ValueError`` if the
+ current targets name a prediction the new model does not declare.
+
Parameters
----------
value : ModelComponent | ComponentCollection | DiffusionModelBase
The new model to fit.
-
- Raises
- ------
- TypeError
- If the value is not a ModelComponent, ComponentCollection, or DiffusionModelBase.
"""
- if not isinstance(value, (ModelComponent, ComponentCollection, DiffusionModelBase)):
- raise TypeError(
- 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase.'
- )
+ self._validate_model(value)
+ self._validate_targets(value, self._targets)
self._model = value
@property
- def modes(self) -> str | list[str] | None:
+ def targets(self) -> str | list[str] | dict[str, str] | None:
"""
- The modes to fit for diffusion models. This can be a single string, a list of strings, or
- None (which defaults to ["area", "width"]).
+ Which predictions of the model to fit, and against which dataset keys.
Returns
-------
- str | list[str] | None
- The modes to fit for diffusion models.
+ str | list[str] | dict[str, str] | None
+ The targets specification (see ``__init__``).
"""
- return self._modes
+ return self._targets
- @modes.setter
- def modes(self, value: str | list[str] | None) -> None:
+ @targets.setter
+ def targets(self, value: str | list[str] | dict[str, str] | None) -> None:
"""
- Set the modes to fit for diffusion models.
+ Set which predictions of the model to fit, and against which dataset keys.
+
+ Validation raises ``TypeError`` if the value has an invalid type for the current model, and
+ ``ValueError`` if it names a prediction the model does not declare.
Parameters
----------
- value : str | list[str] | None
- The new modes to fit for diffusion models.
-
- Raises
- ------
- TypeError
- If the value is not a string, list of strings, or None.
+ value : str | list[str] | dict[str, str] | None
+ The new targets specification (see ``__init__``).
"""
- if value is not None and not isinstance(value, (str, list)):
- raise TypeError('modes must be a string, list of strings, or None')
-
- if isinstance(value, str):
- value = [value]
- if isinstance(value, list) and not all(isinstance(mode, str) for mode in value):
- raise TypeError('All modes in the list must be strings')
- self._modes = value
+ self._validate_targets(self._model, value)
+ self._targets = value
# ------------------------------------------------------------------
# Other methods
# ------------------------------------------------------------------
- def build_callables(self) -> list[Callable]:
+ def get_targets(self) -> list[FitTarget]:
"""
- Build the callables for fitting based on the model and modes.
+ Get the FitTargets this binding fits, with dataset keys resolved.
- Returns
- -------
- list[Callable]
- A list of callables for fitting.
- """
- modes = self._get_modes()
-
- if isinstance(self.model, DiffusionModelBase):
- return [self._build_diffusion_callable(mode) for mode in modes]
-
- return [lambda x, **_: self.model.evaluate(x)]
-
- def get_model_names(self) -> list[str]:
- """
- Get the names of the models based on the current modes.
+ Targets are built from the model at call time, so their units and default dataset keys
+ reflect the model's current state.
Returns
-------
- list[str]
- A list of model names.
+ list[FitTarget]
+ The resolved fit targets.
"""
- modes = self._get_modes()
-
if isinstance(self.model, DiffusionModelBase):
- return [f'{self.model.display_name} {mode}' for mode in modes]
-
- return [self.model.display_name]
-
- def get_parameter_names(self) -> list[str]:
- """
- Get the names of the parameters based on the current modes.
-
- Returns
- -------
- list[str]
- A list of parameter names.
- """
- modes = self._get_modes()
-
- if isinstance(self.model, DiffusionModelBase):
- # This needs to be generalised.
- # TODO: Generalise this for different diffusion models and modes. # noqa TD002 TD003
- if 'delta' in modes:
- return [f'{self.parameter_name} area' for mode in modes]
-
- return [f'{self.parameter_name} {mode}' for mode in modes]
-
- return [self.parameter_name]
+ available = {target.name: target for target in self.model.get_fit_targets()}
+ if self._targets is None:
+ return list(available.values())
+ if isinstance(self._targets, str):
+ return [available[self._targets]]
+ if isinstance(self._targets, list):
+ return [available[name] for name in self._targets]
+ return [
+ replace(available[name], dataset_key=dataset_key)
+ for name, dataset_key in self._targets.items()
+ ]
+
+ return [
+ FitTarget(
+ name='value',
+ dataset_key=self._targets,
+ function=lambda x, model=self.model, **_: model.evaluate(x),
+ label=self.model.display_name,
+ x_unit=self.model.x_unit,
+ y_unit=self.model.y_unit,
+ )
+ ]
# ------------------------------------------------------------------
# Private methods
# ------------------------------------------------------------------
- def _build_diffusion_callable(self, mode: str) -> Callable:
+ @staticmethod
+ def _validate_model(
+ model: ModelComponent | ComponentCollection | DiffusionModelBase,
+ ) -> None:
"""
- Build a callable for a specific diffusion mode.
+ Validate the model type.
Parameters
----------
- mode : str
- The diffusion mode ("area" or "width").
-
- Returns
- -------
- Callable
- A callable for the specified diffusion mode.
+ model : ModelComponent | ComponentCollection | DiffusionModelBase
+ The model to validate.
Raises
------
- ValueError
- If the mode is unknown.
+ TypeError
+ If model is not a ModelComponent, ComponentCollection, or DiffusionModelBase.
"""
- model = self.model
-
- if mode == 'area':
- return lambda x, **_: model.calculate_QISF(x) * model.scale.value
-
- if mode == 'width':
- return lambda x, **_: model.calculate_width(x)
-
- if mode == 'delta':
- return lambda x, **_: model.calculate_EISF(x) * model.scale.value
-
- raise ValueError(f'Unknown diffusion mode: {mode}')
+ if not isinstance(model, (ModelComponent, ComponentCollection, DiffusionModelBase)):
+ raise TypeError(
+ 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase'
+ )
- def _get_modes(self) -> list[str]:
+ @staticmethod
+ def _validate_targets(
+ model: ModelComponent | ComponentCollection | DiffusionModelBase,
+ targets: str | list[str] | dict[str, str] | None,
+ ) -> None:
"""
- Get the modes to fit for diffusion models, defaulting to ["area", "width"] if not set.
+ Validate a targets specification against a model.
- Returns
- -------
- list[str]
- The modes to fit for diffusion models.
+ Parameters
+ ----------
+ model : ModelComponent | ComponentCollection | DiffusionModelBase
+ The model the targets apply to.
+ targets : str | list[str] | dict[str, str] | None
+ The targets specification to validate.
+
+ Raises
+ ------
+ TypeError
+ If targets has an invalid type for the given model.
+ ValueError
+ If targets names a prediction the model does not declare.
"""
- return ['area', 'width'] if self.modes is None else self.modes
+ if isinstance(model, DiffusionModelBase):
+ if targets is None:
+ return
+ if isinstance(targets, str):
+ requested = [targets]
+ elif isinstance(targets, list):
+ requested = targets
+ elif isinstance(targets, dict):
+ requested = list(targets.keys())
+ if not all(isinstance(key, str) for key in targets.values()):
+ raise TypeError('targets dict values must be dataset keys (strings)')
+ else:
+ raise TypeError(
+ 'targets must be None, a prediction name, a list of prediction names, '
+ 'or a dict mapping prediction names to dataset keys'
+ )
+ if not all(isinstance(name, str) for name in requested):
+ raise TypeError('prediction names in targets must be strings')
+
+ available = [target.name for target in model.get_fit_targets()]
+ unknown = sorted(set(requested) - set(available))
+ if unknown:
+ raise ValueError(
+ f'Unknown prediction(s) {", ".join(unknown)} for '
+ f'{model.__class__.__name__}. Available predictions: '
+ f'{", ".join(available)}.'
+ )
+ return
+
+ if not isinstance(targets, str):
+ raise TypeError(
+ 'For component models, targets must be the dataset key (a string) to fit '
+ "the model's evaluate against."
+ )
# ------------------------------------------------------------------
# dunder methods
@@ -331,10 +302,8 @@ def __repr__(self) -> str:
A string representation of the FitBinding.
"""
return (
- f'{self.__class__.__name__}('
- f'parameter_name={self.parameter_name!r},\n'
- f' model={self.model.display_name!r},\n'
- f' modes={self.modes},\n'
- f' display_name={self.display_name!r},\n'
- f' unique_name={self.unique_name!r})'
+ f'{self.__class__.__name__}(\n'
+ f' model={self.model.display_name},\n'
+ f' targets={self.targets},\n'
+ f')'
)
diff --git a/src/easydynamics/analysis/parameter_analysis.py b/src/easydynamics/analysis/parameter_analysis.py
index 142f0b1a4..7e022e008 100644
--- a/src/easydynamics/analysis/parameter_analysis.py
+++ b/src/easydynamics/analysis/parameter_analysis.py
@@ -15,6 +15,7 @@
from easydynamics.analysis.analysis import Analysis
from easydynamics.analysis.fit_binding import FitBinding
from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase
+from easydynamics.utils.fit_target import FitTarget
from easydynamics.utils.utils import _in_notebook
@@ -25,15 +26,15 @@ class ParameterAnalysis(EasyDynamicsModelBase):
Can be used to fit parameters to ModelComponents, ComponentCollections, or DiffusionModelBase
objects, and to plot the parameters and fit results. The parameters to be analyzed can be
provided as a sc.Dataset or directly as an Analysis object. Multiple parameters can be fitted
- simultaneously, and the fit functions can be customized for each parameter. For diffusion
- models, the area and width can be fitted separately (or not at all) by specifying fit settings.
+ simultaneously, and each binding maps its model's predictions onto the dataset keys they are
+ fitted against (for diffusion models e.g. 'area', 'width', or 'delta_area').
Examples
--------
**Fitting Lorentzian widths to a diffusion model**
- After a full Analysis fit, pass the Analysis directly and bind each parameter to a fit function
- using a ``FitBinding``:
+ After a full Analysis fit, pass the Analysis directly and bind the model's predictions to
+ dataset keys using a ``FitBinding``:
```python
import easydynamics as edyn
import easydynamics.sample_model as sm
@@ -41,9 +42,8 @@ class ParameterAnalysis(EasyDynamicsModelBase):
# analysis is an edyn.Analysis object with previously fitted parameters
diffusion_model = sm.BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, scale=0.5)
binding = edyn.FitBinding(
- parameter_name='Lorentzian width',
model=diffusion_model,
- modes=['width'],
+ targets={'width': 'Lorentzian width'},
)
param_analysis = edyn.ParameterAnalysis(
@@ -58,8 +58,8 @@ class ParameterAnalysis(EasyDynamicsModelBase):
```python
area_binding = edyn.FitBinding(
- parameter_name='Lorentzian area',
model=sm.Polynomial(coefficients=[0.5, 0.0]),
+ targets='Lorentzian area',
)
param_analysis = edyn.ParameterAnalysis(
parameters=analysis,
@@ -183,23 +183,24 @@ def fit(self) -> FitResults:
funcs, models = [], []
for binding in self.bindings:
- param_names = binding.get_parameter_names()
- callables = binding.build_callables()
-
- for pname, func in zip(param_names, callables, strict=True):
- if pname not in self.parameters:
+ for target in binding.get_targets():
+ if target.dataset_key not in self.parameters:
raise ValueError(
- f"Parameter '{pname}' from binding '{binding.unique_name}' "
- f'not found in parameters Dataset.'
+ f"Parameter '{target.dataset_key}' from binding "
+ f"'{binding.unique_name}' not found in parameters Dataset."
)
- x, y, weight = self._get_xyweight_from_dataset(pname)
+ x, y, weight = self._get_xyweight_from_dataset(target.dataset_key)
+ x_factor, y_factor = self._get_unit_conversions(target)
+ x = x * x_factor
+ y = y * y_factor
+ weight = weight / y_factor
xs.append(x)
ys.append(y)
ws.append(weight)
- funcs.append(func)
+ funcs.append(target.function)
models.append(binding.model)
mf = MultiFitter(
@@ -258,7 +259,7 @@ def plot(
names = list(self.parameters.keys())
else:
for b in self.bindings:
- names.extend(b.get_parameter_names())
+ names.extend(target.dataset_key for target in b.get_targets())
names = self._normalize_names(names)
@@ -282,14 +283,13 @@ def plot(
data_arrays = {}
model_arrays = {}
- # map parameter names to model names
+ # map dataset keys to model labels
param_to_model = {}
if self.bindings is not None:
for b in self.bindings:
- param_names = b.get_parameter_names()
- model_names = b.get_model_names()
-
- param_to_model.update(dict(zip(param_names, model_names, strict=True)))
+ param_to_model.update({
+ target.dataset_key: target.label for target in b.get_targets()
+ })
for pname in names:
data_arrays[pname] = self.parameters[pname]
@@ -354,22 +354,20 @@ def calculate_model_dataset(self, bindings: list[FitBinding]) -> sc.Dataset:
arrays = {}
for b in bindings:
- param_names = b.get_parameter_names()
- model_names = b.get_model_names()
- callables = b.build_callables()
-
- for pname, mname, func in zip(param_names, model_names, callables, strict=True):
- if pname not in self.parameters:
+ for target in b.get_targets():
+ if target.dataset_key not in self.parameters:
raise ValueError(
- f"Parameter '{pname}' from binding '{b.unique_name}' "
+ f"Parameter '{target.dataset_key}' from binding '{b.unique_name}' "
f'not found in parameters Dataset.'
)
- da = self.parameters[pname]
+ da = self.parameters[target.dataset_key]
x = da.coords['Q']
+ x_factor, y_factor = self._get_unit_conversions(target)
- y_model = func(x.values)
+ y_model = target.function(x.values * x_factor)
+ y_model = y_model / y_factor
- arrays[mname] = sc.DataArray(
+ arrays[target.label] = sc.DataArray(
data=sc.array(dims=['Q'], values=y_model, unit=da.unit),
coords={'Q': x},
)
@@ -521,6 +519,57 @@ def _normalize_names(self, names: str | list[str] | None) -> list[str] | None:
# Private methods
#############
+ def _get_unit_conversions(self, target: FitTarget) -> tuple[float, float]:
+ """
+ Return (x_factor, y_factor) to convert dataset values into the target's declared units.
+
+ x_factor converts Q coordinate values from their stored unit to target.x_unit (e.g.
+ 1/angstrom for diffusion-model predictions). y_factor converts parameter values from their
+ stored unit to target.y_unit. A factor is 1.0 when the target declares the corresponding
+ unit as None (its function takes raw values).
+
+ Parameters
+ ----------
+ target : FitTarget
+ The fit target whose unit contract defines the target units.
+
+ Returns
+ -------
+ tuple[float, float]
+ ``(x_factor, y_factor)`` scale factors to apply before/after model evaluation.
+
+ Raises
+ ------
+ sc.UnitError
+ If x or y units are physically incompatible (e.g. meV vs 1/angstrom).
+ """
+ da = self.parameters[target.dataset_key]
+ x_factor = 1.0
+ y_factor = 1.0
+
+ if target.x_unit is not None:
+ q_unit = str(da.coords['Q'].unit)
+ try:
+ x_factor = sc.to_unit(sc.scalar(1.0, unit=q_unit), str(target.x_unit)).value
+ except Exception as e:
+ raise sc.UnitError(
+ f"Q coordinate unit '{q_unit}' is incompatible with "
+ f"the x_unit '{target.x_unit}' of fit target '{target.label}' "
+ f"for parameter '{target.dataset_key}'."
+ ) from e
+
+ if target.y_unit is not None:
+ param_unit = str(da.unit)
+ try:
+ y_factor = sc.to_unit(sc.scalar(1.0, unit=param_unit), str(target.y_unit)).value
+ except Exception as e:
+ raise sc.UnitError(
+ f"Parameter '{target.dataset_key}' unit '{param_unit}' is incompatible "
+ f"with the y_unit '{target.y_unit}' of fit target '{target.label}'."
+ ) from e
+
+ return x_factor, y_factor
+
def _get_xyweight_from_dataset(
self, parameter_name: str
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -549,21 +598,29 @@ def _get_xyweight_from_dataset(
raise ValueError(f"Parameter name '{parameter_name}' not found in parameters Dataset.")
variances = self._parameters[parameter_name].variances
+ values = self._parameters[parameter_name].values
+ q_values = self._parameters[parameter_name].coords['Q'].values
+
if variances is None:
- weight = np.ones_like(self._parameters[parameter_name].values)
- elif np.any(~np.isfinite(variances)) or np.any(variances <= 0):
+ return q_values, values, np.ones_like(values)
+
+ # NaN variances arise when a parameter is absent for a given Q (parameters_to_dataset
+ # fills np.nan for missing parameters). Filter those rows silently; other non-finite or
+ # non-positive variances are errors.
+ nan_mask = np.isnan(variances)
+ if np.any(~nan_mask & (~np.isfinite(variances) | (variances <= 0))):
raise ValueError(
f"Non-finite variances found for parameter '{parameter_name}', "
f'cannot compute weights.'
)
- else:
- weight = 1 / np.sqrt(variances)
+ valid_mask = ~nan_mask
+ if not np.any(valid_mask):
+ raise ValueError(
+ f"No finite positive variances found for parameter '{parameter_name}', "
+ f'cannot compute weights.'
+ )
- return (
- self._parameters[parameter_name].coords['Q'].values,
- self._parameters[parameter_name].values,
- weight,
- )
+ return q_values[valid_mask], values[valid_mask], 1 / np.sqrt(variances[valid_mask])
#############
# Dunder methods
@@ -587,9 +644,8 @@ def __repr__(self) -> str:
binding_info = [
{
- 'parameter': b.parameter_name,
'model': b.model.display_name,
- 'modes': b.modes,
+ 'targets': b.targets,
}
for b in self._bindings
]
diff --git a/src/easydynamics/base_classes/easydynamics_modelbase.py b/src/easydynamics/base_classes/easydynamics_modelbase.py
index 6d002f71e..3e48fe82c 100644
--- a/src/easydynamics/base_classes/easydynamics_modelbase.py
+++ b/src/easydynamics/base_classes/easydynamics_modelbase.py
@@ -14,7 +14,8 @@ class EasyDynamicsModelBase(NameMixin, ModelBase):
def __init__(
self,
*args: object,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'MyEasyDynamicsModel',
display_name: str | None = None,
unique_name: str | None = None,
@@ -27,8 +28,10 @@ def __init__(
----------
*args : object
Positional arguments to pass to the parent class.
- unit : str | sc.Unit, default='meV'
- Unit of the model.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis.
name : str, default='MyEasyDynamicsModel'
Name of the model.
display_name : str | None, default=None
@@ -58,37 +61,43 @@ def __init__(
**kwargs,
)
- self._unit = _validate_unit(unit)
+ self._x_unit = _validate_unit(x_unit)
+ self._y_unit = _validate_unit(y_unit)
@property
- def unit(self) -> str | sc.Unit | None:
+ def x_unit(self) -> str | sc.Unit | None:
"""
- Get the unit of the model.
+ Get the unit of the x-axis.
Returns
-------
str | sc.Unit | None
- The unit of the model.
+ The unit of the x-axis.
"""
+ return self._x_unit
- return self._unit
+ @x_unit.setter
+ def x_unit(self, _: str) -> None:
+ raise AttributeError(
+ f'x_unit is read-only. Use convert_x_unit to change the unit '
+ f'or create a new {self.__class__.__name__} with the desired unit.'
+ )
- @unit.setter
- def unit(self, _unit_str: str) -> None:
+ @property
+ def y_unit(self) -> str | sc.Unit | None:
"""
- Unit is read-only and cannot be set directly.
-
- Parameters
- ----------
- _unit_str : str
- The new unit to set (ignored).
+ Get the unit of the model output.
- Raises
- ------
- AttributeError
- Always raised to indicate that the unit is read-only.
+ Returns
+ -------
+ str | sc.Unit | None
+ The unit of the y-axis.
"""
+ return self._y_unit
+
+ @y_unit.setter
+ def y_unit(self, _: str) -> None:
raise AttributeError(
- f'Unit is read-only. Use convert_unit to change the unit between allowed types '
+ f'y_unit is read-only. Use convert_y_unit to change the unit '
f'or create a new {self.__class__.__name__} with the desired unit.'
)
diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py
index 535bea989..efcdf6f52 100644
--- a/src/easydynamics/convolution/analytical_convolution.py
+++ b/src/easydynamics/convolution/analytical_convolution.py
@@ -491,6 +491,7 @@ def __repr__(self) -> str:
f'{self.__class__.__name__}('
f'display_name={self.display_name!r}, '
f'unique_name={self.unique_name!r}, '
- f'unit={self._unit}, '
+ f'x_unit={self.x_unit}, '
+ f'y_unit={self.y_unit}, '
f'energy_len={len(self.energy)})'
)
diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py
index d4e8aafee..89c7f2ca6 100644
--- a/src/easydynamics/convolution/convolution.py
+++ b/src/easydynamics/convolution/convolution.py
@@ -80,7 +80,7 @@ class Convolution(NumericalConvolutionBase):
# needs to be rebuilt.
# Note: the public 'energy' property setter always writes to '_energy', so '_energy' alone
# is sufficient — listing 'energy' separately would cause a double invalidation.
- _invalidate_plan_on_change: ClassVar[dict[str, object]] = {
+ _invalidate_plan_on_change: ClassVar[set[str]] = {
'_energy',
'_energy_grid',
'_sample_components',
@@ -101,7 +101,8 @@ def __init__(
temperature: Parameter | Numeric | None = None,
temperature_unit: str | sc.Unit = 'K',
detailed_balance_settings: DetailedBalanceSettings | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
display_name: str | None = 'MyConvolution',
unique_name: str | None = None,
) -> None:
@@ -113,7 +114,7 @@ def __init__(
energy : np.ndarray | sc.Variable
1D array of energy values where the convolution is evaluated.
sample_components : ComponentCollection | ModelComponent
- The sample components to be convolved.
+ The sample components to be convolved.
resolution_components : ComponentCollection | ModelComponent
The resolution components to convolve with.
energy_offset : Numeric | Parameter, default=0.0
@@ -126,8 +127,10 @@ def __init__(
The unit of the temperature parameter.
detailed_balance_settings : DetailedBalanceSettings | None, default=None
The settings for detailed balance. If None, default settings will be used.
- unit : str | sc.Unit, default='meV'
- The unit of the energy.
+ x_unit : str | sc.Unit, default='meV'
+ The unit of the energy axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ The unit of the model output (intensity).
display_name : str | None, default='MyConvolution'
Display name of the model.
unique_name : str | None, default=None
@@ -144,7 +147,8 @@ def __init__(
temperature=temperature,
temperature_unit=temperature_unit,
detailed_balance_settings=detailed_balance_settings,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
display_name=display_name,
unique_name=unique_name,
)
@@ -168,7 +172,7 @@ def convolution(
np.ndarray
The convolved values evaluated at energy.
"""
- if not self.convolution_settings.convolution_plan_is_valid:
+ if not self._convolution_plan_is_current():
self._build_convolution_plan()
total = np.zeros_like(self.energy.values, dtype=float)
@@ -287,7 +291,7 @@ def _build_convolution_plan(self) -> None:
# Update convolvers
self._set_convolvers()
- self.convolution_settings.convolution_plan_is_valid = True
+ self._mark_convolution_plan_current()
def _set_convolvers(self) -> None:
"""
@@ -303,6 +307,8 @@ def _set_convolvers(self) -> None:
energy_offset=self.energy_offset,
sample_components=self._analytical_sample_components,
resolution_components=self._resolution_components,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
else:
self._analytical_convolver = None
@@ -317,11 +323,27 @@ def _set_convolvers(self) -> None:
temperature=self.temperature,
temperature_unit=self._temperature_unit,
detailed_balance_settings=self.detailed_balance_settings,
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
else:
self._numerical_convolver = None
+ def convert_y_unit(self, unit: str) -> None:
+ """
+ Convert the y-axis unit and propagate it to the analytical and numerical sub-convolvers.
+
+ Parameters
+ ----------
+ unit : str
+ The new y-axis unit.
+ """
+ super().convert_y_unit(unit)
+ if getattr(self, '_analytical_convolver', None) is not None:
+ self._analytical_convolver._y_unit = self.y_unit # noqa: SLF001
+ if getattr(self, '_numerical_convolver', None) is not None:
+ self._numerical_convolver._y_unit = self.y_unit # noqa: SLF001
+
# Update some setters so the internal sample models are updated
def __setattr__(self, name: str, value: any) -> None:
"""
@@ -343,6 +365,7 @@ def __setattr__(self, name: str, value: any) -> None:
# Only rebuild the convolution plan if reactions are enabled, to
# avoid issues during __init__
if getattr(self, '_reactions_enabled', False) and name in self._invalidate_plan_on_change:
+ self._plan_is_valid = False
self.convolution_settings.convolution_plan_is_valid = False
def __repr__(self) -> str:
@@ -350,7 +373,8 @@ def __repr__(self) -> str:
f'{self.__class__.__name__}('
f'display_name={self.display_name!r}, '
f'unique_name={self.unique_name!r}, '
- f'unit={self.unit}, '
+ f'x_unit={self.x_unit}, '
+ f'y_unit={self.y_unit}, '
f'energy_len={len(self.energy)}, '
f'temperature={self.temperature})'
)
diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py
index c8516cb0b..2301b70ba 100644
--- a/src/easydynamics/convolution/convolution_base.py
+++ b/src/easydynamics/convolution/convolution_base.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause
+import contextlib
+
import numpy as np
import scipp as sc
from easyscience.variable import Parameter
@@ -9,6 +11,7 @@
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
from easydynamics.utils.utils import Numeric
+from easydynamics.utils.utils import energy_to_scipp
class ConvolutionBase(EasyDynamicsModelBase):
@@ -23,7 +26,8 @@ def __init__(
energy: np.ndarray | sc.Variable,
sample_components: ComponentCollection | ModelComponent | None = None,
resolution_components: ComponentCollection | ModelComponent | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
energy_offset: Numeric | Parameter = 0.0,
display_name: str | None = 'MyConvolution',
unique_name: str | None = None,
@@ -39,8 +43,10 @@ def __init__(
The sample model to be convolved.
resolution_components : ComponentCollection | ModelComponent | None, default=None
The resolution model to convolve with.
- unit : str | sc.Unit, default='meV'
- The unit of the energy.
+ x_unit : str | sc.Unit, default='meV'
+ The unit of the energy axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ The unit of the model output (intensity).
energy_offset : Numeric | Parameter, default=0.0
The energy offset applied to the convolution.
display_name : str | None, default='MyConvolution'
@@ -58,7 +64,8 @@ def __init__(
"""
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
display_name=display_name,
unique_name=unique_name,
)
@@ -70,10 +77,12 @@ def __init__(
raise TypeError(f'Energy must be a numpy ndarray or a scipp Variable. Got {energy}')
if isinstance(energy, np.ndarray):
- energy = sc.array(dims=['energy'], values=energy, unit=unit)
+ energy = energy_to_scipp(energy, x_unit)
if isinstance(energy_offset, Numeric):
- energy_offset = Parameter(name='energy_offset', value=float(energy_offset), unit=unit)
+ energy_offset = Parameter(
+ name='energy_offset', value=float(energy_offset), unit=x_unit
+ )
if not isinstance(energy_offset, Parameter):
raise TypeError('Energy_offset must be a number or a Parameter.')
@@ -147,8 +156,9 @@ def energy_with_offset(self) -> sc.Variable:
sc.Variable
The energy values with the offset applied.
"""
+ offset_value = sc.to_unit(self.energy_offset.full_value, self._energy.unit).value
energy_with_offset = self.energy.copy()
- energy_with_offset.values = self.energy.values - self.energy_offset.value
+ energy_with_offset.values = self.energy.values - offset_value
return energy_with_offset
@property
@@ -187,15 +197,15 @@ def energy(self, energy: np.ndarray | sc.Variable) -> None:
raise TypeError('Energy must be a Number, a numpy ndarray or a scipp Variable.')
if isinstance(energy, np.ndarray):
- self._energy = sc.array(dims=['energy'], values=energy, unit=self._energy.unit)
+ self._energy = energy_to_scipp(energy, self._energy.unit)
if isinstance(energy, sc.Variable):
self._energy = energy
- self._unit = energy.unit
+ self._x_unit = energy.unit
- def convert_unit(self, unit: str | sc.Unit) -> None:
+ def convert_x_unit(self, unit: str | sc.Unit) -> None:
"""
- Convert the energy and energy_offset to the specified unit.
+ Convert the energy axis, energy_offset, and all components to the specified unit.
Parameters
----------
@@ -213,20 +223,62 @@ def convert_unit(self, unit: str | sc.Unit) -> None:
raise TypeError('Energy unit must be a string or scipp unit.')
old_energy = self.energy.copy()
+ old_x_unit = str(self.x_unit)
+ old_offset_unit = str(self.energy_offset.unit)
+
try:
self.energy = sc.to_unit(self.energy, unit)
- except Exception as e:
+ self._energy_offset.convert_unit(unit)
+ if self.sample_components is not None:
+ self.sample_components.convert_x_unit(unit)
+ if self.resolution_components is not None:
+ self.resolution_components.convert_x_unit(unit)
+ except Exception:
self.energy = old_energy
- raise e
+ # Roll back energy_offset if it was already converted to the new unit.
+ if str(self._energy_offset.unit) != old_offset_unit:
+ self._energy_offset.convert_unit(old_offset_unit)
+ # Roll back component collections that may have been partially converted.
+ if self.sample_components is not None:
+ with contextlib.suppress(Exception):
+ self.sample_components.convert_x_unit(old_x_unit)
+ if self.resolution_components is not None:
+ with contextlib.suppress(Exception):
+ self.resolution_components.convert_x_unit(old_x_unit)
+ raise
+
+ self._x_unit = unit
+
+ def convert_y_unit(self, unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit of the sample components.
- old_energy_offset = self.energy_offset
- try:
- self.energy_offset.convert_unit(unit)
- except Exception as e:
- self.energy_offset = old_energy_offset
- raise e
+ Only propagates to sample components (resolution is normalised and unit-independent).
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The new y-axis unit.
- self._unit = unit
+ Raises
+ ------
+ TypeError
+ If unit is not a string or scipp unit.
+ Exception
+ If any component raises during unit conversion. On failure, attempts to roll back.
+ """
+ if not isinstance(unit, (str, sc.Unit)):
+ raise TypeError('y_unit must be a string or scipp unit.')
+ old_y_unit = self.y_unit
+ try:
+ if self.sample_components is not None:
+ self.sample_components.convert_y_unit(unit)
+ self._y_unit = str(unit) if isinstance(unit, sc.Unit) else unit
+ except Exception:
+ if self.sample_components is not None:
+ with contextlib.suppress(Exception):
+ self.sample_components.convert_y_unit(old_y_unit)
+ raise
@property
def sample_components(self) -> ComponentCollection | ModelComponent:
diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py
index e0ddde05d..fbe36b42a 100644
--- a/src/easydynamics/convolution/numerical_convolution.py
+++ b/src/easydynamics/convolution/numerical_convolution.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
+import scipp as sc
from scipy.signal import fftconvolve
from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase
@@ -31,9 +32,9 @@ def convolution(
"""
# Make sure the convolver is updated with the latest convolution
# settings before convolution.
- if not self.convolution_settings.convolution_plan_is_valid:
+ if not self._convolution_plan_is_current():
self._energy_grid = self._create_energy_grid()
- self.convolution_settings.convolution_plan_is_valid = True
+ self._mark_convolution_plan_current()
# Give warnings if peaks are very wide or very narrow
if not self.convolution_settings.suppress_warnings:
@@ -46,18 +47,24 @@ def convolution(
model_name='resolution model',
)
+ # Unit-convert the energy offset to match the energy grid unit.
+ # sc.to_unit returns a new scalar — self.energy_offset is never mutated.
+ offset_value = sc.to_unit(self.energy_offset.full_value, self.energy.unit).value
+
# Evaluate sample model. If called via the Convolution class,
# delta functions are already filtered out.
sample_vals = self.sample_components.evaluate(
self._energy_grid.energy_dense
- self._energy_grid.energy_even_length_offset
- - self.energy_offset.value
+ - offset_value
)
# Detailed balance correction
if self.temperature is not None and self.detailed_balance_settings.use_detailed_balance:
detailed_balance_factor_correction = detailed_balance_factor(
- energy=self._energy_grid.energy_dense - self.energy_offset.value,
+ energy=self._energy_grid.energy_dense
+ - self._energy_grid.energy_even_length_offset
+ - offset_value,
temperature=self.temperature,
energy_unit=self.energy.unit,
divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance,
@@ -90,7 +97,8 @@ def __repr__(self) -> str:
f'{self.__class__.__name__}('
f'display_name={self.display_name!r}, '
f'unique_name={self.unique_name!r}, '
- f'unit={self.unit}, '
+ f'x_unit={self.x_unit}, '
+ f'y_unit={self.y_unit}, '
f'energy_len={len(self.energy)}, '
f'temperature={self.temperature})'
)
diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py
index 188a69dbc..e837a4497 100644
--- a/src/easydynamics/convolution/numerical_convolution_base.py
+++ b/src/easydynamics/convolution/numerical_convolution_base.py
@@ -43,7 +43,8 @@ def __init__(
temperature: Parameter | Numeric | None = None,
temperature_unit: str | sc.Unit = 'K',
detailed_balance_settings: DetailedBalanceSettings | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
display_name: str | None = 'MyConvolution',
unique_name: str | None = None,
) -> None:
@@ -68,8 +69,10 @@ def __init__(
The unit of the temperature parameter.
detailed_balance_settings : DetailedBalanceSettings | None, default=None
The settings for detailed balance. If None, default settings will be used.
- unit : str | sc.Unit, default='meV'
- The unit of the energy.
+ x_unit : str | sc.Unit, default='meV'
+ The unit of the energy axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ The unit of the model output (intensity).
display_name : str | None, default='MyConvolution'
Display name of the model.
unique_name : str | None, default=None
@@ -85,7 +88,8 @@ def __init__(
energy=energy,
sample_components=sample_components,
resolution_components=resolution_components,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
energy_offset=energy_offset,
display_name=display_name,
unique_name=unique_name,
@@ -116,7 +120,36 @@ def __init__(
# When upsample_factor>1, we evaluate on this grid and
# interpolate back to the original values at the end
self._energy_grid = self._create_energy_grid()
- self._convolution_settings.convolution_plan_is_valid = True
+ self._mark_convolution_plan_current()
+
+ def _convolution_plan_is_current(self) -> bool:
+ """
+ Check whether this convolver's plan is up to date.
+
+ Plan validity is tracked per convolver so several convolvers can share one
+ ConvolutionSettings object: a local flag covers convolver-local invalidations (e.g. a new
+ energy grid), while the settings' plan version detects settings changes this convolver has
+ not consumed yet — even if a sibling convolver already rebuilt and set the shared flag.
+ Explicitly setting ``convolution_plan_is_valid = True`` on the settings remains an escape
+ hatch that suppresses rebuilds for all convolvers.
+
+ Returns
+ -------
+ bool
+ True if the plan does not need to be rebuilt.
+ """
+ if not getattr(self, '_plan_is_valid', False):
+ return False
+ return self.convolution_settings._plan_valid_for( # noqa: SLF001
+ self._plan_settings_version_seen
+ )
+
+ def _mark_convolution_plan_current(self) -> None:
+ """Record that this convolver's plan matches its current state and settings."""
+ self._plan_is_valid = True
+ self._plan_settings_version_seen = (
+ self.convolution_settings._mark_plan_built() # noqa: SLF001
+ )
@property
def convolution_settings(self) -> ConvolutionSettings:
@@ -149,6 +182,7 @@ def convolution_settings(self, settings: ConvolutionSettings) -> None:
if not isinstance(settings, ConvolutionSettings):
raise TypeError('settings must be a ConvolutionSettings instance.')
self._convolution_settings = settings
+ self._plan_is_valid = False
self._convolution_settings.convolution_plan_is_valid = False
@ConvolutionBase.energy.setter
@@ -162,6 +196,8 @@ def energy(self, energy: np.ndarray) -> None:
The new energy array.
"""
ConvolutionBase.energy.fset(self, energy)
+ self._energy_grid = self._create_energy_grid()
+ self._plan_is_valid = False
self.convolution_settings.convolution_plan_is_valid = False
@property
@@ -437,12 +473,12 @@ def __repr__(self) -> str:
"""
return (
f'{self.__class__.__name__}('
- f' energy=array of shape {self.energy.values.shape},\n'
- f' sample_components={self.sample_components!r},\n'
- f' resolution_components={self.resolution_components!r},\n'
- f' unit={self.unit}, '
- f' upsample_factor={self.upsample_factor}, '
- f' extension_factor={self.extension_factor}, '
- f' temperature={self.temperature}, '
- f' detailed_balance={self.detailed_balance_settings!r})'
+ f'energy=array of shape {self.energy.values.shape},\n '
+ f'sample_components={self.sample_components!r}, \n'
+ f'resolution_components={self.resolution_components!r},\n '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, '
+ f'upsample_factor={self.upsample_factor}, '
+ f'extension_factor={self.extension_factor}, '
+ f'temperature={self.temperature}, '
+ f'detailed_balance={self.detailed_balance_settings!r})'
)
diff --git a/src/easydynamics/experiment/experiment.py b/src/easydynamics/experiment/experiment.py
index a2fb04fe4..d791ee5db 100644
--- a/src/easydynamics/experiment/experiment.py
+++ b/src/easydynamics/experiment/experiment.py
@@ -13,6 +13,7 @@
from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase
from easydynamics.utils.utils import _in_notebook
+from easydynamics.utils.utils import verify_Q_index
class Experiment(EasyDynamicsBase):
@@ -239,11 +240,6 @@ def get_masked_energy(self, Q_index: int) -> sc.Variable | None:
Q_index : int
The Q index to get the masked energy values for.
- Raises
- ------
- IndexError
- If Q_index is not a valid index for the Q values.
-
Returns
-------
sc.Variable | None
@@ -252,19 +248,58 @@ def get_masked_energy(self, Q_index: int) -> sc.Variable | None:
if self.binned_data is None:
return None
- if (
- not isinstance(Q_index, int)
- or Q_index < 0
- or (self.Q is not None and Q_index >= len(self.Q))
- ):
- raise IndexError('Q_index must be a valid index for the Q values.')
+ verify_Q_index(Q_index, self.Q)
energy = self.binned_data.coords['energy']
- _, _, _, mask = self._extract_x_y_weights_only_finite(Q_index=Q_index)
-
- mask_var = sc.array(dims=['energy'], values=mask)
+ mask_var = self.get_finite_energy_mask(Q_index=Q_index)
return energy[mask_var]
+ def get_finite_energy_mask(self, Q_index: int) -> sc.Variable | None:
+ """
+ Get a boolean scipp Variable selecting energy points with finite intensity at the given Q.
+
+ Parameters
+ ----------
+ Q_index : int
+ The Q index to get the mask for.
+
+ Returns
+ -------
+ sc.Variable | None
+ Boolean scipp Variable of length n_energy with dim ``'energy'``, or None if no data is
+ loaded.
+ """
+ if self.binned_data is None:
+ return None
+
+ verify_Q_index(Q_index, self.Q)
+
+ _, _, _, mask = self.extract_x_y_weights_only_finite(Q_index=Q_index)
+ return sc.array(dims=['energy'], values=mask)
+
+ def get_masked_binned_data(self, Q_index: int) -> sc.DataArray | None:
+ """
+ Get the binned data for a single Q slice with non-finite points masked out.
+
+ Parameters
+ ----------
+ Q_index : int
+ The Q index to extract.
+
+ Returns
+ -------
+ sc.DataArray | None
+ The binned data for the given Q index with NaN/Inf points removed, or None if no data
+ is loaded.
+ """
+ if self.binned_data is None:
+ return None
+
+ verify_Q_index(Q_index, self.Q)
+
+ mask_var = self.get_finite_energy_mask(Q_index=Q_index)
+ return self.binned_data['Q', Q_index][mask_var]
+
###########
# Handle data
###########
@@ -546,7 +581,7 @@ def _extract_x_y_var(self, Q_index: int) -> tuple[np.ndarray, np.ndarray, np.nda
var = data.variances
return x, y, var
- def _extract_x_y_weights_only_finite(
+ def extract_x_y_weights_only_finite(
self, Q_index: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
diff --git a/src/easydynamics/sample_model/background_model.py b/src/easydynamics/sample_model/background_model.py
index 34d313932..031d9493c 100644
--- a/src/easydynamics/sample_model/background_model.py
+++ b/src/easydynamics/sample_model/background_model.py
@@ -47,7 +47,8 @@ def __init__(
self,
display_name: str | None = 'MyBackgroundModel',
unique_name: str | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
components: ModelComponent | ComponentCollection | None = None,
Q: Q_type | None = None,
) -> None:
@@ -60,8 +61,10 @@ def __init__(
Display name of the model.
unique_name : str | None, default=None
Unique name of the model. If None, a unique name will be generated.
- unit : str | sc.Unit, default='meV'
- Unit of the model.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis (energy, Q, etc.).
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity).
components : ModelComponent | ComponentCollection | None, default=None
Template components of the model. If None, no components are added. These components
are copied into ComponentCollections for each Q value.
@@ -71,7 +74,8 @@ def __init__(
super().__init__(
display_name=display_name,
unique_name=unique_name,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
components=components,
Q=Q,
)
@@ -80,7 +84,7 @@ def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'unique_name={self.unique_name!r}, '
- f'unit={self.unit}, '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, '
f'Q_len={None if self._Q is None else len(self._Q)}, '
f'components={self.components})'
)
diff --git a/src/easydynamics/sample_model/component_collection.py b/src/easydynamics/sample_model/component_collection.py
index 54a6b0010..840b18709 100644
--- a/src/easydynamics/sample_model/component_collection.py
+++ b/src/easydynamics/sample_model/component_collection.py
@@ -9,14 +9,14 @@
import numpy as np
import scipp as sc
-from easyscience.variable import DescriptorBase
-from easyscience.variable import Parameter
from easydynamics.base_classes.easydynamics_list import EasyDynamicsList
from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase
from easydynamics.sample_model.components.model_component import ModelComponent
if TYPE_CHECKING:
+ from easyscience.variable import DescriptorBase
+
from easydynamics.utils.utils import Numeric
@@ -54,7 +54,8 @@ class ComponentCollection(EasyDynamicsList, EasyDynamicsModelBase):
def __init__(
self,
components: ModelComponent | list[ModelComponent] | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'ComponentCollection',
display_name: str | None = None,
unique_name: str | None = None,
@@ -66,8 +67,10 @@ def __init__(
----------
components : ModelComponent | list[ModelComponent] | None, default=None
Initial model components to add to the ComponentCollection.
- unit : str | sc.Unit, default='meV'
- Unit of the collection.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis (energy, Q, etc.).
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity).
name : str, default='ComponentCollection'
Name of the collection.
display_name : str | None, default=None
@@ -86,12 +89,14 @@ def __init__(
components = [components]
elif not isinstance(components, list):
raise TypeError(
- f'components must be a ModelComponent or a list of ModelComponent, got {type(components).__name__} instead.' # noqa: E501
+ f'components must be a ModelComponent or a list of ModelComponent, '
+ f'got {type(components).__name__} instead.'
)
for comp in components:
if not isinstance(comp, ModelComponent):
raise TypeError(
- f'All items in components must be instances of ModelComponent, got {type(comp).__name__} instead.' # noqa: E501
+ f'All items in components must be instances of ModelComponent, '
+ f'got {type(comp).__name__} instead.'
)
EasyDynamicsList.__init__(
@@ -102,7 +107,8 @@ def __init__(
EasyDynamicsModelBase.__init__(
self,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
@@ -146,39 +152,74 @@ def is_empty(self, _value: bool) -> None:
'whether the collection has components.'
)
- def convert_unit(self, unit: str | sc.Unit) -> None:
+ # ------------------------------------------------------------------
+ # Unit conversion
+ # ------------------------------------------------------------------
+
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert the x-axis unit of the ComponentCollection and all its components.
+
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ The target x-axis unit to convert to.
+ """
+ self._convert_axis_unit(new_x_unit, axis='x')
+
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
"""
- Convert the unit of the ComponentCollection and all its components.
+ Convert the y-axis unit of the ComponentCollection and all its components.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ The target y-axis unit to convert to.
+ """
+ self._convert_axis_unit(new_y_unit, axis='y')
+
+ def _convert_axis_unit(self, unit: str | sc.Unit, axis: str) -> None:
+ """
+ Convert one axis unit on all components in the collection.
+
+ Converts every component via its ``convert__unit`` method and updates the
+ collection's own unit attribute. On failure, attempts a best-effort rollback of all
+ components to the old unit before re-raising.
Parameters
----------
unit : str | sc.Unit
- The target unit to convert to.
+ The new unit to convert to.
+ axis : str
+ Which axis to convert: ``'x'`` or ``'y'``.
Raises
------
TypeError
- If unit is not a string or sc.Unit.
+ If the provided unit is not a string or sc.Unit.
Exception
If any component cannot be converted to the specified unit.
"""
-
if not isinstance(unit, (str, sc.Unit)):
- raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}')
-
- old_unit = self._unit
+ raise TypeError(f'{axis}_unit must be a string or sc.Unit, got {type(unit).__name__}')
+ method = f'convert_{axis}_unit'
+ old_unit = self.x_unit if axis == 'x' else self.y_unit
try:
for component in self:
- component.convert_unit(unit)
- self._unit = unit
+ getattr(component, method)(unit)
+ unit_str = str(unit) if isinstance(unit, sc.Unit) else unit
+ if axis == 'x':
+ self._x_unit = unit_str
+ else:
+ self._y_unit = unit_str
except Exception as e:
- # Attempt to rollback on failure
- try:
- for component in self:
- component.convert_unit(old_unit)
- except Exception: # noqa: S110
- pass # Best effort rollback
+ if old_unit is not None:
+ try:
+ for component in self:
+ getattr(component, method)(old_unit)
+ except Exception: # noqa: S110
+ pass
raise e
# ------------------------------------------------------------------
@@ -211,7 +252,6 @@ def list_component_names(self) -> list[str]:
list[str]
List of names of the components in the collection.
"""
-
return [component.name for component in self]
def normalize_area(self) -> None:
@@ -230,28 +270,38 @@ def normalize_area(self) -> None:
raise ValueError('No components in the model to normalize.')
area_params = []
- total_area = Parameter(name='total_area', value=0.0, unit=self._unit)
for component in self:
if hasattr(component, 'area'):
area_params.append(component.area)
- total_area += component.area
else:
warnings.warn(
f"Component '{component.name}' does not have an 'area' attribute "
- f'and will be skipped in normalization.',
+ 'and will be skipped in normalization.',
UserWarning,
stacklevel=2,
)
- if total_area.value == 0:
+ if not area_params:
+ raise ValueError('No components with an area attribute; cannot normalize.')
+
+ # Sum the areas in a common unit so components with different (but compatible)
+ # units normalize correctly. Dividing each value by the total expressed in the
+ # reference unit makes the areas sum to 1 in that unit.
+ reference_unit = str(area_params[0].unit)
+ total_area_value = sum(
+ sc.to_unit(sc.scalar(p.value, unit=str(p.unit)), reference_unit).value
+ for p in area_params
+ )
+
+ if total_area_value == 0:
raise ValueError('Total area is zero; cannot normalize.')
- if not np.isfinite(total_area.value):
+ if not np.isfinite(total_area_value):
raise ValueError('Total area is not finite; cannot normalize.')
for param in area_params:
- param.value /= total_area.value
+ param.value /= total_area_value
# ------------------------------------------------------------------
# Other methods
@@ -259,17 +309,20 @@ def normalize_area(self) -> None:
def get_all_variables(self) -> list[DescriptorBase]:
"""
- Get all parameters from the model component.
+ Get all parameters from all model components.
Returns
-------
list[DescriptorBase]
- List of parameters in the component.
+ List of parameters in the collection.
"""
-
return [var for component in self for var in component.get_all_variables()]
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
+ def evaluate(
+ self,
+ x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
+ output: str = 'numpy',
+ ) -> np.ndarray | sc.Variable:
"""
Evaluate the sum of all components.
@@ -277,22 +330,35 @@ def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray)
----------
x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
Energy axis.
+ output : str, default='numpy'
+ 'numpy' returns np.ndarray; 'scipp' returns sc.Variable with y_unit.
Returns
-------
- np.ndarray
+ np.ndarray | sc.Variable
Evaluated model values.
"""
-
if not self:
- return np.zeros_like(x)
- return sum(component.evaluate(x) for component in self)
+ if isinstance(x, (sc.Variable, sc.DataArray)):
+ values = np.zeros_like(x.values, dtype=float)
+ dim = x.dims[0] if x.dims else 'x'
+ else:
+ values = np.zeros_like(x, dtype=float)
+ dim = 'x'
+ if output == 'scipp':
+ return sc.array(dims=[dim], values=values, unit=self.y_unit)
+ return values
+ # This is needed to handle both scipp and numpy output - a normal call to sum does not work
+ gen = (component.evaluate(x, output=output) for component in self)
+ first = next(gen)
+ return sum(gen, first)
def evaluate_component(
self,
x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
name: str,
- ) -> np.ndarray:
+ output: str = 'numpy',
+ ) -> np.ndarray | sc.Variable:
"""
Evaluate a single component by name.
@@ -302,6 +368,8 @@ def evaluate_component(
Energy axis.
name : str
Component name.
+ output : str, default='numpy'
+ 'numpy' returns np.ndarray; 'scipp' returns sc.Variable with y_unit.
Raises
------
@@ -314,22 +382,17 @@ def evaluate_component(
Returns
-------
- np.ndarray
+ np.ndarray | sc.Variable
Evaluated values for the specified component.
"""
if not self:
raise ValueError('No components in the model to evaluate.')
-
if not isinstance(name, str):
raise TypeError(f'Component name must be a string, got {type(name)} instead.')
-
matches = [comp for comp in self if comp.name == name]
if not matches:
raise KeyError(f"No component named '{name}' exists.")
-
- component = matches[0]
-
- return component.evaluate(x)
+ return matches[0].evaluate(x, output=output)
def fix_all_parameters(self) -> None:
"""Fix all free parameters in the model."""
@@ -376,11 +439,10 @@ def __repr__(self) -> str:
String representation of the ComponentCollection.
"""
comp_names = ', '.join(c.name for c in self) or 'No components'
-
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, unit={self.unit},\n'
- f' components=[{comp_names}])'
+ f"{self.__class__.__name__}(name='{self.name}', "
+ f"x_unit='{self.x_unit}', y_unit='{self.y_unit}',\n"
+ f'Components: {comp_names})'
)
def to_dict(self) -> dict:
@@ -395,7 +457,8 @@ def to_dict(self) -> dict:
return {
'@module': self.__class__.__module__,
'@class': self.__class__.__name__,
- 'unit': str(self.unit),
+ 'x_unit': str(self.x_unit),
+ 'y_unit': str(self.y_unit),
'name': self.name,
'display_name': self.display_name,
'components': [c.to_dict() for c in self._data],
@@ -433,13 +496,14 @@ def deserialise_component(d: dict) -> ModelComponent:
cls = getattr(module, d['@class'])
return cls.from_dict(d)
- components = [deserialise_component(c) for c in obj_dict.get('components', [])]
+ components = [deserialise_component(c) for c in obj_dict['components']]
return cls(
components=components,
- unit=obj_dict.get('unit', 'meV'),
- name=obj_dict.get('name', 'ComponentCollection'),
- display_name=obj_dict.get('display_name'),
+ x_unit=obj_dict['x_unit'],
+ y_unit=obj_dict['y_unit'],
+ name=obj_dict['name'],
+ display_name=obj_dict['display_name'],
)
def __copy__(self) -> ComponentCollection:
@@ -451,5 +515,4 @@ def __copy__(self) -> ComponentCollection:
ComponentCollection
A deep copy of the ComponentCollection.
"""
-
return self.from_dict(self.to_dict())
diff --git a/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py b/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py
index 9d4d1ed75..90a63562f 100644
--- a/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py
+++ b/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py
@@ -20,8 +20,11 @@ class DampedHarmonicOscillator(CreateParametersMixin, ModelComponent):
r"""
Model of a Damped Harmonic Oscillator (DHO).
- The intensity is given by $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2
- \gamma x)^2 \right)}, $$ where $A$ is the area, $x_0$ is the center, and $\gamma$ is the width.
+ $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2\gamma x)^2 \right)} $$
+
+ where $A$ is the area (``area``), $x_0$ is the center (``center``), and $\gamma$ is the half
+ width at half max (``width``). area has unit = x_unit * y_unit; center and width have unit =
+ x_unit.
Examples
--------
@@ -55,56 +58,57 @@ def __init__(
area: Numeric = 1.0,
center: Numeric = 1.0,
width: Numeric = 1.0,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'DampedHarmonicOscillator',
display_name: str | None = None,
unique_name: str | None = None,
) -> None:
"""
- Initialize the Damped Harmonic Oscillator.
+ Initialize the Damped Harmonic Oscillator component.
Parameters
----------
area : Numeric, default=1.0
- Area under the curve.
+ Integrated area under the DHO profile. Unit is ``x_unit * y_unit``.
center : Numeric, default=1.0
- Resonance frequency, approximately the peak position.
+ Resonance frequency (x_0) in x_unit; approximately the peak position. Must be strictly
+ positive; a minimum of ``DHO_MINIMUM_CENTER`` (1e-10) is enforced.
width : Numeric, default=1.0
- Damping constant, approximately the half width at half max (HWHM) of the peaks. By
- default, 1.0.
- unit : str | sc.Unit, default='meV'
- Unit of the parameters.
+ Damping coefficient (gamma) in x_unit. Must be strictly positive. Approximately equal
+ to the HWHM of each peak.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. center and width are stored in this unit. area_unit = x_unit *
+ y_unit.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='DampedHarmonicOscillator'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name of the component.
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. If None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
"""
-
super().__init__(
name=name,
display_name=display_name,
unique_name=unique_name,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
)
- # These methods live in ValidationMixin
- area = self._create_area_parameter(area=area, name=name, unit=self._unit)
- center = self._create_center_parameter(
+ # These methods live in CreateParametersMixin
+ self._area = self._create_area_parameter(
+ area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit
+ )
+ self._center = self._create_center_parameter(
center=center,
name=name,
fix_if_none=False,
- unit=self._unit,
+ x_unit=self.x_unit,
enforce_minimum_center=True,
)
-
- width = self._create_width_parameter(width=width, name=name, unit=self._unit)
-
- self._area = area
- self._center = center
- self._width = width
+ self._width = self._create_width_parameter(width=width, name=name, x_unit=self.x_unit)
@property
def area(self) -> Parameter:
@@ -114,24 +118,22 @@ def area(self) -> Parameter:
Returns
-------
Parameter
- The area parameter.
+ The area Parameter with unit ``x_unit * y_unit``.
"""
return self._area
@area.setter
def area(self, value: Numeric) -> None:
"""
- Set the value of the area parameter.
-
Parameters
----------
value : Numeric
- The new value for the area parameter.
+ New area value (in current area unit = x_unit * y_unit).
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
if not isinstance(value, Numeric):
raise TypeError('area must be a number')
@@ -140,35 +142,32 @@ def area(self, value: Numeric) -> None:
@property
def center(self) -> Parameter:
"""
- Get the center parameter.
+ Get the center parameter (resonance frequency).
Returns
-------
Parameter
- The center parameter.
+ The resonance frequency (x_0) Parameter with unit ``x_unit``.
"""
return self._center
@center.setter
def center(self, value: Numeric) -> None:
"""
- Set the value of the center parameter.
-
Parameters
----------
value : Numeric
- The new value for the center parameter.
+ New resonance frequency in x_unit. Must be strictly positive.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
ValueError
- If the value is not positive.
+ If *value* is not positive.
"""
if not isinstance(value, Numeric):
raise TypeError('center must be a number')
-
if float(value) <= 0:
raise ValueError('center must be positive')
self._center.value = value
@@ -176,66 +175,94 @@ def center(self, value: Numeric) -> None:
@property
def width(self) -> Parameter:
"""
- Get the width parameter.
+ Get the width parameter (damping coefficient).
Returns
-------
Parameter
- The width parameter.
+ The damping coefficient (gamma) Parameter with unit ``x_unit``.
"""
return self._width
@width.setter
def width(self, value: Numeric) -> None:
"""
- Set the value of the width parameter.
-
Parameters
----------
value : Numeric
- The new value for the width parameter.
+ New damping coefficient in x_unit. Must be strictly positive.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
ValueError
- If the value is not positive.
+ If *value* is not positive.
"""
if not isinstance(value, Numeric):
raise TypeError('width must be a number')
-
if float(value) <= 0:
raise ValueError('width must be positive')
-
self._width.value = value
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
r"""
- Evaluate the Damped Harmonic Oscillator at the given x values.
+ Evaluate the DHO at x_vals.
+
+ $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2\gamma x)^2 \right)} $$
- If x is a scipp Variable, the unit of the DHO will be converted to match x. The intensity
- is given by $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2 \gamma x)^2
- \right)}, $$ where $A$ is the area, $x_0$ is the center, and $\gamma$ is the width.
+ where *A* is ``area``, *x*₀ is ``center`` (resonance frequency), and *gamma* is ``width``
+ (damping coefficient). Here *I* is the scattered intensity. Parameters in the model's own
+ units are temporarily converted to eval_unit for the computation.
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the DHO.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
- The intensity of the DHO at the given x values.
+ Evaluated DHO values at x_vals.
"""
+ center = self._resolve_param_value(self._center, eval_unit)
+ width = self._resolve_param_value(self._width, eval_unit)
+ area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit))
- x = self._prepare_x_for_evaluate(x)
+ normalization = 2 * center**2 * width / np.pi
+ denominator = (x_vals**2 - center**2) ** 2 + (2 * width * x_vals) ** 2
+ # denominator cannot reach zero: center > 0 enforced by DHO_MINIMUM_CENTER
+ return area * normalization / denominator
- normalization = 2 * self.center.value**2 * self.width.value / np.pi
- # No division by zero here, width>0 enforced in setter
- denominator = (x**2 - self.center.value**2) ** 2 + (2 * self.width.value * x) ** 2
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert x-axis parameters (center, width) and area to new_x_unit.
+
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
+ """
+ self._convert_x_unit_area_based(
+ new_x_unit=new_x_unit,
+ x_params=[self._center, self._width],
+ area_param=self._area,
+ )
- return self.area.value * normalization / (denominator)
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit by rescaling the area parameter.
+
+ The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ """
+ self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area)
def __repr__(self) -> str:
"""
@@ -247,10 +274,9 @@ def __repr__(self) -> str:
A string representation of the Damped Harmonic Oscillator.
"""
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self._unit},\n'
- f' area={self.area},\n'
- f' center={self.center},\n'
- f' width={self.width})'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n '
+ f' area = {self.area},\n '
+ f' center = {self.center},\n '
+ f' width = {self.width})'
)
diff --git a/src/easydynamics/sample_model/components/delta_function.py b/src/easydynamics/sample_model/components/delta_function.py
index bca8c4239..9bea615ab 100644
--- a/src/easydynamics/sample_model/components/delta_function.py
+++ b/src/easydynamics/sample_model/components/delta_function.py
@@ -11,8 +11,7 @@
from easydynamics.sample_model.components.model_component import ModelComponent
from easydynamics.utils.utils import Numeric
-EPSILON = 1e-8 # small number to avoid floating point issues
-
+EPSILON = 1e-8 # tolerance for bin-edge comparisons
if TYPE_CHECKING:
import scipp as sc
@@ -23,9 +22,12 @@ class DeltaFunction(CreateParametersMixin, ModelComponent):
"""
Delta function.
- Evaluates to zero everywhere, except in convolutions, where it acts as an identity. This is
- handled by the Convolution method. If the center is not provided, it will be centered at 0 and
- fixed, which is typically what you want in QENS.
+ When called directly, returns zero everywhere except at the bin nearest to ``center``, where it
+ returns ``area / bin_width``. In convolutions it acts as an identity element (handled by the
+ ``Convolution`` class). area has unit = x_unit * y_unit; center has unit = x_unit.
+
+ If the center is not provided, it will be centered at 0 and fixed, which is typically what you
+ want in QENS.
Examples
--------
@@ -57,7 +59,8 @@ def __init__(
self,
center: Numeric | None = None,
area: Numeric = 1.0,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'DeltaFunction',
display_name: str | None = None,
unique_name: str | None = None,
@@ -68,35 +71,35 @@ def __init__(
Parameters
----------
center : Numeric | None, default=None
- Center of the delta function. If None, it will be centered at 0 and fixed.
+ Position of the delta function in x_unit. If None, defaults to 0 and the center
+ parameter is fixed.
area : Numeric, default=1.0
- Total area under the curve.
- unit : str | sc.Unit, default='meV'
- Unit of the parameters.
+ Integrated area (weight) of the delta function. Unit is ``x_unit * y_unit``.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. center is stored in this unit. area_unit = x_unit * y_unit.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='DeltaFunction'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name of the component.
+ Display name of the component, shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. If None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
"""
- # Validate inputs and create Parameters if not given
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
)
- # These methods live in ValidationMixin
- area = self._create_area_parameter(area=area, name=name, unit=self._unit)
- center = self._create_center_parameter(
- center=center, name=name, fix_if_none=True, unit=self._unit
+ self._area = self._create_area_parameter(
+ area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit
+ )
+ self._center = self._create_center_parameter(
+ center=center, name=name, fix_if_none=True, x_unit=self.x_unit
)
-
- self._area = area
- self._center = center
@property
def area(self) -> Parameter:
@@ -106,27 +109,23 @@ def area(self) -> Parameter:
Returns
-------
Parameter
- The area parameter.
+ The area Parameter with unit ``x_unit * y_unit``.
"""
-
return self._area
@area.setter
def area(self, value: Numeric) -> None:
"""
- Set the value of the area parameter.
-
Parameters
----------
value : Numeric
- The new value for the area parameter.
+ New area value (in current area unit = x_unit * y_unit).
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
-
if not isinstance(value, Numeric):
raise TypeError('area must be a number')
self._area.value = value
@@ -139,27 +138,24 @@ def center(self) -> Parameter:
Returns
-------
Parameter
- The center parameter.
+ The center Parameter with unit ``x_unit``.
"""
-
return self._center
@center.setter
def center(self, value: Numeric | None) -> None:
"""
- Set the center parameter value.
-
Parameters
----------
value : Numeric | None
- The new value for the center parameter. If None, defaults to 0 and is fixed.
+ New center value in x_unit. If None, the center is set to 0 and the parameter is
+ fixed.
Raises
------
TypeError
- If the value is not a number or None.
+ If *value* is not None and not a numeric type.
"""
-
if value is None:
value = 0.0
self._center.fixed = True
@@ -167,53 +163,87 @@ def center(self, value: Numeric | None) -> None:
raise TypeError('center must be a number')
self._center.value = value
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
"""
- Evaluate the Delta function at the given x values.
+ Evaluate the Delta function at x_vals.
- The Delta function evaluates to zero everywhere, except at the center. Its numerical
- integral is equal to the area. It acts as an identity in convolutions.
+ Parameters in the model's own units are temporarily converted to eval_unit for the
+ computation.
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the Delta function.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit. Assumed sorted for the bin-width computation.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
- The evaluated Delta function at the given x values.
+ Zero everywhere, with a single non-zero bin nearest the center when center falls within
+ the x range.
+
+ Notes
+ -----
+ When ``center`` falls within the x range, the bin nearest to ``center`` receives ``area /
+ bin_width`` rather than zero. In convolutions, the DeltaFunction acts as an identity
+ element (handled by the Convolution class).
"""
+ center = self._resolve_param_value(self._center, eval_unit)
+ area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit))
- # x assumed sorted, 1D numpy array
- x = self._prepare_x_for_evaluate(x)
- model = np.zeros_like(x, dtype=float)
- center = self.center.value
- area = self.area.value
+ model = np.zeros_like(x_vals, dtype=float)
- if x.min() - EPSILON <= center <= x.max() + EPSILON:
+ if x_vals.min() - EPSILON <= center <= x_vals.max() + EPSILON:
# nearest index
- i = np.argmin(np.abs(x - center))
+ i = np.argmin(np.abs(x_vals - center))
# left half-width
- if i == 0: # noqa: SIM108
- left = x[1] - x[0] if x.size > 1 else 0.5
+ if i == 0:
+ left = x_vals[1] - x_vals[0] if x_vals.size > 1 else 0.5
else:
- left = x[i] - x[i - 1]
+ left = x_vals[i] - x_vals[i - 1]
# right half-width
- if i == x.size - 1: # noqa: SIM108
- right = x[-1] - x[-2] if x.size > 1 else 0.5
+ if i == x_vals.size - 1:
+ right = x_vals[-1] - x_vals[-2] if x_vals.size > 1 else 0.5
else:
- right = x[i + 1] - x[i]
+ right = x_vals[i + 1] - x_vals[i]
# effective bin width: half left + half right
bin_width = 0.5 * (left + right)
-
model[i] = area / bin_width
return model
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert x-axis parameters (center) and area to new_x_unit.
+
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
+ """
+ self._convert_x_unit_area_based(
+ new_x_unit=new_x_unit,
+ x_params=[self._center],
+ area_param=self._area,
+ )
+
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit by rescaling the area parameter.
+
+ The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ """
+ self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area)
+
def __repr__(self) -> str:
"""
Return a string representation of the Delta function.
@@ -223,11 +253,9 @@ def __repr__(self) -> str:
str
A string representation of the Delta function.
"""
-
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self.unit},\n'
- f' area={self.area},\n'
- f' center={self.center})'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n'
+ f' area = {self.area},\n'
+ f' center = {self.center})'
)
diff --git a/src/easydynamics/sample_model/components/exponential.py b/src/easydynamics/sample_model/components/exponential.py
index 8b7e3241a..25b1048e9 100644
--- a/src/easydynamics/sample_model/components/exponential.py
+++ b/src/easydynamics/sample_model/components/exponential.py
@@ -16,11 +16,10 @@ class Exponential(CreateParametersMixin, ModelComponent):
r"""
Model of an exponential function.
- The intensity is given by
+ $$ I(x) = A e^{B (x-x_0)} $$
- $$ I(x) = A e^{B (x-x_0)}, $$
-
- where $A$ is the amplitude, $x_0$ is the center, and $B$ describes the rate of decay or growth.
+ where $A$ is the amplitude, $x_0$ is the center, and $B$ is the rate. amplitude has unit =
+ x_unit * y_unit; center has unit = x_unit; rate has unit = 1/x_unit.
Examples
--------
@@ -53,7 +52,8 @@ def __init__(
amplitude: Numeric = 1.0,
center: Numeric | None = None,
rate: Numeric = 1.0,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'Exponential',
display_name: str | None = None,
unique_name: str | None = None,
@@ -64,61 +64,59 @@ def __init__(
Parameters
----------
amplitude : Numeric, default=1.0
- Amplitude of the Exponential.
+ Pre-exponential factor A. Unit is ``x_unit * y_unit``.
center : Numeric | None, default=None
- Center of the Exponential. If None, the center is fixed at 0.
+ Reference point x_0 in x_unit. If None, defaults to 0 and the center parameter is
+ fixed.
rate : Numeric, default=1.0
- Decay or growth constant of the Exponential.
- unit : str | sc.Unit, default='meV'
- Unit of the parameters.
+ Exponential rate B in units of ``1/x_unit``.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. center is stored in this unit; rate is stored in ``1/x_unit``.
+ amplitude_unit = x_unit * y_unit.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='Exponential'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name of the component.
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. If None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
Raises
------
TypeError
- If amplitude, center, or rate are not numbers or Parameters.
+ If *amplitude* or *rate* is not numeric.
ValueError
- If amplitude, center or rate are not finite numbers.
+ If *amplitude* or *rate* is not finite.
"""
- # Validate inputs and create Parameters if not given
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
)
- if not isinstance(amplitude, (Parameter, Numeric)):
- raise TypeError('amplitude must be a number or a Parameter.')
-
- if isinstance(amplitude, Numeric):
- if not np.isfinite(amplitude):
- raise ValueError('amplitude must be a finite number or a Parameter')
+ x_unit_str = str(x_unit) if isinstance(x_unit, sc.Unit) else x_unit
+ amplitude_unit = str(sc.Unit(x_unit_str) * sc.Unit(self.y_unit))
- amplitude = Parameter(name=name + ' amplitude', value=float(amplitude), unit=unit)
-
- center = self._create_center_parameter(
- center=center, name=name, fix_if_none=True, unit=self._unit
+ if not isinstance(amplitude, Numeric):
+ raise TypeError('amplitude must be a number.')
+ if not np.isfinite(amplitude):
+ raise ValueError('amplitude must be finite.')
+ self._amplitude = Parameter(
+ name=name + ' amplitude', value=float(amplitude), unit=amplitude_unit
)
- if not isinstance(rate, (Parameter, Numeric)):
- raise TypeError('rate must be a number or a Parameter.')
-
- if isinstance(rate, Numeric):
- if not np.isfinite(rate):
- raise ValueError('rate must be a finite number or a Parameter')
-
- rate = Parameter(name=name + ' rate', value=float(rate), unit='1/' + str(unit))
+ self._center = self._create_center_parameter(
+ center=center, name=name, fix_if_none=True, x_unit=self.x_unit
+ )
- self._amplitude = amplitude
- self._center = center
- self._rate = rate
+ if not isinstance(rate, Numeric):
+ raise TypeError('rate must be a number.')
+ if not np.isfinite(rate):
+ raise ValueError('rate must be finite.')
+ self._rate = Parameter(name=name + ' rate', value=float(rate), unit='1/' + x_unit_str)
@property
def amplitude(self) -> Parameter:
@@ -128,27 +126,23 @@ def amplitude(self) -> Parameter:
Returns
-------
Parameter
- The amplitude parameter.
+ The amplitude Parameter with unit ``x_unit * y_unit``.
"""
-
return self._amplitude
@amplitude.setter
def amplitude(self, value: Numeric) -> None:
"""
- Set the value of the amplitude parameter.
-
Parameters
----------
value : Numeric
- The new value for the amplitude parameter.
+ New amplitude value (in current amplitude unit = x_unit * y_unit).
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
-
if not isinstance(value, Numeric):
raise TypeError('amplitude must be a number')
self._amplitude.value = value
@@ -161,31 +155,27 @@ def center(self) -> Parameter:
Returns
-------
Parameter
- The center parameter.
+ The center (x_0) Parameter with unit ``x_unit``.
"""
-
return self._center
@center.setter
def center(self, value: Numeric | None) -> None:
"""
- Set the center parameter value.
-
Parameters
----------
value : Numeric | None
- The new value for the center parameter.
+ New center value in x_unit. If None, the center is set to 0 and the parameter is
+ fixed.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not None and not a numeric type.
"""
-
if value is None:
value = 0.0
self._center.fixed = True
-
if not isinstance(value, Numeric):
raise TypeError('center must be a number')
self._center.value = value
@@ -198,94 +188,84 @@ def rate(self) -> Parameter:
Returns
-------
Parameter
- The rate parameter.
+ The exponential rate (B) Parameter with unit ``1/x_unit``.
"""
return self._rate
@rate.setter
def rate(self, value: Numeric) -> None:
"""
- Set the rate parameter value.
-
Parameters
----------
value : Numeric
- The new value for the rate parameter.
+ New exponential rate in ``1/x_unit``.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
if not isinstance(value, Numeric):
raise TypeError('rate must be a number')
-
self._rate.value = value
- def evaluate(
- self,
- x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
- ) -> np.ndarray:
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
r"""
- Evaluate the Exponential at the given x values.
+ Evaluate the Exponential at x_vals.
- If x is a scipp Variable, the unit of the Exponential will be converted to match x. The
- intensity is given by $$ I(x) = A \exp\left( r (x - x_0) \right) $$
-
- where $A$ is the amplitude, $x_0$ is the center, and $r$ is the rate.
+ Parameters in the model's own units are temporarily converted to eval_unit for the
+ computation.
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the Exponential.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
- The intensity of the Exponential at the given x values.
+ Evaluated exponential values at x_vals.
"""
+ eval_rate_unit = None if eval_unit is None else '1/' + str(eval_unit)
- x = self._prepare_x_for_evaluate(x)
- exponent = self.rate.value * (x - self.center.value)
+ center = self._resolve_param_value(self._center, eval_unit)
+ rate = self._resolve_param_value(self._rate, eval_rate_unit)
+ amplitude = self._resolve_param_value(self._amplitude, self._eval_area_unit(eval_unit))
- return self.amplitude.value * np.exp(exponent)
+ exponent = rate * (x_vals - center)
+ return amplitude * np.exp(exponent)
- def convert_unit(self, unit: str | sc.Unit) -> None:
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
"""
- Convert the unit of the Parameters in the component.
+ Convert center and amplitude to new_x_unit, rate to 1/new_x_unit.
Parameters
----------
- unit : str | sc.Unit
- The new unit to convert to.
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit. The
+ rate unit is set to ``1/new_x_unit``.
+ """
+ self._convert_x_unit_area_based(
+ new_x_unit=new_x_unit,
+ x_params=[self._center],
+ area_param=self._amplitude,
+ inverse_params=[self._rate],
+ )
- Raises
- ------
- TypeError
- If unit is not a string or sc.Unit.
- Exception
- If conversion fails for any parameter.
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
"""
+ Convert the y-axis unit by rescaling the amplitude parameter.
+
+ The amplitude is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``.
- if not isinstance(unit, (str, sc.Unit)):
- raise TypeError('unit must be a string or sc.Unit')
-
- old_unit = self._unit
- pars = [self.amplitude, self.center]
- try:
- for p in pars:
- p.convert_unit(unit)
- self.rate.convert_unit('1/' + str(unit))
- self._unit = unit
- except Exception as e:
- # Attempt to rollback on failure
- try:
- for p in pars:
- p.convert_unit(old_unit)
- self.rate.convert_unit('1/' + str(old_unit))
- except Exception: # noqa: S110
- pass # Best effort rollback
- raise e
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ """
+ self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._amplitude)
def __repr__(self) -> str:
"""
@@ -296,12 +276,10 @@ def __repr__(self) -> str:
str
A string representation of the Exponential.
"""
-
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self._unit},\n'
- f' amplitude={self.amplitude},\n'
- f' center={self.center},\n'
- f' rate={self.rate})'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n '
+ f' amplitude = {self.amplitude},\n '
+ f' center = {self.center},\n '
+ f' rate = {self.rate})'
)
diff --git a/src/easydynamics/sample_model/components/expression_component.py b/src/easydynamics/sample_model/components/expression_component.py
index feda61c6f..3c85e57e9 100644
--- a/src/easydynamics/sample_model/components/expression_component.py
+++ b/src/easydynamics/sample_model/components/expression_component.py
@@ -3,19 +3,24 @@
from __future__ import annotations
+import warnings
+from copy import copy
from typing import TYPE_CHECKING
from typing import ClassVar
+import scipp as sc
import sympy as sp
+from easyscience.variable import DescriptorNumber
from easyscience.variable import Parameter
from scipy.special import erf
-from easydynamics.sample_model.components.model_component import ModelComponent
-from easydynamics.utils.utils import Numeric
-
if TYPE_CHECKING:
import numpy as np
- import scipp as sc
+
+from easydynamics.sample_model.components.model_component import ModelComponent
+from easydynamics.utils.utils import Numeric
+from easydynamics.utils.utils import hbar
+from easydynamics.utils.utils import kb
class ExpressionComponent(ModelComponent):
@@ -40,7 +45,7 @@ class ExpressionComponent(ModelComponent):
expr = sm.ExpressionComponent(
'A * exp(-(x - x0)**2 / (2*sigma**2))',
parameters={'A': 10, 'x0': 0, 'sigma': 1},
- unit='meV',
+ x_unit='meV',
display_name='Gaussian Peak',
)
x = np.linspace(-3, 3, 100)
@@ -54,18 +59,45 @@ class ExpressionComponent(ModelComponent):
expr.A = 5
expr.sigma = 0.5
```
+
+ **Giving parameters units**
+
+ Parameters are dimensionless by default. Units can be given per parameter at construction, or
+ relabelled later with ``set_unit`` (the numeric value is kept as-is). When units are in use,
+ the unit of the evaluated expression is derived from the parameter units and x_unit (see
+ ``output_unit``), and a warning is issued if it does not match y_unit:
+ ```python
+ expr = sm.ExpressionComponent(
+ 'A * exp(-(x - x0)**2 / (2*sigma**2))',
+ parameters={'A': 10, 'x0': 0, 'sigma': 1},
+ parameter_units={'A': '1/meV', 'x0': 'meV', 'sigma': 'meV'},
+ y_unit='1/meV',
+ )
+ expr.set_unit('A', '1/meV')
+ ```
+
+ **Physical constants**
+
+ The symbols ``hbar`` (in meV*ps) and ``kb`` (in meV/K) are provided automatically as read-only
+ constants (DescriptorNumbers) when they appear in the expression:
+ ```python
+ boltzmann = sm.ExpressionComponent(
+ 'exp(-x / (kb * T))',
+ parameters={'T': 300.0},
+ parameter_units={'T': 'K'},
+ )
+ ```
+ Use e.g. ``boltzmann.kb.convert_unit('eV/K')`` to work in another unit (this rescales the
+ value, unlike ``set_unit``).
"""
- # -------------------------
- # Allowed symbolic functions
- # -------------------------
_ALLOWED_FUNCS: ClassVar[dict[str, object]] = {
- # Exponentials & logs
+ # exponential / logarithmic
'exp': sp.exp,
'log': sp.log,
'ln': sp.log,
'sqrt': sp.sqrt,
- # Trigonometric
+ # trigonometric
'sin': sp.sin,
'cos': sp.cos,
'tan': sp.tan,
@@ -73,25 +105,23 @@ class ExpressionComponent(ModelComponent):
'cot': sp.cot,
'sec': sp.sec,
'csc': sp.csc,
+ # inverse trigonometric
'asin': sp.asin,
'acos': sp.acos,
'atan': sp.atan,
- # Hyperbolic
+ # hyperbolic
'sinh': sp.sinh,
'cosh': sp.cosh,
'tanh': sp.tanh,
- # Misc
+ # rounding / sign
'abs': sp.Abs,
'sign': sp.sign,
'floor': sp.floor,
'ceil': sp.ceiling,
- # Special functions
+ # special functions
'erf': sp.erf,
}
- # -------------------------
- # Allowed constants
- # -------------------------
_ALLOWED_CONSTANTS: ClassVar[dict[str, object]] = {
'pi': sp.pi,
'E': sp.E,
@@ -99,11 +129,30 @@ class ExpressionComponent(ModelComponent):
_RESERVED_NAMES: ClassVar[dict[str, object]] = {'x'}
+ # Physical constants provided automatically as read-only DescriptorNumbers when their
+ # symbol appears in the expression: name -> (source constant from utils, default unit).
+ _PHYSICAL_CONSTANTS: ClassVar[dict[str, tuple[DescriptorNumber, str]]] = {
+ 'hbar': (hbar, 'meV*ps'),
+ 'kb': (kb, 'meV/K'),
+ }
+
+ # Sympy functions that preserve the unit of their argument; all other allowed functions
+ # require a dimensionless argument and return a dimensionless result.
+ _UNIT_PRESERVING_FUNCS: ClassVar[tuple] = (sp.Abs, sp.floor, sp.ceiling)
+ # Functions that accept any unit but return a dimensionless result.
+ _DIMENSIONLESS_RESULT_FUNCS: ClassVar[tuple] = (sp.sign,)
+
+ # parameter_units is applied at construction only; the resulting units are captured in
+ # the serialized Parameter dicts, so the argument is skipped during serialization.
+ _REDIRECT: ClassVar[dict[str, object]] = {'parameter_units': None}
+
def __init__(
self,
expression: str,
parameters: dict[str, Numeric] | None = None,
- unit: str | sc.Unit = 'meV',
+ parameter_units: dict[str, str | sc.Unit] | None = None,
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'Expression',
display_name: str | None = None,
unique_name: str | None = None,
@@ -114,26 +163,43 @@ def __init__(
Parameters
----------
expression : str
- The symbolic expression as a string. Must contain 'x' as the independent variable.
+ The symbolic expression as a string. Must contain 'x' as the independent variable. The
+ symbols ``hbar`` and ``kb`` are provided automatically as read-only physical constants
+ (in meV*ps and meV/K respectively) unless overridden via *parameters*.
parameters : dict[str, Numeric] | None, default=None
- Dictionary of parameter names and their initial values.
- unit : str | sc.Unit, default='meV'
- Unit of the output.
+ Dictionary of parameter names and their initial values. Parameters that are not given a
+ unit are dimensionless.
+ parameter_units : dict[str, str | sc.Unit] | None, default=None
+ Optional units per parameter name. Each entry sets the unit of the named parameter
+ without rescaling its value (see :meth:`set_unit`), and takes precedence over the unit
+ of a Parameter instance given in *parameters*. When units are in use, a warning is
+ issued if the expression's output unit does not match y_unit.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='Expression'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name for the component.
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
Unique name for the component.
Raises
------
ValueError
- If the expression is invalid or does not contain 'x'.
+ If the expression is invalid or does not contain 'x', or if parameter_units names a
+ parameter that is not in the expression.
TypeError
- If any parameter value is not numeric.
+ If any parameter value is not numeric, or if parameter_units is not a dictionary.
"""
- super().__init__(unit=unit, name=name, display_name=display_name, unique_name=unique_name)
+ super().__init__(
+ x_unit=x_unit,
+ y_unit=y_unit,
+ name=name,
+ display_name=display_name,
+ unique_name=unique_name,
+ )
if 'np.' in expression:
raise ValueError(
@@ -158,15 +224,14 @@ def __init__(
if 'x' not in symbol_names:
raise ValueError("Expression must contain 'x' as independent variable")
-
# Reject unknown functions early so invalid expressions fail at init,
# not later during numerical evaluation.
allowed_function_names = set(self._ALLOWED_FUNCS) | {
func.__name__ for func in self._ALLOWED_FUNCS.values()
}
-
# Walk all function-call nodes in the parsed expression (e.g. sin(x), foo(x)).
# Keep only function names that are not in our allowlist.
+
unknown_function_names: set[str] = set()
function_atoms = self._expr.atoms(sp.Function)
for function_atom in function_atoms:
@@ -175,12 +240,10 @@ def __init__(
unknown_function_names.add(function_name)
unknown_functions = sorted(unknown_function_names)
-
if unknown_functions:
raise ValueError(
f'Unsupported function(s) in expression: {", ".join(unknown_functions)}'
)
-
# Create parameters
if parameters is not None and not isinstance(parameters, dict):
raise TypeError(
@@ -196,34 +259,65 @@ def __init__(
)
parameters = parameters or {}
self._parameters: dict[str, Parameter] = {}
+ self._constants: dict[str, DescriptorNumber] = {}
self._symbol_names = symbol_names
for name in self._symbol_names:
if name in self._RESERVED_NAMES:
continue
+ # Physical constants are provided automatically, unless the user explicitly
+ # supplies a parameter with the same name.
+ if name in self._PHYSICAL_CONSTANTS and name not in parameters:
+ self._constants[name] = self._create_physical_constant(name)
+ continue
+
value = parameters.get(name, 1.0)
if isinstance(value, Parameter):
self._parameters[name] = value
-
elif isinstance(value, dict) and value.get('@class') == 'Parameter':
self._parameters[name] = Parameter.from_dict(value)
else:
self._parameters[name] = Parameter(
name=name,
value=value,
- unit=self._unit,
+ unit='dimensionless',
+ )
+
+ if parameter_units is not None:
+ if not isinstance(parameter_units, dict):
+ raise TypeError(
+ f'parameter_units must be None or a dictionary, '
+ f'got {type(parameter_units).__name__}'
+ )
+ constant_names = sorted(set(parameter_units) & set(self._constants))
+ if constant_names:
+ raise ValueError(
+ f'parameter_units given for physical constant(s): '
+ f'{", ".join(constant_names)}. Use e.g. '
+ f'expr.{constant_names[0]}.convert_unit(...) to change the unit of a '
+ f'constant (this rescales its value).'
)
+ unknown_names = sorted(set(parameter_units) - set(self._parameters))
+ if unknown_names:
+ raise ValueError(
+ f'parameter_units given for unknown parameter(s): {", ".join(unknown_names)}'
+ )
+ for parameter_name, parameter_unit in parameter_units.items():
+ # Relabel without the per-call output-unit warning; the consistency check
+ # runs once below, after all units are applied.
+ self._relabel_parameter_unit(parameter_name, parameter_unit)
# Create numerical function
ordered_symbols = [sp.Symbol(name) for name in self._symbol_names]
-
self._func = sp.lambdify(
ordered_symbols,
self._expr,
modules=[{'erf': erf}, 'numpy'],
)
+ self._warn_if_output_unit_mismatch()
+
# -------------------------
# Properties
# -------------------------
@@ -255,31 +349,48 @@ def expression(self, _new_expr: str) -> None:
AttributeError
Always raised to prevent changing the expression.
"""
+
raise AttributeError('Expression cannot be changed after initialization')
- def evaluate(
- self,
- x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
- ) -> np.ndarray:
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
"""
- Evaluate the expression for given x values.
+ Evaluate the expression for the given x values.
+
+ Unit conversion of parameters is not supported for ExpressionComponent. If x_vals is
+ expressed in a different unit than x_unit, a warning is issued and the values are used
+ as-is.
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- Input values for the independent variable.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
Evaluated results.
"""
- x = self._prepare_x_for_evaluate(x)
+ if (
+ eval_unit is not None
+ and self.x_unit is not None
+ and sc.Unit(eval_unit) != sc.Unit(self.x_unit)
+ ):
+ warnings.warn(
+ f'Input x has unit {eval_unit} but {self.__class__.__name__} has '
+ f'x_unit {self.x_unit}. ExpressionComponent cannot auto-convert parameters. '
+ 'x values are used as-is.',
+ UserWarning,
+ stacklevel=3,
+ )
args = []
for name in self._symbol_names:
if name == 'x':
- args.append(x)
+ args.append(x_vals)
+ elif name in self._constants:
+ args.append(self._constants[name].value)
else:
args.append(self._parameters[name].value)
@@ -296,11 +407,302 @@ def get_all_variables(self) -> list[Parameter]:
"""
return list(self._parameters.values())
- def convert_unit(self, _new_unit: str | sc.Unit) -> None:
+ def set_unit(self, name: str, unit: str | sc.Unit) -> None:
+ """
+ Set the unit of a parameter without rescaling its value.
+
+ This relabels the unit: the numeric value, bounds, and variance are kept as-is. Use
+ ``Parameter.convert_unit`` instead to rescale a value into a compatible unit. Issues a
+ warning if the resulting output unit of the expression no longer matches y_unit. Raises the
+ same exceptions as :meth:`_relabel_parameter_unit` on invalid input.
+
+ Parameters
+ ----------
+ name : str
+ Name of the parameter whose unit to set.
+ unit : str | sc.Unit
+ The new unit.
"""
- Convert the unit of the expression.
+ self._relabel_parameter_unit(name, unit)
+ self._warn_if_output_unit_mismatch()
- Unit conversion is not implemented for ExpressionComponent.
+ def _relabel_parameter_unit(self, name: str, unit: str | sc.Unit) -> None:
+ """
+ Validate and apply a unit relabel to a parameter (see set_unit).
+
+ Parameters
+ ----------
+ name : str
+ Name of the parameter whose unit to set.
+ unit : str | sc.Unit
+ The new unit.
+
+ Raises
+ ------
+ TypeError
+ If name is not a string or unit is not a string or sc.Unit.
+ KeyError
+ If no parameter with the given name exists.
+ ValueError
+ If unit is not a valid scipp unit.
+ AttributeError
+ If the parameter is a physical constant or a dependent parameter.
+ """
+ if not isinstance(name, str):
+ raise TypeError(f'name must be a string, got {type(name).__name__}')
+ if '_constants' in self.__dict__ and name in self._constants:
+ raise AttributeError(
+ f"'{name}' is a physical constant. Use expr.{name}.convert_unit(...) to change "
+ f'its unit (this rescales its value).'
+ )
+ if '_parameters' not in self.__dict__ or name not in self._parameters:
+ raise KeyError(f"No parameter named '{name}' in this {self.__class__.__name__}.")
+ if not isinstance(unit, (str, sc.Unit)):
+ raise TypeError(f'unit must be a string or sc.Unit, got {type(unit).__name__}')
+ try:
+ new_unit = sc.Unit(str(unit))
+ except sc.UnitError as e:
+ raise ValueError(f"'{unit}' is not a valid scipp unit.") from e
+
+ param = self._parameters[name]
+ if not param.independent:
+ raise AttributeError(
+ f"Cannot set the unit of dependent parameter '{name}'; its unit is controlled "
+ f'by the dependency expression.'
+ )
+
+ # Parameter.unit is read-only and convert_unit rescales the value, so a pure relabel
+ # has to swap the underlying scipp scalars (value and bounds) directly.
+ param._scalar = sc.scalar( # noqa: SLF001
+ param.value,
+ unit=new_unit,
+ variance=param._scalar.variance, # noqa: SLF001
+ )
+ param._min = sc.scalar(param._min.value, unit=new_unit) # noqa: SLF001
+ param._max = sc.scalar(param._max.value, unit=new_unit) # noqa: SLF001
+
+ @classmethod
+ def _create_physical_constant(cls, name: str) -> DescriptorNumber:
+ """
+ Create a DescriptorNumber for a supported physical constant.
+
+ The constant is a per-instance copy of the shared constant in ``easydynamics.utils``,
+ converted to its default unit, so converting its unit on one component does not affect
+ others.
+
+ Parameters
+ ----------
+ name : str
+ The constant's symbol name (a key of ``_PHYSICAL_CONSTANTS``).
+
+ Returns
+ -------
+ DescriptorNumber
+ The constant's value in its default unit.
+ """
+ source, default_unit = cls._PHYSICAL_CONSTANTS[name]
+ constant = copy(source)
+ constant.convert_unit(default_unit)
+ return constant
+
+ @property
+ def constants(self) -> dict[str, DescriptorNumber]:
+ """
+ Get the physical constants used by the expression.
+
+ Returns
+ -------
+ dict[str, DescriptorNumber]
+ The automatically provided constants (e.g. ``hbar``, ``kb``) keyed by symbol name.
+ """
+ return dict(self._constants)
+
+ @property
+ def output_unit(self) -> str:
+ """
+ Get the unit of the evaluated expression, derived from x_unit and the parameter units.
+
+ The unit is propagated through the expression tree: addition requires compatible units,
+ multiplication and powers combine units, and functions like ``exp`` or ``sin`` require a
+ dimensionless argument. Propagation raises ``sc.UnitError`` if the expression is not
+ unit-consistent (e.g. adding meV to a dimensionless quantity, or taking ``exp`` of a
+ quantity with a unit).
+
+ Returns
+ -------
+ str
+ The unit of the evaluated expression.
+ """
+ return str(self._canonical_unit(self._propagate_unit(self._expr)))
+
+ def _canonical_unit(self, unit: sc.Unit) -> sc.Unit:
+ """
+ Replace a unit by an equal, cleaner-printing candidate where possible.
+
+ Unit algebra accumulates tiny floating-point scale factors (e.g. ``meV * 1/meV`` prints as
+ ``0.999999999999999889`` rather than ``dimensionless``); scipp's unit equality still
+ matches, so map such units onto the common candidates for readable output.
+
+ Parameters
+ ----------
+ unit : sc.Unit
+ The unit to canonicalize.
+
+ Returns
+ -------
+ sc.Unit
+ An equal unit with a cleaner string representation, or the input unit unchanged.
+ """
+ candidate_strings = ['dimensionless', self.y_unit, self.x_unit]
+ for candidate_string in candidate_strings:
+ if candidate_string is None:
+ continue
+ candidate = sc.Unit(candidate_string)
+ if unit == candidate:
+ return candidate
+ return unit
+
+ def _propagate_unit(self, node: sp.Basic) -> sc.Unit:
+ """
+ Recursively derive the unit of a sympy expression node.
+
+ Parameters
+ ----------
+ node : sp.Basic
+ The sympy expression node to derive the unit of.
+
+ Raises
+ ------
+ sc.UnitError
+ If the node mixes incompatible units, applies a function that requires a dimensionless
+ argument to a quantity with a unit, or raises a quantity with a unit to a symbolic or
+ unsupported power.
+
+ Returns
+ -------
+ sc.Unit
+ The unit of the node.
+ """
+ dimensionless = sc.Unit('dimensionless')
+
+ if isinstance(node, sp.Symbol):
+ name = str(node)
+ if name == 'x':
+ return sc.Unit(self.x_unit) if self.x_unit is not None else dimensionless
+ if name in self._constants:
+ return sc.Unit(str(self._constants[name].unit))
+ return sc.Unit(str(self._parameters[name].unit))
+
+ if node.is_number:
+ return dimensionless
+
+ if isinstance(node, sp.Add):
+ units = [self._propagate_unit(argument) for argument in node.args]
+ for unit in units[1:]:
+ try:
+ sc.to_unit(sc.scalar(1.0, unit=unit), units[0])
+ except sc.UnitError as e:
+ raise sc.UnitError(
+ f'The expression adds quantities with incompatible units '
+ f'{units[0]} and {unit}.'
+ ) from e
+ return units[0]
+
+ if isinstance(node, sp.Mul):
+ result = dimensionless
+ for argument in node.args:
+ result = result * self._propagate_unit(argument)
+ return result
+
+ if isinstance(node, sp.Pow):
+ base_unit = self._propagate_unit(node.base)
+ exponent_unit = self._propagate_unit(node.exp)
+ if exponent_unit != dimensionless:
+ raise sc.UnitError(f'An exponent must be dimensionless, got {exponent_unit}.')
+ if base_unit == dimensionless:
+ return dimensionless
+ if not node.exp.is_number:
+ raise sc.UnitError(
+ f'Cannot determine units: quantity with unit {base_unit} is raised to a '
+ f'symbolic power.'
+ )
+ exponent = float(node.exp)
+ if exponent == int(exponent):
+ return base_unit ** int(exponent)
+ if 2 * exponent == int(2 * exponent):
+ # Half-integer power: the square root of an integer power.
+ return sc.sqrt(sc.scalar(1.0, unit=base_unit ** int(2 * exponent))).unit
+ raise sc.UnitError(
+ f'Cannot determine units: quantity with unit {base_unit} is raised to the '
+ f'power {exponent}.'
+ )
+
+ if isinstance(node, self._UNIT_PRESERVING_FUNCS):
+ return self._propagate_unit(node.args[0])
+
+ if isinstance(node, self._DIMENSIONLESS_RESULT_FUNCS):
+ self._propagate_unit(node.args[0])
+ return dimensionless
+
+ if isinstance(node, sp.Function):
+ for argument in node.args:
+ argument_unit = self._propagate_unit(argument)
+ try:
+ sc.to_unit(sc.scalar(1.0, unit=argument_unit), dimensionless)
+ except sc.UnitError as e:
+ raise sc.UnitError(
+ f'{node.func.__name__} requires a dimensionless argument, '
+ f'got {argument_unit}.'
+ ) from e
+ return dimensionless
+
+ raise sc.UnitError(
+ f'Cannot determine units for expression node {node} of type {type(node).__name__}.'
+ )
+
+ def _warn_if_output_unit_mismatch(self) -> None:
+ """
+ Warn if the expression's output unit does not match y_unit.
+
+ The check only runs when units are in use, i.e. when the expression uses physical constants
+ or any parameter has a unit other than dimensionless. Unit-agnostic expressions (all
+ parameters dimensionless) stay silent.
+ """
+ units_in_use = bool(self._constants) or any(
+ str(parameter.unit) != 'dimensionless' for parameter in self._parameters.values()
+ )
+ if not units_in_use:
+ return
+
+ try:
+ output_unit = sc.Unit(self.output_unit)
+ except sc.UnitError as e:
+ warnings.warn(
+ f'The expression is not unit-consistent: {e}',
+ UserWarning,
+ stacklevel=3,
+ )
+ return
+
+ y_unit = sc.Unit(self.y_unit) if self.y_unit is not None else sc.Unit('dimensionless')
+ if output_unit != y_unit:
+ warnings.warn(
+ f'The expression evaluates to unit {output_unit}, which does not match '
+ f'y_unit {y_unit}. The evaluated values are labelled with y_unit; adjust the '
+ f'parameter units or y_unit to make them consistent.',
+ UserWarning,
+ stacklevel=3,
+ )
+
+ def convert_x_unit(self, _new_unit: str | sc.Unit) -> None:
+ """
+ Convert the x-axis unit of the expression.
+
+ Unit conversion is not implemented for ExpressionComponent. Should it ever be needed, the
+ viable path is dimensional analysis on the parameter units: for each parameter, determine
+ the power n of the x-dimension in its unit and rescale its value by the x-unit conversion
+ factor to the power n (the generalization of Polynomial's power-law rescaling). This only
+ works when x_unit has a single unambiguous dimension.
Parameters
----------
@@ -312,21 +714,39 @@ def convert_unit(self, _new_unit: str | sc.Unit) -> None:
NotImplementedError
Always raised to indicate unit conversion is not supported.
"""
+ raise NotImplementedError('Unit conversion is not implemented for ExpressionComponent')
+
+ def convert_y_unit(self, _new_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit of the expression.
+
+ Unit conversion is not implemented for ExpressionComponent. See convert_x_unit for the
+ approach that would make it possible.
+
+ Parameters
+ ----------
+ _new_unit : str | sc.Unit
+ The new unit to convert to (ignored).
+ Raises
+ ------
+ NotImplementedError
+ Always raised to indicate unit conversion is not supported.
+ """
raise NotImplementedError('Unit conversion is not implemented for ExpressionComponent')
# -------------------------
# dunder methods
# -------------------------
- def __getattr__(self, name: str) -> Parameter:
+ def __getattr__(self, name: str) -> Parameter | DescriptorNumber:
"""
- Allow access to parameters as attributes.
+ Allow access to parameters and physical constants as attributes.
Parameters
----------
name : str
- Name of the parameter to access.
+ Name of the parameter or constant to access.
Raises
------
@@ -335,11 +755,13 @@ def __getattr__(self, name: str) -> Parameter:
Returns
-------
- Parameter
- The parameter with the given name.
+ Parameter | DescriptorNumber
+ The parameter or constant with the given name.
"""
if '_parameters' in self.__dict__ and name in self._parameters:
return self._parameters[name]
+ if '_constants' in self.__dict__ and name in self._constants:
+ return self._constants[name]
raise AttributeError(f"{self.__class__.__name__} has no attribute '{name}'")
def __setattr__(self, name: str, value: Numeric) -> None:
@@ -355,30 +777,31 @@ def __setattr__(self, name: str, value: Numeric) -> None:
Raises
------
+ AttributeError
+ If the name refers to a physical constant.
TypeError
If the value is not numeric.
"""
+ if '_constants' in self.__dict__ and name in self._constants:
+ raise AttributeError(f"'{name}' is a physical constant and cannot be set.")
if '_parameters' in self.__dict__ and name in self._parameters:
param = self._parameters[name]
-
if not isinstance(value, Numeric):
raise TypeError(f'{name} must be numeric')
-
param.value = value
else:
- # For other attributes, use default behavior
super().__setattr__(name, value)
def __dir__(self) -> list[str]:
"""
- Include parameter names in dir() output for better IDE support.
+ Include parameter and constant names in dir() output for better IDE support.
Returns
-------
list[str]
- List of attribute names, including parameters.
+ List of attribute names, including parameters and constants.
"""
- return super().__dir__() + list(self._parameters.keys())
+ return super().__dir__() + list(self._parameters.keys()) + list(self._constants.keys())
def __repr__(self) -> str:
"""
@@ -390,10 +813,15 @@ def __repr__(self) -> str:
String representation of the ExpressionComponent.
"""
param_str = ', '.join(f'{k}={v.value}' for k, v in self._parameters.items())
+ constants_str = ''
+ if self._constants:
+ constants_str = ',\n constants={ ' + ', '.join(
+ f'{k}={v.value} {v.unit}' for k, v in self._constants.items()
+ )
+ constants_str += ' }'
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self._unit},\n'
- f' expr={self._expression_str!r},\n'
- f' parameters={{{param_str}}})'
+ f'{self.__class__.__name__}(name={self.name}, display_name={self.display_name}, '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit},\n'
+ f" expr='{self._expression_str}',\n"
+ f' parameters={{ {param_str} }}{constants_str} )'
)
diff --git a/src/easydynamics/sample_model/components/gaussian.py b/src/easydynamics/sample_model/components/gaussian.py
index 6a1805e13..364e89aba 100644
--- a/src/easydynamics/sample_model/components/gaussian.py
+++ b/src/easydynamics/sample_model/components/gaussian.py
@@ -20,12 +20,11 @@ class Gaussian(CreateParametersMixin, ModelComponent):
r"""
Model of a Gaussian function.
- The intensity is given by
-
$$ I(x) = \frac{A}{\sigma \sqrt{2\pi}} \exp\left( -\frac{1}{2} \left(\frac{x -
x_0}{\sigma}\right)^2 \right) $$
- where $A$ is the area, $x_0$ is the center, and $\sigma$ is the width.
+ where $A$ is the area, $x_0$ is the center, and $\sigma$ is the width. area has unit = x_unit *
+ y_unit; center and width have unit = x_unit.
If the center is not provided, it will be centered at 0 and fixed, which is typically what you
want in QENS.
@@ -62,7 +61,8 @@ def __init__(
area: Numeric = 1.0,
center: Numeric | None = None,
width: Numeric = 1.0,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'Gaussian',
display_name: str | None = None,
unique_name: str | None = None,
@@ -73,39 +73,38 @@ def __init__(
Parameters
----------
area : Numeric, default=1.0
- Area of the Gaussian.
+ Integrated area under the Gaussian. Unit is ``x_unit * y_unit``.
center : Numeric | None, default=None
- Center of the Gaussian. If None, defaults to 0 and is fixed.
+ Peak position in x_unit. If None, defaults to 0 and the center parameter is fixed.
width : Numeric, default=1.0
- Standard deviation.
- unit : str | sc.Unit, default='meV'
- Unit of the parameters.
+ Standard deviation (sigma) in x_unit. Must be strictly positive.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. center and width are stored in this unit. area_unit = x_unit *
+ y_unit.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='Gaussian'
- Name of the component for indexing.
- display_name : str | None, default=None
Name of the component.
+ display_name : str | None, default=None
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. if None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
"""
- # Validate inputs and create Parameters if not given
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
)
- # These methods live in ValidationMixin
- area = self._create_area_parameter(area=area, name=name, unit=self._unit)
- center = self._create_center_parameter(
- center=center, name=name, fix_if_none=True, unit=self._unit
+ self._area = self._create_area_parameter(
+ area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit
)
- width = self._create_width_parameter(width=width, name=name, unit=self._unit)
-
- self._area = area
- self._center = center
- self._width = width
+ self._center = self._create_center_parameter(
+ center=center, name=name, fix_if_none=True, x_unit=self.x_unit
+ )
+ self._width = self._create_width_parameter(width=width, name=name, x_unit=self.x_unit)
@property
def area(self) -> Parameter:
@@ -115,27 +114,23 @@ def area(self) -> Parameter:
Returns
-------
Parameter
- The area parameter.
+ The area Parameter with unit ``x_unit * y_unit``.
"""
-
return self._area
@area.setter
def area(self, value: Numeric) -> None:
"""
- Set the value of the area parameter.
-
Parameters
----------
value : Numeric
- The new value for the area parameter.
+ New area value (in current area unit = x_unit * y_unit).
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
-
if not isinstance(value, Numeric):
raise TypeError('area must be a number')
self._area.value = value
@@ -148,27 +143,24 @@ def center(self) -> Parameter:
Returns
-------
Parameter
- The center parameter.
+ The center Parameter with unit ``x_unit``.
"""
-
return self._center
@center.setter
def center(self, value: Numeric | None) -> None:
"""
- Set the center parameter value.
-
Parameters
----------
value : Numeric | None
- The new value for the center parameter. If None, defaults to 0 and is fixed.
+ New center value in x_unit. If None, the center is set to 0 and the parameter is
+ fixed.
Raises
------
TypeError
- If the value is not a number or None.
+ If *value* is not None and not a numeric type.
"""
-
if value is None:
value = 0.0
self._center.fixed = True
@@ -179,48 +171,43 @@ def center(self, value: Numeric | None) -> None:
@property
def width(self) -> Parameter:
"""
- Get the width parameter (standard deviation).
+ Get the width parameter (sigma).
Returns
-------
Parameter
- The width parameter.
+ The width (sigma) Parameter with unit ``x_unit``.
"""
return self._width
@width.setter
def width(self, value: Numeric) -> None:
"""
- Set the width parameter value.
-
Parameters
----------
value : Numeric
- The new value for the width parameter.
+ New width value in x_unit. Must be strictly positive.
Raises
------
TypeError
- If the value is not a number or None.
+ If *value* is not a numeric type.
ValueError
- If the value is not positive.
+ If *value* is not positive.
"""
if not isinstance(value, Numeric):
raise TypeError('width must be a number')
-
if float(value) <= 0:
raise ValueError('width must be positive')
-
self._width.value = value
- def evaluate(
- self,
- x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
- ) -> np.ndarray:
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
r"""
- Evaluate the Gaussian at the given x values.
+ Evaluate the Gaussian at x_vals.
+
+ Parameters in the model's own units are temporarily converted to eval_unit for the
+ computation.
- If x is a scipp Variable, the unit of the Gaussian will be converted to match x. The
intensity is given by $$ I(x) = \frac{A}{\sigma \sqrt{2\pi}} \exp\left( -\frac{1}{2}
\left(\frac{x - x_0}{\sigma}\right)^2 \right) $$
@@ -228,21 +215,51 @@ def evaluate(
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the Gaussian.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
- The intensity of the Gaussian at the given x values.
+ Evaluated Gaussian values at x_vals.
+ """
+ center = self._resolve_param_value(self._center, eval_unit)
+ width = self._resolve_param_value(self._width, eval_unit)
+ area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit))
+
+ normalization = 1 / (np.sqrt(2 * np.pi) * width)
+ exponent = -0.5 * ((x_vals - center) / width) ** 2
+ return area * normalization * np.exp(exponent)
+
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert x-axis parameters (center, width) and area to new_x_unit.
+
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
"""
+ self._convert_x_unit_area_based(
+ new_x_unit=new_x_unit,
+ x_params=[self._center, self._width],
+ area_param=self._area,
+ )
- x = self._prepare_x_for_evaluate(x)
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis (output) unit by rescaling the area parameter.
- normalization = 1 / (np.sqrt(2 * np.pi) * self.width.value)
- exponent = -0.5 * ((x - self.center.value) / self.width.value) ** 2
+ The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``.
- return self.area.value * normalization * np.exp(exponent)
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ """
+ self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area)
def __repr__(self) -> str:
"""
@@ -253,12 +270,10 @@ def __repr__(self) -> str:
str
A string representation of the Gaussian.
"""
-
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self._unit},\n'
- f' area={self.area},\n'
- f' center={self.center},\n'
- f' width={self.width})'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n'
+ f' area = {self.area},\n'
+ f' center = {self.center},\n'
+ f' width = {self.width})'
)
diff --git a/src/easydynamics/sample_model/components/lorentzian.py b/src/easydynamics/sample_model/components/lorentzian.py
index f2da78c84..fe36340fc 100644
--- a/src/easydynamics/sample_model/components/lorentzian.py
+++ b/src/easydynamics/sample_model/components/lorentzian.py
@@ -20,9 +20,10 @@ class Lorentzian(CreateParametersMixin, ModelComponent):
r"""
Model of a Lorentzian function.
- The intensity is given by $$ I(x) = \frac{A}{\pi} \frac{\Gamma}{(x - x_0)^2 + \Gamma^2}, $$
- where $A$ is the area, $x_0$ is the center, and $\Gamma$ is the half width at half maximum
- (HWHM).
+ $$ I(x) = \frac{A}{\pi} \frac{\Gamma}{(x - x_0)^2 + \Gamma^2} $$
+
+ where $A$ is the area, $x_0$ is the center, and $\Gamma$ is the hald width at half max (HWHM).
+ area has unit = x_unit * y_unit; center and width have unit = x_unit.
If the center is not provided, it will be centered at 0 and fixed, which is typically what you
want in QENS.
@@ -58,7 +59,8 @@ def __init__(
area: Numeric = 1.0,
center: Numeric | None = None,
width: Numeric = 1.0,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'Lorentzian',
display_name: str | None = None,
unique_name: str | None = None,
@@ -69,39 +71,38 @@ def __init__(
Parameters
----------
area : Numeric, default=1.0
- Area of the Lorentzian.
+ Integrated area under the Lorentzian. Unit is ``x_unit * y_unit``.
center : Numeric | None, default=None
- Center of the Lorentzian. If None, defaults to 0 and is fixed.
+ Peak position in x_unit. If None, defaults to 0 and the center parameter is fixed.
width : Numeric, default=1.0
- Half width at half maximum (HWHM).
- unit : str | sc.Unit, default='meV'
- Unit of the parameters.
+ Half-width at half-maximum (HWHM, gamma) in x_unit. Must be strictly positive.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. center and width are stored in this unit. area_unit = x_unit *
+ y_unit.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='Lorentzian'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name for the component.
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. If None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
"""
-
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
)
- # These methods live in ValidationMixin
- area = self._create_area_parameter(area=area, name=name, unit=self._unit)
- center = self._create_center_parameter(
- center=center, name=name, fix_if_none=True, unit=self._unit
+ self._area = self._create_area_parameter(
+ area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit
)
- width = self._create_width_parameter(width=width, name=name, unit=self._unit)
-
- self._area = area
- self._center = center
- self._width = width
+ self._center = self._create_center_parameter(
+ center=center, name=name, fix_if_none=True, x_unit=self.x_unit
+ )
+ self._width = self._create_width_parameter(width=width, name=name, x_unit=self.x_unit)
@property
def area(self) -> Parameter:
@@ -111,24 +112,22 @@ def area(self) -> Parameter:
Returns
-------
Parameter
- The area parameter.
+ The area Parameter with unit ``x_unit * y_unit``.
"""
return self._area
@area.setter
def area(self, value: Numeric) -> None:
"""
- Set the value of the area parameter.
-
Parameters
----------
value : Numeric
- The new value for the area parameter.
+ New area value (in current area unit = x_unit * y_unit).
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
if not isinstance(value, Numeric):
raise TypeError('area must be a number')
@@ -142,26 +141,24 @@ def center(self) -> Parameter:
Returns
-------
Parameter
- The center parameter.
+ The center Parameter with unit ``x_unit``.
"""
return self._center
@center.setter
def center(self, value: Numeric | None) -> None:
"""
- Set the value of the center parameter.
-
Parameters
----------
value : Numeric | None
- The new value for the center parameter. If None, defaults to 0 and is fixed.
+ New center value in x_unit. If None, the center is set to 0 and the parameter is
+ fixed.
Raises
------
TypeError
- If the value is not a number or None.
+ If *value* is not None and not a numeric type.
"""
-
if value is None:
value = 0.0
self._center.fixed = True
@@ -177,63 +174,85 @@ def width(self) -> Parameter:
Returns
-------
Parameter
- The width parameter.
+ The HWHM (gamma) Parameter with unit ``x_unit``.
"""
return self._width
@width.setter
def width(self, value: Numeric) -> None:
"""
- Set the width parameter value (HWHM).
-
Parameters
----------
value : Numeric
- The new value for the width parameter.
+ New HWHM value in x_unit. Must be strictly positive.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
ValueError
- If the value is not positive.
+ If *value* is not positive.
"""
if not isinstance(value, Numeric):
raise TypeError('width must be a number')
-
if float(value) <= 0:
raise ValueError('width must be positive')
self._width.value = value
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
r"""
- Evaluate the Lorentzian at the given x values.
+ Evaluate the Lorentzian at x_vals.
- If x is a scipp Variable, the unit of the Lorentzian will be converted to match x. The
- intensity is given by
-
- $$ I(x) = \frac{A}{\pi} \frac{\Gamma}{(x - x_0)^2 + \Gamma^2}, $$
-
- where $A$ is the area, $x_0$ is the center, and $\Gamma$ is the half width at half maximum
- (HWHM).
+ Parameters in the model's own units are temporarily converted to eval_unit for the
+ computation.
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the Lorentzian.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
- The intensity of the Lorentzian at the given x values.
+ Evaluated Lorentzian values at x_vals.
"""
+ center = self._resolve_param_value(self._center, eval_unit)
+ width = self._resolve_param_value(self._width, eval_unit)
+ area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit))
- x = self._prepare_x_for_evaluate(x)
+ normalization = width / np.pi
+ denominator = (x_vals - center) ** 2 + width**2
+ return area * normalization / denominator
- normalization = self.width.value / np.pi
- denominator = (x - self.center.value) ** 2 + self.width.value**2
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert x-axis parameters (center, width) and area to new_x_unit.
- return self.area.value * normalization / denominator
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
+ """
+ self._convert_x_unit_area_based(
+ new_x_unit=new_x_unit,
+ x_params=[self._center, self._width],
+ area_param=self._area,
+ )
+
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis (output) unit by rescaling the area parameter.
+
+ The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ """
+ self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area)
def __repr__(self) -> str:
"""
@@ -245,10 +264,9 @@ def __repr__(self) -> str:
A string representation of the Lorentzian.
"""
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self._unit},\n'
- f' area={self.area},\n'
- f' center={self.center},\n'
- f' width={self.width})'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n'
+ f' area = {self.area},\n'
+ f' center = {self.center},\n'
+ f' width = {self.width})'
)
diff --git a/src/easydynamics/sample_model/components/mixins.py b/src/easydynamics/sample_model/components/mixins.py
index 7552d2e80..82be533f3 100644
--- a/src/easydynamics/sample_model/components/mixins.py
+++ b/src/easydynamics/sample_model/components/mixins.py
@@ -18,52 +18,54 @@ class CreateParametersMixin:
"""
Provides parameter creation and validation methods for model components.
- This mixin provides methods to create and validate common physics parameters (area, center,
- width) with appropriate bounds and type checking.
+ area_unit = x_unit * y_unit, so when y_unit='dimensionless', area_unit = x_unit.
"""
def _create_area_parameter(
self,
area: Numeric,
name: str,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
minimum_area: float = MINIMUM_AREA,
) -> Parameter:
"""
- Validate and convert a number to a Parameter describing the area of a function.
-
- If the area is negative, a warning is raised. If the area is non-negative, its minimum is
- set to 0 to avoid it accidentally becoming negative during fitting.
+ Create a Parameter for the area with unit = x_unit * y_unit.
Parameters
----------
area : Numeric
- The area value.
+ Initial area value.
name : str
- The name of the model component.
- unit : str | sc.Unit, default='meV'
- The unit of the area Parameter.
+ Base name used to label the Parameter (``name + ' area'``).
+ x_unit : str | sc.Unit, default='meV'
+ X-axis unit. The resulting area unit is ``x_unit * y_unit``.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Y-axis unit. The resulting area unit is ``x_unit * y_unit``.
minimum_area : float, default=MINIMUM_AREA
- The minimum allowed area.
+ Lower bound applied to the Parameter when the area is non-negative. When *area* is
+ negative no lower bound is set and a :class:`UserWarning` is issued.
+
+ Returns
+ -------
+ Parameter
+ Configured area Parameter with ``unit = x_unit * y_unit``.
Raises
------
TypeError
- If area is not a number.
+ If *area* is not a numeric type.
ValueError
- If area is not a finite number.
-
- Returns
- -------
- Parameter
- The validated area Parameter.
+ If *area* is not finite.
"""
if not isinstance(area, Numeric):
raise TypeError('area must be a number.')
if not np.isfinite(area):
raise ValueError('area must be a finite number.')
- area_param = Parameter(name=name + ' area', value=float(area), unit=unit)
+
+ area_unit = str(sc.Unit(x_unit) * sc.Unit(y_unit))
+ area_param = Parameter(name=name + ' area', value=float(area), unit=area_unit)
if area_param.value < 0:
warnings.warn(
@@ -81,36 +83,38 @@ def _create_center_parameter(
center: Numeric | None,
name: str,
fix_if_none: bool,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
enforce_minimum_center: bool = False,
) -> Parameter:
"""
- Validate and convert a number to a Parameter describing the center of a function.
+ Create a Parameter for the center with unit = x_unit.
Parameters
----------
center : Numeric | None
- The center value.
+ Initial center value. If None, the center is set to 0.0 and ``fixed`` is controlled by
+ *fix_if_none*.
name : str
- The name of the model component.
+ Base name used to label the Parameter (``name + ' center'``).
fix_if_none : bool
- Whether to fix the center Parameter if center is None.
- unit : str | sc.Unit, default='meV'
- The unit of the center Parameter.
+ Whether to fix the Parameter when *center* is None.
+ x_unit : str | sc.Unit, default='meV'
+ X-axis unit, applied to the center Parameter.
enforce_minimum_center : bool, default=False
- Whether to enforce a minimum center value to avoid zero center in DHO.
+ If True, the Parameter's lower bound is raised to ``DHO_MINIMUM_CENTER`` (1e-10) to
+ prevent a zero center.
+
+ Returns
+ -------
+ Parameter
+ Configured center Parameter with ``unit = x_unit``.
Raises
------
TypeError
- If center is not None or a number.
+ If *center* is not None and not a numeric type.
ValueError
- If center is a number but not finite.
-
- Returns
- -------
- Parameter
- The validated center Parameter.
+ If *center* is not None and not finite.
"""
if center is not None and not isinstance(center, Numeric):
raise TypeError('center must be None or a number.')
@@ -119,16 +123,18 @@ def _create_center_parameter(
center_param = Parameter(
name=name + ' center',
value=0.0,
- unit=unit,
+ unit=x_unit,
fixed=fix_if_none,
)
else:
if not np.isfinite(center):
raise ValueError('center must be None or a finite number.')
+ center_param = Parameter(name=name + ' center', value=float(center), unit=x_unit)
- center_param = Parameter(name=name + ' center', value=float(center), unit=unit)
if enforce_minimum_center and center_param.min < DHO_MINIMUM_CENTER:
center_param.min = DHO_MINIMUM_CENTER
+ if center_param.value < DHO_MINIMUM_CENTER:
+ center_param.value = DHO_MINIMUM_CENTER
return center_param
def _create_width_parameter(
@@ -136,36 +142,37 @@ def _create_width_parameter(
width: Numeric,
name: str,
param_name: str = 'width',
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
minimum_width: float = MINIMUM_WIDTH,
) -> Parameter:
"""
- Validate and convert a number to a Parameter describing the width of a function.
+ Create a Parameter for the width with unit = x_unit.
Parameters
----------
width : Numeric
- The width value.
+ Initial width value. Must be strictly positive (>= *minimum_width*).
name : str
- The name of the model component.
+ Base name used to label the Parameter (``name + ' ' + param_name``).
param_name : str, default='width'
- The name of the width parameter.
- unit : str | sc.Unit, default='meV'
- The unit of the width Parameter.
+ Logical name of the parameter used in the label and error messages (e.g.
+ ``'gaussian_width'``, ``'lorentzian_width'``).
+ x_unit : str | sc.Unit, default='meV'
+ X-axis unit, applied to the width Parameter.
minimum_width : float, default=MINIMUM_WIDTH
- The minimum allowed width.
+ Absolute lower bound for the width to prevent division-by-zero.
+
+ Returns
+ -------
+ Parameter
+ Configured width Parameter with ``unit = x_unit`` and ``min = minimum_width``.
Raises
------
TypeError
- If width is not a number.
+ If *width* is not a numeric type.
ValueError
- If width is non-positive.
-
- Returns
- -------
- Parameter
- The validated width Parameter.
+ If *width* is not finite or is smaller than *minimum_width*.
"""
if not isinstance(width, Numeric):
raise TypeError(f'{param_name} must be a number.')
@@ -180,6 +187,6 @@ def _create_width_parameter(
return Parameter(
name=name + ' ' + param_name,
value=float(width),
- unit=unit,
+ unit=x_unit,
min=minimum_width,
)
diff --git a/src/easydynamics/sample_model/components/model_component.py b/src/easydynamics/sample_model/components/model_component.py
index 878f501fa..1d2c98d0f 100644
--- a/src/easydynamics/sample_model/components/model_component.py
+++ b/src/easydynamics/sample_model/components/model_component.py
@@ -3,8 +3,8 @@
from __future__ import annotations
-import warnings
from abc import abstractmethod
+from typing import TYPE_CHECKING
import numpy as np
import scipp as sc
@@ -12,6 +12,10 @@
from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase
from easydynamics.utils.utils import Numeric
+from easydynamics.utils.utils import convert_parameter_unit
+
+if TYPE_CHECKING:
+ from easyscience.variable import Parameter
class ModelComponent(EasyDynamicsModelBase):
@@ -19,107 +23,145 @@ class ModelComponent(EasyDynamicsModelBase):
def __init__(
self,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'ModelComponent',
display_name: str | None = None,
unique_name: str | None = None,
) -> None:
"""
- Initialize the ModelComponent.
-
Parameters
----------
- unit : str | sc.Unit, default='meV'
- The unit of the model component.
+ x_unit : str | sc.Unit, default='meV'
+ Unit for the x-axis (independent variable) of this component.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit for the y-axis (dependent variable / output) of this component.
name : str, default='ModelComponent'
- The name of the model component for indexing.
+ Internal name used for parameter labelling and logging.
display_name : str | None, default=None
- A human-readable name for the component.
+ Human-readable name shown in plots and reports. Falls back to *name* if None.
unique_name : str | None, default=None
- A unique identifier for the component.
+ Globally unique identifier. Auto-generated if None.
"""
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
)
@property
- def unit(self) -> str:
+ def x_unit(self) -> str | None:
"""
- Get the unit.
-
+ Get the unit of the x axis.
Returns
-------
- str
- The unit of the model component.
+ str | None
+ The current x-axis unit as a string, or None if no unit is set.
"""
- return str(self._unit)
+ return str(self._x_unit) if self._x_unit is not None else None
- @unit.setter
- def unit(self, _unit_str: str) -> None:
+ @x_unit.setter
+ def x_unit(self, _: str) -> None:
"""
- Unit is read-only.
+ Unit is read-only; raises AttributeError always.
- Use convert_unit to change the unit between allowed types or create a new ModelComponent
- with the desired unit.
+ Use :meth:`convert_x_unit` to change the unit, or create a new instance with the desired
+ unit.
- Parameters
- ----------
- _unit_str : str
- The new unit to set.
+ Raises
+ ------
+ AttributeError
+ Always raised when this setter is called.
+ """
+ raise AttributeError(
+ f'x_unit is read-only. Use convert_x_unit to change the unit '
+ f'or create a new {self.__class__.__name__} with the desired unit.'
+ )
+
+ @property
+ def y_unit(self) -> str | None:
+ """
+ Get the unit of the y axis
+
+ Returns
+ -------
+ str | None
+ The current y-axis unit as a string, or None if no unit is set.
+ """
+ return str(self._y_unit) if self._y_unit is not None else None
+
+ @y_unit.setter
+ def y_unit(self, _: str) -> None:
+ """
+ Unit is read-only; raises AttributeError always.
+
+ Use :meth:`convert_y_unit` to change the unit, or create a new instance with the desired
+ unit.
Raises
------
AttributeError
- Always raised since unit is read-only.
+ Always raised when this setter is called.
"""
raise AttributeError(
- f'Unit is read-only. Use convert_unit to change the unit between allowed types '
+ f'y_unit is read-only. Use convert_y_unit to change the unit '
f'or create a new {self.__class__.__name__} with the desired unit.'
)
def fix_all_parameters(self) -> None:
- """Fix all parameters in the model component."""
+ """
+ Fix all parameters in the model component.
- pars = self.get_fittable_parameters()
- for p in pars:
+ Sets ``fixed=True`` on every fittable parameter returned by
+ :meth:`get_fittable_parameters`.
+ """
+ for p in self.get_fittable_parameters():
p.fixed = True
def free_all_parameters(self) -> None:
- """Free all parameters in the model component."""
+ """
+ Free all parameters in the model component.
+
+ Sets ``fixed=False`` on every fittable parameter returned by
+ :meth:`get_fittable_parameters`.
+ """
for p in self.get_fittable_parameters():
p.fixed = False
def _prepare_x_for_evaluate(
self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- ) -> np.ndarray:
+ ) -> tuple[np.ndarray, str | None, str]:
"""
- Prepare the input x for evaluation by handling units and converting to a numpy array.
+ Validate x and extract its values, detected unit, and dimension name.
+
+ x is never converted. When x carries a unit, the caller is responsible for resolving
+ parameter values to that unit via _resolve_param_value.
Parameters
----------
x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The input data to prepare.
+ Input x values to validate and extract.
+
+ Returns
+ -------
+ tuple[np.ndarray, str | None, str]
+ x_values : np.ndarray of raw float values (no unit conversion) detected_unit : str unit
+ of x if scipp input, else None dim : scipp dimension name if scipp input, else 'x'
Raises
------
- ValueError
- If x contains NaN or infinite values, or if a sc.DataArray has more than one
- coordinate.
UnitError
- If x has incompatible units that cannot be converted to the component's unit.
-
- Returns
- -------
- np.ndarray
- The prepared input data as a numpy array.
+ If x has a unit incompatible with the model's x_unit.
+ ValueError
+ If x contains NaN or infinite values, or if a DataArray has more than one coordinate.
"""
+ detected_unit: str | None = None
+ dim: str = 'x'
+ dim_from_dataarray: bool = False
- # Handle units
if isinstance(x, sc.DataArray):
- # Check that there's exactly one coordinate
coords = dict(x.coords)
ncoords = len(coords)
if ncoords != 1:
@@ -128,31 +170,25 @@ def _prepare_x_for_evaluate(
f'scipp.DataArray must have exactly one coordinate to be used as input `x`. '
f'Found {ncoords} coordinates: {coord_names}.'
)
- # get the coordinate, it's a sc.Variable
- _, coord_obj = next(iter(coords.items()))
+ dim, coord_obj = next(iter(coords.items()))
x = coord_obj
+ dim_from_dataarray = True
+
if isinstance(x, sc.Variable):
- # Need to check if the units are consistent,
- # and convert if not.
+ detected_unit = str(x.unit)
+ if not dim_from_dataarray:
+ dim = x.dims[0] if x.dims else 'x'
x_in = x.value if x.sizes == {} else x.values
- if self._unit is not None and x.unit != self._unit:
- self_unit_for_warning = self._unit
+
+ # Validate that x's unit is compatible with model's x_unit
+ if self.x_unit is not None and detected_unit != self.x_unit:
try:
- self.convert_unit(x.unit.name)
+ sc.to_unit(sc.scalar(1.0, unit=detected_unit), self.x_unit)
except Exception as e:
raise UnitError(
- f'Input x has unit {x.unit}, but {self.__class__.__name__} component \
- has unit {self._unit}. \
- Failed to convert {self.__class__.__name__} to {x.unit}.'
+ f'Input x has unit {detected_unit}, which is incompatible with '
+ f'{self.__class__.__name__} x_unit {self.x_unit}.'
) from e
-
- warnings.warn(
- f'Input x has unit {x.unit}, but {self.__class__.__name__} component \
- has unit {self_unit_for_warning}. \
- Converting {self.__class__.__name__} to {x.unit}.',
- UserWarning,
- stacklevel=3,
- )
else:
x_in = x
@@ -167,60 +203,262 @@ def _prepare_x_for_evaluate(
if any(np.isinf(x_in)):
raise ValueError('Input x contains infinite values.')
- return np.sort(x_in)
+ return x_in, detected_unit, dim
+
+ def _resolve_param_value(self, param: Parameter, target_unit: str | None) -> float:
+ """
+ Return param's value converted to target_unit without mutating param.
+
+ If target_unit is None or already matches param's unit, returns param.value directly. Uses
+ a temporary scipp scalar for the conversion.
+
+ Parameters
+ ----------
+ param : Parameter
+ The parameter whose value should be resolved.
+ target_unit : str | None
+ The unit to which the parameter value should be converted. When None (or equal to the
+ parameter's own unit) the raw value is returned without any conversion.
+
+ Returns
+ -------
+ float
+ The parameter value expressed in *target_unit*.
+ """
+ if target_unit is None or str(param.unit) == str(target_unit):
+ return param.value
+ return sc.to_unit(sc.scalar(param.value, unit=str(param.unit)), target_unit).value
- def convert_unit(self, unit: str | sc.Unit) -> None:
+ def _convert_x_unit_area_based(
+ self,
+ new_x_unit: str | sc.Unit,
+ x_params: list,
+ area_param: Parameter,
+ inverse_params: list | None = None,
+ ) -> None:
"""
- Convert the unit of the Parameters in the component.
+ Shared convert_x_unit logic for components with an area parameter (area = x_unit * y_unit).
+
+ Validates the input type, converts all x-axis parameters, the area parameter, and any
+ reciprocal (``1/x_unit``) parameters to the new unit, and updates ``_x_unit``. Rolls back
+ all conversions if any step fails.
Parameters
----------
- unit : str | sc.Unit
- The new unit to convert to.
+ new_x_unit : str | sc.Unit
+ Target x-axis unit.
+ x_params : list
+ Parameters whose unit equals *x_unit* (e.g. center, width).
+ area_param : Parameter
+ The parameter whose unit equals ``x_unit * y_unit``.
+ inverse_params : list | None, default=None
+ Parameters whose unit equals ``1/x_unit`` (e.g. an exponential rate).
Raises
------
TypeError
- If the provided unit is not a str or sc.Unit.
+ If *new_x_unit* is not a ``str`` or ``sc.Unit``.
Exception
- If the provided unit is invalid or incompatible with the component's parameters.
+ If the conversion fails; all parameters are rolled back to their original units.
+ """
+ if not isinstance(new_x_unit, (str, sc.Unit)):
+ raise TypeError(f'x_unit must be a string or sc.Unit, got {type(new_x_unit).__name__}')
+ inverse_params = inverse_params or []
+ old_x_unit = self.x_unit
+ new_x_str = str(new_x_unit) if isinstance(new_x_unit, sc.Unit) else new_x_unit
+ new_area_unit = str(sc.Unit(new_x_str) * sc.Unit(self.y_unit))
+ try:
+ for p in x_params:
+ convert_parameter_unit(p, new_x_unit)
+ convert_parameter_unit(area_param, new_area_unit)
+ for p in inverse_params:
+ convert_parameter_unit(p, '1/' + new_x_str)
+ self._x_unit = new_x_str
+ except Exception as e:
+ try:
+ old_area_unit = str(sc.Unit(old_x_unit) * sc.Unit(self.y_unit))
+ for p in x_params:
+ convert_parameter_unit(p, old_x_unit)
+ convert_parameter_unit(area_param, old_area_unit)
+ for p in inverse_params:
+ convert_parameter_unit(p, '1/' + str(old_x_unit))
+ except Exception: # noqa: S110
+ pass
+ raise e
+
+ def _convert_y_unit_area_based(
+ self,
+ new_y_unit: str | sc.Unit,
+ area_param: Parameter,
+ ) -> None:
"""
- if not isinstance(unit, (str, sc.Unit)):
- raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}')
+ Shared convert_y_unit logic for components with an area parameter (area = x_unit * y_unit).
- old_unit = self._unit
+ Validates the input type, rescales the area parameter from ``x_unit * old_y_unit`` to
+ ``x_unit * new_y_unit``, and updates ``_y_unit``. Rolls back on failure.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ area_param : Parameter
+ The parameter whose unit equals ``x_unit * y_unit``.
+
+ Raises
+ ------
+ TypeError
+ If *new_y_unit* is not a ``str`` or ``sc.Unit``.
+ Exception
+ If the conversion fails; the area parameter is rolled back to its original unit.
+ """
+ if not isinstance(new_y_unit, (str, sc.Unit)):
+ raise TypeError(f'y_unit must be a string or sc.Unit, got {type(new_y_unit).__name__}')
+ old_y_unit = self.y_unit
+ new_area_unit = str(sc.Unit(self.x_unit) * sc.Unit(new_y_unit))
+ try:
+ convert_parameter_unit(area_param, new_area_unit)
+ self._y_unit = str(new_y_unit) if isinstance(new_y_unit, sc.Unit) else new_y_unit
+ except Exception as e:
+ try:
+ old_area_unit = str(sc.Unit(self.x_unit) * sc.Unit(old_y_unit))
+ convert_parameter_unit(area_param, old_area_unit)
+ except Exception: # noqa: S110
+ pass
+ raise e
+
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert the x-axis unit of the component.
+
+ The base implementation converts all parameters. Subclasses with mixed-unit parameters
+ (e.g. area ≠ x_unit) should override this method.
+
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
+
+ Raises
+ ------
+ TypeError
+ If *new_x_unit* is not a ``str`` or ``sc.Unit``.
+ Exception
+ If the conversion between the current unit and *new_x_unit* fails. On failure the
+ component is rolled back to its original unit.
+ """
+ if not isinstance(new_x_unit, (str, sc.Unit)):
+ raise TypeError(f'x_unit must be a string or sc.Unit, got {type(new_x_unit).__name__}')
+
+ old_unit = self.x_unit
pars = self.get_all_parameters()
try:
for p in pars:
- p.convert_unit(unit)
- self._unit = unit
+ convert_parameter_unit(p, new_x_unit)
+ self._x_unit = str(new_x_unit) if isinstance(new_x_unit, sc.Unit) else new_x_unit
except Exception as e:
- # Attempt to rollback on failure
try:
for p in pars:
- if hasattr(p, 'convert_unit'):
- p.convert_unit(old_unit)
+ convert_parameter_unit(p, old_unit)
except Exception: # noqa: S110
- pass # Best effort rollback
+ pass
raise e
- @abstractmethod
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis (output) unit. Subclasses with an area parameter should override this.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+
+ Raises
+ ------
+ NotImplementedError
+ Always raised in this base implementation. Subclasses that carry an area parameter
+ (area_unit = x_unit * y_unit) must override this method to rescale the area
+ appropriately.
+ """
+ raise NotImplementedError(f'{self.__class__.__name__} does not support convert_y_unit.')
+
+ def evaluate(
+ self,
+ x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
+ output: str = 'numpy',
+ ) -> np.ndarray | sc.Variable:
"""
- Abstract method to evaluate the model component at input x.
+ Evaluate the model component at input x.
- Must be implemented by subclasses.
+ When x carries a unit (scipp input), parameter values are temporarily converted to that
+ unit for the computation without mutating the parameters.
Parameters
----------
x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the component.
+ Input x values.
+ output : str, default='numpy'
+ 'numpy' returns np.ndarray; 'scipp' returns sc.Variable with y_unit.
+
+ Raises
+ ------
+ ValueError
+ If output is not 'numpy' or 'scipp'.
+
+ Returns
+ -------
+ np.ndarray | sc.Variable
+ Evaluated model values at x.
+ """
+ if output not in ('numpy', 'scipp'):
+ raise ValueError(f"output must be 'numpy' or 'scipp', got {output!r}")
+ x_vals, detected_unit, dim = self._prepare_x_for_evaluate(x)
+ eval_unit = detected_unit or self.x_unit
+ values = self._evaluate_values(x_vals, eval_unit)
+ if output == 'scipp':
+ return sc.array(dims=[dim], values=values, unit=self.y_unit)
+ return values
+
+ @abstractmethod
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
+ """
+ Compute the component's values for raw x values expressed in eval_unit.
+
+ Implementations must resolve their parameters to *eval_unit* (via
+ :meth:`_resolve_param_value`) and return plain numpy values; the public :meth:`evaluate`
+ handles input validation and output wrapping.
+
+ Parameters
+ ----------
+ x_vals : np.ndarray
+ Raw x values, expressed in *eval_unit*.
+ eval_unit : str | None
+ The unit of *x_vals* (the detected input unit, falling back to the component's x_unit),
+ or None if no unit information is available.
Returns
-------
np.ndarray
- Evaluated function values.
+ Evaluated model values at x_vals.
+ """
+
+ def _eval_area_unit(self, eval_unit: str | None) -> str | None:
+ """
+ Get the area unit (``eval_unit * y_unit``) matching an evaluation unit.
+
+ Parameters
+ ----------
+ eval_unit : str | None
+ The unit x values are expressed in during evaluation.
+
+ Returns
+ -------
+ str | None
+ The corresponding area unit, or None when either unit is unknown (in which case
+ parameter values are used unconverted).
"""
+ if eval_unit is None or self.y_unit is None:
+ return None
+ return str(sc.Unit(eval_unit) * sc.Unit(self.y_unit))
def __repr__(self) -> str:
"""
@@ -231,5 +469,7 @@ def __repr__(self) -> str:
str
A string representation of the ModelComponent.
"""
-
- return f'{self.__class__.__name__}(unique_name={self.unique_name!r}, unit={self._unit})'
+ return (
+ f'{self.__class__.__name__}(unique_name={self.unique_name}, '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit})'
+ )
diff --git a/src/easydynamics/sample_model/components/polynomial.py b/src/easydynamics/sample_model/components/polynomial.py
index a36700dff..683af8d9f 100644
--- a/src/easydynamics/sample_model/components/polynomial.py
+++ b/src/easydynamics/sample_model/components/polynomial.py
@@ -23,8 +23,10 @@ class Polynomial(ModelComponent):
r"""
Polynomial function component.
- The intensity is given by $$ I(x) = c_0 + c_1 x + c_2 x^2 + ... + c_N x^N, $$ where $C_i$ are
- the coefficients.
+ $$ I(x) = c_0 + c_1 x + c_2 x^2 + ... + c_N x^N $$
+
+ Coefficients are stored as dimensionless Parameters. When x_unit changes, the coefficient
+ values are rescaled so the evaluated result stays the same. The output unit is y_unit.
Examples
--------
@@ -53,39 +55,46 @@ class Polynomial(ModelComponent):
def __init__(
self,
coefficients: Sequence[Numeric | Parameter] = (0.0,),
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'Polynomial',
display_name: str | None = None,
unique_name: str | None = None,
+ suppress_warnings: bool = False,
) -> None:
"""
- Initialize the Polynomial component.
-
Parameters
----------
coefficients : Sequence[Numeric | Parameter], default=(0.0,)
- Coefficients c0, c1, ..., cN.
- unit : str | sc.Unit, default='meV'
- Unit of the Polynomial component.
+ Ordered list of polynomial coefficients ``[c0, c1, ..., cN]`` where the polynomial is
+ ``c0 + c1*x + c2*x^2 + ... + cN*x^N``. Each element may be a plain numeric value
+ (wrapped into a dimensionless :class:`Parameter`) or an existing :class:`Parameter`
+ instance. Must contain at least one element.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. When the x_unit is changed via :meth:`convert_x_unit`, coefficient
+ values are rescaled by power-law factors so the evaluated output remains unchanged.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='Polynomial'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name of the Polynomial component.
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. If None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
+ suppress_warnings : bool, default=False
+ Whether to suppress warnings
Raises
------
TypeError
- If coefficients is not a sequence of numbers or Parameters or if any item in
- coefficients is not a number or Parameter.
+ If *coefficients* is not a list, tuple, or ndarray, or if any element is neither
+ numeric nor a :class:`Parameter`.
ValueError
- If coefficients is an empty sequence.
+ If *coefficients* is empty.
"""
-
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
@@ -93,17 +102,15 @@ def __init__(
if not isinstance(coefficients, (list, tuple, np.ndarray)):
raise TypeError(
- 'coefficients must be a sequence (list/tuple/ndarray) \
- of numbers or Parameter objects.'
+ 'coefficients must be a sequence (list/tuple/ndarray) '
+ 'of numbers or Parameter objects.'
)
if len(coefficients) == 0:
raise ValueError('At least one coefficient must be provided.')
- # Internal storage of Parameter objects
self._coefficients: list[Parameter] = []
- # Coefficients are treated as dimensionless Parameters
for i, coef in enumerate(coefficients):
if isinstance(coef, Parameter):
param = coef
@@ -113,19 +120,20 @@ def __init__(
raise TypeError('Each coefficient must be either a numeric value or a Parameter.')
self._coefficients.append(param)
- # Helper scipp scalar to track unit conversions
- # (value initialized to 1 with provided unit)
- self._unit_conversion_helper = sc.scalar(value=1.0, unit=unit)
+ # Tracks the current x_unit scale for convert_x_unit power-law rescaling
+ self._x_unit_helper = sc.scalar(value=1.0, unit=x_unit)
+
+ self.suppress_warnings = suppress_warnings
@property
def coefficients(self) -> list[Parameter]:
"""
Get the coefficients of the polynomial as a list of Parameters.
-
Returns
-------
list[Parameter]
- The coefficients of the polynomial.
+ A shallow copy of the internal coefficient list ``[c0, c1, ..., cN]``. Modifying the
+ returned list does not affect the model; use the setter to replace values.
"""
return list(self._coefficients)
@@ -133,25 +141,24 @@ def coefficients(self) -> list[Parameter]:
def coefficients(self, coeffs: Sequence[Numeric | Parameter]) -> None:
"""
Set the coefficients of the polynomial.
-
- Length must match current number of coefficients.
-
Parameters
----------
coeffs : Sequence[Numeric | Parameter]
- New coefficients as a sequence of numbers or Parameters.
+ New coefficient values. Must be a list, tuple, or ndarray and must have the same
+ length as the current number of coefficients. Numeric values update the existing
+ Parameter's ``.value``; a Parameter instance replaces the stored Parameter entirely.
Raises
------
TypeError
- If coeffs is not a sequence of numbers or Parameters or if any item in coeffs is not a
- number or Parameter.
+ If *coeffs* is not a list, tuple, or ndarray, or if any element is neither numeric nor
+ a Parameter.
ValueError
- If the length of coeffs does not match the existing number of coefficients.
+ If the length of *coeffs* does not match the current number of coefficients.
"""
if not isinstance(coeffs, (list, tuple, np.ndarray)):
raise TypeError(
- 'coefficients must be a sequence (list/tuple/ndarray) of numbers or Parameter .'
+ 'coefficients must be a sequence (list/tuple/ndarray) of numbers or Parameter.'
)
if len(coeffs) != len(self._coefficients):
raise ValueError(
@@ -159,7 +166,6 @@ def coefficients(self, coeffs: Sequence[Numeric | Parameter]) -> None:
)
for i, coef in enumerate(coeffs):
if isinstance(coef, Parameter):
- # replace parameter
self._coefficients[i] = coef
elif isinstance(coef, Numeric):
self._coefficients[i].value = float(coef)
@@ -169,56 +175,20 @@ def coefficients(self, coeffs: Sequence[Numeric | Parameter]) -> None:
def coefficient_values(self) -> list[float]:
"""
Get the coefficients of the polynomial as a list.
-
Returns
-------
list[float]
- The coefficient values of the polynomial.
+ Current numeric values of all coefficients ``[c0.value, c1.value, ..., cN.value]``.
"""
return [param.value for param in self._coefficients]
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
- r"""
- Evaluate the Polynomial at the given x values.
-
- The intensity is given by $$ I(x) = c_0 + c_1 x + c_2 x^2 + ...
- + c_N x^N, $$ where $C_i$ are the coefficients.
-
- Parameters
- ----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the Polynomial.
-
- Returns
- -------
- np.ndarray
- The evaluated Polynomial at the given x values.
- """
-
- x = self._prepare_x_for_evaluate(x)
-
- result = np.zeros_like(x, dtype=float)
- for i, param in enumerate(self._coefficients):
- result += param.value * np.power(x, i)
-
- if any(result < 0):
- warnings.warn(
- f'The Polynomial with unique_name {self.unique_name} has negative values, '
- 'which may not be physically meaningful.',
- UserWarning,
- stacklevel=2,
- )
- return result
-
@property
def degree(self) -> int:
"""
- Get the degree of the polynomial.
-
Returns
-------
int
- The degree of the polynomial.
+ Polynomial degree, equal to ``len(coefficients) - 1``.
"""
return len(self._coefficients) - 1
@@ -230,73 +200,167 @@ def degree(self, _value: int) -> None:
Parameters
----------
_value : int
- The new degree of the polynomial.
+ Ignored; this setter always raises :exc:`AttributeError`.
Raises
------
AttributeError
- Always raised since degree cannot be set directly.
+ Always raised when this setter is called.
"""
raise AttributeError(
'The degree of the polynomial is determined by the number of coefficients '
'and cannot be set directly.'
)
- def get_all_variables(self) -> list[DescriptorBase]:
+ @property
+ def suppress_warnings(self) -> bool:
+ """
+ Get whether or not to suppress warnings.
+ """
+ return self._suppress_warnings
+
+ @suppress_warnings.setter
+ def suppress_warnings(self, value: bool) -> None:
"""
- Get all variables from the model component.
+ Choose whether or not to suppress warnings.
+
+ Parameters
+ ----------
+ value : bool
+ Whether or not to suppress warnings
+
+ Raises
+ ------
+ TypeError
+ If suppress_warnings is not True or False
+ """
+ if not isinstance(value, bool):
+ raise TypeError('Suppress_warnings must be True or False')
+ self._suppress_warnings = value
+
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
+ r"""
+ Evaluate the Polynomial at x_vals.
+
+ When x_vals is expressed in a different unit than the stored x_unit, coefficient values are
+ temporarily rescaled (same power-law logic as convert_x_unit) without mutation.
+
+ Parameters
+ ----------
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
+ Returns
+ -------
+ np.ndarray
+ Evaluated polynomial values.
+ """
+ if (
+ eval_unit is not None
+ and self.x_unit is not None
+ and sc.Unit(eval_unit) != sc.Unit(self.x_unit)
+ ):
+ # Temporary coefficient rescaling — no mutation
+ helper = sc.scalar(1.0, unit=self.x_unit)
+ helper_in_x = sc.to_unit(helper, eval_unit)
+ scale = helper.value / helper_in_x.value
+ coeff_vals = [p.value * scale**i for i, p in enumerate(self._coefficients)]
+ else:
+ coeff_vals = [p.value for p in self._coefficients]
+
+ result = np.zeros_like(x_vals, dtype=float)
+ for i, cv in enumerate(coeff_vals):
+ result += cv * np.power(x_vals, i)
+
+ if not self._suppress_warnings and any(result < 0):
+ warnings.warn(
+ f'The Polynomial with unique_name {self.unique_name} has negative values, '
+ 'which may not be physically meaningful.',
+ UserWarning,
+ stacklevel=3,
+ )
+
+ return result
+
+ def get_all_variables(self) -> list[DescriptorBase]:
+ """
Returns
-------
list[DescriptorBase]
- List of variables in the component.
+ The coefficient Parameters that constitute the fittable variables of this polynomial
+ component.
"""
return list(self._coefficients)
- def convert_unit(self, unit: str | sc.Unit) -> None:
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
"""
- Convert the unit of the polynomial.
+ Convert the x-axis unit by rescaling coefficients with power-law factors.
+
+ Each coefficient ``c_i`` is rescaled by ``(old_scale / new_scale) ** i`` so the evaluated
+ polynomial output is unchanged after the conversion.
Parameters
----------
- unit : str | sc.Unit
- The target unit to convert to.
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
Raises
------
UnitError
- If the provided unit is not a string or sc.Unit.
+ If *new_x_unit* is not a valid unit string or ``sc.Unit``, or if the conversion between
+ the current unit and *new_x_unit* fails.
"""
+ if not isinstance(new_x_unit, (str, sc.Unit)):
+ raise UnitError('new_x_unit must be a string or a scipp unit.')
- if not isinstance(unit, (str, sc.Unit)):
- raise UnitError('unit must be a string or a scipp unit.')
-
- # Find out how much the unit changes
- # by converting a helper variable
- conversion_value_before = self._unit_conversion_helper.value
- self._unit_conversion_helper = sc.to_unit(self._unit_conversion_helper, unit=unit)
- conversion_value_after = self._unit_conversion_helper.value
+ conversion_value_before = self._x_unit_helper.value
+ self._x_unit_helper = sc.to_unit(self._x_unit_helper, unit=new_x_unit)
+ conversion_value_after = self._x_unit_helper.value
for i, param in enumerate(self._coefficients):
- param.value *= (
- conversion_value_before / conversion_value_after
- ) ** i # set the values directly to the appropriate power
+ param.value *= (conversion_value_before / conversion_value_after) ** i
- self._unit = unit
+ self._x_unit = str(new_x_unit) if isinstance(new_x_unit, sc.Unit) else new_x_unit
- def __repr__(self) -> str:
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
"""
- Return a string representation of the Polynomial.
+ Rescale all coefficients so the evaluated output remains the same physical value.
- Returns
- -------
- str
- A string representation of the Polynomial.
+ All coefficients are multiplied by the conversion factor from ``old_y_unit`` to
+ ``new_y_unit`` so that ``I(x) [new_y_unit]`` represents the same physical quantity as
+ ``I(x) [old_y_unit]``.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit. Must be dimensionally compatible with the current y_unit.
+
+ Raises
+ ------
+ UnitError
+ If *new_y_unit* is not a valid unit string or ``sc.Unit``, or if the conversion between
+ the current y_unit and *new_y_unit* fails.
"""
+ if not isinstance(new_y_unit, (str, sc.Unit)):
+ raise UnitError('new_y_unit must be a string or a scipp unit.')
+ old_y_unit = self.y_unit or 'dimensionless'
+ new_y_str = str(new_y_unit) if isinstance(new_y_unit, sc.Unit) else new_y_unit
+
+ # Compute conversion factor: 1 old_y_unit expressed in new_y_unit
+ y_helper = sc.scalar(1.0, unit=old_y_unit)
+ y_helper_new = sc.to_unit(y_helper, new_y_str)
+ scale = y_helper_new.value / y_helper.value
+
+ for param in self._coefficients:
+ param.value *= scale
+ self._y_unit = new_y_str
+
+ def __repr__(self) -> str:
coeffs_str = ', '.join(f'{param.name}={param.value}' for param in self._coefficients)
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self._unit},\n'
- f' coefficients=[{coeffs_str}])'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n'
+ f' coefficients = [{coeffs_str}])'
)
diff --git a/src/easydynamics/sample_model/components/voigt.py b/src/easydynamics/sample_model/components/voigt.py
index 5ca29105c..ee8c29046 100644
--- a/src/easydynamics/sample_model/components/voigt.py
+++ b/src/easydynamics/sample_model/components/voigt.py
@@ -19,13 +19,14 @@
class Voigt(CreateParametersMixin, ModelComponent):
r"""
- Voigt profile, a convolution of Gaussian and Lorentzian.
+ Voigt profile — convolution of Gaussian and Lorentzian.
+
+ Uses ``scipy.special.voigt_profile`` to evaluate the profile. area has unit = x_unit * y_unit;
+ center, gaussian_width, and lorentzian_width have unit = x_unit.
If the center is not provided, it will be centered at 0 and fixed, which is typically what you
want in QENS.
- Use scipy.special.voigt_profile to evaluate the Voigt profile.
-
Examples
--------
**Creating a Voigt profile with a fixed center (typical QENS use)**
@@ -60,7 +61,8 @@ def __init__(
center: Numeric | Parameter | None = None,
gaussian_width: Numeric | Parameter = 1.0,
lorentzian_width: Numeric | Parameter = 1.0,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'Voigt',
display_name: str | None = None,
unique_name: str | None = None,
@@ -71,54 +73,52 @@ def __init__(
Parameters
----------
area : Numeric | Parameter, default=1.0
- Total area under the curve.
+ Integrated area under the Voigt profile. Unit is ``x_unit * y_unit``.
center : Numeric | Parameter | None, default=None
- Center of the Voigt profile.
+ Peak position in x_unit. If None, defaults to 0 and the center parameter is fixed.
gaussian_width : Numeric | Parameter, default=1.0
- Standard deviation of the Gaussian part.
+ Gaussian component standard deviation (sigma) in x_unit. Must be strictly positive.
lorentzian_width : Numeric | Parameter, default=1.0
- Half width at half max (HWHM) of the Lorentzian part.
- unit : str | sc.Unit, default='meV'
- Unit of the parameters.
+ Lorentzian component HWHM (gamma) in x_unit. Must be strictly positive.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis. center, gaussian_width, and lorentzian_width are stored in this
+ unit. area_unit = x_unit * y_unit.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
name : str, default='Voigt'
- Name of the component for indexing.
+ Name of the component.
display_name : str | None, default=None
- Display name of the component.
+ Display name shown when plotting. Falls back to *name* if None.
unique_name : str | None, default=None
- Unique name of the component. If None, a unique_name is automatically generated. By
- default, None.
+ Globally unique identifier. Auto-generated if None.
"""
-
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
name=name,
display_name=display_name,
unique_name=unique_name,
)
- # These methods live in ValidationMixin
- area = self._create_area_parameter(area=area, name=name, unit=self._unit)
- center = self._create_center_parameter(
- center=center, name=name, fix_if_none=True, unit=self._unit
+ self._area = self._create_area_parameter(
+ area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit
)
- gaussian_width = self._create_width_parameter(
+ self._center = self._create_center_parameter(
+ center=center, name=name, fix_if_none=True, x_unit=self.x_unit
+ )
+ self._gaussian_width = self._create_width_parameter(
width=gaussian_width,
name=name,
param_name='gaussian_width',
- unit=self._unit,
+ x_unit=self.x_unit,
)
- lorentzian_width = self._create_width_parameter(
+ self._lorentzian_width = self._create_width_parameter(
width=lorentzian_width,
name=name,
param_name='lorentzian_width',
- unit=self._unit,
+ x_unit=self.x_unit,
)
- self._area = area
- self._center = center
- self._gaussian_width = gaussian_width
- self._lorentzian_width = lorentzian_width
-
@property
def area(self) -> Parameter:
"""
@@ -127,24 +127,22 @@ def area(self) -> Parameter:
Returns
-------
Parameter
- The area parameter.
+ The area Parameter with unit ``x_unit * y_unit``.
"""
return self._area
@area.setter
def area(self, value: Numeric) -> None:
"""
- Set the value of the area parameter.
-
Parameters
----------
value : Numeric
- The new value for the area parameter.
+ New area value (in current area unit = x_unit * y_unit).
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
"""
if not isinstance(value, Numeric):
raise TypeError('area must be a number')
@@ -158,24 +156,23 @@ def center(self) -> Parameter:
Returns
-------
Parameter
- The center parameter.
+ The center Parameter with unit ``x_unit``.
"""
return self._center
@center.setter
def center(self, value: Numeric | None) -> None:
"""
- Set the value of the center parameter.
-
Parameters
----------
value : Numeric | None
- The new value for the center parameter. If None, defaults to 0 and is fixed.
+ New center value in x_unit. If None, the center is set to 0 and the parameter is
+ fixed.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not None and not a numeric type.
"""
if value is None:
value = 0.0
@@ -187,31 +184,30 @@ def center(self, value: Numeric | None) -> None:
@property
def gaussian_width(self) -> Parameter:
"""
- Get the Gaussian width parameter.
+ Get the Gaussian width parameter (sigma).
Returns
-------
Parameter
- The Gaussian width parameter.
+ The Gaussian component width (sigma) Parameter with unit ``x_unit``.
"""
return self._gaussian_width
@gaussian_width.setter
def gaussian_width(self, value: Numeric) -> None:
"""
- Set the width parameter value.
-
+ Set the gaussian width parameter value.
Parameters
----------
value : Numeric
- The new value for the width parameter.
+ New Gaussian width (sigma) in x_unit. Must be strictly positive.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
ValueError
- If the value is not positive.
+ If *value* is not positive.
"""
if not isinstance(value, Numeric):
raise TypeError('gaussian_width must be a number')
@@ -227,7 +223,7 @@ def lorentzian_width(self) -> Parameter:
Returns
-------
Parameter
- The Lorentzian width parameter.
+ The Lorentzian component HWHM (gamma) Parameter with unit ``x_unit``.
"""
return self._lorentzian_width
@@ -235,18 +231,17 @@ def lorentzian_width(self) -> Parameter:
def lorentzian_width(self, value: Numeric) -> None:
"""
Set the value of the Lorentzian width parameter.
-
Parameters
----------
value : Numeric
- The new value for the Lorentzian width parameter.
+ New Lorentzian HWHM (gamma) in x_unit. Must be strictly positive.
Raises
------
TypeError
- If the value is not a number.
+ If *value* is not a numeric type.
ValueError
- If the value is not positive.
+ If *value* is not positive.
"""
if not isinstance(value, Numeric):
raise TypeError('lorentzian_width must be a number')
@@ -254,33 +249,60 @@ def lorentzian_width(self, value: Numeric) -> None:
raise ValueError('lorentzian_width must be positive')
self._lorentzian_width.value = value
- def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray:
- r"""
- Evaluate the Voigt at the given x values.
+ def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray:
+ """
+ Evaluate the Voigt at x_vals.
- If x is a scipp Variable, the unit of the Voigt will be converted to match x. The Voigt
- evaluates to the convolution of a Gaussian with sigma gaussian_width and a Lorentzian with
- half width at half max lorentzian_width, centered at center, with area equal to area.
+ Parameters in the model's own units are temporarily converted to eval_unit for the
+ computation.
Parameters
----------
- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values at which to evaluate the Voigt.
+ x_vals : np.ndarray
+ Raw x values expressed in eval_unit.
+ eval_unit : str | None
+ The unit of x_vals.
Returns
-------
np.ndarray
- The intensity of the Voigt at the given x values.
+ Evaluated Voigt profile values at x_vals.
"""
+ center = self._resolve_param_value(self._center, eval_unit)
+ gw = self._resolve_param_value(self._gaussian_width, eval_unit)
+ lw = self._resolve_param_value(self._lorentzian_width, eval_unit)
+ area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit))
+
+ return area * voigt_profile(x_vals - center, gw, lw)
- x = self._prepare_x_for_evaluate(x)
+ def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None:
+ """
+ Convert x-axis parameters (center, widths) and area to new_x_unit.
- return self.area.value * voigt_profile(
- x - self.center.value,
- self.gaussian_width.value,
- self.lorentzian_width.value,
+ Parameters
+ ----------
+ new_x_unit : str | sc.Unit
+ Target x-axis unit. Must be dimensionally compatible with the current x_unit.
+ """
+ self._convert_x_unit_area_based(
+ new_x_unit=new_x_unit,
+ x_params=[self._center, self._gaussian_width, self._lorentzian_width],
+ area_param=self._area,
)
+ def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit by rescaling the area parameter.
+
+ The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``.
+
+ Parameters
+ ----------
+ new_y_unit : str | sc.Unit
+ Target y-axis unit.
+ """
+ self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area)
+
def __repr__(self) -> str:
"""
Return a string representation of the Voigt.
@@ -290,12 +312,11 @@ def __repr__(self) -> str:
str
A string representation of the Voigt.
"""
-
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, unit={self._unit},\n'
- f' area={self.area},\n'
- f' center={self.center},\n'
- f' gaussian_width={self.gaussian_width},\n'
- f' lorentzian_width={self.lorentzian_width})'
+ f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, '
+ f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n'
+ f' area = {self.area},\n'
+ f' center = {self.center},\n'
+ f' gaussian_width = {self.gaussian_width},\n'
+ f' lorentzian_width = {self.lorentzian_width})'
)
diff --git a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py
index 7255f88a0..b15d00120 100644
--- a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py
+++ b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py
@@ -52,7 +52,8 @@ def __init__(
scale: Numeric = 1.0,
diffusion_coefficient: Numeric = 1.0,
Q: Q_type | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'BrownianTranslationalDiffusion',
display_name: str | None = 'BrownianTranslationalDiffusion',
lorentzian_name: str | None = None,
@@ -70,8 +71,10 @@ def __init__(
Diffusion coefficient D in m^2/s.
Q : Q_type | None, default=None
Q values for the model. If None, Q is not set.
- unit : str | sc.Unit, default='meV'
- Unit of the diffusion model. Must be convertible to meV.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis (energy/frequency). Must be convertible to meV.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity). Determines scale.unit = x_unit * y_unit.
name : str, default='BrownianTranslationalDiffusion'
Name of the diffusion model.
display_name : str | None, default='BrownianTranslationalDiffusion'
@@ -96,7 +99,8 @@ def __init__(
"""
super().__init__(
Q=Q,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
scale=scale,
name=name,
display_name=display_name,
@@ -186,7 +190,7 @@ def calculate_width(self, Q: Q_type | None = None) -> np.ndarray:
Q = self._ensure_Q(Q)
unit_conversion_factor = self._hbar * self.diffusion_coefficient / (self._angstrom**2)
- unit_conversion_factor.convert_unit(self.unit)
+ unit_conversion_factor.convert_unit(self.x_unit)
return Q**2 * unit_conversion_factor.value
def calculate_EISF(self, Q: Q_type | None = None) -> np.ndarray:
@@ -240,11 +244,11 @@ def create_component_collections(
Lorentzian has a width given by $D*Q^2$ and an area given by the scale parameter
multiplied by the QISF (which is 1 for this model).
"""
- Q = self.Q
- if Q is None:
+ if self.Q is None:
self._component_collections = []
return self._component_collections
+ Q = self.Q.values
component_collection_list = [None] * len(Q)
# In more complex models, this is used to scale the area of the
# Lorentzians and the delta function.
@@ -257,31 +261,37 @@ def create_component_collections(
component_collection_list[i] = ComponentCollection(
name=f'{self.name}_Q{Q_value:.2f}',
display_name=f'{self.display_name}_Q{Q_value:.2f}',
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
lorentzian_component = Lorentzian(
name=self.lorentzian_name,
display_name=self.lorentzian_display_name,
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
# Make the width dependent on Q
dependency_expression = self._write_width_dependency_expression(Q[i])
dependency_map = self._write_width_dependency_map_expression()
- lorentzian_component.width.make_dependent_on(
- dependency_expression=dependency_expression,
- dependency_map=dependency_map,
- desired_unit=self.unit,
- )
-
- # Make the area dependent on Q
- area_dependency_map = self._write_area_dependency_map_expression()
- lorentzian_component.area.make_dependent_on(
- dependency_expression=self._write_area_dependency_expression(QISF[i]),
- dependency_map=area_dependency_map,
- )
+ # easyscience propagates inf bounds through arithmetic, producing inf/inf=nan
+ # as a transient intermediate. Python's min/max ignore nan so the final bounds
+ # are correct; suppress the spurious numpy RuntimeWarning.
+ with np.errstate(invalid='ignore'):
+ lorentzian_component.width.make_dependent_on(
+ dependency_expression=dependency_expression,
+ dependency_map=dependency_map,
+ desired_unit=self.x_unit,
+ )
+
+ # Make the area dependent on Q
+ area_dependency_map = self._write_area_dependency_map_expression()
+ lorentzian_component.area.make_dependent_on(
+ dependency_expression=self._write_area_dependency_expression(QISF[i]),
+ dependency_map=area_dependency_map,
+ )
component_collection_list[i].append_component(lorentzian_component)
@@ -339,43 +349,6 @@ def _write_width_dependency_map_expression(self) -> dict[str, DescriptorNumber]:
'angstrom': self._angstrom,
}
- def _write_area_dependency_expression(self, QISF: float) -> str:
- """
- Write the dependency expression for the area to make dependent Parameters.
-
- Parameters
- ----------
- QISF : float
- Quasielastic Incoherent Scattering Function.
-
- Raises
- ------
- TypeError
- If QISF is not a float.
-
- Returns
- -------
- str
- Dependency expression for the area.
- """
- if not isinstance(QISF, (float)):
- raise TypeError('QISF must be a float.')
-
- return f'{QISF} * scale'
-
- def _write_area_dependency_map_expression(self) -> dict[str, DescriptorNumber]:
- """
- Write the dependency map expression to make dependent Parameters.
-
- Returns
- -------
- dict[str, DescriptorNumber]
- Dependency map for the area.
- """
- return {
- 'scale': self.scale,
- }
-
# ------------------------------------------------------------------
# dunder methods
# ------------------------------------------------------------------
@@ -392,6 +365,7 @@ def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'name={self.name!r}, display_name={self.display_name!r},\n'
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n'
f' diffusion_coefficient={self.diffusion_coefficient},\n'
f' scale={self.scale})'
)
diff --git a/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py b/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py
index 1e709dbc4..12af54314 100644
--- a/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py
+++ b/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py
@@ -11,8 +11,11 @@
from easydynamics.sample_model.components import DeltaFunction
from easydynamics.sample_model.components import Lorentzian
from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase
+from easydynamics.utils.fit_target import FitTarget
from easydynamics.utils.utils import Numeric
from easydynamics.utils.utils import Q_type
+from easydynamics.utils.utils import angstrom
+from easydynamics.utils.utils import verify_Q_index
MINIMUM_WIDTH = 1e-10 # To avoid division by zero
@@ -63,7 +66,8 @@ def __init__(
lorentzian_width: Numeric = 1.0,
allow_Q_variation: dict | None = None,
Q: Q_type | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'DeltaLorentz',
display_name: str | None = None,
lorentzian_name: str = 'Lorentzian',
@@ -92,8 +96,10 @@ def __init__(
allowed.
Q : Q_type | None, default=None
Q values for the model. If None, Q is not set.
- unit : str | sc.Unit, default='meV'
- Unit of the diffusion model. Must be convertible to meV.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis (energy/frequency). Must be convertible to meV.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity). Determines scale.unit = x_unit * y_unit.
name : str, default='DeltaLorentz'
Name of the diffusion model.
display_name : str | None, default=None
@@ -120,7 +126,8 @@ def __init__(
"""
super().__init__(
scale=scale,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
Q=Q,
lorentzian_name=lorentzian_name,
lorentzian_display_name=lorentzian_display_name,
@@ -134,6 +141,18 @@ def __init__(
# --------------------------------------------------------------
self._mean_u_squared = self._create_mean_u_squared_parameter(mean_u_squared)
+ # Dimensionless /angstrom^2 for use inside exp() in the area dependency
+ # expressions and the EISF/QISF calculations. Deriving it through the dependency
+ # graph keeps those quantities correct if mean_u_squared is converted to another
+ # unit (e.g. nm**2), which a raw `.value` would silently get wrong.
+ with np.errstate(invalid='ignore'):
+ self._mean_u_squared_over_angstrom_squared = Parameter.from_dependency(
+ name='mean_u_squared_over_angstrom_squared',
+ dependency_expression='mean_u_squared / angstrom**2',
+ dependency_map={'mean_u_squared': self._mean_u_squared, 'angstrom': angstrom},
+ )
+ self._mean_u_squared_over_angstrom_squared.set_desired_unit('dimensionless')
+
self._A_0, self._A_1 = self._create_A0_A1_parameters(A_0)
self._lorentzian_width = self._create_lorentzian_width_parameter(lorentzian_width)
@@ -398,8 +417,18 @@ def calculate_width(self, Q: Q_type = None) -> np.ndarray:
-------
np.ndarray
HWHM values in the unit of the model (e.g., meV).
+
+ Raises
+ ------
+ ValueError
+ If Q-variation is enabled but Q has not been set on the model yet.
"""
if self._allow_Q_variation['lorentzian_width'] is True:
+ if not self._lorentzian_width_list:
+ raise ValueError(
+ 'Lorentzian width Q-variation list is empty. '
+ 'Set Q before calling calculate_width.'
+ )
widths = [lorentzian_width.value for lorentzian_width in self._lorentzian_width_list]
return np.array(widths)
@@ -424,12 +453,13 @@ def calculate_EISF(self, Q: Q_type = None) -> np.ndarray:
EISF values (dimensionless).
"""
Q = self._ensure_Q(Q)
+ mean_u_squared = self._mean_u_squared_over_angstrom_squared.value
if self._allow_Q_variation['A_0'] is True:
A_0_values = [A_0_.value for A_0_ in self._A_0_list]
- return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_0_values)
+ return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_0_values)
A_0_values = [self.A_0.value] * len(Q)
- return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_0_values)
+ return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_0_values)
def calculate_QISF(self, Q: Q_type = None) -> np.ndarray:
"""
@@ -446,12 +476,13 @@ def calculate_QISF(self, Q: Q_type = None) -> np.ndarray:
QISF values (dimensionless).
"""
Q = self._ensure_Q(Q)
+ mean_u_squared = self._mean_u_squared_over_angstrom_squared.value
if self._allow_Q_variation['A_0'] is True:
A_1_values = [A_1_.value for A_1_ in self._A_1_list]
- return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_1_values)
+ return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_1_values)
A_1_values = [self.A_1.value] * len(Q)
- return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_1_values)
+ return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_1_values)
def create_component_collections(
self,
@@ -465,10 +496,11 @@ def create_component_collections(
List of ComponentCollections with Lorentzian and delta function components for each Q
value.
"""
- Q = self.Q
- if Q is None:
+ if self.Q is None:
return []
+ Q = self.Q.values
+
if self._allow_Q_variation['A_0'] is True:
A_0_list, A_1_list = self._create_A0_A1_parameter_lists(self.A_0)
self._A_0_list = A_0_list
@@ -484,19 +516,21 @@ def create_component_collections(
for i, Q_value in enumerate(Q):
component_collection_list[i] = ComponentCollection(
display_name=f'{self.display_name}_Q{Q_value:.2f}',
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
# ------------------------------#
# Create Lorentzian
# ------------------------------#
- lorentzian_component = Lorentzian(
+ lorz_component = Lorentzian(
name=self.lorentzian_name,
display_name=self.lorentzian_display_name,
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
if self._allow_Q_variation['lorentzian_width'] is True:
- lorentzian_component._width = self._lorentzian_width_list[i] # noqa: SLF001
+ lorz_component._width = self._lorentzian_width_list[i] # noqa: SLF001
# If the width is allowed to vary with Q it is independent.
# If the width is not allowed to vary with Q it must be made
@@ -504,10 +538,10 @@ def create_component_collections(
if self._allow_Q_variation['lorentzian_width'] is False:
dependency_map = self._write_width_dependency_map_expression()
- lorentzian_component.width.make_dependent_on(
+ lorz_component.width.make_dependent_on(
dependency_expression=self._write_lorz_width_dependency_expression(Q_value),
dependency_map=dependency_map,
- desired_unit=self.unit,
+ desired_unit=self.x_unit,
)
# The area is always a dependent parameter in this model, as
# it depends on the scale, mean_u_squared and A_1 parameters
@@ -520,12 +554,12 @@ def create_component_collections(
else:
dependency_map = self._write_lorz_area_dependency_map_expression(None)
- lorentzian_component.area.make_dependent_on(
+ lorz_component.area.make_dependent_on(
dependency_expression=self._write_lorz_area_dependency_expression(Q_value),
dependency_map=dependency_map,
)
- component_collection_list[i].append_component(lorentzian_component)
+ component_collection_list[i].append_component(lorz_component)
# ------------------------------#
# Create delta function
@@ -534,7 +568,8 @@ def create_component_collections(
delta_component = DeltaFunction(
name=self.delta_name,
display_name=self.delta_display_name,
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
if self._allow_Q_variation['A_0'] is True:
@@ -551,6 +586,32 @@ def create_component_collections(
return component_collection_list
+ def get_fit_targets(self) -> list[FitTarget]:
+ """
+ Get the fittable predictions of the DeltaLorentz model as FitTargets.
+
+ Extends the base ``'area'`` and ``'width'`` predictions with ``'delta_area'`` (``scale *
+ EISF(Q)``, the delta function's weight), whose default dataset key is derived from the
+ delta component's name.
+
+ Returns
+ -------
+ list[FitTarget]
+ The fittable predictions of this model.
+ """
+ targets = super().get_fit_targets()
+ targets.append(
+ FitTarget(
+ name='delta_area',
+ dataset_key=f'{self.delta_name} area',
+ function=lambda Q, model=self, **_: model.calculate_EISF(Q) * model.scale.value,
+ label=f'{self.display_name} delta_area',
+ x_unit='1/angstrom',
+ y_unit=str(self.scale.unit),
+ )
+ )
+ return targets
+
def get_global_variables(self) -> list[Parameter]:
"""
Get all global variables from the diffusion model.
@@ -587,22 +648,9 @@ def get_independent_variables(self, Q_index: int | None = None) -> list[Paramete
-------
list[Parameter]
List of independent variables in the model.
-
- Raises
- ------
- ValueError
- If Q_index is not None and is not a valid index for the Q values in the model.
"""
- if Q_index is not None and (
- not isinstance(Q_index, int)
- or Q_index < 0
- or Q_index >= len(self._component_collections)
- ):
- raise ValueError(
- f'Q_index must be an integer between 0 and '
- f'{len(self._component_collections) - 1}, or None.'
- )
+ verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True)
variables = []
if self._allow_Q_variation['A_0'] is True:
@@ -629,22 +677,9 @@ def get_all_variables(self, Q_index: int | None = None) -> list[DescriptorNumber
-------
list[DescriptorNumber]
List of all variables in the model.
-
- Raises
- ------
- ValueError
- If Q_index is not None and is not a valid index for the Q values in the model.
"""
- if Q_index is not None and (
- not isinstance(Q_index, int)
- or Q_index < 0
- or Q_index >= len(self._component_collections)
- ):
- raise ValueError(
- f'Q_index must be an integer between 0 and '
- f'{len(self._component_collections) - 1}, or None.'
- )
+ verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True)
variables = self.get_global_variables()
variables.extend(self.get_independent_variables(Q_index=Q_index))
@@ -814,7 +849,7 @@ def _create_lorentzian_width_parameter(self, lorentzian_width: Numeric) -> Param
value=float(lorentzian_width),
fixed=False,
min=MINIMUM_WIDTH,
- unit=self.unit,
+ unit=self.x_unit,
)
def _create_A0_A1_parameter_lists(
@@ -837,7 +872,7 @@ def _create_A0_A1_parameter_lists(
"""
A_0_list = []
A_1_list = []
- for _ in self.Q:
+ for _ in range(len(self.Q)):
a0 = Parameter(
name='A_0',
value=float(A_0.value),
@@ -881,9 +916,9 @@ def _create_lorentzian_width_parameter_list(
value=float(lorentzian_width.value),
fixed=False,
min=MINIMUM_WIDTH,
- unit=self.unit,
+ unit=self.x_unit,
)
- for _ in self.Q
+ for _ in range(len(self.Q))
]
# ------------------------------------------------------------------
@@ -914,6 +949,20 @@ def _on_Q_change(self) -> None:
self._lorentzian_width_list = []
self._component_collections = self.create_component_collections()
+ def _convert_extra_x_unit_parameters(self, unit_str: str) -> None:
+ """
+ Convert the Lorentzian width template to the new x-axis unit.
+
+ The per-Q width list (when Q-variation is enabled) holds the very Parameter objects used by
+ the components, so those are converted in place with the collections.
+
+ Parameters
+ ----------
+ unit_str : str
+ The new x-axis unit as a string.
+ """
+ self._lorentzian_width.convert_unit(unit_str)
+
def _write_lorz_width_dependency_expression(self, Q: float) -> str:
"""
Write the dependency expression for the width as a function of Q to make dependent
@@ -976,7 +1025,9 @@ def _write_lorz_area_dependency_expression(self, Q: float) -> str:
if not isinstance(Q, (float)):
raise TypeError('Q must be a float.')
- return f'scale * exp(-mean_u_squared.value * {Q}**2 / 3) * A_1'
+ # mean_u_squared_ratio is /angstrom^2, kept dimensionless through the dependency
+ # graph so the expression stays correct if mean_u_squared is converted to another unit.
+ return f'scale * exp(-mean_u_squared_ratio.value * {Q}**2 / 3) * A_1'
def _write_lorz_area_dependency_map_expression(
self, Q_index: int | None
@@ -998,13 +1049,13 @@ def _write_lorz_area_dependency_map_expression(
if Q_index is None:
return {
'scale': self.scale,
- 'mean_u_squared': self.mean_u_squared,
+ 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared,
'A_1': self.A_1,
}
return {
'scale': self.scale,
- 'mean_u_squared': self.mean_u_squared,
+ 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared,
'A_1': self._A_1_list[Q_index],
}
@@ -1030,7 +1081,9 @@ def _write_delta_area_dependency_expression(self, Q: float) -> str:
if not isinstance(Q, (float)):
raise TypeError('Q must be a float.')
- return f'scale * exp(-mean_u_squared.value * {Q}**2 / 3) * A_0'
+ # mean_u_squared_ratio is /angstrom^2, kept dimensionless through the dependency
+ # graph so the expression stays correct if mean_u_squared is converted to another unit.
+ return f'scale * exp(-mean_u_squared_ratio.value * {Q}**2 / 3) * A_0'
def _write_delta_area_dependency_map_expression(
self,
@@ -1053,12 +1106,12 @@ def _write_delta_area_dependency_map_expression(
if Q_index is None:
return {
'scale': self.scale,
- 'mean_u_squared': self.mean_u_squared,
+ 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared,
'A_0': self.A_0,
}
return {
'scale': self.scale,
- 'mean_u_squared': self.mean_u_squared,
+ 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared,
'A_0': self._A_0_list[Q_index],
}
@@ -1076,10 +1129,10 @@ def __repr__(self) -> str:
String representation of the DeltaLorentz model.
"""
return (
- f'{self.__class__.__name__}('
- f'display_name={self.display_name!r}, unit={self.unit},\n'
- f' mean_u_squared={self.mean_u_squared},\n'
- f' A_0={self.A_0}, A_1={self.A_1},\n'
- f' lorentzian_width={self.lorentzian_width},\n'
+ f'DeltaLorentz(display_name={self.display_name}, '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n'
+ f' mean_u_squared={self.mean_u_squared}, \n'
+ f' A_0={self.A_0}, A_1={self.A_1}, \n'
+ f' lorentzian_width={self.lorentzian_width}, \n'
f' scale={self.scale})'
)
diff --git a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py
index 7d752b826..da6533f07 100644
--- a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py
+++ b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause
+import contextlib
+
import numpy as np
import scipp as sc
from easyscience.variable import DescriptorNumber
@@ -9,9 +11,11 @@
from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase
from easydynamics.sample_model.component_collection import ComponentCollection
+from easydynamics.utils.fit_target import FitTarget
from easydynamics.utils.utils import Numeric
from easydynamics.utils.utils import Q_type
from easydynamics.utils.utils import _validate_and_convert_Q
+from easydynamics.utils.utils import verify_Q_index
class DiffusionModelBase(EasyDynamicsModelBase):
@@ -21,7 +25,8 @@ def __init__(
self,
scale: Numeric = 1.0,
Q: Q_type | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'DiffusionModel',
display_name: str | None = 'DiffusionModel',
lorentzian_name: str | None = None,
@@ -31,14 +36,20 @@ def __init__(
"""
Initialize a new DiffusionModel.
+ Unit validation raises ``UnitError`` if x_unit is not a string or scipp Unit, or if it
+ cannot be converted to meV.
+
Parameters
----------
scale : Numeric, default=1.0
- Scale factor for the diffusion model. Must be a non-negative number.
+ Scale factor for the diffusion model. Must be a non-negative number. Its unit equals
+ area_unit = x_unit * y_unit because scale * QISF/EISF (dimensionless) = component area.
Q : Q_type | None, default=None
Q values for the model. If None, Q is not set.
- unit : str | sc.Unit, default='meV'
- Unit of the diffusion model. Must be convertible to meV.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis (energy/frequency). Must be convertible to meV.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity). Together with x_unit determines area_unit.
name : str, default='DiffusionModel'
Name of the diffusion model.
display_name : str | None, default='DiffusionModel'
@@ -57,21 +68,13 @@ def __init__(
------
TypeError
If scale is not a number.
- UnitError
- If unit is not a string or scipp Unit, or if it cannot be converted to meV.
ValueError
If scale is negative.
"""
self._Q = _validate_and_convert_Q(Q)
- try:
- test = DescriptorNumber(name='test', value=1, unit=unit)
- test.convert_unit('meV')
- except Exception as e:
- raise UnitError(
- f'Invalid unit: {unit}. Unit must be a string or scipp Unit and convertible to meV.' # noqa: E501
- ) from e
+ self._assert_convertible_to_mev(x_unit)
if not isinstance(scale, Numeric):
raise TypeError('scale must be a number.')
@@ -79,10 +82,18 @@ def __init__(
if float(scale) < 0:
raise ValueError('scale must be non-negative.')
- scale = Parameter(name='scale', value=float(scale), fixed=False, min=0.0, unit=unit)
- self._scale = scale
+ area_unit = str(sc.Unit(str(x_unit)) * sc.Unit(str(y_unit)))
+ self._scale = Parameter(
+ name='scale', value=float(scale), fixed=False, min=0.0, unit=area_unit
+ )
- super().__init__(unit=unit, name=name, display_name=display_name, unique_name=unique_name)
+ super().__init__(
+ x_unit=x_unit,
+ y_unit=y_unit,
+ name=name,
+ display_name=display_name,
+ unique_name=unique_name,
+ )
if lorentzian_name is None:
lorentzian_name = name
@@ -102,7 +113,10 @@ def __init__(
if self.Q is None:
self._component_collections = []
else:
- self._component_collections = [ComponentCollection()] * len(self.Q)
+ self._component_collections = [
+ ComponentCollection(x_unit=self.x_unit, y_unit=self.y_unit)
+ for _ in range(len(self.Q))
+ ]
# ------------------------------------------------------------------
# Properties
@@ -145,14 +159,14 @@ def scale(self, scale: Numeric) -> None:
self._scale.value = float(scale)
@property
- def Q(self) -> np.ndarray | None:
+ def Q(self) -> sc.Variable | None:
"""
Get the Q values of the SampleModel.
Returns
-------
- np.ndarray | None
- The Q values of the SampleModel, or None if not set.
+ sc.Variable | None
+ The Q values of the SampleModel in 1/angstrom, or None if not set.
"""
return self._Q
@@ -184,7 +198,7 @@ def Q(self, value: Q_type | None) -> None:
self._on_Q_change()
return
- if len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q):
+ if len(old_Q) != len(new_Q) or not sc.allclose(old_Q, new_Q):
raise ValueError(
'New Q values are not similar to the old ones. '
'To change Q values, first run clear_Q().'
@@ -274,6 +288,177 @@ def clear_Q(self, confirm: bool = False) -> None:
self._Q = None
self._on_Q_change()
+ # ------------------------------------------------------------------
+ # Unit conversion
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _assert_convertible_to_mev(unit: str | sc.Unit) -> None:
+ """
+ Assert that the given unit is an energy unit (convertible to meV).
+
+ Frequency axes such as 1/ps are not supported for now.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The unit to validate.
+
+ Raises
+ ------
+ UnitError
+ If unit is not a string or scipp Unit, or if it cannot be converted to meV.
+ """
+ try:
+ test = DescriptorNumber(name='test', value=1, unit=str(unit))
+ test.convert_unit('meV')
+ except Exception as e:
+ raise UnitError(
+ f'Invalid unit: {unit}. Unit must be a string or scipp Unit and convertible to meV.' # noqa: E501
+ ) from e
+
+ def convert_x_unit(self, unit: str | sc.Unit) -> None:
+ """
+ Convert the x-axis unit of the diffusion model.
+
+ Converts the scale parameter (unit ``x_unit * y_unit``), any subclass-specific x-unit
+ parameters, and the existing component collections in place — parameter values and object
+ identity are preserved, and nothing is scheduled for regeneration. Only energy units are
+ supported (the unit must be convertible to meV).
+
+ Unit validation raises ``UnitError`` when the unit is not convertible to meV.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The new x-axis unit.
+
+ Raises
+ ------
+ TypeError
+ If unit is not a string or sc.Unit.
+ Exception
+ If any conversion fails; the already-converted state is rolled back best-effort before
+ re-raising.
+ """
+ if not isinstance(unit, (str, sc.Unit)):
+ raise TypeError(f'x_unit must be a string or sc.Unit, got {type(unit).__name__}')
+ unit_str = str(unit)
+ self._assert_convertible_to_mev(unit_str)
+
+ old_x_unit = str(self.x_unit)
+ new_scale_unit = str(sc.Unit(unit_str) * sc.Unit(str(self.y_unit)))
+ self._scale.convert_unit(new_scale_unit)
+ try:
+ self._convert_extra_x_unit_parameters(unit_str)
+ for collection in self._component_collections:
+ collection.convert_x_unit(unit_str)
+ except Exception:
+ old_scale_unit = str(sc.Unit(old_x_unit) * sc.Unit(str(self.y_unit)))
+ with contextlib.suppress(Exception):
+ self._scale.convert_unit(old_scale_unit)
+ with contextlib.suppress(Exception):
+ self._convert_extra_x_unit_parameters(old_x_unit)
+ for collection in self._component_collections:
+ with contextlib.suppress(Exception):
+ collection.convert_x_unit(old_x_unit)
+ raise
+
+ self._x_unit = unit_str
+
+ def convert_y_unit(self, unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit of the diffusion model.
+
+ Converts the scale parameter from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit`` and
+ the existing component collections in place — parameter values and object identity are
+ preserved, and nothing is scheduled for regeneration. The new y-unit must be dimensionally
+ compatible with the current one; the scale conversion raises ``UnitError`` otherwise.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The new y-axis unit.
+
+ Raises
+ ------
+ TypeError
+ If unit is not a string or sc.Unit.
+ Exception
+ If any conversion fails; the already-converted state is rolled back best-effort before
+ re-raising.
+ """
+ if not isinstance(unit, (str, sc.Unit)):
+ raise TypeError(f'y_unit must be a string or sc.Unit, got {type(unit).__name__}')
+ unit_str = str(unit)
+
+ old_y_unit = str(self.y_unit)
+ new_scale_unit = str(sc.Unit(str(self.x_unit)) * sc.Unit(unit_str))
+ self._scale.convert_unit(new_scale_unit)
+ try:
+ for collection in self._component_collections:
+ collection.convert_y_unit(unit_str)
+ except Exception:
+ old_scale_unit = str(sc.Unit(str(self.x_unit)) * sc.Unit(old_y_unit))
+ with contextlib.suppress(Exception):
+ self._scale.convert_unit(old_scale_unit)
+ for collection in self._component_collections:
+ with contextlib.suppress(Exception):
+ collection.convert_y_unit(old_y_unit)
+ raise
+
+ self._y_unit = unit_str
+
+ def _convert_extra_x_unit_parameters(self, unit_str: str) -> None:
+ """
+ Convert subclass-specific parameters that carry the x-axis unit.
+
+ The base implementation does nothing; subclasses with x-unit parameters (e.g. a Lorentzian
+ width template) override this method.
+
+ Parameters
+ ----------
+ unit_str : str
+ The new x-axis unit as a string.
+ """
+
+ # ------------------------------------------------------------------
+ # Fit targets
+ # ------------------------------------------------------------------
+
+ def get_fit_targets(self) -> list[FitTarget]:
+ """
+ Get the fittable predictions of the diffusion model as FitTargets.
+
+ The base implementation declares ``'area'`` (``scale * QISF(Q)``) and ``'width'`` (the HWHM
+ ``Gamma(Q)``), with default dataset keys derived from the Lorentzian component's name.
+ Subclasses with additional predictions (e.g. a delta-function weight) extend this list. The
+ targets are snapshots: units and default keys reflect the model state at call time.
+
+ Returns
+ -------
+ list[FitTarget]
+ The fittable predictions of this model.
+ """
+ return [
+ FitTarget(
+ name='area',
+ dataset_key=f'{self.lorentzian_name} area',
+ function=lambda Q, model=self, **_: model.calculate_QISF(Q) * model.scale.value,
+ label=f'{self.display_name} area',
+ x_unit='1/angstrom',
+ y_unit=str(self.scale.unit),
+ ),
+ FitTarget(
+ name='width',
+ dataset_key=f'{self.lorentzian_name} width',
+ function=lambda Q, model=self, **_: model.calculate_width(Q),
+ label=f'{self.display_name} width',
+ x_unit='1/angstrom',
+ y_unit=str(self.x_unit),
+ ),
+ ]
+
# ------------------------------------------------------------------
# Methods
# ------------------------------------------------------------------
@@ -305,21 +490,8 @@ def get_independent_variables(self, Q_index: int | None = None) -> list[Paramete
-------
list[Parameter]
List of independent variables in the model.
-
- Raises
- ------
- ValueError
- If Q_index is not None and is not a valid index for the Q values in the model.
"""
- if Q_index is not None and (
- not isinstance(Q_index, int)
- or Q_index < 0
- or Q_index >= len(self._component_collections)
- ):
- raise ValueError(
- f'Q_index must be an integer between 0 and '
- f'{max(len(self._component_collections) - 1, 0)}, or None.'
- )
+ verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True)
return []
@@ -337,22 +509,8 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
-------
list[Parameter]
A list of all Parameters from the diffusion model.
-
- Raises
- ------
- ValueError
- If Q_index is out of bounds for the number of ComponentCollections.
"""
-
- if Q_index is not None and (
- not isinstance(Q_index, int)
- or Q_index < 0
- or Q_index >= len(self._component_collections)
- ):
- raise ValueError(
- f'Q_index must be an integer between 0 and '
- f'{max(len(self._component_collections) - 1, 0)}, or None.'
- )
+ verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True)
variables = self.get_global_variables()
variables.extend(self.get_independent_variables(Q_index))
@@ -448,7 +606,9 @@ def create_component_collections(self) -> list[ComponentCollection]:
self._component_collections = []
return self._component_collections
- self._component_collections = [ComponentCollection()] * len(self.Q)
+ self._component_collections = [
+ ComponentCollection(x_unit=self.x_unit, y_unit=self.y_unit) for _ in range(len(self.Q))
+ ]
return self._component_collections
@@ -464,43 +624,64 @@ def get_component_collections(
The index of the desired ComponentCollection. If None, all ComponentCollections are
returned.
- Raises
- ------
- TypeError
- If Q_index is not an int.
- IndexError
- If Q_index is out of bounds for the number of ComponentCollections.
-
Returns
-------
ComponentCollection | list[ComponentCollection]
The ComponentCollection at the specified Q index. If Q_index is None, a list of all
ComponentCollections is returned.
"""
+ verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True)
if Q_index is None:
return self._component_collections
- if not isinstance(Q_index, int):
- raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}')
- if Q_index < 0 or Q_index >= len(self._component_collections):
- raise IndexError(
- f'Q_index {Q_index} is out of bounds for component collections '
- f'of length {len(self._component_collections)}'
- )
return self._component_collections[Q_index]
# ------------------------------------------------------------------
# private methods
# ------------------------------------------------------------------
+ def _write_area_dependency_expression(self, QISF: float) -> str:
+ """
+ Write the dependency expression for the Lorentzian area.
+
+ Parameters
+ ----------
+ QISF : float
+ Q-dependent incoherent scattering function value.
+
+ Raises
+ ------
+ TypeError
+ If QISF is not a float.
+
+ Returns
+ -------
+ str
+ Dependency expression for the area.
+ """
+ if not isinstance(QISF, float):
+ raise TypeError('QISF must be a float.')
+ return f'{QISF} * scale'
+
+ def _write_area_dependency_map_expression(self) -> dict[str, DescriptorNumber]:
+ """
+ Write the dependency map for the Lorentzian area.
+
+ Returns
+ -------
+ dict[str, DescriptorNumber]
+ Dependency map for the area.
+ """
+ return {'scale': self.scale}
+
def _on_Q_change(self) -> None:
"""Handle changes to the Q values."""
self.create_component_collections()
def _ensure_Q(self, Q: Q_type) -> np.ndarray:
"""
- Convert Q to a numpy array, ensuring it is not None. Uses the stored Q if no input is
- given.
+ Convert Q to a numpy array of values in 1/angstrom, ensuring it is not None. Uses the
+ stored Q if no input is given.
Parameters
----------
@@ -522,7 +703,7 @@ def _ensure_Q(self, Q: Q_type) -> np.ndarray:
if Q is None:
raise ValueError('Q must be provided either as an argument or set in the model.')
- return _validate_and_convert_Q(Q)
+ return _validate_and_convert_Q(Q).values
# ------------------------------------------------------------------
# dunder methods
@@ -538,8 +719,7 @@ def __repr__(self) -> str:
String representation of the DiffusionModel.
"""
return (
- f'{self.__class__.__name__}('
- f'name={self.name!r}, display_name={self.display_name!r}, '
- f'unit={self.unit},\n'
+ f'{self.__class__.__name__}(name={self.name}, display_name={self.display_name}, '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n'
f' scale={self.scale})'
)
diff --git a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py
index 0cd8a8284..67a693205 100644
--- a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py
+++ b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py
@@ -56,7 +56,8 @@ def __init__(
diffusion_coefficient: Numeric = 1.0,
relaxation_time: Numeric = 1.0,
Q: Q_type | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
name: str = 'JumpTranslationalDiffusion',
display_name: str | None = 'JumpTranslationalDiffusion',
lorentzian_name: str | None = None,
@@ -76,8 +77,10 @@ def __init__(
Relaxation time t in ps.
Q : Q_type | None, default=None
Q values for the model. If None, Q is not set.
- unit : str | sc.Unit, default='meV'
- Unit of the diffusion model. Must be convertible to meV.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis (energy/frequency). Must be convertible to meV.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity). Determines scale.unit = x_unit * y_unit.
name : str, default='JumpTranslationalDiffusion'
Name of the diffusion model.
display_name : str | None, default='JumpTranslationalDiffusion'
@@ -101,7 +104,8 @@ def __init__(
"""
super().__init__(
Q=Q,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
scale=scale,
name=name,
display_name=display_name,
@@ -246,7 +250,7 @@ def calculate_width(self, Q: Q_type | None = None) -> np.ndarray:
unit_conversion_factor_numerator = (
self._hbar * self.diffusion_coefficient / (self._angstrom**2)
)
- unit_conversion_factor_numerator.convert_unit(self.unit)
+ unit_conversion_factor_numerator.convert_unit(self.x_unit)
numerator = unit_conversion_factor_numerator.value * Q**2
@@ -306,11 +310,11 @@ def create_component_collections(
list[ComponentCollection]
List of ComponentCollections with Jump Diffusion Lorentzian components.
"""
- Q = self.Q
- if Q is None:
+ if self.Q is None:
self._component_collections = []
return self._component_collections
+ Q = self.Q.values
component_collection_list = [None] * len(Q)
# In more complex models, this is used to scale the area of the
# Lorentzians and the delta function.
@@ -323,31 +327,37 @@ def create_component_collections(
component_collection_list[i] = ComponentCollection(
name=f'{self.name}_Q{Q_value:.2f}',
display_name=f'{self.display_name}_Q{Q_value:.2f}',
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
lorentzian_component = Lorentzian(
name=self.lorentzian_name,
display_name=self.lorentzian_display_name,
- unit=self.unit,
+ x_unit=self.x_unit,
+ y_unit=self.y_unit,
)
# Make the width dependent on Q
dependency_expression = self._write_width_dependency_expression(Q[i])
dependency_map = self._write_width_dependency_map_expression()
- lorentzian_component.width.make_dependent_on(
- dependency_expression=dependency_expression,
- dependency_map=dependency_map,
- desired_unit=self.unit,
- )
-
- # Make the area dependent on Q
- area_dependency_map = self._write_area_dependency_map_expression()
- lorentzian_component.area.make_dependent_on(
- dependency_expression=self._write_area_dependency_expression(QISF[i]),
- dependency_map=area_dependency_map,
- )
+ # easyscience propagates inf bounds through arithmetic, producing inf/inf=nan
+ # as a transient intermediate. Python's min/max ignore nan so the final bounds
+ # are correct; suppress the spurious numpy RuntimeWarning.
+ with np.errstate(invalid='ignore'):
+ lorentzian_component.width.make_dependent_on(
+ dependency_expression=dependency_expression,
+ dependency_map=dependency_map,
+ desired_unit=self.x_unit,
+ )
+
+ # Make the area dependent on Q
+ area_dependency_map = self._write_area_dependency_map_expression()
+ lorentzian_component.area.make_dependent_on(
+ dependency_expression=self._write_area_dependency_expression(QISF[i]),
+ dependency_map=area_dependency_map,
+ )
component_collection_list[i].append_component(lorentzian_component)
@@ -405,44 +415,6 @@ def _write_width_dependency_map_expression(self) -> dict[str, DescriptorNumber]:
'angstrom': self._angstrom,
}
- def _write_area_dependency_expression(self, QISF: float) -> str:
- """
- Write the dependency expression for the area to make dependent Parameters.
-
- Parameters
- ----------
- QISF : float
- Q-dependent intermediate scattering function.
-
- Raises
- ------
- TypeError
- If QISF is not a float.
-
- Returns
- -------
- str
- Dependency expression for the area.
- """
-
- if not isinstance(QISF, (float)):
- raise TypeError('QISF must be a float.')
-
- return f'{QISF} * scale'
-
- def _write_area_dependency_map_expression(self) -> dict[str, DescriptorNumber]:
- """
- Write the dependency map expression to make dependent Parameters.
-
- Returns
- -------
- dict[str, DescriptorNumber]
- Dependency map for the area.
- """
- return {
- 'scale': self.scale,
- }
-
################################
# dunder methods
################################
@@ -459,6 +431,7 @@ def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'name={self.name!r}, display_name={self.display_name!r},\n'
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n'
f' diffusion_coefficient={self.diffusion_coefficient},\n'
f' scale={self.scale})'
)
diff --git a/src/easydynamics/sample_model/instrument_model.py b/src/easydynamics/sample_model/instrument_model.py
index 24c3f1719..c6d731129 100644
--- a/src/easydynamics/sample_model/instrument_model.py
+++ b/src/easydynamics/sample_model/instrument_model.py
@@ -3,7 +3,6 @@
from copy import copy
-import numpy as np
import scipp as sc
from easyscience.base_classes.new_base import NewBase
from easyscience.variable import Parameter
@@ -15,6 +14,7 @@
from easydynamics.utils.utils import Q_type
from easydynamics.utils.utils import _validate_and_convert_Q
from easydynamics.utils.utils import _validate_unit
+from easydynamics.utils.utils import verify_Q_index
class InstrumentModel(NewBase):
@@ -60,7 +60,7 @@ def __init__(
resolution_model: ResolutionModel | SampleModel | None = None,
background_model: BackgroundModel | None = None,
energy_offset: Numeric | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
) -> None:
"""
Initialize an InstrumentModel.
@@ -83,7 +83,7 @@ def __init__(
energy_offset : Numeric | None, default=None
Template energy offset of the instrument. Will be copied to each Q value. If None, the
energy offset will be 0.
- unit : str | sc.Unit, default='meV'
+ x_unit : str | sc.Unit, default='meV'
The unit of the energy axis.
Raises
@@ -97,7 +97,7 @@ def __init__(
unique_name=unique_name,
)
- self._unit = _validate_unit(unit)
+ self._x_unit = _validate_unit(x_unit)
if resolution_model is None:
self._resolution_model = ResolutionModel()
@@ -130,7 +130,7 @@ def __init__(
self._energy_offset = Parameter(
name='energy_offset',
value=float(energy_offset),
- unit=self.unit,
+ unit=self.x_unit,
fixed=False,
)
self._energy_offsets: list = []
@@ -189,7 +189,6 @@ def background_model(self) -> BackgroundModel:
BackgroundModel
The background model of the instrument.
"""
-
return self._background_model
@background_model.setter
@@ -207,7 +206,6 @@ def background_model(self, value: BackgroundModel) -> None:
TypeError
If value is not a BackgroundModel.
"""
-
if not isinstance(value, BackgroundModel):
raise TypeError(
f'background_model must be a BackgroundModel, got {type(value).__name__}'
@@ -216,14 +214,14 @@ def background_model(self, value: BackgroundModel) -> None:
self._on_background_model_change()
@property
- def Q(self) -> np.ndarray | None:
+ def Q(self) -> sc.Variable | None:
"""
Get the Q values of the InstrumentModel.
Returns
-------
- np.ndarray | None
- The Q values of the InstrumentModel, or None if not set.
+ sc.Variable | None
+ The Q values of the InstrumentModel in 1/angstrom, or None if not set.
"""
return self._Q
@@ -256,68 +254,63 @@ def Q(self, value: Q_type | None) -> None:
self._on_Q_change()
return
- if len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q):
+ if len(old_Q) != len(new_Q) or not sc.allclose(old_Q, new_Q):
raise ValueError(
'New Q values are not similar to the old ones. '
'To change Q values, first run clear_Q().'
)
@property
- def unit(self) -> str | sc.Unit:
+ def x_unit(self) -> str | sc.Unit | None:
"""
- Get the unit of the InstrumentModel.
+ Get the x-axis unit of the InstrumentModel.
Returns
-------
- str | sc.Unit
- The unit of the InstrumentModel.
+ str | sc.Unit | None
+ The x-axis unit of the InstrumentModel.
"""
- return self._unit
+ return self._x_unit
- @unit.setter
- def unit(self, _unit_str: str) -> None:
+ @x_unit.setter
+ def x_unit(self, _: str) -> None:
"""
- Set the unit of the InstrumentModel.
-
- The unit is read-only and cannot be set directly. Use convert_unit to change the unit
- between allowed types or create a new InstrumentModel with the desired unit.
+ x_unit is read-only and cannot be set directly.
- Parameters
- ----------
- _unit_str : str
- The new unit for the InstrumentModel (ignored).
+ Use convert_x_unit to change the unit between allowed types or create a new InstrumentModel
+ with the desired unit.
Raises
------
AttributeError
- Always, as the unit is read-only.
+ Always, as x_unit is read-only.
"""
raise AttributeError(
- f'Unit is read-only. Use convert_unit to change the unit between allowed types '
+ f'x_unit is read-only. Use convert_x_unit to change the unit between allowed types '
f'or create a new {self.__class__.__name__} with the desired unit.'
)
@property
def energy_offset(self) -> Parameter:
"""
- Get the energy offset template parameter of the instrument model.
+ Get the template energy offset of the instrument.
Returns
-------
Parameter
- The energy offset template parameter of the instrument model.
+ The energy offset Parameter. Each Q value gets its own copy via get_energy_offset().
"""
return self._energy_offset
@energy_offset.setter
def energy_offset(self, value: Numeric) -> None:
"""
- Set the offset parameter of the instrument model.
+ Set the template energy offset value, propagating to all Q-specific offsets.
Parameters
----------
value : Numeric
- The new value for the energy offset parameter. Will be copied to all Q values.
+ The new energy offset value in x_unit.
Raises
------
@@ -327,7 +320,6 @@ def energy_offset(self, value: Numeric) -> None:
if not isinstance(value, Numeric):
raise TypeError(f'energy_offset must be a number, got {type(value).__name__}')
self._energy_offset.value = value
-
self._on_energy_offset_change()
# --------------------------------------------------------------
@@ -358,32 +350,32 @@ def clear_Q(self, confirm: bool = False) -> None:
self.resolution_model.clear_Q(confirm=True)
self._on_Q_change()
- def convert_unit(self, unit_str: str | sc.Unit) -> None:
+ def convert_x_unit(self, x_unit: str | sc.Unit) -> None:
"""
Convert the unit of the InstrumentModel.
Parameters
----------
- unit_str : str | sc.Unit
+ x_unit : str | sc.Unit
The unit to convert to.
Raises
------
ValueError
- If unit_str is not a valid unit string or scipp Unit.
+ If x_unit is not a valid unit string or scipp Unit.
"""
- unit = _validate_unit(unit_str)
+ unit = _validate_unit(x_unit)
if unit is None:
- raise ValueError('unit_str must be a valid unit string or scipp Unit')
+ raise ValueError('x_unit must be a valid unit string or scipp Unit')
- self._background_model.convert_unit(unit)
- self._resolution_model.convert_unit(unit)
- self._energy_offset.convert_unit(unit)
+ self._background_model.convert_x_unit(unit)
+ self._resolution_model.convert_x_unit(unit)
self._ensure_energy_offsets_current()
+ self._energy_offset.convert_unit(unit)
for offset in self._energy_offsets:
offset.convert_unit(unit)
- self._unit = unit
+ self._x_unit = unit
def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
"""
@@ -395,13 +387,6 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
The index of the Q value to get variables for. If None, get variables for all Q values.
- Raises
- ------
- TypeError
- If Q_index is not an int or None.
- IndexError
- If Q_index is out of bounds for the Q values in the InstrumentModel.
-
Returns
-------
list[Parameter]
@@ -413,15 +398,10 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
return []
self._ensure_energy_offsets_current()
+ verify_Q_index(Q_index=Q_index, Q=self._Q, allow_none=True)
if Q_index is None:
variables = [self._energy_offsets[i] for i in range(len(self._Q))]
else:
- if not isinstance(Q_index, int):
- raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}')
- if Q_index < 0 or Q_index >= len(self._Q):
- raise IndexError(
- f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}'
- )
variables = [self._energy_offsets[Q_index]]
variables.extend(self._background_model.get_all_variables(Q_index=Q_index))
@@ -458,10 +438,6 @@ def get_energy_offset(
------
ValueError
If no Q values are set in the InstrumentModel.
- IndexError
- If Q_index is out of bounds.
- TypeError
- If Q_index is not an int or None.
Returns
-------
@@ -473,15 +449,10 @@ def get_energy_offset(
raise ValueError('No Q values are set in the InstrumentModel.')
self._ensure_energy_offsets_current()
+ verify_Q_index(Q_index=Q_index, Q=self._Q, allow_none=True)
if Q_index is None:
return self._energy_offsets
- if not isinstance(Q_index, int):
- raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}')
-
- if Q_index < 0 or Q_index >= len(self._Q):
- raise IndexError(f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}')
-
return self._energy_offsets[Q_index]
def fix_energy_offset(self, Q_index: int | None = None) -> None:
@@ -531,27 +502,14 @@ def _fix_or_free_energy_offset(self, Q_index: int | None = None, fixed: bool = T
energy offsets for all Q values.
fixed : bool, default=True
Whether to fix (True) or free (False) the energy offset.
-
- Raises
- ------
- TypeError
- If Q_index is not an int or None.
- IndexError
- If Q_index is out of bounds for the Q values in the InstrumentModel.
"""
-
self._ensure_energy_offsets_current()
+ verify_Q_index(Q_index=Q_index, Q=self._Q, allow_none=True)
if Q_index is None:
+ self._energy_offset.fixed = fixed
for offset in self._energy_offsets:
offset.fixed = fixed
else:
- if not isinstance(Q_index, int):
- raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}')
-
- if Q_index < 0 or Q_index >= len(self._Q):
- raise IndexError(
- f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}'
- )
self._energy_offsets[Q_index].fixed = fixed
def _ensure_energy_offsets_current(self) -> None:
@@ -566,7 +524,7 @@ def _generate_energy_offsets(self) -> None:
self._energy_offsets = []
return
- self._energy_offsets = [copy(self._energy_offset) for _ in self._Q]
+ self._energy_offsets = [copy(self._energy_offset) for _ in range(len(self._Q))]
def _on_Q_change(self) -> None:
"""Handle changes to the Q values."""
@@ -601,11 +559,10 @@ def __repr__(self) -> str:
str
A string representation of the InstrumentModel.
"""
-
return (
f'{self.__class__.__name__}('
f'unique_name={self.unique_name!r}, '
- f'unit={self.unit}, '
+ f'x_unit={self.x_unit}, '
f'Q_len={None if self._Q is None else len(self._Q)}, '
f'resolution_model={self._resolution_model!r}, '
f'background_model={self._background_model!r})'
diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py
index a2f1d7765..f2e13ae9d 100644
--- a/src/easydynamics/sample_model/model_base.py
+++ b/src/easydynamics/sample_model/model_base.py
@@ -13,6 +13,7 @@
from easydynamics.utils.utils import Numeric
from easydynamics.utils.utils import Q_type
from easydynamics.utils.utils import _validate_and_convert_Q
+from easydynamics.utils.utils import verify_Q_index
class ModelBase(EasyDynamicsModelBase):
@@ -26,7 +27,8 @@ def __init__(
self,
display_name: str = 'MyModelBase',
unique_name: str | None = None,
- unit: str | sc.Unit | None = 'meV',
+ x_unit: str | sc.Unit | None = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
components: ModelComponent | ComponentCollection | None = None,
Q: Q_type | None = None,
) -> None:
@@ -39,8 +41,10 @@ def __init__(
Display name of the model.
unique_name : str | None, default=None
Unique name of the model. If None, a unique name will be generated.
- unit : str | sc.Unit | None, default='meV'
- Unit of the model.
+ x_unit : str | sc.Unit | None, default='meV'
+ Unit of the x-axis (energy, Q, etc.).
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the model output (intensity).
components : ModelComponent | ComponentCollection | None, default=None
Template components of the model. If None, no components are added. These components
are copied into ComponentCollections for each Q value.
@@ -53,7 +57,8 @@ def __init__(
If components is not a ModelComponent or ComponentCollection.
"""
super().__init__(
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
display_name=display_name,
unique_name=unique_name,
)
@@ -74,17 +79,19 @@ def __init__(
self.append_component(components)
def evaluate(
- self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- ) -> list[np.ndarray]:
+ self,
+ x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
+ output: str = 'numpy',
+ ) -> list[np.ndarray] | list[sc.Variable]:
"""
Evaluate the sample model at all Q for the given x values.
Parameters
----------
x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- Energy axis values to evaluate the model at. If a scipp Variable or DataArray is
- provided, the unit of the model will be converted to match the unit of x for
- evaluation, and the result will be returned in the same unit as x.
+ Energy axis values to evaluate the model at.
+ output : str, default='numpy'
+ 'numpy' returns np.ndarray per Q; 'scipp' returns sc.Variable per Q.
Raises
------
@@ -93,15 +100,16 @@ def evaluate(
Returns
-------
- list[np.ndarray]
- A list of numpy arrays containing the evaluated model values for each Q. The length of
- the list will match the number of Q values in the model.
+ list[np.ndarray] | list[sc.Variable]
+ A list of arrays containing the evaluated model values for each Q. The length of the
+ list will match the number of Q values in the model.
"""
-
self._ensure_component_collections_current()
if not self._component_collections:
raise ValueError('No components in the model to evaluate.')
- return [collection.evaluate(x) for collection in self._component_collections]
+ return [
+ collection.evaluate(x, output=output) for collection in self._component_collections
+ ]
# ------------------------------------------------------------------
# Component management
@@ -186,14 +194,14 @@ def component_collections_is_dirty(self) -> bool:
return self._component_collections_is_dirty
@property
- def Q(self) -> np.ndarray | None:
+ def Q(self) -> sc.Variable | None:
"""
Get the Q values of the SampleModel.
Returns
-------
- np.ndarray | None
- The Q values of the SampleModel, or None if not set.
+ sc.Variable | None
+ The Q values of the SampleModel in 1/angstrom, or None if not set.
"""
return self._Q
@@ -225,7 +233,7 @@ def Q(self, value: Q_type | None) -> None:
self._on_Q_change()
return
- if len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q):
+ if len(old_Q) != len(new_Q) or not sc.allclose(old_Q, new_Q):
raise ValueError(
'New Q values are not similar to the old ones. '
'To change Q values, first run clear_Q().'
@@ -257,14 +265,42 @@ def clear_Q(self, confirm: bool = False) -> None:
# Other methods
# ------------------------------------------------------------------
- def convert_unit(self, unit: str | sc.Unit) -> None:
+ def convert_x_unit(self, unit: str | sc.Unit) -> None:
+ """
+ Convert the x-axis unit of all components in the model.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The new x-axis unit to convert to.
+ """
+ self._convert_axis_unit(unit, axis='x')
+
+ def convert_y_unit(self, unit: str | sc.Unit) -> None:
+ """
+ Convert the y-axis unit of all components in the model.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The new y-axis unit to convert to.
+ """
+ self._convert_axis_unit(unit, axis='y')
+
+ def _convert_axis_unit(self, unit: str | sc.Unit, axis: str) -> None:
"""
- Convert the unit of the ComponentCollection and all its components.
+ Convert one axis unit on all template components and per-Q collections.
+
+ Converts every child via its ``convert__unit`` method and updates the model's own
+ unit attribute. On failure, attempts a best-effort rollback of all children to the old unit
+ before re-raising.
Parameters
----------
unit : str | sc.Unit
The new unit to convert to.
+ axis : str
+ Which axis to convert: ``'x'`` or ``'y'``.
Raises
------
@@ -273,24 +309,31 @@ def convert_unit(self, unit: str | sc.Unit) -> None:
Exception
If the provided unit is not compatible with the current unit.
"""
-
- old_unit = self._unit
-
if not isinstance(unit, (str, sc.Unit)):
raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}')
+
+ method = f'convert_{axis}_unit'
+ old_unit = self.x_unit if axis == 'x' else self.y_unit
try:
for component in self.components:
- component.convert_unit(unit)
- self._unit = unit
+ getattr(component, method)(unit)
+ for collection in self._component_collections:
+ getattr(collection, method)(unit)
+ unit_str = str(unit) if isinstance(unit, sc.Unit) else unit
+ if axis == 'x':
+ self._x_unit = unit_str
+ else:
+ self._y_unit = unit_str
except Exception as e:
- # Attempt to rollback on failure
- try:
- for component in self.components:
- component.convert_unit(old_unit)
- except Exception: # noqa: S110
- pass # Best effort rollback
+ if old_unit is not None:
+ try:
+ for component in self.components:
+ getattr(component, method)(old_unit)
+ for collection in self._component_collections:
+ getattr(collection, method)(old_unit)
+ except Exception: # noqa: S110
+ pass
raise e
- self._on_components_change()
def fix_all_parameters(self) -> None:
"""Fix all Parameters in all ComponentCollections."""
@@ -314,21 +357,14 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
If None, get variables for all ComponentCollections. If int, get variables for the
ComponentCollection at this index.
- Raises
- ------
- TypeError
- If Q_index is not an int or None.
- IndexError
- If Q_index is out of bounds for the number of ComponentCollections.
-
Returns
-------
list[Parameter]
A list of all Parameters and Descriptors from the ComponentCollections in the
ModelBase.
"""
-
self._ensure_component_collections_current()
+ verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True)
if Q_index is None:
all_vars = [
var
@@ -336,13 +372,6 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
for var in collection.get_all_variables()
]
else:
- if not isinstance(Q_index, int):
- raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}')
- if Q_index < 0 or Q_index >= len(self._component_collections):
- raise IndexError(
- f'Q_index {Q_index} is out of bounds for component collections '
- f'of length {len(self._component_collections)}'
- )
all_vars = self._component_collections[Q_index].get_all_variables()
return all_vars
@@ -355,26 +384,13 @@ def get_component_collection(self, Q_index: int) -> ComponentCollection:
Q_index : int
The index of the desired ComponentCollection.
- Raises
- ------
- TypeError
- If Q_index is not an int.
- IndexError
- If Q_index is out of bounds for the number of ComponentCollections.
-
Returns
-------
ComponentCollection
The ComponentCollection at the given Q index.
"""
self._ensure_component_collections_current()
- if not isinstance(Q_index, int):
- raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}')
- if Q_index < 0 or Q_index >= len(self._component_collections):
- raise IndexError(
- f'Q_index {Q_index} is out of bounds for component collections '
- f'of length {len(self._component_collections)}'
- )
+ verify_Q_index(Q_index=Q_index, Q=self.Q)
return self._component_collections[Q_index]
def normalize_area(self) -> None:
@@ -397,13 +413,12 @@ def _ensure_component_collections_current(self) -> None:
def _generate_component_collections(self) -> None:
"""Generate ComponentCollections for each Q value."""
-
if self.Q is None:
self._component_collections = []
return
self._component_collections = []
- for _ in self.Q:
+ for _ in range(len(self.Q)):
self._component_collections.append(copy(self._components))
def _on_Q_change(self) -> None:
@@ -430,7 +445,8 @@ def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'unique_name={self.unique_name!r}, '
- f'unit={self.unit}, '
- f'Q={self.Q}, '
+ f'x_unit={self.x_unit}, '
+ f'y_unit={self.y_unit}, '
+ f'Q={None if self.Q is None else self.Q.values}, '
f'components={self.components})'
)
diff --git a/src/easydynamics/sample_model/resolution_model.py b/src/easydynamics/sample_model/resolution_model.py
index ce5117ff7..a9fc9e1ee 100644
--- a/src/easydynamics/sample_model/resolution_model.py
+++ b/src/easydynamics/sample_model/resolution_model.py
@@ -51,12 +51,13 @@ def __init__(
self,
display_name: str = 'MyResolutionModel',
unique_name: str | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
components: ModelComponent | ComponentCollection | None = None,
Q: Q_type | None = None,
) -> None:
"""
- Initialize a ResolutionModel.
+ Initialize the ResolutionModel.
Parameters
----------
@@ -64,19 +65,20 @@ def __init__(
Display name of the model.
unique_name : str | None, default=None
Unique name of the model. If None, a unique name will be generated.
- unit : str | sc.Unit, default='meV'
- Unit of the model.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
components : ModelComponent | ComponentCollection | None, default=None
- Template components of the model. If None, no components are added. These components
- are copied into ComponentCollections for each Q value.
+ Template components. DeltaFunction, Polynomial, and Exponential are not allowed.
Q : Q_type | None, default=None
Q values for the model. If None, Q is not set.
"""
-
super().__init__(
display_name=display_name,
unique_name=unique_name,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
components=components,
Q=Q,
)
@@ -85,8 +87,8 @@ def append_component(self, component: ModelComponent | ComponentCollection) -> N
"""
Append a component to the ResolutionModel.
- Does not allow DeltaFunction or Polynomial components, as these are not physical resolution
- components.
+ Does not allow DeltaFunction, Polynomial, or Exponential components, as these are not
+ physical resolution components.
Parameters
----------
@@ -96,7 +98,7 @@ def append_component(self, component: ModelComponent | ComponentCollection) -> N
Raises
------
TypeError
- If the component is a DeltaFunction or Polynomial.
+ If the component is a DeltaFunction, Polynomial, or Exponential.
"""
components = component if isinstance(component, ComponentCollection) else (component,)
@@ -151,22 +153,27 @@ def from_sample_model(
resolution_model = cls(
display_name=sample_model.display_name,
- unit=sample_model.unit,
+ x_unit=sample_model.x_unit,
+ y_unit=sample_model.y_unit,
components=sample_model.components,
Q=sample_model.Q,
)
if sample_model.Q is not None:
- resolution_model._ensure_component_collections_current()
- for index in range(len(sample_model.Q)):
- resolution_model._component_collections[index] = copy(
- sample_model.get_component_collection(Q_index=index)
- )
- if normalize_area:
- resolution_model.normalize_area()
-
- if fix_parameters:
- resolution_model.fix_all_parameters()
+ # Prepare the per-Q collections detached from the model so no EasyScience
+ # callback can schedule a rebuild halfway through, then install them and
+ # clear the dirty flag in one final step.
+ collections = [
+ copy(sample_model.get_component_collection(Q_index=index))
+ for index in range(len(sample_model.Q))
+ ]
+ for collection in collections:
+ if normalize_area:
+ collection.normalize_area()
+ if fix_parameters:
+ collection.fix_all_parameters()
+ resolution_model._component_collections = collections
+ resolution_model._component_collections_is_dirty = False
return resolution_model
@@ -174,7 +181,8 @@ def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'unique_name={self.unique_name!r}, '
- f'unit={self.unit}, '
+ f'x_unit={self.x_unit}, '
+ f'y_unit={self.y_unit}, '
f'Q_len={None if self._Q is None else len(self._Q)}, '
f'components={self.components})'
)
diff --git a/src/easydynamics/sample_model/sample_model.py b/src/easydynamics/sample_model/sample_model.py
index 3991a10f9..f0039161f 100644
--- a/src/easydynamics/sample_model/sample_model.py
+++ b/src/easydynamics/sample_model/sample_model.py
@@ -66,7 +66,8 @@ def __init__(
self,
display_name: str = 'MySampleModel',
unique_name: str | None = None,
- unit: str | sc.Unit = 'meV',
+ x_unit: str | sc.Unit = 'meV',
+ y_unit: str | sc.Unit = 'dimensionless',
components: ModelComponent | ComponentCollection | None = None,
Q: Q_type | None = None,
diffusion_models: DiffusionModelBase | list[DiffusionModelBase] | None = None,
@@ -83,33 +84,31 @@ def __init__(
Display name of the model.
unique_name : str | None, default=None
Unique name of the model. If None, a unique name will be generated.
- unit : str | sc.Unit, default='meV'
- Unit of the model. If None,.
+ x_unit : str | sc.Unit, default='meV'
+ Unit of the x-axis.
+ y_unit : str | sc.Unit, default='dimensionless'
+ Unit of the y-axis (output).
components : ModelComponent | ComponentCollection | None, default=None
- Template components of the model. If None, no components are added. These components
- are copied into ComponentCollections for each Q value.
+ Template components copied into each Q's ComponentCollection.
Q : Q_type | None, default=None
- Q values for the model. If None, Q is not set.
+ Q values. If None, Q is not set.
diffusion_models : DiffusionModelBase | list[DiffusionModelBase] | None, default=None
- Diffusion models to include in the SampleModel. If None, no diffusion models are added.
+ Diffusion models to include. Each must be a DiffusionModelBase.
temperature : float | None, default=None
- Temperature for detailed balancing. If None, no detailed balancing is applied. By
- default, None.
+ Sample temperature in temperature_unit. If provided, detailed balance is applied.
temperature_unit : str | sc.Unit, default='K'
- Unit of the temperature.
+ Unit for the temperature parameter.
detailed_balance_settings : DetailedBalanceSettings | None, default=None
- Settings for detailed balancing.
+ Detailed balance settings. If None, default settings are used.
Raises
------
TypeError
- If diffusion_models is not a DiffusionModelBase, a list of DiffusionModelBase, or None,
- or if temperature is not a number or None, or if detailed_balance_settings is not a
- DetailedBalanceSettings instance.
+ If diffusion_models contains non-DiffusionModelBase items, temperature is not numeric,
+ or detailed_balance_settings is not a DetailedBalanceSettings instance.
ValueError
If temperature is negative.
"""
-
if diffusion_models is None:
self._diffusion_models = []
elif isinstance(diffusion_models, DiffusionModelBase):
@@ -131,7 +130,8 @@ def __init__(
super().__init__(
display_name=display_name,
unique_name=unique_name,
- unit=unit,
+ x_unit=x_unit,
+ y_unit=y_unit,
components=components,
Q=Q,
)
@@ -178,7 +178,6 @@ def append_diffusion_model(self, diffusion_model: DiffusionModelBase) -> None:
TypeError
If the diffusion_model is not a DiffusionModelBase.
"""
-
if not isinstance(diffusion_model, DiffusionModelBase):
raise TypeError(
f'diffusion_model must be a DiffusionModelBase, got {type(diffusion_model).__name__}' # noqa: E501
@@ -201,15 +200,19 @@ def remove_diffusion_model(self, name: str) -> None:
ValueError
If no DiffusionModel with the given name is found.
"""
- for i, dm in enumerate(self.diffusion_models):
- if dm.name == name:
- del self.diffusion_models[i]
- self._component_collections_is_dirty = True
- return
- raise ValueError(
- f'No DiffusionModel with name {name} found. \n'
- f'The available names are: {[dm.name for dm in self.diffusion_models]}'
- )
+ matches = [i for i, dm in enumerate(self.diffusion_models) if dm.name == name]
+ if len(matches) == 0:
+ raise ValueError(
+ f'No DiffusionModel with name {name!r} found. '
+ f'Available names are: {[dm.name for dm in self.diffusion_models]}'
+ )
+ if len(matches) > 1:
+ raise ValueError(
+ f'Multiple DiffusionModels share the name {name!r}. '
+ f'Rename them to have unique names before removing.'
+ )
+ del self.diffusion_models[matches[0]]
+ self._component_collections_is_dirty = True
def clear_diffusion_models(self) -> None:
"""Clear all DiffusionModels from the SampleModel."""
@@ -250,7 +253,6 @@ def diffusion_models(
TypeError
If value is not a DiffusionModelBase, a list of DiffusionModelBase, or None.
"""
-
if value is None:
self._diffusion_models = []
self._on_diffusion_models_change()
@@ -292,7 +294,7 @@ def temperature(self, value: Numeric | None) -> None:
Parameters
----------
value : Numeric | None
- The temperature value to set. Can be a number or None to unset the temperature.
+ The temperature value. If None, temperature is cleared (no detailed balance).
Raises
------
@@ -325,12 +327,12 @@ def temperature(self, value: Numeric | None) -> None:
@property
def temperature_unit(self) -> str | sc.Unit:
"""
- Get the temperature unit of the SampleModel.
+ Get the temperature unit.
Returns
-------
str | sc.Unit
- The unit of the temperature Parameter.
+ The unit of the temperature parameter.
"""
return self._temperature_unit
@@ -351,8 +353,8 @@ def temperature_unit(self, _value: str | sc.Unit) -> None:
"""
raise AttributeError(
- f'Temperature_unit is read-only. Use convert_temperature_unit to change the unit between allowed types ' # noqa: E501
- f'or create a new {self.__class__.__name__} with the desired unit.'
+ f'Temperature_unit is read-only. Use convert_temperature_unit to change the unit '
+ f'between allowed types or create a new {self.__class__.__name__} with the desired unit.' # noqa: E501
)
def convert_temperature_unit(self, unit: str | sc.Unit) -> None:
@@ -382,10 +384,53 @@ def convert_temperature_unit(self, unit: str | sc.Unit) -> None:
self._temperature_unit = unit
except Exception:
# Attempt to rollback on failure
+
with suppress(Exception):
self.temperature.convert_unit(old_unit)
raise
+ def _convert_axis_unit(self, unit: str | sc.Unit, axis: str) -> None:
+ """
+ Convert one axis unit on the SampleModel, including any attached diffusion models.
+
+ Extends the ModelBase conversion (template components and per-Q collections) by also
+ converting each diffusion model, whose regenerated collections would otherwise come back in
+ the old unit on the next rebuild.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The new unit to convert to.
+ axis : str
+ Which axis to convert: ``'x'`` or ``'y'``.
+
+ Raises
+ ------
+ Exception
+ If any conversion fails; the already-converted diffusion models and the ModelBase state
+ are rolled back before re-raising.
+ """
+ old_unit = self.x_unit if axis == 'x' else self.y_unit
+ super()._convert_axis_unit(unit, axis)
+
+ method = f'convert_{axis}_unit'
+ converted_models = []
+ try:
+ for diffusion_model in self.diffusion_models:
+ getattr(diffusion_model, method)(unit)
+ converted_models.append(diffusion_model)
+ except Exception:
+ for diffusion_model in converted_models:
+ with suppress(Exception):
+ getattr(diffusion_model, method)(old_unit)
+ if old_unit is not None:
+ with suppress(Exception):
+ super()._convert_axis_unit(old_unit, axis)
+ raise
+ # Everything is converted in place (the merged collections share the diffusion models'
+ # component objects), so nothing is marked dirty: conversion must not discard the
+ # per-Q state by triggering a rebuild.
+
@property
def normalize_detailed_balance(self) -> bool:
"""
@@ -420,12 +465,12 @@ def normalize_detailed_balance(self, value: bool) -> None:
@property
def use_detailed_balance(self) -> bool:
"""
- Get whether to apply detailed balance to the model.
+ Get whether detailed balance correction is applied.
Returns
-------
bool
- True if detailed balance is applied, False otherwise.
+ True if detailed balance is applied during evaluation, False otherwise
"""
return self.detailed_balance_settings.use_detailed_balance
@@ -451,12 +496,12 @@ def use_detailed_balance(self, value: bool) -> None:
@property
def detailed_balance_settings(self) -> DetailedBalanceSettings:
"""
- Get the DetailedBalanceSettings of the SampleModel.
+ Get the detailed balance settings.
Returns
-------
DetailedBalanceSettings
- The DetailedBalanceSettings of the SampleModel.
+ The detailed balance settings object.
"""
return self._detailed_balance_settings
@@ -484,31 +529,33 @@ def detailed_balance_settings(self, value: DetailedBalanceSettings) -> None:
# ------------------------------------------------------------------
def evaluate(
- self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- ) -> list[np.ndarray]:
+ self,
+ x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray,
+ output: str = 'numpy',
+ ) -> list[np.ndarray] | list[sc.Variable]:
"""
Evaluate the sample model at all Q for the given x values.
Parameters
----------
x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray
- The x values to evaluate the model at. Can be a number, list, numpy array, scipp
- Variable, or scipp DataArray.
+ The x values to evaluate the model at.
+ output : str, default='numpy'
+ 'numpy' returns list of np.ndarray; 'scipp' returns list of sc.Variable.
Returns
-------
- list[np.ndarray]
+ list[np.ndarray] | list[sc.Variable]
List of evaluated model values for each Q.
"""
-
- y = super().evaluate(x)
+ y = super().evaluate(x, output=output)
if self.temperature is not None and self.detailed_balance_settings.use_detailed_balance:
DBF = detailed_balance_factor(
energy=x,
temperature=self.temperature,
divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance,
- energy_unit=self.unit,
+ energy_unit=self.x_unit,
)
y = [yi * DBF for yi in y]
@@ -530,10 +577,8 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]:
Returns
-------
list[Parameter]
- List of all Parameters and Descriptors, including temperature if set and all variables
- from diffusion models.
+ All Parameters and Descriptors in the SampleModel.
"""
-
all_vars = super().get_all_variables(Q_index=Q_index)
if self.temperature is not None:
all_vars.append(self.temperature)
@@ -557,8 +602,7 @@ def _generate_component_collections(self) -> None:
if self.Q is None:
return
- # Generate components from diffusion models
- # and add to component collections
+ # Generate components from diffusion models and add to component collections
for diffusion_model in self.diffusion_models:
diffusion_collections = diffusion_model.get_component_collections()
for target, source in zip(
@@ -595,12 +639,11 @@ def __repr__(self) -> str:
str
A string representation of the SampleModel.
"""
-
return (
- f'{self.__class__.__name__}('
- f'unique_name={self.unique_name!r}, unit={self.unit},\n'
- f' Q={self.Q},\n'
- f' components={self.components}, diffusion_models={self.diffusion_models},\n'
- f' temperature={self.temperature},\n'
- f' detailed_balance_settings={self.detailed_balance_settings})'
+ f'{self.__class__.__name__}(unique_name={self.unique_name}, '
+ f'x_unit={self.x_unit}, y_unit={self.y_unit}, '
+ f'Q = {None if self.Q is None else self.Q.values}, \n '
+ f'components = {self.components}, diffusion_models = {self.diffusion_models}, '
+ f'temperature = {self.temperature}, '
+ f'detailed_balance_settings = {self.detailed_balance_settings})'
)
diff --git a/src/easydynamics/settings/convolution_settings.py b/src/easydynamics/settings/convolution_settings.py
index fadc13a1e..309efa9c6 100644
--- a/src/easydynamics/settings/convolution_settings.py
+++ b/src/easydynamics/settings/convolution_settings.py
@@ -90,6 +90,13 @@ def __init__(
self._suppress_warnings = suppress_warnings
self._convolution_plan_is_valid = False
+ # Plan-invalidation bookkeeping for convolvers sharing this settings object.
+ # _plan_version is bumped whenever the plan is invalidated (accuracy knob change or
+ # an explicit convolution_plan_is_valid = False). _blessed_plan_version records the
+ # version at the last explicit convolution_plan_is_valid = True, which suppresses
+ # rebuilds for all convolvers until the next invalidation.
+ self._plan_version = 0
+ self._blessed_plan_version = -1
@property
def upsample_factor(self) -> Numeric | None:
@@ -173,6 +180,11 @@ def extension_factor(self, factor: Numeric) -> None:
If factor is negative.
"""
+ if factor is None:
+ self._extension_factor = factor
+ self.convolution_plan_is_valid = False
+ return
+
if not isinstance(factor, Numeric):
raise TypeError('Extension factor must be a number.')
if factor < 0.0:
@@ -211,6 +223,48 @@ def convolution_plan_is_valid(self, is_valid: bool) -> None:
if not isinstance(is_valid, bool):
raise TypeError('convolution_plan_is_valid must be True or False.')
self._convolution_plan_is_valid = is_valid
+ if is_valid:
+ # An explicit True blesses the current version: all convolvers sharing these
+ # settings may skip rebuilding until the next invalidation.
+ self._blessed_plan_version = self._plan_version
+ else:
+ # An explicit False (or a knob setter delegating here) invalidates the plan for
+ # every convolver sharing these settings.
+ self._plan_version += 1
+
+ def _plan_valid_for(self, seen_version: int) -> bool:
+ """
+ Check whether a convolver that last rebuilt at seen_version can skip rebuilding.
+
+ Parameters
+ ----------
+ seen_version : int
+ The plan version the convolver recorded when it last rebuilt its plan.
+
+ Returns
+ -------
+ bool
+ True if the plan flag is set and no invalidation happened since the convolver's rebuild
+ (or the current version was explicitly blessed).
+ """
+ return self._convolution_plan_is_valid and (
+ seen_version == self._plan_version or self._blessed_plan_version == self._plan_version
+ )
+
+ def _mark_plan_built(self) -> int:
+ """
+ Record that one convolver rebuilt its plan.
+
+ Sets the plan flag without blessing the current version, so other convolvers sharing this
+ settings object still detect invalidations they have not consumed yet.
+
+ Returns
+ -------
+ int
+ The current plan version, to be stored by the convolver.
+ """
+ self._convolution_plan_is_valid = True
+ return self._plan_version
@property
def suppress_warnings(self) -> bool:
@@ -243,6 +297,22 @@ def suppress_warnings(self, suppress: bool) -> None:
raise TypeError('suppress_warnings must be True or False.')
self._suppress_warnings = suppress
+ def __copy__(self) -> 'ConvolutionSettings':
+ """
+ Return a shallow copy of the ConvolutionSettings.
+
+ Returns
+ -------
+ 'ConvolutionSettings'
+ A new ConvolutionSettings instance with the same parameter values.
+ """
+ return ConvolutionSettings(
+ upsample_factor=self.upsample_factor,
+ extension_factor=self.extension_factor,
+ suppress_warnings=self.suppress_warnings,
+ display_name=self.display_name,
+ )
+
def __repr__(self) -> str:
"""
Return a string representation of the ConvolutionSettings.
diff --git a/src/easydynamics/utils/fit_target.py b/src/easydynamics/utils/fit_target.py
new file mode 100644
index 000000000..55c7ea5c5
--- /dev/null
+++ b/src/easydynamics/utils/fit_target.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: 2026 EasyScience contributors
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+
+@dataclass(frozen=True)
+class FitTarget:
+ """
+ One fittable prediction of a model, bound to a key in a parameters Dataset.
+
+ Models declare their predictions by returning FitTargets (see
+ ``DiffusionModelBase.get_fit_targets``), and ``FitBinding`` maps them onto the dataset keys
+ they should be fitted against. Instances are immutable snapshots created on demand, so the
+ units always reflect the model state at the time the targets are built.
+
+ Attributes
+ ----------
+ name : str
+ The prediction's name (e.g. ``'width'``, ``'area'``, ``'delta_area'``, ``'value'``).
+ dataset_key : str
+ The key in the parameters Dataset holding the data this prediction is fitted against.
+ function : Callable
+ The fit function; called as ``function(x)`` with raw x values expressed in *x_unit* and
+ returning raw values expressed in *y_unit*.
+ label : str
+ Display label used for plots and results (e.g. ``'DeltaLorentz width'``).
+ x_unit : str | None
+ The unit *function* expects its input in, or None if no unit conversion applies.
+ y_unit : str | None
+ The unit of *function*'s output, or None if no unit conversion applies.
+ """
+
+ name: str
+ dataset_key: str
+ function: Callable
+ label: str
+ x_unit: str | None
+ y_unit: str | None
diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py
index d9fc27c0f..02ec54e5f 100644
--- a/src/easydynamics/utils/utils.py
+++ b/src/easydynamics/utils/utils.py
@@ -4,8 +4,10 @@
import numpy as np
import scipp as sc
from easyscience.variable import DescriptorNumber
+from easyscience.variable import Parameter
from numpy.typing import ArrayLike
from scipp.constants import hbar as scipp_hbar
+from scipp.constants import k as scipp_k
Numeric = float | int
@@ -13,19 +15,101 @@
energy_type = np.ndarray | Numeric | list | ArrayLike | sc.Variable
hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar)
+kb = DescriptorNumber.from_scipp('kb', scipp_k)
angstrom = DescriptorNumber('angstrom', 1e-10, unit='m')
+def verify_Q_index(Q_index: int, Q: sc.Variable | None, allow_none: bool = False) -> None:
+ """
+ Verify that Q_index is a valid integer index into Q.
+
+ Parameters
+ ----------
+ Q_index : int
+ Index to validate.
+ Q : sc.Variable | None
+ The Q values (may be None if no data is loaded).
+ allow_none : bool, default=False
+ Whether or not to allow Q_index to be None
+
+ Raises
+ ------
+ TypeError
+ If Q_index is not an int (or not an int or None when allow_none=True).
+ IndexError
+ If Q_index is out of range.
+ ValueError
+ If Q is None and Q_index is not None.
+ """
+ if allow_none and Q_index is None:
+ return
+
+ if Q_index is None or not isinstance(Q_index, int):
+ if allow_none:
+ raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}')
+ raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}')
+
+ if Q is None:
+ raise ValueError('Q is None, cannot validate Q_index.')
+
+ if not (0 <= Q_index < len(Q)):
+ raise IndexError(f'Q_index {Q_index} is out of bounds for Q of length {len(Q)}')
+
+
+def convert_parameter_unit(parameter: Parameter, unit: str | sc.Unit) -> None:
+ """
+ Convert a parameter to a new unit, keeping dependent parameters consistent.
+
+ Independent parameters are converted with ``convert_unit``. Dependent parameters are converted
+ with ``set_desired_unit``, so the new unit survives later dependency-graph recomputations (a
+ plain ``convert_unit`` would be reverted to the old desired unit the next time the dependency
+ expression is re-evaluated).
+
+ Parameters
+ ----------
+ parameter : Parameter
+ The parameter to convert.
+ unit : str | sc.Unit
+ The unit to convert to.
+ """
+ if parameter.independent:
+ parameter.convert_unit(str(unit))
+ else:
+ parameter.set_desired_unit(str(unit))
+
+
+def energy_to_scipp(energy: np.ndarray, unit: str | sc.Unit) -> sc.Variable:
+ """
+ Convert a numpy energy array to a scipp Variable with dimension 'energy'.
+
+ Parameters
+ ----------
+ energy : np.ndarray
+ The energy array to be converted
+ unit : str | sc.Unit
+ The unit of the energy
+
+ Returns
+ -------
+ sc.Variable
+ Energy as sc.Variable.
+ """
+ return sc.array(dims=['energy'], values=energy, unit=unit)
+
+
def _validate_and_convert_Q(
Q: np.ndarray | Numeric | list | ArrayLike | sc.Variable | None,
-) -> np.ndarray | None:
+) -> sc.Variable | None:
"""
- Validate and convert Q to a numpy array.
+ Validate and convert Q to a scipp Variable in 1/angstrom.
+
+ Numbers, lists, and numpy arrays are assumed to be in 1/angstrom. Scipp Variables may be in any
+ unit convertible to 1/angstrom and are converted.
Parameters
----------
Q : np.ndarray | Numeric | list | ArrayLike | sc.Variable | None
- Scattering vector values in 1/angstrom.
+ Scattering vector values.
Raises
------
@@ -37,8 +121,8 @@ def _validate_and_convert_Q(
Returns
-------
- np.ndarray | None
- Q as a np.ndarray or None if Q is None.
+ sc.Variable | None
+ Q as a sc.Variable with dimension 'Q' and unit 1/angstrom, or None if Q is None.
"""
if Q is None:
return None
@@ -59,7 +143,7 @@ def _validate_and_convert_Q(
if Q.dims != ('Q',):
raise ValueError("Q must have a single dimension named 'Q'.")
Q = Q.to(unit='1/angstrom')
- return Q.values
+ return Q
def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None:
@@ -92,6 +176,30 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None:
return unit
+def _assert_valid_unit(unit: str | sc.Unit) -> None:
+ """
+ Assert that the given unit is recognised by scipp.
+
+ Parameters
+ ----------
+ unit : str | sc.Unit
+ The unit to validate.
+
+ Raises
+ ------
+ TypeError
+ If unit is not a string or scipp Unit.
+ ValueError
+ If the string is not a valid scipp unit.
+ """
+ if not isinstance(unit, (str, sc.Unit)):
+ raise TypeError(f'unit must be a string or sc.Unit, got {type(unit).__name__}')
+ try:
+ sc.Unit(str(unit))
+ except sc.UnitError as e:
+ raise ValueError(f"'{unit}' is not a valid scipp unit.") from e
+
+
def _in_notebook() -> bool:
"""
Check if the code is running in a Jupyter notebook.
diff --git a/tests/performance_tests/convolution/convolution_width_thresholds.ipynb b/tests/performance_tests/convolution/convolution_width_thresholds.ipynb
index fbb5df296..543002d54 100644
--- a/tests/performance_tests/convolution/convolution_width_thresholds.ipynb
+++ b/tests/performance_tests/convolution/convolution_width_thresholds.ipynb
@@ -48,7 +48,7 @@
")\n",
"numerical_convolver.upsample_factor = None\n",
"for gwidth in gaussian_widths:\n",
- " sample_components.components[0].width = gwidth\n",
+ " sample_components[0].width = gwidth\n",
" y_analytical = analytical_convolver.convolution()\n",
"\n",
" y_numerical = numerical_convolver.convolution()\n",
@@ -106,8 +106,8 @@
")\n",
"numerical_convolver.upsample_factor = None\n",
"for gwidth, gcenter in zip(gaussian_widths, gaussian_centers, strict=True):\n",
- " sample_components.components[0].width = gwidth\n",
- " sample_components.components[0].center = gcenter\n",
+ " sample_components[0].width = gwidth\n",
+ " sample_components[0].center = gcenter\n",
" y_analytical = analytical_convolver.convolution()\n",
"\n",
" y_numerical = numerical_convolver.convolution()\n",
@@ -135,7 +135,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
@@ -149,7 +149,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.14.4"
+ "version": "3.14.5"
}
},
"nbformat": 4,
diff --git a/tests/performance_tests/utils/detailed_balance_approximations.ipynb b/tests/performance_tests/utils/detailed_balance_approximations.ipynb
index f50de3cf5..7b5a75fa0 100644
--- a/tests/performance_tests/utils/detailed_balance_approximations.ipynb
+++ b/tests/performance_tests/utils/detailed_balance_approximations.ipynb
@@ -42,8 +42,6 @@
"metadata": {},
"outputs": [],
"source": [
- "import numpy as np\n",
- "\n",
"x = np.linspace(1e-10, 1e-5, 1000)\n",
"\n",
"y = x / (1 - np.exp(-x))\n",
@@ -62,7 +60,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "newdynamics",
+ "display_name": "default",
"language": "python",
"name": "python3"
},
@@ -76,7 +74,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.13"
+ "version": "3.14.5"
}
},
"nbformat": 4,
diff --git a/tests/unit/easydynamics/analysis/test_analysis.py b/tests/unit/easydynamics/analysis/test_analysis.py
index 516b718ef..2a67357ce 100644
--- a/tests/unit/easydynamics/analysis/test_analysis.py
+++ b/tests/unit/easydynamics/analysis/test_analysis.py
@@ -98,6 +98,24 @@ def test_analysis_list_contains_all_Q_indices(self, analysis):
for i in range(3):
assert analysis.analysis_list[i].Q_index == i
+ def test_analysis_list_shares_convolution_settings(self, analysis):
+ # WHEN THEN EXPECT: every Analysis1d holds the same ConvolutionSettings object, so
+ # user changes to analysis.convolution_settings reach all Q indices (regression:
+ # per-Q copies silently detached the settings)
+ for analysis1d in analysis.analysis_list:
+ assert analysis1d.convolution_settings is analysis.convolution_settings
+
+ def test_convolution_settings_changes_reach_all_Q(self, analysis):
+ # WHEN: the analysis list has been built
+ assert len(analysis.analysis_list) == 3
+
+ # THEN: mutate the settings on the Analysis after the fact
+ analysis.convolution_settings.upsample_factor = 20
+
+ # EXPECT: the per-Q analyses see the new value
+ for analysis1d in analysis.analysis_list:
+ assert analysis1d.convolution_settings.upsample_factor == 20
+
def test_analysis_list_setter_raises(self, analysis):
# WHEN / THEN / EXPECT
with pytest.raises(
@@ -152,7 +170,7 @@ def test_calculate_with_invalid_Q_index(self, analysis):
# WHEN / THEN / EXPECT
with pytest.raises(
IndexError,
- match='must be a valid index',
+ match='Q_index 3 is out of bounds',
):
analysis.calculate(Q_index=3)
@@ -512,7 +530,7 @@ def test_parameters_to_dataset_different_units(self, analysis):
)
# Convert the unit of a component to eV.
- analysis.sample_model.get_component_collection(Q_index=1)[0].convert_unit('eV')
+ analysis.sample_model.get_component_collection(Q_index=1)[0].convert_x_unit('eV')
# THEN
parameters_dataset = analysis.parameters_to_dataset()
@@ -533,12 +551,18 @@ def test_parameters_to_dataset_different_units(self, analysis):
assert 'Q' in parameters_dataset[parameter_name].dims
def test_parameters_to_dataset_raises_on_duplicate_names(self, analysis):
- # Add a second Gaussian with the same parameter names as the first
- analysis.sample_model.append_component(
- Gaussian(name='GaussianName', display_name='Gaussian2', area=0.5)
- )
+ # Add a second Gaussian with the same parameter names as the first. Appending and the
+ # later per-Q copy both warn about the duplicate names; capture them so the test suite
+ # stays warning-clean.
+ with pytest.warns(UserWarning, match='Duplicate component names'):
+ analysis.sample_model.append_component(
+ Gaussian(name='GaussianName', display_name='Gaussian2', area=0.5)
+ )
- with pytest.raises(ValueError, match='Duplicate parameter names'):
+ with (
+ pytest.warns(UserWarning, match='Duplicate component names'),
+ pytest.raises(ValueError, match='Duplicate parameter names'),
+ ):
analysis.parameters_to_dataset()
@pytest.mark.parametrize(
@@ -732,15 +756,18 @@ def test_on_instrument_model_changed(self, analysis):
def test_on_convolution_settings_changed(self, analysis):
# WHEN
- new_convolution_settings = ConvolutionSettings()
+ new_convolution_settings = ConvolutionSettings(upsample_factor=7, extension_factor=0.3)
# THEN (this calls _on_convolution_settings_changed internally)
analysis.convolution_settings = new_convolution_settings
- # EXPECT
+ # EXPECT: the parent holds the new settings object
assert analysis.convolution_settings is new_convolution_settings
+ # Each Analysis1d gets its own copy of the settings (so plan_is_valid is
+ # tracked independently per Q), but all values are propagated from the parent.
for analysis1d in analysis.analysis_list:
- assert analysis1d.convolution_settings is new_convolution_settings
+ assert analysis1d.convolution_settings.upsample_factor == 7
+ assert analysis1d.convolution_settings.extension_factor == pytest.approx(0.3)
def test_fit_single_Q_valid(self, analysis):
# WHEN
@@ -757,7 +784,7 @@ def test_fit_single_Q_invalid_Q_index(self, analysis):
# WHEN / THEN / EXPECT
with pytest.raises(
IndexError,
- match='must be a valid index',
+ match='Q_index 3 is out of bounds',
):
analysis.fit(Q_index=3)
@@ -824,6 +851,51 @@ def test_fit_all_Q_simultaneously(self, analysis):
# And that the result from the fit method is returned
assert result == fake_fit_result
+ def test_fit_all_Q_simultaneously_with_nan_data(self):
+ # Regression test: energy must be sliced via sc.array(dims=['energy'], values=mask),
+ # NOT via energy[numpy_bool_array]. The bug: numpy booleans are treated as integer
+ # indices by scipp (True→1, False→0), so energy[[True,False,True]] returns 3 elements
+ # with wrong values instead of filtering to the 2 finite points.
+
+ # WHEN: data with a NaN at index 1; finite energies are -1.0 and 1.0
+ Q = sc.array(dims=['Q'], values=[1.0], unit='1/Angstrom')
+ energy = sc.array(dims=['energy'], values=[-1.0, 0.0, 1.0], unit='meV')
+ values = np.array([[1.0, np.nan, 2.0]])
+ variances = np.array([[0.1, 0.1, 0.1]])
+ data = sc.array(dims=['Q', 'energy'], values=values, variances=variances)
+ data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy})
+ experiment = Experiment(data=data_array)
+
+ # Set up a sample model and analysis
+ sample_model = SampleModel(components=Gaussian(name='G'))
+ analysis = Analysis(experiment=experiment, sample_model=sample_model)
+
+ captured_energy = []
+ original_refresh = analysis.analysis_list[0].refresh_convolver
+
+ # Patch the refresh_convolver method to capture the energy passed to it
+ def capture_refresh(energy, **kwargs):
+ captured_energy.append(energy)
+ return original_refresh(energy=energy, **kwargs)
+
+ analysis.analysis_list[0].refresh_convolver = capture_refresh
+ analysis.get_fit_functions = MagicMock(return_value=['fit_fn'])
+
+ fake_fitter_instance = MagicMock()
+ fake_fitter_instance.fit.return_value = object()
+
+ # THEN
+ with patch(
+ 'easydynamics.analysis.analysis.MultiFitter',
+ return_value=fake_fitter_instance,
+ ):
+ analysis._fit_all_Q_simultaneously()
+
+ # EXPECT: only the 2 finite energy points (-1.0, 1.0) were passed, not all 3
+ assert len(captured_energy) == 1
+ assert len(captured_energy[0]) == 2
+ np.testing.assert_array_equal(captured_energy[0].values, [-1.0, 1.0])
+
def test_get_fit_functions(self, analysis):
# WHEN
@@ -923,9 +995,9 @@ def test_create_components_dataset_single_Q(self, analysis_single_Q):
assert components_dataset.sizes['Q'] == 1
assert components_dataset.coords['Q'].ndim == 1
- def test_ensure_analysis_list_current_clears_dirty_when_Q_is_none(self):
- # An Analysis with no experiment has Q=None; _ensure_analysis_list_current should still
- # clear the dirty flag without attempting to build the list.
+ def test_ensure_analysis_list_current_stays_dirty_when_Q_is_none(self):
+ # When Q is None, _ensure_analysis_list_current should NOT clear the dirty flag —
+ # it must stay dirty so the list is rebuilt as soon as Q becomes available.
# WHEN
analysis = Analysis(display_name='NoQ')
@@ -935,8 +1007,8 @@ def test_ensure_analysis_list_current_clears_dirty_when_Q_is_none(self):
# THEN
result = analysis.analysis_list
- # EXPECT - dirty flag cleared, list stays empty
- assert analysis._analysis_list_is_dirty is False
+ # EXPECT - dirty flag preserved, list stays empty
+ assert analysis._analysis_list_is_dirty is True
assert result == []
def test_rebin_marks_analysis_list_dirty(self, analysis):
@@ -1039,3 +1111,9 @@ def test_direct_experiment_rebin_does_not_update_analysis_list(self, analysis):
# EXPECT - analysis_list is NOT marked dirty (callers must use Analysis.rebin())
assert analysis._analysis_list_is_dirty is False
+
+ def test_repr(self, analysis):
+ repr_str = repr(analysis)
+ assert 'Analysis' in repr_str
+ assert 'display_name=' in repr_str
+ assert 'n_analyses=' in repr_str
diff --git a/tests/unit/easydynamics/analysis/test_analysis1d.py b/tests/unit/easydynamics/analysis/test_analysis1d.py
index 5bd571634..a44c3ed4b 100644
--- a/tests/unit/easydynamics/analysis/test_analysis1d.py
+++ b/tests/unit/easydynamics/analysis/test_analysis1d.py
@@ -75,8 +75,8 @@ def test_Q_index_setter(self, analysis1d):
@pytest.mark.parametrize(
'invalid_Q_index, expected_exception, expected_message',
[
- (-1, IndexError, 'Q_index must be'),
- (10, IndexError, 'Q_index must be'),
+ (-1, IndexError, 'Q_index -1 is out of bounds'),
+ (10, IndexError, 'Q_index 10 is out of bounds'),
('invalid', TypeError, 'Q_index must be '),
(np.nan, TypeError, 'Q_index must be '),
([1, 2], TypeError, 'Q_index must be '),
@@ -111,10 +111,10 @@ def test_calculate_updates_convolver_and_calls_calculate(self, analysis1d):
result = analysis1d.calculate()
# EXPECT
-
analysis1d._create_convolver.assert_called_once()
- assert analysis1d._convolver is fake_convolver
- analysis1d._calculate.assert_called_once()
+ # calculate() passes the convolver directly to _calculate without storing it on self
+ _, call_kwargs = analysis1d._calculate.call_args
+ assert call_kwargs['convolver'] is fake_convolver
np.testing.assert_array_equal(result, expected_result)
def test__calculate_adds_sample_and_background(self, analysis1d):
@@ -150,7 +150,7 @@ def test_fit_calls_fitter_with_correct_arguments(self, analysis1d):
fake_weights = np.array([0.1, 0.2, 0.3])
fake_mask = np.array([True, False, True])
- analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock(
+ analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock(
return_value=(fake_x, fake_y, fake_weights, fake_mask)
)
@@ -180,7 +180,7 @@ def test_fit_calls_fitter_with_correct_arguments(self, analysis1d):
fit_function='fit_func',
)
- analysis1d.experiment._extract_x_y_weights_only_finite.assert_called_once()
+ analysis1d.experiment.extract_x_y_weights_only_finite.assert_called_once()
fake_fitter_instance.fit.assert_called_once_with(
x=fake_x,
@@ -508,15 +508,19 @@ def test_calculate_energy_with_offset_different_units(self, analysis1d):
# WHEN
energy = analysis1d.experiment.energy
energy_offset = analysis1d.instrument_model.get_energy_offset(Q_index=analysis1d.Q_index)
- energy_offset.value = 1.0 # override with a simple value for testing
- energy_offset.convert_unit('eV')
+ energy_offset.value = 1.0 # set to 1.0 in original unit (meV)
+ energy_offset.convert_unit('eV') # now 0.001 eV, still represents 1 meV
# THEN
result = analysis1d._calculate_energy_with_offset(energy, energy_offset)
- # EXPECT
- expected = energy.values - energy_offset.value
- np.testing.assert_array_equal(result.values, expected)
+ # EXPECT: offset must be converted to energy's unit before subtraction
+ offset_in_energy_unit = sc.to_unit(
+ sc.scalar(energy_offset.value, unit=str(energy_offset.unit)),
+ str(energy.unit),
+ ).value
+ expected = energy.values - offset_in_energy_unit
+ np.testing.assert_array_almost_equal(result.values, expected)
def test_calculate_energy_with_offset_raises_if_incompatible_units(self, analysis1d):
# WHEN
@@ -524,9 +528,7 @@ def test_calculate_energy_with_offset_raises_if_incompatible_units(self, analysi
energy_offset = Parameter(name='energy_offset', value=1.0, unit='m') # incompatible unit
# THEN / EXPECT
- with pytest.raises(
- sc.UnitError, match='Energy and energy offset must have compatible units'
- ):
+ with pytest.raises(sc.UnitError):
analysis1d._calculate_energy_with_offset(energy, energy_offset)
#############
@@ -920,7 +922,7 @@ def test_fit_marks_convolver_dirty_when_sample_model_components_change(self, ana
'easydynamics.analysis.analysis1d.EasyScienceFitter',
return_value=MagicMock(fit=MagicMock(return_value=MagicMock())),
):
- analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock(
+ analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock(
return_value=(
np.array([1.0, 2.0, 3.0]),
np.array([1.0, 2.0, 3.0]),
@@ -946,7 +948,7 @@ def test_fit_does_not_rebuild_convolver_when_nothing_changed(self, analysis1d):
'easydynamics.analysis.analysis1d.EasyScienceFitter',
return_value=MagicMock(fit=MagicMock(return_value=MagicMock())),
):
- analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock(
+ analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock(
return_value=(
np.array([1.0, 2.0, 3.0]),
np.array([1.0, 2.0, 3.0]),
@@ -1030,7 +1032,7 @@ def test_fit_marks_convolver_dirty_when_resolution_model_components_change(self,
'easydynamics.analysis.analysis1d.EasyScienceFitter',
return_value=MagicMock(fit=MagicMock(return_value=MagicMock())),
):
- analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock(
+ analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock(
return_value=(
np.array([1.0, 2.0, 3.0]),
np.array([1.0, 2.0, 3.0]),
@@ -1042,3 +1044,67 @@ def test_fit_marks_convolver_dirty_when_resolution_model_components_change(self,
# EXPECT
analysis1d._create_convolver.assert_called_once()
+
+ # ───── Regression tests ─────
+
+ @pytest.fixture
+ def analysis1d_with_nan(self):
+ """analysis1d fixture whose data contains a NaN at the second energy point."""
+ Q = sc.array(dims=['Q'], values=[1.0], unit='1/Angstrom')
+ energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV')
+ data = sc.array(
+ dims=['Q', 'energy'],
+ values=[[1.0, float('nan'), 3.0]],
+ variances=[[0.1, 0.2, 0.3]],
+ )
+ data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy})
+ experiment = Experiment(data=data_array)
+ sample_model = SampleModel(components=Gaussian())
+ instrument_model = InstrumentModel()
+ return Analysis1d(
+ display_name='TestNaN',
+ experiment=experiment,
+ sample_model=sample_model,
+ instrument_model=instrument_model,
+ Q_index=0,
+ )
+
+ def test_create_residuals_array_with_nan_data_does_not_crash(self, analysis1d_with_nan):
+ # Before the fix, residuals subtracted a 2-point model from 3-point data
+ # (including NaN) which caused a dimension mismatch crash.
+ # WHEN
+
+ # THEN
+ result = analysis1d_with_nan._create_residuals_array()
+
+ # EXPECT
+ assert isinstance(result, sc.DataArray)
+ # Only the 2 finite energy points survive the mask.
+ assert result.sizes['energy'] == 2
+
+ def test_data_and_model_to_datagroup_with_nan_excludes_nan_from_data(
+ self, analysis1d_with_nan
+ ):
+ # Before the fix, 'Data' contained the full 3-point grid (including NaN)
+ # and computing Residuals crashed on the dimension mismatch.
+ # WHEN
+ energy = sc.array(dims=['energy'], values=[20.0, 30.0, 40.0], unit='meV')
+
+ # THEN
+ datagroup = analysis1d_with_nan.data_and_model_to_datagroup(
+ energy=energy, include_residuals=True
+ )
+
+ # EXPECT
+ assert isinstance(datagroup, sc.DataGroup)
+ # 'Data' must contain only the 2 finite points.
+ assert datagroup['Data'].sizes['energy'] == 2
+ # Residuals must be present and have matching size.
+ assert 'Residuals' in datagroup
+ assert datagroup['Residuals'].sizes['energy'] == 2
+
+ def test_repr(self, analysis1d):
+ repr_str = repr(analysis1d)
+ assert 'Analysis1d' in repr_str
+ assert 'display_name=' in repr_str
+ assert 'Q_index=' in repr_str
diff --git a/tests/unit/easydynamics/analysis/test_analysis_base.py b/tests/unit/easydynamics/analysis/test_analysis_base.py
index 7507b0b35..b68210c61 100644
--- a/tests/unit/easydynamics/analysis/test_analysis_base.py
+++ b/tests/unit/easydynamics/analysis/test_analysis_base.py
@@ -92,19 +92,28 @@ def test_init_detailed_balance_settings(self):
assert analysis.detailed_balance_settings is detailed_balance_settings
def test_init_extra_parameter(self):
+ # WHEN
extra_parameter = Parameter(name='param1', value=1.0)
+
+ # THEN
analysis = AnalysisBase(extra_parameters=extra_parameter)
+ # EXPECT
assert analysis._extra_parameters == [extra_parameter]
def test_init_extra_parameters(self):
+ # WHEN
extra_parameters = [
Parameter(name='param1', value=1.0),
Parameter(name='param2', value=2.0),
]
+
+ # THEN
analysis = AnalysisBase(extra_parameters=extra_parameters)
+ # EXPECT
assert analysis._extra_parameters == extra_parameters
def test_init_calls_on_experiment_changed(self):
+ # WHEN THEN EXPECT
with patch.object(AnalysisBase, '_on_experiment_changed') as mock_on_experiment_changed:
AnalysisBase()
mock_on_experiment_changed.assert_called_once()
@@ -159,74 +168,99 @@ def test_init_calls_on_experiment_changed(self):
],
)
def test_init_invalid_inputs(self, kwargs, expected_exception, expected_message):
+ # WHEN THEN EXPECT
with pytest.raises(expected_exception, match=expected_message):
AnalysisBase(**kwargs)
def test_experiment_setter_calls_on_experiment_changed(self, analysis_base):
+ # WHEN THEN EXPECT
with patch.object(analysis_base, '_on_experiment_changed') as mock_on_experiment_changed:
new_experiment = Experiment()
analysis_base.experiment = new_experiment
mock_on_experiment_changed.assert_called_once()
def test_experiment_setter_invalid_type(self, analysis_base):
+ # WHEN THEN EXPECT
with pytest.raises(TypeError, match='experiment must be an instance of Experiment'):
analysis_base.experiment = 'not an experiment'
def test_experiment_setter_valid(self, analysis_base):
+ # WHEN
new_experiment = Experiment()
analysis_base.experiment = new_experiment
+ # THEN EXPECT
assert analysis_base.experiment == new_experiment
def test_sample_model_setter_invalid_type(self, analysis_base):
+ # WHEN THEN EXPECT
with pytest.raises(TypeError, match='sample_model must be an instance of SampleModel'):
analysis_base.sample_model = 'not a sample model'
def test_sample_model_setter_valid(self, analysis_base):
+ # WHEN
new_sample_model = SampleModel()
+
+ # THEN
analysis_base.sample_model = new_sample_model
+ # EXPECT
assert analysis_base.sample_model == new_sample_model
def test_sample_model_setter_calls_on_sample_model_changed(self, analysis_base):
+ # WHEN
+ new_sample_model = SampleModel()
with patch.object(
analysis_base, '_on_sample_model_changed'
) as mock_on_sample_model_changed:
- new_sample_model = SampleModel()
+ # THEN
analysis_base.sample_model = new_sample_model
+
+ # EXPECT
mock_on_sample_model_changed.assert_called_once()
def test_instrument_model_setter_invalid_type(self, analysis_base):
+ # WHEN THEN EXPECT
with pytest.raises(
TypeError, match='instrument_model must be an instance of InstrumentModel'
):
analysis_base.instrument_model = 'not an instrument model'
def test_instrument_model_setter_valid(self, analysis_base):
+ # WHEN
new_instrument_model = InstrumentModel()
+
+ # THEN
analysis_base.instrument_model = new_instrument_model
+
+ # EXPECT
assert analysis_base.instrument_model == new_instrument_model
def test_instrument_model_setter_calls_on_instrument_model_changed(self, analysis_base):
+ # WHEN
+ new_instrument_model = InstrumentModel()
with patch.object(
analysis_base, '_on_instrument_model_changed'
) as mock_on_instrument_model_changed:
- new_instrument_model = InstrumentModel()
+ # THEN
analysis_base.instrument_model = new_instrument_model
+
+ # EXPECT
mock_on_instrument_model_changed.assert_called_once()
def test_Q_property(self, analysis_base):
- # Create a mock Q value
+ # WHEN
fake_Q = [1, 2, 3]
-
- # Patch the 'experiment' attribute's Q property
with patch.object(
type(analysis_base.experiment), 'Q', new_callable=PropertyMock
) as mock_Q:
mock_Q.return_value = fake_Q
- result = analysis_base.Q # Access the property
+ # THEN
+ result = analysis_base.Q
+ # EXPECT
assert result == fake_Q
mock_Q.assert_called_once()
def test_Q_setter_raises(self, analysis_base):
+ # WHEN THEN EXPECT
with pytest.raises(
AttributeError,
match=r'Q is a read-only property derived from the Experiment.',
@@ -234,19 +268,20 @@ def test_Q_setter_raises(self, analysis_base):
analysis_base.Q = [1, 2, 3]
def test_energy_property(self, analysis_base):
- # Create a mock energy value
+ # WHEN
fake_energy = [10, 20, 30]
-
- # Patch the 'experiment' attribute's energy property
with patch.object(
type(analysis_base.experiment), 'energy', new_callable=PropertyMock
) as mock_energy:
mock_energy.return_value = fake_energy
- result = analysis_base.energy # Access the property
+ # THEN
+ result = analysis_base.energy
+ # EXPECT
assert result == fake_energy
mock_energy.assert_called_once()
def test_energy_setter_raises(self, analysis_base):
+ # WHEN THEN EXPECT
with pytest.raises(
AttributeError,
match=r'energy is a read-only property derived from the Experiment.',
@@ -254,30 +289,32 @@ def test_energy_setter_raises(self, analysis_base):
analysis_base.energy = [10, 20, 30]
def test_temperature_property_no_temperature(self, analysis_base):
- # Patch the 'experiment' attribute's temperature property to
- # return None
+ # WHEN
with patch.object(
type(analysis_base.sample_model), 'temperature', new_callable=PropertyMock
) as mock_temperature:
mock_temperature.return_value = None
- result = analysis_base.temperature # Access the property
+ # THEN
+ result = analysis_base.temperature
+ # EXPECT
assert result is None
mock_temperature.assert_called_once()
def test_temperature_property(self, analysis_base):
- # Create a mock temperature value
+ # WHEN
fake_temperature = 300
-
- # Patch the 'sample_model' attribute's temperature property
with patch.object(
type(analysis_base.sample_model), 'temperature', new_callable=PropertyMock
) as mock_temperature:
mock_temperature.return_value = fake_temperature
- result = analysis_base.temperature # Access the property
+ # THEN
+ result = analysis_base.temperature
+ # EXPECT
assert result == fake_temperature
mock_temperature.assert_called_once()
def test_temperature_setter_raises(self, analysis_base):
+ # WHEN THEN EXPECT
with pytest.raises(
AttributeError,
match='temperature is a read-only property',
@@ -381,6 +418,7 @@ def test_extra_parameters_property(self, analysis_base, extra_parameters):
],
)
def test_extra_parameters_setter_invalid_type(self, analysis_base, invalid_extra_parameters):
+ # WHEN THEN EXPECT
with pytest.raises(
TypeError,
match='extra_parameters must be',
@@ -392,6 +430,7 @@ def test_extra_parameters_setter_invalid_type(self, analysis_base, invalid_extra
#############
def test_normalize_resolution_calls_instrument_model(self, analysis_base):
+ # WHEN THEN EXPECT
with patch.object(
analysis_base.instrument_model, 'normalize_resolution'
) as mock_normalize_resolution:
@@ -458,12 +497,13 @@ def test_get_parameters_near_bounds_with_tolerances(self, analysis_base_with_com
def test_get_parameters_near_bounds_errors(
self, analysis_base_with_components, rtol, atol, expected_param, expected_error
):
+ # WHEN THEN
with pytest.raises(expected_error) as exc:
analysis_base_with_components.get_parameters_near_bounds(
rtol=rtol,
atol=atol,
)
-
+ # EXPECT
assert expected_param in str(exc.value)
def test_not_finite_parameters(self, analysis_base_with_components):
@@ -499,10 +539,9 @@ def test_on_experiment_changed_updates_Q(self, analysis_base):
analysis_base._on_experiment_changed()
# EXPECT
- # assert that the Q attribute was set
np.testing.assert_array_equal(analysis_base.Q, fake_Q)
- np.testing.assert_array_equal(analysis_base.sample_model.Q, fake_Q)
- np.testing.assert_array_equal(analysis_base.instrument_model.Q, fake_Q)
+ np.testing.assert_array_equal(analysis_base.sample_model.Q.values, fake_Q)
+ np.testing.assert_array_equal(analysis_base.instrument_model.Q.values, fake_Q)
def test_on_sample_model_changed_updates_Q(self, analysis_base):
# WHEN
@@ -518,37 +557,45 @@ def test_on_sample_model_changed_updates_Q(self, analysis_base):
analysis_base._on_sample_model_changed()
# EXPECT
- np.testing.assert_array_equal(analysis_base.sample_model.Q, fake_Q)
+ np.testing.assert_array_equal(analysis_base.sample_model.Q.values, fake_Q)
def test_on_instrument_model_changed_updates_Q(self, analysis_base):
+ # WHEN
fake_Q = [1, 2, 3]
-
- # Patch the Q property of analysis_base
with patch.object(
type(analysis_base.experiment), 'Q', new_callable=PropertyMock
) as mock_Q:
mock_Q.return_value = fake_Q
-
+ # THEN
analysis_base._on_instrument_model_changed()
- np.testing.assert_array_equal(analysis_base.instrument_model.Q, fake_Q)
+ # EXPECT
+ np.testing.assert_array_equal(analysis_base.instrument_model.Q.values, fake_Q)
- def test_verify_Q_index_valid(self, analysis_base):
+ def test_verify_Q_index_valid(self, analysis_base_with_components):
# WHEN
valid_Q_index = 0
# THEN
- result = analysis_base._verify_Q_index(valid_Q_index)
+ result = analysis_base_with_components._verify_Q_index(valid_Q_index)
# EXPECT
assert result == valid_Q_index
- def test_verify_Q_index_invalid(self, analysis_base):
+ def test_verify_Q_index_invalid(self, analysis_base_with_components):
# WHEN
invalid_Q_index = -1
# THEN / EXPECT
- with pytest.raises(IndexError, match='Q_index must be a valid index'):
- analysis_base._verify_Q_index(invalid_Q_index)
+ with pytest.raises(IndexError, match='Q_index -1 is out of bounds for Q of length 2'):
+ analysis_base_with_components._verify_Q_index(invalid_Q_index)
+
+ def test_verify_Q_index_invalid_when_Q_is_none(self, analysis_base):
+ # WHEN
+ positive_Q_index = 0
+
+ # THEN / EXPECT
+ with pytest.raises(ValueError, match='Q is None, cannot validate Q_index'):
+ analysis_base._verify_Q_index(positive_Q_index)
def test_repr(self, analysis_base):
# WHEN
diff --git a/tests/unit/easydynamics/analysis/test_fit_binding.py b/tests/unit/easydynamics/analysis/test_fit_binding.py
index 191d1ba1a..901ab9e50 100644
--- a/tests/unit/easydynamics/analysis/test_fit_binding.py
+++ b/tests/unit/easydynamics/analysis/test_fit_binding.py
@@ -2,288 +2,272 @@
# SPDX-License-Identifier: BSD-3-Clause
-from unittest.mock import Mock
-
+import numpy as np
import pytest
from easydynamics.analysis.fit_binding import FitBinding
from easydynamics.sample_model.components.gaussian import Gaussian
+from easydynamics.sample_model.components.polynomial import Polynomial
from easydynamics.sample_model.diffusion_model.brownian_translational_diffusion import (
BrownianTranslationalDiffusion,
)
+from easydynamics.sample_model.diffusion_model.delta_lorentz import DeltaLorentz
+from easydynamics.utils.fit_target import FitTarget
class TestFitBinding:
@pytest.fixture
- def fit_binding(self):
- model = Gaussian()
- return FitBinding(parameter_name='parameter1', model=model)
+ def component_binding(self):
+ model = Gaussian(display_name='GaussianModel')
+ return FitBinding(model=model, targets='parameter1')
@pytest.fixture
- def diffusion_fit_binding(self):
- model = BrownianTranslationalDiffusion()
- return FitBinding(parameter_name='parameter3', model=model)
-
- def test_initialization(self, fit_binding):
- # WHEN THEN EXPECT
- assert isinstance(fit_binding, FitBinding)
- assert fit_binding.parameter_name == 'parameter1'
- assert isinstance(fit_binding.model, Gaussian)
- assert fit_binding.modes is None
-
- @pytest.mark.parametrize(
- 'parameter_name, model, modes, error_msg',
- [
- # parameter_name errors
- (123, Gaussian(), None, 'parameter_name must be a string'),
- (None, Gaussian(), None, 'parameter_name must be a string'),
- # model errors
- (
- 'param',
- 123,
- None,
- 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase',
- ),
- (
- 'param',
- 'not_a_model',
- None,
- 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase',
- ),
- # modes type errors
- (
- 'param',
- Gaussian(),
- 123,
- 'modes must be a string, list of strings, or None',
- ),
- (
- 'param',
- Gaussian(),
- {'mode': 'area'},
- 'modes must be a string, list of strings, or None',
- ),
- # modes list content errors
- (
- 'param',
- Gaussian(),
- ['area', 123],
- 'All modes in the list must be strings',
- ),
- ('param', Gaussian(), [None], 'All modes in the list must be strings'),
- ],
- )
- def test_fitbinding_init_errors(self, parameter_name, model, modes, error_msg):
- with pytest.raises(TypeError, match=error_msg):
- FitBinding(
- parameter_name=parameter_name,
- model=model,
- modes=modes,
- )
+ def diffusion_binding(self):
+ model = BrownianTranslationalDiffusion(lorentzian_name='Lorentzian')
+ return FitBinding(model=model)
# ------------------------------------------------------------------
- # Properties
+ # Initialization and validation
# ------------------------------------------------------------------
- def test_parameter_name_setter(self, fit_binding):
- # WHEN
- fit_binding.parameter_name = 'new_parameter'
+ def test_initialization(self, component_binding):
+ # WHEN THEN EXPECT
+ assert isinstance(component_binding, FitBinding)
+ assert component_binding.targets == 'parameter1'
+ assert isinstance(component_binding.model, Gaussian)
- # THEN EXPECT
- assert fit_binding.parameter_name == 'new_parameter'
+ def test_initialization_invalid_model_raises(self):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='model must be a ModelComponent'):
+ FitBinding(model='not_a_model', targets='parameter1')
- def test_parameter_name_setter_errors(self, fit_binding):
- with pytest.raises(TypeError, match='parameter_name must be a string'):
- fit_binding.parameter_name = 123
+ @pytest.mark.parametrize(
+ 'targets',
+ [None, ['parameter1'], {'value': 'parameter1'}, 123],
+ ids=['none', 'list', 'dict', 'int'],
+ )
+ def test_component_model_requires_string_targets(self, targets):
+ # WHEN THEN EXPECT: component models have a single prediction, so targets must be
+ # the dataset key
+ with pytest.raises(TypeError, match='targets must be the dataset key'):
+ FitBinding(model=Gaussian(), targets=targets)
- def test_model_setter(self, fit_binding):
- # WHEN
+ def test_diffusion_unknown_prediction_raises(self):
+ # WHEN THEN EXPECT
model = BrownianTranslationalDiffusion()
+ with pytest.raises(ValueError, match=r'Unknown prediction.*Available predictions'):
+ FitBinding(model=model, targets=['nonsense'])
- # THEN
- fit_binding.model = model
-
- # EXPECT
- assert fit_binding.model is model
+ def test_diffusion_unknown_prediction_in_dict_raises(self):
+ # WHEN THEN EXPECT
+ model = BrownianTranslationalDiffusion()
+ with pytest.raises(ValueError, match='Unknown prediction'):
+ FitBinding(model=model, targets={'delta_area': 'Elastic area'})
- def test_model_setter_errors(self, fit_binding):
- with pytest.raises(
- TypeError,
- match='model must be a ModelComponent, ComponentCollection, or DiffusionModelBase',
- ):
- fit_binding.model = 'not_a_model'
+ def test_diffusion_invalid_targets_type_raises(self):
+ # WHEN THEN EXPECT
+ model = BrownianTranslationalDiffusion()
+ with pytest.raises(TypeError, match='targets must be None'):
+ FitBinding(model=model, targets=123)
- def test_modes_setter(self, fit_binding):
- # WHEN
- fit_binding.modes = 'area'
+ def test_diffusion_non_string_dataset_key_raises(self):
+ # WHEN THEN EXPECT
+ model = BrownianTranslationalDiffusion()
+ with pytest.raises(TypeError, match='dataset keys'):
+ FitBinding(model=model, targets={'width': 123})
- # THEN EXPECT
- assert fit_binding.modes == ['area']
+ # ------------------------------------------------------------------
+ # Properties
+ # ------------------------------------------------------------------
- # WHEN
- fit_binding.modes = ['area', 'width']
+ def test_model_setter_revalidates_targets(self):
+ # WHEN: a binding using DeltaLorentz's delta_area prediction
+ binding = FitBinding(model=DeltaLorentz(), targets=['delta_area'])
- # THEN EXPECT
- assert fit_binding.modes == ['area', 'width']
+ # THEN EXPECT: switching to a model without that prediction fails
+ with pytest.raises(ValueError, match='Unknown prediction'):
+ binding.model = BrownianTranslationalDiffusion()
- def test_modes_setter_errors(self, fit_binding):
- with pytest.raises(TypeError, match='modes must be a string, list of strings, or None'):
- fit_binding.modes = 123
+ def test_model_setter_invalid_type_raises(self, component_binding):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='model must be a ModelComponent'):
+ component_binding.model = 'not_a_model'
- with pytest.raises(TypeError, match='modes must be a string, list of strings, or None'):
- fit_binding.modes = {'mode': 'area'}
+ def test_targets_setter(self, diffusion_binding):
+ # WHEN THEN
+ diffusion_binding.targets = ['width']
- with pytest.raises(TypeError, match='All modes in the list must be strings'):
- fit_binding.modes = ['area', 123]
+ # EXPECT
+ assert [t.name for t in diffusion_binding.get_targets()] == ['width']
- with pytest.raises(TypeError, match='All modes in the list must be strings'):
- fit_binding.modes = [None]
+ def test_targets_setter_invalid_raises(self, diffusion_binding):
+ # WHEN THEN EXPECT
+ with pytest.raises(ValueError, match='Unknown prediction'):
+ diffusion_binding.targets = ['nonsense']
# ------------------------------------------------------------------
- # Other methods
+ # get_targets
# ------------------------------------------------------------------
- def test_build_callables_component(self, fit_binding):
+ def test_component_target(self, component_binding):
# WHEN
- mock_model = Mock()
- mock_model.evaluate.return_value = 1.0
- fit_binding._model = mock_model
-
- # THEN
- callables = fit_binding.build_callables()
+ targets = component_binding.get_targets()
+
+ # THEN EXPECT: a single target evaluating the component, with its units
+ assert len(targets) == 1
+ target = targets[0]
+ assert isinstance(target, FitTarget)
+ assert target.name == 'value'
+ assert target.dataset_key == 'parameter1'
+ assert target.label == 'GaussianModel'
+ assert target.x_unit == component_binding.model.x_unit
+ assert target.y_unit == component_binding.model.y_unit
+
+ def test_component_target_function_evaluates_model(self, component_binding):
+ # WHEN
+ target = component_binding.get_targets()[0]
+ x = np.linspace(-1, 1, 5)
- # EXPECT
- assert len(callables) == 1
- assert callable(callables[0])
- assert callables[0](0) == pytest.approx(1.0)
- mock_model.evaluate.assert_called_once_with(0)
+ # THEN EXPECT
+ np.testing.assert_allclose(target.function(x), component_binding.model.evaluate(x))
- def test_build_callables_diffusion(self, diffusion_fit_binding):
+ def test_diffusion_default_targets(self, diffusion_binding):
# WHEN
- mock_model = Mock(spec=BrownianTranslationalDiffusion)
- mock_model.calculate_QISF.return_value = 2.0
- mock_model.scale.value = 3.0
- mock_model.calculate_width.return_value = 0.5
- diffusion_fit_binding._model = mock_model
+ targets = diffusion_binding.get_targets()
- # THEN
- callables = diffusion_fit_binding.build_callables()
+ # THEN EXPECT: all predictions with default dataset keys from the Lorentzian name
+ assert [t.name for t in targets] == ['area', 'width']
+ assert [t.dataset_key for t in targets] == ['Lorentzian area', 'Lorentzian width']
- # EXPECT
- assert len(callables) == 2
- assert callable(callables[0])
- assert callable(callables[1])
- assert callables[0](0) == pytest.approx(6.0) # 2.0 * 3.0
- assert callables[1](0) == pytest.approx(0.5)
- mock_model.calculate_QISF.assert_called_once_with(0)
- mock_model.calculate_width.assert_called_once_with(0)
-
- def test_build_callables_diffusion_with_modes(self, diffusion_fit_binding):
+ def test_diffusion_string_target(self):
# WHEN
- diffusion_fit_binding.modes = 'area'
- mock_model = Mock(spec=BrownianTranslationalDiffusion)
- mock_model.calculate_QISF.return_value = 2.0
- mock_model.scale.value = 3.0
- diffusion_fit_binding._model = mock_model
+ model = BrownianTranslationalDiffusion(lorentzian_name='QENS')
+ binding = FitBinding(model=model, targets='width')
# THEN
- callables = diffusion_fit_binding.build_callables()
+ targets = binding.get_targets()
# EXPECT
- assert len(callables) == 1
- assert callable(callables[0])
- assert callables[0](0) == pytest.approx(6.0) # 2.0 * 3.0
- mock_model.calculate_QISF.assert_called_once_with(0)
+ assert len(targets) == 1
+ assert targets[0].name == 'width'
+ assert targets[0].dataset_key == 'QENS width'
- def test_get_model_names(self, fit_binding):
- # WHEN THEN
- model_names = fit_binding.get_model_names()
-
- # EXPECT
- assert model_names == ['Gaussian']
+ def test_diffusion_list_targets(self):
+ # WHEN
+ model = BrownianTranslationalDiffusion(lorentzian_name='QENS')
+ binding = FitBinding(model=model, targets=['width', 'area'])
+
+ # THEN EXPECT: order follows the user's list
+ assert [t.name for t in binding.get_targets()] == ['width', 'area']
+
+ def test_diffusion_dict_targets_map_dataset_keys(self):
+ # WHEN: mapping DeltaLorentz predictions to custom dataset keys
+ model = DeltaLorentz()
+ binding = FitBinding(
+ model=model,
+ targets={
+ 'width': 'L width',
+ 'area': 'L area',
+ 'delta_area': 'Elastic area',
+ },
+ )
- def test_get_model_names_diffusion(self, diffusion_fit_binding):
- # WHEN THEN
- model_names = diffusion_fit_binding.get_model_names()
+ # THEN
+ targets = {t.name: t for t in binding.get_targets()}
# EXPECT
- assert model_names == [
- 'BrownianTranslationalDiffusion area',
- 'BrownianTranslationalDiffusion width',
- ]
+ assert targets['width'].dataset_key == 'L width'
+ assert targets['area'].dataset_key == 'L area'
+ assert targets['delta_area'].dataset_key == 'Elastic area'
- def test_get_parameter_names(self, fit_binding):
- # WHEN THEN
- parameter_names = fit_binding.get_parameter_names()
+ def test_delta_lorentz_default_targets(self):
+ # WHEN: default keys derive from the component names
+ model = DeltaLorentz(lorentzian_name='Lorentzian', delta_name='Delta function')
+ binding = FitBinding(model=model)
- # EXPECT
- assert parameter_names == ['parameter1']
-
- def test_get_parameter_names_diffusion(self, diffusion_fit_binding):
- # WHEN THEN
- parameter_names = diffusion_fit_binding.get_parameter_names()
+ # THEN
+ targets = {t.name: t for t in binding.get_targets()}
# EXPECT
- assert parameter_names == ['parameter3 area', 'parameter3 width']
+ assert set(targets) == {'area', 'width', 'delta_area'}
+ assert targets['area'].dataset_key == 'Lorentzian area'
+ assert targets['width'].dataset_key == 'Lorentzian width'
+ assert targets['delta_area'].dataset_key == 'Delta function area'
- # ------------------------------------------------------------------
- # Private methods
- # ------------------------------------------------------------------
+ def test_diffusion_target_units(self, diffusion_binding):
+ # WHEN
+ targets = {t.name: t for t in diffusion_binding.get_targets()}
- def test_build_diffusion_callable(self, diffusion_fit_binding):
+ # THEN EXPECT: predictions take Q in 1/angstrom; widths are in the model's x_unit and
+ # areas in the scale unit
+ assert targets['width'].x_unit == '1/angstrom'
+ assert targets['width'].y_unit == diffusion_binding.model.x_unit
+ assert targets['area'].y_unit == str(diffusion_binding.model.scale.unit)
- # WHEN
- mock_model = Mock()
- mock_model.calculate_QISF.return_value = 2.0
- mock_model.scale.value = 3.0
- mock_model.calculate_width.return_value = 0.5
- diffusion_fit_binding._model = mock_model
+ def test_diffusion_target_units_track_conversions(self, diffusion_binding):
+ # WHEN: targets are snapshots of the live model, so unit conversions are reflected
+ diffusion_binding.model.convert_x_unit('ueV')
# THEN
- area_callable = diffusion_fit_binding._build_diffusion_callable(mode='area')
- width_callable = diffusion_fit_binding._build_diffusion_callable(mode='width')
+ targets = {t.name: t for t in diffusion_binding.get_targets()}
# EXPECT
- assert area_callable(0) == pytest.approx(6.0) # 2.0 * 3.0
- mock_model.calculate_QISF.assert_called_once_with(0)
+ assert targets['width'].y_unit == 'ueV'
- assert width_callable(0) == pytest.approx(0.5)
- mock_model.calculate_width.assert_called_once_with(0)
+ def test_diffusion_target_functions(self):
+ # WHEN
+ model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, scale=0.5)
+ binding = FitBinding(model=model)
+ Q = np.array([0.5, 1.0])
# THEN
- result_area = area_callable(0, unused_arg=123)
- result_width = width_callable(0, unused_arg=123)
+ targets = {t.name: t for t in binding.get_targets()}
# EXPECT
- assert result_area == pytest.approx(6.0) # Should ignore unused_arg
- assert result_width == pytest.approx(0.5) # Should ignore unused_arg
-
- def test_build_diffusion_callable_errors(self, diffusion_fit_binding):
- with pytest.raises(ValueError, match='Unknown diffusion mode: invalid_mode'):
- diffusion_fit_binding._build_diffusion_callable(mode='invalid_mode')
+ np.testing.assert_allclose(targets['width'].function(Q), model.calculate_width(Q))
+ np.testing.assert_allclose(
+ targets['area'].function(Q), model.calculate_QISF(Q) * model.scale.value
+ )
- def test_get_modes(self, diffusion_fit_binding):
+ def test_delta_lorentz_delta_area_function(self):
# WHEN
- modes = diffusion_fit_binding._get_modes()
-
- # EXPECT
- assert modes == ['area', 'width']
+ model = DeltaLorentz(A_0=0.5, mean_u_squared=0.3, scale=2.0, lorentzian_width=0.1)
+ binding = FitBinding(model=model, targets=['delta_area'])
+ Q = np.array([0.5, 1.0])
# THEN
- diffusion_fit_binding.modes = 'area'
- modes = diffusion_fit_binding._get_modes()
+ target = binding.get_targets()[0]
+
# EXPECT
- assert modes == ['area']
+ np.testing.assert_allclose(target.function(Q), model.calculate_EISF(Q) * model.scale.value)
# ------------------------------------------------------------------
# dunder methods
# ------------------------------------------------------------------
- def test_repr(self, fit_binding):
- # WHEN
- repr_str = repr(fit_binding)
- # THEN EXPECT
+ def test_repr(self, diffusion_binding):
+ # WHEN THEN
+ repr_str = repr(diffusion_binding)
+
+ # EXPECT
assert 'FitBinding' in repr_str
- assert "parameter_name='parameter1'" in repr_str
- assert "model='Gaussian'" in repr_str
- assert 'modes=None' in repr_str
+ assert 'model=' in repr_str
+ assert 'targets=' in repr_str
+
+
+class TestFitBindingWorkflows:
+ """End-to-end regression tests for the standard ParameterAnalysis workflows."""
+
+ def test_polynomial_targets_gaussian_area(self):
+ # WHEN: fitting a Polynomial to a 'Gaussian area' dataset key
+ polynomial = Polynomial(
+ coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV', display_name='Line'
+ )
+ binding = FitBinding(model=polynomial, targets='Gaussian area')
+
+ # THEN
+ target = binding.get_targets()[0]
+
+ # EXPECT
+ assert target.dataset_key == 'Gaussian area'
+ assert target.label == 'Line'
diff --git a/tests/unit/easydynamics/analysis/test_parameter_analysis.py b/tests/unit/easydynamics/analysis/test_parameter_analysis.py
index 14d9c58c6..031f813cf 100644
--- a/tests/unit/easydynamics/analysis/test_parameter_analysis.py
+++ b/tests/unit/easydynamics/analysis/test_parameter_analysis.py
@@ -17,12 +17,25 @@
from easydynamics.sample_model.diffusion_model.brownian_translational_diffusion import (
BrownianTranslationalDiffusion,
)
+from easydynamics.utils.fit_target import FitTarget
+
+
+def make_target(dataset_key, function, label, x_unit=None, y_unit=None, name='value'):
+ """Build a FitTarget for mocking FitBinding.get_targets in tests."""
+ return FitTarget(
+ name=name,
+ dataset_key=dataset_key,
+ function=function,
+ label=label,
+ x_unit=x_unit,
+ y_unit=y_unit,
+ )
class TestParameterAnalysis:
@pytest.fixture
def dataset(self):
- Q = sc.array(dims=['Q'], values=[0.1, 0.2])
+ Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom')
return sc.Dataset(
data={
'parameter1': sc.DataArray(
@@ -63,11 +76,14 @@ def mock_model_dataset(self):
@pytest.fixture
def parameter_analysis(self, dataset):
- model = Polynomial(coefficients=[1.0, 0.5])
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV')
diffusion_model = BrownianTranslationalDiffusion()
- fit_binding1 = FitBinding(parameter_name='parameter1', model=model)
- fit_binding2 = FitBinding(parameter_name='parameter3', model=diffusion_model)
+ fit_binding1 = FitBinding(model=model, targets='parameter1')
+ fit_binding2 = FitBinding(
+ model=diffusion_model,
+ targets={'area': 'parameter3 area', 'width': 'parameter3 width'},
+ )
return ParameterAnalysis(parameters=dataset, bindings=[fit_binding1, fit_binding2])
@@ -75,8 +91,11 @@ def test_initialization(self, parameter_analysis):
# WHEN THEN EXPECT
assert isinstance(parameter_analysis, ParameterAnalysis)
assert len(parameter_analysis.bindings) == 2
- assert parameter_analysis.bindings[0].parameter_name == 'parameter1'
- assert parameter_analysis.bindings[1].parameter_name == 'parameter3'
+ assert parameter_analysis.bindings[0].targets == 'parameter1'
+ assert parameter_analysis.bindings[1].targets == {
+ 'area': 'parameter3 area',
+ 'width': 'parameter3 width',
+ }
def test_parameter_property(self, parameter_analysis):
# WHEN
@@ -133,7 +152,7 @@ def test_bindings_property(self, parameter_analysis):
# WHEN
model = Polynomial(coefficients=[2.0, 1.0])
- new_binding = FitBinding(parameter_name='parameter2', model=model)
+ new_binding = FitBinding(model=model, targets='parameter2')
parameter_analysis.bindings = new_binding
# THEN EXPECT
@@ -162,7 +181,7 @@ def test_fit_no_parameters_raises(self, parameter_analysis):
def test_fit_wrong_parameter_name_raises(self, parameter_analysis):
# WHEN
model = Polynomial(coefficients=[2.0, 1.0])
- incorrect_binding = FitBinding(parameter_name='nonexistent_parameter', model=model)
+ incorrect_binding = FitBinding(model=model, targets='nonexistent_parameter')
parameter_analysis.bindings = incorrect_binding
# THEN EXPECT
@@ -172,6 +191,87 @@ def test_fit_wrong_parameter_name_raises(self, parameter_analysis):
):
parameter_analysis.fit()
+ def test_fit_incompatible_x_unit_raises(self, dataset):
+ # WHEN: Polynomial has x_unit='meV' but Q coordinate has unit '1/angstrom'
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='meV', y_unit='meV')
+ binding = FitBinding(model=model, targets='parameter1')
+ pa = ParameterAnalysis(parameters=dataset, bindings=binding)
+
+ # THEN EXPECT
+ with pytest.raises(Exception, match=r'Q coordinate unit .* is incompatible'):
+ pa.fit()
+
+ def test_fit_incompatible_y_unit_raises(self, dataset):
+ # WHEN: Polynomial has y_unit='1/angstrom' but parameter1 has unit 'meV'
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='1/angstrom')
+ binding = FitBinding(model=model, targets='parameter1')
+ pa = ParameterAnalysis(parameters=dataset, bindings=binding)
+
+ # THEN EXPECT
+ with pytest.raises(Exception, match="Parameter 'parameter1' unit 'meV' is incompatible"):
+ pa.fit()
+
+ def test_fit_converts_x_unit(self):
+ # WHEN: Q is in '1/m' but the model declares x_unit='1/angstrom'.
+ # 1 1/m = 1e-10 1/angstrom, so values [1e10, 2e10] 1/m → [1.0, 2.0] 1/angstrom.
+ Q = sc.array(dims=['Q'], values=[1e10, 2e10], unit='1/m')
+ dataset = sc.Dataset(
+ data={
+ 'param': sc.DataArray(
+ data=sc.array(dims=['Q'], values=[1.0, 2.0], variances=[0.1, 0.2], unit='meV'),
+ coords={'Q': Q},
+ )
+ }
+ )
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV')
+ pa = ParameterAnalysis(
+ parameters=dataset, bindings=FitBinding(model=model, targets='param')
+ )
+
+ with patch('easydynamics.analysis.parameter_analysis.MultiFitter') as MockMultiFitter:
+ MockMultiFitter.return_value.fit.return_value = MagicMock()
+ # THEN
+ pa.fit()
+ x_passed = MockMultiFitter.return_value.fit.call_args.kwargs['x'][0]
+
+ # EXPECT: x values converted from 1/m to 1/angstrom
+ np.testing.assert_allclose(x_passed, [1.0, 2.0])
+
+ def test_fit_converts_y_unit(self):
+ # WHEN: parameter is in 'eV' but model declares y_unit='meV'.
+ # 1 eV = 1000 meV, so values [0.001, 0.002] eV → [1.0, 2.0] meV.
+ # Weights (= 1/std) scale by 1/y_factor: sqrt(1e-6)=1e-3 eV^-1 → 1.0 meV^-1.
+ Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom')
+ dataset = sc.Dataset(
+ data={
+ 'param': sc.DataArray(
+ data=sc.array(
+ dims=['Q'],
+ values=[0.001, 0.002],
+ variances=[1e-6, 4e-6],
+ unit='eV',
+ ),
+ coords={'Q': Q},
+ )
+ }
+ )
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV')
+ pa = ParameterAnalysis(
+ parameters=dataset, bindings=FitBinding(model=model, targets='param')
+ )
+
+ with patch('easydynamics.analysis.parameter_analysis.MultiFitter') as MockMultiFitter:
+ MockMultiFitter.return_value.fit.return_value = MagicMock()
+ # THEN
+ pa.fit()
+ kwargs = MockMultiFitter.return_value.fit.call_args.kwargs
+ y_passed = kwargs['y'][0]
+ w_passed = kwargs['weights'][0]
+
+ # EXPECT: y values converted from eV to meV; weights scaled inversely
+ np.testing.assert_allclose(y_passed, [1.0, 2.0])
+ np.testing.assert_allclose(w_passed, [1.0, 0.5])
+
def test_fit_success(self, parameter_analysis):
# WHEN
mock_result = MagicMock()
@@ -469,9 +569,11 @@ def test_calculate_model_dataset_missing_parameter_raises(self, parameter_analys
binding = parameter_analysis.bindings[0]
with (
- patch.object(binding, 'get_parameter_names', return_value=['missing_name']),
- patch.object(binding, 'get_model_names', return_value=['model']),
- patch.object(binding, 'build_callables', return_value=[lambda x: x]),
+ patch.object(
+ binding,
+ 'get_targets',
+ return_value=[make_target('missing_name', lambda x: x, 'model')],
+ ),
pytest.raises(ValueError, match='not found in parameters Dataset'),
):
parameter_analysis.calculate_model_dataset([binding])
@@ -483,10 +585,10 @@ def test_calculate_model_dataset_single_binding(self, parameter_analysis):
# Mock callable to return predictable output
mock_callable = MagicMock(return_value=np.array([10.0, 20.0]))
- with (
- patch.object(binding, 'build_callables', return_value=[mock_callable]),
- patch.object(binding, 'get_model_names', return_value=['model1']),
- patch.object(binding, 'get_parameter_names', return_value=['parameter1']),
+ with patch.object(
+ binding,
+ 'get_targets',
+ return_value=[make_target('parameter1', mock_callable, 'model1')],
):
# THEN
result = parameter_analysis.calculate_model_dataset([binding])
@@ -516,12 +618,16 @@ def test_calculate_model_dataset_multiple_bindings(self, parameter_analysis):
mock_callable2 = MagicMock(return_value=np.array([30.0, 40.0]))
with (
- patch.object(binding1, 'build_callables', return_value=[mock_callable1]),
- patch.object(binding1, 'get_model_names', return_value=['model1']),
- patch.object(binding1, 'get_parameter_names', return_value=['parameter1']),
- patch.object(binding2, 'build_callables', return_value=[mock_callable2]),
- patch.object(binding2, 'get_model_names', return_value=['model2']),
- patch.object(binding2, 'get_parameter_names', return_value=['parameter2']),
+ patch.object(
+ binding1,
+ 'get_targets',
+ return_value=[make_target('parameter1', mock_callable1, 'model1')],
+ ),
+ patch.object(
+ binding2,
+ 'get_targets',
+ return_value=[make_target('parameter2', mock_callable2, 'model2')],
+ ),
):
# THEN
result = parameter_analysis.calculate_model_dataset([binding1, binding2])
@@ -562,19 +668,18 @@ def test_calculate_model_dataset_multiple_bindings_diffusion(self, parameter_ana
mock_callable2w = MagicMock(return_value=np.array([50.0, 60.0]))
with (
- patch.object(binding1, 'build_callables', return_value=[mock_callable1]),
- patch.object(binding1, 'get_model_names', return_value=['model1']),
- patch.object(binding1, 'get_parameter_names', return_value=['parameter1']),
patch.object(
- binding2,
- 'build_callables',
- return_value=[mock_callable2a, mock_callable2w],
+ binding1,
+ 'get_targets',
+ return_value=[make_target('parameter1', mock_callable1, 'model1')],
),
- patch.object(binding2, 'get_model_names', return_value=['model2a', 'model2w']),
patch.object(
binding2,
- 'get_parameter_names',
- return_value=['parameter3 area', 'parameter3 width'],
+ 'get_targets',
+ return_value=[
+ make_target('parameter3 area', mock_callable2a, 'model2a', name='area'),
+ make_target('parameter3 width', mock_callable2w, 'model2w', name='width'),
+ ],
),
):
# THEN
@@ -614,10 +719,76 @@ def test_calculate_model_dataset_multiple_bindings_diffusion(self, parameter_ana
np.testing.assert_allclose(args[0], [0.1, 0.2])
assert kwargs == {}
+ def test_calculate_model_dataset_converts_x_unit(self):
+ # WHEN: Q is in '1/m' but model declares x_unit='1/angstrom'.
+ # Values [1e10, 2e10] 1/m → [1.0, 2.0] 1/angstrom passed to callable.
+ Q = sc.array(dims=['Q'], values=[1e10, 2e10], unit='1/m')
+ dataset = sc.Dataset(
+ data={
+ 'param': sc.DataArray(
+ data=sc.array(dims=['Q'], values=[1.0, 2.0], unit='meV'),
+ coords={'Q': Q},
+ )
+ }
+ )
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV')
+ binding = FitBinding(model=model, targets='param')
+ pa = ParameterAnalysis(parameters=dataset, bindings=binding)
+
+ mock_callable = MagicMock(return_value=np.array([10.0, 20.0]))
+ # THEN
+ with patch.object(
+ binding,
+ 'get_targets',
+ return_value=[
+ make_target(
+ 'param', mock_callable, 'Polynomial', x_unit='1/angstrom', y_unit='meV'
+ )
+ ],
+ ):
+ pa.calculate_model_dataset([binding])
+
+ # EXPECT: callable received x in 1/angstrom, not raw 1/m values
+ args, _ = mock_callable.call_args
+ np.testing.assert_allclose(args[0], [1.0, 2.0])
+
+ def test_calculate_model_dataset_converts_y_unit(self):
+ # WHEN: parameter is in 'eV' but model declares y_unit='meV'.
+ # Callable returns [1.0, 2.0] meV → stored as [0.001, 0.002] eV in the DataArray.
+ Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom')
+ dataset = sc.Dataset(
+ data={
+ 'param': sc.DataArray(
+ data=sc.array(dims=['Q'], values=[0.001, 0.002], unit='eV'),
+ coords={'Q': Q},
+ )
+ }
+ )
+ model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV')
+ binding = FitBinding(model=model, targets='param')
+ pa = ParameterAnalysis(parameters=dataset, bindings=binding)
+
+ mock_callable = MagicMock(return_value=np.array([1.0, 2.0])) # meV
+ # THEN
+ with patch.object(
+ binding,
+ 'get_targets',
+ return_value=[
+ make_target(
+ 'param', mock_callable, 'Polynomial', x_unit='1/angstrom', y_unit='meV'
+ )
+ ],
+ ):
+ result = pa.calculate_model_dataset([binding])
+
+ # EXPECT: model output converted from meV back to eV
+ np.testing.assert_allclose(result['Polynomial'].values, [0.001, 0.002])
+ assert result['Polynomial'].unit == sc.Unit('eV')
+
def test_append_binding(self, parameter_analysis):
# WHEN
model = Polynomial(coefficients=[2.0, 1.0])
- new_binding = FitBinding(parameter_name='parameter2', model=model)
+ new_binding = FitBinding(model=model, targets='parameter2')
# THEN
parameter_analysis.append_binding(new_binding)
@@ -666,8 +837,8 @@ def test_get_all_variables_overlapping_variables(self, parameter_analysis):
# Create two bindings with overlapping variables
model1 = Gaussian(display_name='model1')
- binding1 = FitBinding(parameter_name='parameter1', model=model1)
- binding2 = FitBinding(parameter_name='parameter2', model=model1)
+ binding1 = FitBinding(model=model1, targets='parameter1')
+ binding2 = FitBinding(model=model1, targets='parameter2')
parameter_analysis.bindings = [binding1, binding2]
@@ -832,14 +1003,14 @@ def test_get_xyweight_from_dataset_wrong_parameter_name_raises(self, parameter_a
parameter_analysis._get_xyweight_from_dataset('nonexistent_parameter')
@pytest.mark.parametrize(
- 'non_finite_variance',
- [np.inf, -np.inf, np.nan, -1.0, 0.0],
- ids=['inf', '-inf', 'nan', 'negative', 'zero'],
+ 'bad_variance',
+ [np.inf, -np.inf, -1.0, 0.0],
+ ids=['inf', '-inf', 'negative', 'zero'],
)
def test_get_xyweight_from_dataset_non_finite_weights_raises(
- self, parameter_analysis, non_finite_variance
+ self, parameter_analysis, bad_variance
):
- # WHEN
+ # Non-NaN invalid variances (inf, negative, zero) should still raise.
Q = sc.array(dims=['Q'], values=[0.1, 0.2])
parameter_analysis.parameters = sc.Dataset(
data={
@@ -847,7 +1018,7 @@ def test_get_xyweight_from_dataset_non_finite_weights_raises(
data=sc.array(
dims=['Q'],
values=[1.0, 2.0],
- variances=[1.0, non_finite_variance],
+ variances=[1.0, bad_variance],
unit='meV',
),
coords={'Q': Q},
@@ -861,6 +1032,33 @@ def test_get_xyweight_from_dataset_non_finite_weights_raises(
):
parameter_analysis._get_xyweight_from_dataset('parameter1')
+ def test_get_xyweight_from_dataset_nan_variance_filters_row(self, parameter_analysis):
+ # NaN variances arise when a parameter is absent for a given Q; those rows are filtered.
+
+ # WHEN
+ Q = sc.array(dims=['Q'], values=[0.1, 0.2])
+ parameter_analysis.parameters = sc.Dataset(
+ data={
+ 'parameter1': sc.DataArray(
+ data=sc.array(
+ dims=['Q'],
+ values=[1.0, np.nan],
+ variances=[0.25, np.nan],
+ unit='meV',
+ ),
+ coords={'Q': Q},
+ ),
+ }
+ )
+
+ # THEN
+ x, y, w = parameter_analysis._get_xyweight_from_dataset('parameter1')
+
+ # EXPECT
+ np.testing.assert_allclose(x, [0.1])
+ np.testing.assert_allclose(y, [1.0])
+ np.testing.assert_allclose(w, [1 / np.sqrt(0.25)])
+
def test_get_xyweight_from_dataset_valid(self, parameter_analysis):
# WHEN THEN
x, y, w = parameter_analysis._get_xyweight_from_dataset('parameter1')
@@ -871,6 +1069,74 @@ def test_get_xyweight_from_dataset_valid(self, parameter_analysis):
expected_w = 1 / np.sqrt([0.1, 0.2])
np.testing.assert_allclose(w, expected_w)
+ def test_get_xyweight_from_dataset_no_variances(self, parameter_analysis):
+ # WHEN
+ Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom')
+ parameter_analysis.parameters = sc.Dataset(
+ data={
+ 'parameter1': sc.DataArray(
+ data=sc.array(dims=['Q'], values=[1.0, 2.0], unit='meV'),
+ coords={'Q': Q},
+ ),
+ }
+ )
+
+ # THEN
+ x, y, w = parameter_analysis._get_xyweight_from_dataset('parameter1')
+
+ # EXPECT
+ np.testing.assert_allclose(x, [0.1, 0.2])
+ np.testing.assert_allclose(y, [1.0, 2.0])
+ np.testing.assert_allclose(w, [1.0, 1.0])
+
+ def test_get_xyweight_from_dataset_all_nan_variances_raises(self, parameter_analysis):
+ # WHEN
+ Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom')
+ parameter_analysis.parameters = sc.Dataset(
+ data={
+ 'parameter1': sc.DataArray(
+ data=sc.array(
+ dims=['Q'],
+ values=[np.nan, np.nan],
+ variances=[np.nan, np.nan],
+ unit='meV',
+ ),
+ coords={'Q': Q},
+ ),
+ }
+ )
+
+ # THEN EXPECT
+ with pytest.raises(
+ ValueError,
+ match="No finite positive variances found for parameter 'parameter1'",
+ ):
+ parameter_analysis._get_xyweight_from_dataset('parameter1')
+
+ def test_get_unit_conversions_none_x_unit(self, parameter_analysis):
+ # WHEN
+ model = Polynomial(coefficients=[1.0], x_unit=None, y_unit='meV')
+ binding = FitBinding(model=model, targets='parameter1')
+
+ # THEN
+ x_factor, y_factor = parameter_analysis._get_unit_conversions(binding.get_targets()[0])
+
+ # EXPECT
+ assert x_factor == pytest.approx(1.0)
+ assert y_factor == pytest.approx(1.0)
+
+ def test_get_unit_conversions_none_y_unit(self, parameter_analysis):
+ # WHEN
+ model = Polynomial(coefficients=[1.0], x_unit='1/angstrom', y_unit=None)
+ binding = FitBinding(model=model, targets='parameter1')
+
+ # THEN
+ x_factor, y_factor = parameter_analysis._get_unit_conversions(binding.get_targets()[0])
+
+ # EXPECT
+ assert x_factor == pytest.approx(1.0)
+ assert y_factor == pytest.approx(1.0)
+
def test_repr(self, parameter_analysis):
# WHEN
repr_str = repr(parameter_analysis)
@@ -883,3 +1149,102 @@ def test_repr(self, parameter_analysis):
assert f'n_parameters={len(parameter_analysis.parameters)}' in repr_str
assert 'parameter_names=' in repr_str
assert 'bindings=' in repr_str
+
+
+class TestParameterAnalysisWorkflows:
+ """End-to-end fits for the standard workflows on synthetic data."""
+
+ @staticmethod
+ def _dataset_from_targets(model, Q, unit_overrides=None):
+ """Build a parameters Dataset from a model's own predictions (no noise)."""
+ unit_overrides = unit_overrides or {}
+ Q_coord = sc.array(dims=['Q'], values=Q, unit='1/angstrom')
+ data = {}
+ for target in model.get_fit_targets():
+ values = target.function(Q)
+ unit = unit_overrides.get(target.name, target.y_unit)
+ factor = sc.to_unit(sc.scalar(1.0, unit=target.y_unit), unit).value
+ data[target.dataset_key] = sc.DataArray(
+ data=sc.array(
+ dims=['Q'],
+ values=values * factor,
+ variances=(0.01 * values * factor) ** 2,
+ unit=unit,
+ ),
+ coords={'Q': Q_coord},
+ )
+ return sc.Dataset(data)
+
+ def test_delta_lorentz_three_target_simultaneous_fit(self):
+ # WHEN: synthetic width, area, and delta area curves from a known DeltaLorentz
+ from easydynamics.sample_model.diffusion_model.delta_lorentz import DeltaLorentz
+
+ Q = np.linspace(0.4, 2.0, 9)
+ truth = DeltaLorentz(scale=2.0, mean_u_squared=0.3, A_0=0.6, lorentzian_width=0.12)
+ dataset = self._dataset_from_targets(truth, Q)
+
+ # THEN: fit a fresh DeltaLorentz to all three dataset keys simultaneously
+ fit_model = DeltaLorentz(scale=1.5, mean_u_squared=0.2, A_0=0.5, lorentzian_width=0.1)
+ pa = ParameterAnalysis(parameters=dataset, bindings=FitBinding(model=fit_model))
+ pa.fit()
+
+ # EXPECT: the shared parameters are recovered across all three curves
+ assert fit_model.scale.value == pytest.approx(2.0, rel=1e-3)
+ assert fit_model.mean_u_squared.value == pytest.approx(0.3, rel=1e-3)
+ assert fit_model.A_0.value == pytest.approx(0.6, rel=1e-3)
+ assert fit_model.lorentzian_width.value == pytest.approx(0.12, rel=1e-3)
+
+ def test_jump_diffusion_width_only_fit(self):
+ # WHEN: synthetic widths from a known jump diffusion model
+ from easydynamics.sample_model.diffusion_model.jump_translational_diffusion import (
+ JumpTranslationalDiffusion,
+ )
+
+ Q = np.linspace(0.4, 2.0, 9)
+ truth = JumpTranslationalDiffusion(diffusion_coefficient=2.4e-9, relaxation_time=2.0)
+ dataset = self._dataset_from_targets(truth, Q)
+
+ # THEN: fit only the width prediction
+ fit_model = JumpTranslationalDiffusion(diffusion_coefficient=1.0e-9, relaxation_time=1.0)
+ fit_model.scale.fixed = True
+ pa = ParameterAnalysis(
+ parameters=dataset, bindings=FitBinding(model=fit_model, targets=['width'])
+ )
+ pa.fit()
+
+ # EXPECT
+ assert fit_model.diffusion_coefficient.value == pytest.approx(2.4e-9, rel=1e-3)
+ assert fit_model.relaxation_time.value == pytest.approx(2.0, rel=1e-3)
+
+ def test_diffusion_fit_converts_dataset_units(self):
+ # WHEN: the dataset stores widths in ueV and Q in 1/nm, while the model works in meV
+ # and 1/angstrom (regression: diffusion fits used to ignore units entirely, silently
+ # misfitting scaled data)
+ Q = np.linspace(0.4, 2.0, 9)
+ truth = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9)
+ width_target = {t.name: t for t in truth.get_fit_targets()}['width']
+ widths_mev = width_target.function(Q)
+
+ Q_coord = sc.array(dims=['Q'], values=Q * 10, unit='1/nm') # 1 1/angstrom = 10 1/nm
+ dataset = sc.Dataset({
+ width_target.dataset_key: sc.DataArray(
+ data=sc.array(
+ dims=['Q'],
+ values=widths_mev * 1000, # meV -> ueV
+ variances=(0.01 * widths_mev * 1000) ** 2,
+ unit='ueV',
+ ),
+ coords={'Q': Q_coord},
+ )
+ })
+
+ # THEN
+ fit_model = BrownianTranslationalDiffusion(diffusion_coefficient=1.0e-9)
+ fit_model.scale.fixed = True
+ pa = ParameterAnalysis(
+ parameters=dataset, bindings=FitBinding(model=fit_model, targets=['width'])
+ )
+ pa.fit()
+
+ # EXPECT: the diffusion coefficient is recovered despite the unit differences
+ assert fit_model.diffusion_coefficient.value == pytest.approx(2.4e-9, rel=1e-3)
diff --git a/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py
index 6767640a4..c205c9959 100644
--- a/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py
+++ b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py
@@ -13,7 +13,7 @@ class TestEasyDynamicsModelBase:
def easy_dynamics_modelbase(self):
"""Fixture for creating an instance of EasyDynamicsModelBase."""
- return EasyDynamicsModelBase(name='TestModel', unit='meV')
+ return EasyDynamicsModelBase(name='TestModel', x_unit='meV')
def test_initialization(self, easy_dynamics_modelbase):
"""Test that the EasyDynamicsModelBase is initialized correctly."""
@@ -22,6 +22,8 @@ def test_initialization(self, easy_dynamics_modelbase):
assert easy_dynamics_modelbase.name == 'TestModel'
assert easy_dynamics_modelbase.display_name == 'TestModel'
assert easy_dynamics_modelbase.unique_name is not None
+ assert easy_dynamics_modelbase.x_unit == 'meV'
+ assert easy_dynamics_modelbase.y_unit == 'dimensionless'
def test_init_raises_type_error_for_invalid_name(self):
"""Test that initializing with an invalid name raises a TypeError."""
@@ -68,9 +70,15 @@ def test_name_setter_invalid_type(self, easy_dynamics_modelbase, invalid_name):
def test_unit_property(self, easy_dynamics_modelbase):
# WHEN THEN EXPECT
- assert easy_dynamics_modelbase.unit == 'meV'
+ assert easy_dynamics_modelbase.x_unit == 'meV'
+ assert easy_dynamics_modelbase.y_unit == 'dimensionless'
- def test_unit_setter_raises(self, easy_dynamics_modelbase):
+ def test_x_unit_setter_raises(self, easy_dynamics_modelbase):
# WHEN / THEN / EXPECT
- with pytest.raises(AttributeError, match='Use convert_unit to change '):
- easy_dynamics_modelbase.unit = 'K'
+ with pytest.raises(AttributeError, match='read-only'):
+ easy_dynamics_modelbase.x_unit = 'K'
+
+ def test_y_unit_setter_raises(self, easy_dynamics_modelbase):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match='read-only'):
+ easy_dynamics_modelbase.y_unit = '1/meV'
diff --git a/tests/unit/easydynamics/convolution/test_analytical_convolution.py b/tests/unit/easydynamics/convolution/test_analytical_convolution.py
index c9d3f46e2..55c6cc297 100644
--- a/tests/unit/easydynamics/convolution/test_analytical_convolution.py
+++ b/tests/unit/easydynamics/convolution/test_analytical_convolution.py
@@ -289,7 +289,6 @@ def test_convolute_analytic_pair_delta(self, default_analytical_convolution, fun
"""
# WHEN THEN
delta_function = DeltaFunction(area=2.0, center=0.5)
- function1 = Gaussian(area=3.0, center=-1.0, width=1.0)
convoluted = default_analytical_convolution._convolute_analytic_pair(
delta_function, function1
diff --git a/tests/unit/easydynamics/convolution/test_convolution.py b/tests/unit/easydynamics/convolution/test_convolution.py
index dbb4290d4..2c4c5da3c 100644
--- a/tests/unit/easydynamics/convolution/test_convolution.py
+++ b/tests/unit/easydynamics/convolution/test_convolution.py
@@ -73,7 +73,8 @@ def test_init(self, default_convolution):
assert default_convolution.upsample_factor == 5
assert default_convolution.extension_factor == pytest.approx(0.2)
assert default_convolution.temperature is None
- assert default_convolution.unit == 'meV'
+ assert default_convolution.x_unit == 'meV'
+ assert default_convolution.y_unit == 'dimensionless'
assert default_convolution.detailed_balance_settings.normalize_detailed_balance is True
assert isinstance(default_convolution._energy_grid, EnergyGrid)
@@ -107,7 +108,8 @@ def test_init_components(self, convolution_with_components):
assert convolution_with_components.upsample_factor == 5
assert convolution_with_components.extension_factor == pytest.approx(0.2)
assert convolution_with_components.temperature is None
- assert convolution_with_components.unit == 'meV'
+ assert convolution_with_components.x_unit == 'meV'
+ assert convolution_with_components.y_unit == 'dimensionless'
assert (
convolution_with_components.detailed_balance_settings.normalize_detailed_balance
is True
@@ -403,7 +405,7 @@ def test_check_if_pair_is_analytic_raises_with_delta_in_resolution(self, default
)
@pytest.mark.parametrize('delta_component', [True, False], ids=['with_delta', 'without_delta'])
@pytest.mark.parametrize(
- 'temperature', [None, 100], ids=['with_temperature', 'without_temperature']
+ 'temperature', [None, 100], ids=['without_temperature', 'with_temperature']
)
def test_build_convolution_plan(
self,
@@ -577,3 +579,54 @@ def test_setattr_invalidates_plan_for_tracked_attribute(
# EXPECT
assert conv.convolution_settings.convolution_plan_is_valid is False
+
+ def test_convert_y_unit_propagates_to_sub_convolvers(self):
+ # WHEN: Convolution with Gaussian sample and resolution components,
+ # both with y_unit='1/meV'
+ energy = np.linspace(-10, 10, 5001)
+ sample_components = ComponentCollection()
+ sample_components.append_component(
+ Gaussian(name='G', area=1.0, center=0.0, width=0.5, x_unit='meV', y_unit='1/meV')
+ )
+ resolution_components = ComponentCollection()
+ resolution_components.append_component(
+ Gaussian(name='R', area=1.0, center=0.0, width=0.3, x_unit='meV', y_unit='1/meV')
+ )
+ conv = Convolution(
+ energy=energy,
+ sample_components=sample_components,
+ resolution_components=resolution_components,
+ )
+
+ # THEN: convert y_unit to '1/eV'
+ conv.convert_y_unit('1/eV')
+
+ # EXPECT: y_unit updated and propagated to sub-convolvers via Convolution.convert_y_unit
+ assert conv.y_unit == '1/eV'
+ assert conv._analytical_convolver._y_unit == '1/eV'
+
+ def test_convert_y_unit_propagates_to_numerical_convolver(self):
+ # WHEN: a DHO sample component forces a numerical convolver
+ energy = np.linspace(-10, 10, 5001)
+ sample_components = ComponentCollection()
+ sample_components.append_component(
+ DampedHarmonicOscillator(
+ name='DHO', area=1.0, center=1.0, width=0.5, x_unit='meV', y_unit='1/meV'
+ )
+ )
+ resolution_components = ComponentCollection()
+ resolution_components.append_component(
+ Gaussian(name='R', area=1.0, center=0.0, width=0.3, x_unit='meV', y_unit='1/meV')
+ )
+ conv = Convolution(
+ energy=energy,
+ sample_components=sample_components,
+ resolution_components=resolution_components,
+ )
+
+ # THEN
+ conv.convert_y_unit('1/eV')
+
+ # EXPECT
+ assert conv.y_unit == '1/eV'
+ assert conv._numerical_convolver._y_unit == '1/eV'
diff --git a/tests/unit/easydynamics/convolution/test_convolution_base.py b/tests/unit/easydynamics/convolution/test_convolution_base.py
index fe8ceef62..7bc26fff7 100644
--- a/tests/unit/easydynamics/convolution/test_convolution_base.py
+++ b/tests/unit/easydynamics/convolution/test_convolution_base.py
@@ -32,6 +32,8 @@ def test_init(self, convolution_base):
assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100))
assert isinstance(convolution_base._sample_components, ComponentCollection)
assert isinstance(convolution_base._resolution_components, ComponentCollection)
+ assert convolution_base.x_unit == 'meV'
+ assert convolution_base.y_unit == 'dimensionless'
def test_init_with_model_component(self):
# WHEN
@@ -78,7 +80,7 @@ def test_init_energy_numerical_none_offset(self):
'energy': 'invalid',
'sample_components': ComponentCollection(),
'resolution_components': ComponentCollection(),
- 'unit': 'meV',
+ 'x_unit': 'meV',
'energy_offset': 0,
},
'Energy must be',
@@ -88,7 +90,7 @@ def test_init_energy_numerical_none_offset(self):
'energy': np.linspace(-10, 10, 100),
'sample_components': 'invalid',
'resolution_components': ComponentCollection(),
- 'unit': 'meV',
+ 'x_unit': 'meV',
'energy_offset': 0,
},
(
@@ -101,7 +103,7 @@ def test_init_energy_numerical_none_offset(self):
'energy': np.linspace(-10, 10, 100),
'sample_components': ComponentCollection(),
'resolution_components': 'invalid',
- 'unit': 'meV',
+ 'x_unit': 'meV',
'energy_offset': 0,
},
(
@@ -114,7 +116,7 @@ def test_init_energy_numerical_none_offset(self):
'energy': np.linspace(-10, 10, 100),
'sample_components': ComponentCollection(),
'resolution_components': ComponentCollection(),
- 'unit': 123,
+ 'x_unit': 123,
'energy_offset': 0,
},
'unit must be ',
@@ -124,7 +126,17 @@ def test_init_energy_numerical_none_offset(self):
'energy': np.linspace(-10, 10, 100),
'sample_components': ComponentCollection(),
'resolution_components': ComponentCollection(),
- 'unit': 'meV',
+ 'y_unit': 123,
+ 'energy_offset': 0,
+ },
+ 'unit must be ',
+ ),
+ (
+ {
+ 'energy': np.linspace(-10, 10, 100),
+ 'sample_components': ComponentCollection(),
+ 'resolution_components': ComponentCollection(),
+ 'x_unit': 'meV',
'energy_offset': 'invalid',
},
'Energy_offset must be ',
@@ -181,17 +193,17 @@ def test_unit_setter_raises(self, convolution_base):
# WHEN THEN EXPECT
with pytest.raises(
AttributeError,
- match=r'Use convert_unit to change the unit between allowed types ',
+ match=r'read-only',
):
- convolution_base.unit = 'K'
+ convolution_base.x_unit = 'K'
def test_convert_unit(self, convolution_base):
# WHEN THEN
- convolution_base.convert_unit('eV')
+ convolution_base.convert_x_unit('eV')
# EXPECT
assert convolution_base.energy.unit == 'eV'
- assert convolution_base.unit == 'eV'
+ assert convolution_base.x_unit == 'eV'
assert np.allclose(convolution_base.energy.values, np.linspace(-0.01, 0.01, 100))
def test_convert_unit_invalid_type_raises(self, convolution_base):
@@ -200,7 +212,7 @@ def test_convert_unit_invalid_type_raises(self, convolution_base):
TypeError,
match=r'Energy unit must be a string or scipp unit.',
):
- convolution_base.convert_unit(123)
+ convolution_base.convert_x_unit(123)
def test_convert_unit_invalid_unit_rollback(self, convolution_base):
# WHEN THEN
@@ -208,10 +220,10 @@ def test_convert_unit_invalid_unit_rollback(self, convolution_base):
UnitError,
match=r'Conversion from `meV` to `s` is not valid.',
):
- convolution_base.convert_unit('s')
+ convolution_base.convert_x_unit('s')
# EXPECT
- assert convolution_base.unit == 'meV'
+ assert convolution_base.x_unit == 'meV'
assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100))
def test_convert_unit_invalid_offset_unit_rollback(self, convolution_base):
@@ -223,10 +235,10 @@ def test_convert_unit_invalid_offset_unit_rollback(self, convolution_base):
UnitError,
match=r'Conversion from `s` to `meV` is not valid.',
):
- convolution_base.convert_unit('meV')
+ convolution_base.convert_x_unit('meV')
# EXPECT
- assert convolution_base.unit == 'meV'
+ assert convolution_base.x_unit == 'meV'
assert convolution_base.energy_offset.unit == 's'
def test_energy_offset_property(self, convolution_base):
@@ -255,6 +267,21 @@ def test_energy_with_offset_setter_raises(self, convolution_base):
with pytest.raises(AttributeError):
convolution_base.energy_with_offset = 5
+ def test_energy_with_offset_unit_conversion(self, convolution_base):
+ # WHEN: energy is in meV and energy_offset given in eV (0.001 eV = 1 meV)
+ convolution_base.energy_offset = Parameter(
+ name='energy_offset',
+ value=0.001,
+ unit='eV', # 0.001 eV = 1 meV
+ )
+
+ # THEN
+ result = convolution_base.energy_with_offset
+
+ # EXPECT: offset is converted to meV before subtracting → shifted by -1 meV
+ expected = convolution_base.energy.values - 1.0
+ np.testing.assert_allclose(result.values, expected, rtol=1e-5)
+
def test_sample_components_property(self, convolution_base):
# WHEN THEN EXPECT
assert isinstance(convolution_base.sample_components, ComponentCollection)
@@ -303,3 +330,76 @@ def test_resolution_components_setter_invalid_type_raises(self, convolution_base
),
):
convolution_base.resolution_components = 'invalid'
+
+ def test_y_unit_setter_raises(self, convolution_base):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match='read-only'):
+ convolution_base.y_unit = '1/meV'
+
+ def test_convert_y_unit(self, convolution_base):
+ # WHEN: sample component with y_unit='1/meV'
+ convolution_base.sample_components.append_component(
+ Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ )
+
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ convolution_base.convert_y_unit('1/eV')
+
+ # EXPECT: y_unit updated and propagated to sample components
+ assert convolution_base.y_unit == '1/eV'
+ for component in convolution_base.sample_components:
+ assert component.y_unit == '1/eV'
+
+ def test_convert_y_unit_invalid_type_raises(self, convolution_base):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ convolution_base.convert_y_unit(123)
+
+ def test_convert_y_unit_rollback_on_failure(self, convolution_base):
+ # WHEN: sample component with default y_unit='dimensionless'
+ convolution_base.sample_components.append_component(Gaussian(area=1.0, x_unit='meV'))
+
+ # THEN EXPECT: 'K' is dimensionally incompatible → triggers rollback
+ with pytest.raises(UnitError):
+ convolution_base.convert_y_unit('K')
+
+ assert convolution_base.y_unit == 'dimensionless'
+
+ def test_convert_x_unit_without_sample_components(self):
+ # WHEN
+ cb = ConvolutionBase(energy=np.linspace(-10, 10, 100), sample_components=None)
+
+ # THEN
+ cb.convert_x_unit('eV')
+
+ # EXPECT
+ assert cb.energy.unit == 'eV'
+ assert np.allclose(cb.energy.values, np.linspace(-0.01, 0.01, 100))
+ assert cb.x_unit == 'eV'
+
+ def test_convert_x_unit_without_resolution_components(self):
+ # WHEN
+ sample = ComponentCollection(display_name='Sample')
+ cb = ConvolutionBase(
+ energy=np.linspace(-10, 10, 100),
+ sample_components=sample,
+ resolution_components=None,
+ )
+
+ # THEN
+ cb.convert_x_unit('eV')
+
+ # EXPECT
+ assert cb.energy.unit == 'eV'
+ assert np.allclose(cb.energy.values, np.linspace(-0.01, 0.01, 100))
+ assert cb.x_unit == 'eV'
+
+ def test_convert_y_unit_without_sample_components(self):
+ # WHEN
+ cb = ConvolutionBase(energy=np.linspace(-10, 10, 100), sample_components=None)
+
+ # THEN
+ cb.convert_y_unit('1/meV')
+
+ # EXPECT
+ assert cb.y_unit == '1/meV'
diff --git a/tests/unit/easydynamics/convolution/test_energy_grid.py b/tests/unit/easydynamics/convolution/test_energy_grid.py
index 01cf19d40..e0390b3ee 100644
--- a/tests/unit/easydynamics/convolution/test_energy_grid.py
+++ b/tests/unit/easydynamics/convolution/test_energy_grid.py
@@ -8,12 +8,14 @@
class TestEnergyGrid:
def test_energy_grid_attributes(self):
+ # WHEN
energy_dense = np.array([-1.0, 0.0, 1.0])
energy_dense_centered = np.array([-1.0, 0.0, 1.0])
energy_dense_step = 1.0
energy_span_dense = 2.0
energy_even_length_offset = 0.0
+ # THEN
energy_grid = EnergyGrid(
energy_dense=energy_dense,
energy_span_dense=energy_span_dense,
@@ -22,6 +24,7 @@ def test_energy_grid_attributes(self):
energy_dense_step=energy_dense_step,
)
+ # EXPECT
assert np.array_equal(energy_grid.energy_dense, energy_dense)
assert np.array_equal(energy_grid.energy_dense_centered, energy_dense_centered)
assert energy_grid.energy_dense_step == energy_dense_step
diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py
index 755381c24..6e43ce017 100644
--- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py
+++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py
@@ -4,33 +4,86 @@
import numpy as np
import pytest
import scipp as sc
+from easyscience.variable import Parameter
from scipy.signal import fftconvolve
from easydynamics.convolution.energy_grid import EnergyGrid
from easydynamics.convolution.numerical_convolution import NumericalConvolution
from easydynamics.sample_model import Gaussian
from easydynamics.sample_model.component_collection import ComponentCollection
+from easydynamics.settings.convolution_settings import ConvolutionSettings
from easydynamics.utils.detailed_balance import detailed_balance_factor
+def _make_numerical_convolution(convolution_settings=None):
+ energy = np.linspace(-10, 10, 5001)
+ sample_components = ComponentCollection(display_name='ComponentCollection')
+ sample_components.append_component(Gaussian(name='Gaussian1', area=2.0, center=0.1, width=0.4))
+ resolution_components = ComponentCollection(display_name='ResolutionModel')
+ resolution_components.append_component(
+ Gaussian(name='GaussianRes', area=3.0, center=0.2, width=0.5)
+ )
+
+ return NumericalConvolution(
+ energy=energy,
+ sample_components=sample_components,
+ resolution_components=resolution_components,
+ convolution_settings=convolution_settings,
+ )
+
+
class TestNumericalConvolution:
@pytest.fixture
def default_numerical_convolution(self):
- energy = np.linspace(-10, 10, 5001)
- sample_components = ComponentCollection(display_name='ComponentCollection')
- sample_components.append_component(
- Gaussian(name='Gaussian1', area=2.0, center=0.1, width=0.4)
- )
- resolution_components = ComponentCollection(display_name='ResolutionModel')
- resolution_components.append_component(
- Gaussian(name='GaussianRes', area=3.0, center=0.2, width=0.5)
- )
+ return _make_numerical_convolution()
- return NumericalConvolution(
- energy=energy,
- sample_components=sample_components,
- resolution_components=resolution_components,
- )
+ def test_settings_change_invalidates_all_sharing_convolvers(self):
+ # WHEN: two convolvers sharing one ConvolutionSettings, and an accuracy knob changes
+ settings = ConvolutionSettings()
+ convolver_a = _make_numerical_convolution(settings)
+ convolver_b = _make_numerical_convolution(settings)
+ settings.upsample_factor = 10
+
+ # THEN: A consumes the change by rebuilding its plan
+ convolver_a.convolution()
+
+ # EXPECT: A is current, but B still sees the stale plan (regression: A's rebuild
+ # setting the shared flag used to mask the settings change from B)
+ assert convolver_a._convolution_plan_is_current() is True
+ assert convolver_b._convolution_plan_is_current() is False
+
+ # and B becomes current after its own rebuild
+ convolver_b.convolution()
+ assert convolver_b._convolution_plan_is_current() is True
+
+ def test_manual_plan_valid_true_blesses_all_sharing_convolvers(self):
+ # WHEN: a settings change followed by an explicit convolution_plan_is_valid = True
+ settings = ConvolutionSettings()
+ convolver_a = _make_numerical_convolution(settings)
+ convolver_b = _make_numerical_convolution(settings)
+ settings.upsample_factor = 10
+
+ # THEN: the escape hatch suppresses rebuilds for every convolver sharing the settings
+ settings.convolution_plan_is_valid = True
+
+ # EXPECT
+ assert convolver_a._convolution_plan_is_current() is True
+ assert convolver_b._convolution_plan_is_current() is True
+
+ def test_manual_plan_valid_false_invalidates_all_sharing_convolvers(self):
+ # WHEN: two current convolvers sharing one ConvolutionSettings
+ settings = ConvolutionSettings()
+ convolver_a = _make_numerical_convolution(settings)
+ convolver_b = _make_numerical_convolution(settings)
+
+ # THEN: an explicit False invalidates the plan for every convolver, and one
+ # convolver rebuilding does not mask the invalidation from the other
+ settings.convolution_plan_is_valid = False
+ convolver_a.convolution()
+
+ # EXPECT
+ assert convolver_a._convolution_plan_is_current() is True
+ assert convolver_b._convolution_plan_is_current() is False
def test_init(self, default_numerical_convolution):
"""
@@ -48,7 +101,8 @@ def test_init(self, default_numerical_convolution):
assert default_numerical_convolution.upsample_factor == 5
assert default_numerical_convolution.extension_factor == pytest.approx(0.2)
assert default_numerical_convolution.temperature is None
- assert default_numerical_convolution.unit == 'meV'
+ assert default_numerical_convolution.x_unit == 'meV'
+ assert default_numerical_convolution.y_unit == 'dimensionless'
assert (
default_numerical_convolution.detailed_balance_settings.normalize_detailed_balance
is True
@@ -233,3 +287,53 @@ def fake_interp(*args, **kwargs): # noqa: ARG001
assert result.shape == conv.energy.values.shape
else:
assert result.shape == dense.shape
+
+ def test_convolution_with_energy_offset_in_different_unit(self, default_numerical_convolution):
+ # WHEN: energy grid is in meV, offset given as a Parameter in eV (0.001 eV = 1.0 meV)
+ conv = default_numerical_convolution
+ conv.energy_offset = Parameter(name='energy_offset', value=0.001, unit='eV')
+ result_ev = conv.convolution()
+
+ # THEN: run again with the numerically equivalent offset as an explicit meV Parameter
+ conv.energy_offset = Parameter(name='energy_offset', value=1.0, unit='meV')
+ result_mev = conv.convolution()
+
+ # EXPECT: unit conversion is transparent — results are numerically identical
+ np.testing.assert_allclose(result_ev, result_mev, rtol=1e-6)
+
+ def test_repr(self, default_numerical_convolution):
+ r = repr(default_numerical_convolution)
+ assert 'NumericalConvolution' in r
+ assert 'energy_len=' in r
+
+ # ───── Regression tests ─────
+
+ def test_detailed_balance_energy_includes_even_length_offset(
+ self, default_numerical_convolution, monkeypatch
+ ):
+ # WHEN: even-length energy grid with temperature
+ nc = default_numerical_convolution
+ nc.energy = np.linspace(-10, 10, 100) # even-length → non-zero energy_even_length_offset
+ nc.temperature = 300.0
+
+ captured = {}
+
+ # spy_dbf wraps detailed_balance_factor: captures the energy argument passed to it,
+ # then delegates to the real function so the convolution still produces a valid result.
+ def spy_dbf(**kwargs):
+ captured['energy'] = kwargs['energy']
+ return detailed_balance_factor(**kwargs)
+
+ monkeypatch.setattr(
+ 'easydynamics.convolution.numerical_convolution.detailed_balance_factor', spy_dbf
+ )
+
+ # THEN: run convolution
+ nc.convolution()
+
+ # EXPECT: DBF receives energy_dense - energy_even_length_offset
+ # Before the fix, energy_even_length_offset was omitted, causing a half-bin error.
+ grid = nc._energy_grid
+ np.testing.assert_allclose(
+ captured['energy'], grid.energy_dense - grid.energy_even_length_offset, atol=1e-12
+ )
diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py
index 73bcee0bd..2d64e0da1 100644
--- a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py
+++ b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py
@@ -48,7 +48,8 @@ def test_init(self, default_numerical_convolution_base):
assert default_numerical_convolution_base.upsample_factor == 5
assert default_numerical_convolution_base.extension_factor == pytest.approx(0.2)
assert default_numerical_convolution_base.temperature is None
- assert default_numerical_convolution_base.unit == 'meV'
+ assert default_numerical_convolution_base.x_unit == 'meV'
+ assert default_numerical_convolution_base.y_unit == 'dimensionless'
assert (
default_numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance
is True
@@ -79,7 +80,7 @@ def test_init_with_custom_parameters(self):
detailed_balance_settings=detailed_balance_settings,
temperature=temperature,
temperature_unit=temperature_unit,
- unit=unit,
+ x_unit=unit,
)
# EXPECT
@@ -87,7 +88,8 @@ def test_init_with_custom_parameters(self):
assert numerical_convolution_base.extension_factor == pytest.approx(0.5)
assert numerical_convolution_base.temperature.value == temperature
assert numerical_convolution_base.temperature.unit == temperature_unit
- assert numerical_convolution_base.unit == unit
+ assert numerical_convolution_base.x_unit == unit
+ assert numerical_convolution_base.y_unit == 'dimensionless'
assert (
numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance
is False
@@ -489,9 +491,11 @@ def test_create_energy_grid_upsample_none_non_uniform_raises(
and non-uniform energy raises ValueError (the energy grid must
always be uniform).
"""
- # WHEN
+ # WHEN: non-uniform energy with upsample_factor=None
default_numerical_convolution_base.energy = np.array([0, 1, 3, 6, 10])
default_numerical_convolution_base.upsample_factor = None
+
+ # THEN EXPECT
with pytest.raises(
ValueError,
match='Input array `energy` must be uniformly spaced if upsample_factor is not given',
@@ -638,12 +642,12 @@ def test_repr(self, default_numerical_convolution_base):
# Sample and resolution models
assert 'ComponentCollection' in repr_str
- assert 'components=[No components]' in repr_str
+ assert 'Components: No components' in repr_str
assert 'sample_components=' in repr_str
assert 'resolution_components=' in repr_str
# Important parameters
- assert 'unit=meV' in repr_str
+ assert 'x_unit=meV' in repr_str
assert 'upsample_factor=5' in repr_str
assert 'extension_factor=0.2' in repr_str
assert 'temperature=None' in repr_str
diff --git a/tests/unit/easydynamics/experiment/test_experiment.py b/tests/unit/easydynamics/experiment/test_experiment.py
index 7ffc83c87..ed360d94e 100644
--- a/tests/unit/easydynamics/experiment/test_experiment.py
+++ b/tests/unit/easydynamics/experiment/test_experiment.py
@@ -142,6 +142,18 @@ def test_load_hdf5_invalid_file_raises(self, experiment):
with pytest.raises(OSError):
experiment.load_hdf5('non_existent_file.h5')
+ def test_load_hdf5_invalid_data_type_raises(self, experiment):
+ "Test that loading a file that returns a non-DataArray raises TypeError"
+ # WHEN THEN EXPECT
+ with (
+ patch(
+ 'easydynamics.experiment.experiment.sc_load_hdf5',
+ return_value=sc.scalar(1.0),
+ ),
+ pytest.raises(TypeError, match=r'sc\.DataArray'),
+ ):
+ experiment.load_hdf5('fake_file.h5')
+
def test_save_hdf5(self, tmp_path, experiment):
"Test saving data to an HDF5 file. Load the saved file"
'using scipp and compare to the original data.'
@@ -319,15 +331,55 @@ def test_get_masked_energy_no_data_returns_None(self):
@pytest.mark.parametrize(
'Q_index',
- [-1, 100, 'not an index'],
- ids=['negative_index', 'out_of_bounds_index', 'invalid_type'],
+ [-1, 100],
+ ids=['negative_index', 'out_of_bounds_index'],
)
def test_get_masked_energy_invalid_Q_index_raises(self, experiment_with_data, Q_index):
- "Test getting masked energy raises IndexError when Q index is invalid"
+ "Test getting masked energy raises IndexError when Q index is out of range"
# WHEN THEN EXPECT
with pytest.raises(IndexError):
experiment_with_data.get_masked_energy(Q_index=Q_index)
+ def test_get_masked_energy_invalid_type_raises(self, experiment_with_data):
+ "Test getting masked energy raises TypeError when Q index is not an integer"
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ experiment_with_data.get_masked_energy(Q_index='not an index')
+
+ def test_get_finite_energy_mask_returns_none_without_data(self):
+ "Test get_finite_energy_mask returns None when no data is loaded"
+ # WHEN THEN EXPECT
+ assert Experiment().get_finite_energy_mask(Q_index=0) is None
+
+ def test_get_masked_binned_data_returns_none_without_data(self):
+ "Test get_masked_binned_data returns None when no data is loaded"
+ # WHEN THEN EXPECT
+ assert Experiment().get_masked_binned_data(Q_index=0) is None
+
+ def test_get_finite_energy_mask_with_data(self, experiment_with_data):
+ "Test get_finite_energy_mask returns a boolean scipp variable when data is loaded"
+ # WHEN
+
+ # THEN
+ mask = experiment_with_data.get_finite_energy_mask(Q_index=0)
+
+ # EXPECT
+ assert mask is not None
+ assert mask.dims == ('energy',)
+ assert len(mask) == 3
+
+ def test_get_masked_binned_data_with_data(self, experiment_with_data):
+ "Test get_masked_binned_data returns a DataArray with all-finite data loaded"
+ # WHEN
+
+ # THEN
+ result = experiment_with_data.get_masked_binned_data(Q_index=0)
+
+ # EXPECT
+ assert result is not None
+ assert isinstance(result, sc.DataArray)
+ assert result.dims == ('energy',)
+
##############
# test plotting
##############
@@ -496,8 +548,8 @@ def test_extract_x_y_var(self, experiment_with_data):
experiment_with_data.data.variances[Q_index],
)
- def test_extract_x_y_weights_only_finite_zero_variances(self, experiment_with_data):
- "Test that _extract_x_y_weights_only_finite raises ValueError when variances contain zeros"
+ def testextract_x_y_weights_only_finite_zero_variances(self, experiment_with_data):
+ "Test that extract_x_y_weights_only_finite raises ValueError when variances contain zeros"
# WHEN
Q_index = 0
invalid_data = experiment_with_data._data.copy()
@@ -507,10 +559,10 @@ def test_extract_x_y_weights_only_finite_zero_variances(self, experiment_with_da
# THEN EXPECT
with pytest.raises(ValueError, match='Cannot compute weights: some variances are zero'):
- Experiment(data=invalid_data)._extract_x_y_weights_only_finite(Q_index=Q_index)
+ Experiment(data=invalid_data).extract_x_y_weights_only_finite(Q_index=Q_index)
- def test_extract_x_y_weights_only_finite(self, experiment_with_data):
- "Test that _extract_x_y_weights_only_finite only returns finite values"
+ def testextract_x_y_weights_only_finite(self, experiment_with_data):
+ "Test that extract_x_y_weights_only_finite only returns finite values"
# WHEN
Q_index = 0
invalid_data = experiment_with_data._data.copy()
@@ -518,7 +570,7 @@ def test_extract_x_y_weights_only_finite(self, experiment_with_data):
invalid_data.data.variances[Q_index][1] = np.nan
# THEN
- x, y, weights, mask = Experiment(data=invalid_data)._extract_x_y_weights_only_finite(
+ x, y, weights, mask = Experiment(data=invalid_data).extract_x_y_weights_only_finite(
Q_index=Q_index
)
@@ -533,7 +585,7 @@ def test_extract_x_y_weights_only_finite(self, experiment_with_data):
# Mask should indicate which values were removed
assert np.array_equal(mask, [False, False, True])
- def test_extract_x_y_weights_only_finite_zero_variance(self, experiment_with_data):
+ def testextract_x_y_weights_only_finite_zero_variance(self, experiment_with_data):
"Test getting x y and weights when variances are None"
# WHEN
Q_index = 0
@@ -543,9 +595,7 @@ def test_extract_x_y_weights_only_finite_zero_variance(self, experiment_with_dat
experiment_with_data.data = data
# THEN
- x, y, weights, mask = experiment_with_data._extract_x_y_weights_only_finite(
- Q_index=Q_index
- )
+ x, y, weights, mask = experiment_with_data.extract_x_y_weights_only_finite(Q_index=Q_index)
# EXPECT
assert np.array_equal(x, experiment_with_data.energy.values)
diff --git a/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py b/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py
index 02fc31dfa..12f73235f 100644
--- a/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py
+++ b/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py
@@ -5,7 +5,9 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
+from scipp import UnitError
from scipy.integrate import simpson
from easydynamics.sample_model import DampedHarmonicOscillator
@@ -20,7 +22,7 @@ def dho(self):
area=2.0,
center=1.5,
width=0.3,
- unit='meV',
+ x_unit='meV',
)
def test_init_no_inputs(self):
@@ -32,7 +34,8 @@ def test_init_no_inputs(self):
assert dho.area.value == pytest.approx(1.0)
assert dho.center.value == pytest.approx(1.0)
assert dho.width.value == pytest.approx(1.0)
- assert dho.unit == 'meV'
+ assert dho.x_unit == 'meV'
+ assert dho.y_unit == 'dimensionless'
def test_initialization(self, dho: DampedHarmonicOscillator):
# WHEN THEN EXPECT
@@ -40,26 +43,36 @@ def test_initialization(self, dho: DampedHarmonicOscillator):
assert dho.area.value == pytest.approx(2.0)
assert dho.center.value == pytest.approx(1.5)
assert dho.width.value == pytest.approx(0.3)
- assert dho.unit == 'meV'
+ assert dho.x_unit == 'meV'
@pytest.mark.parametrize(
'kwargs, expected_message',
[
(
- {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'unit': 'meV'},
+ {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'x_unit': 'meV'},
'area must be a number',
),
(
- {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'unit': 'meV'},
+ {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'x_unit': 'meV'},
'center must be ',
),
(
- {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'unit': 'meV'},
+ {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'x_unit': 'meV'},
'width must be a number',
),
(
- {'area': 2.0, 'center': 0.5, 'width': 0.6, 'unit': 123},
- 'unit must be None',
+ {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 123},
+ 'unit must be None, a string',
+ ),
+ (
+ {
+ 'area': 2.0,
+ 'center': 0.5,
+ 'width': 0.6,
+ 'x_unit': 'meV',
+ 'y_unit': 123,
+ },
+ 'unit must be None, a string',
),
],
)
@@ -78,7 +91,7 @@ def test_negative_width_raises(self):
area=2.0,
center=0.5,
width=-0.6,
- unit='meV',
+ x_unit='meV',
)
def test_negative_area_warns(self):
@@ -89,7 +102,7 @@ def test_negative_area_warns(self):
area=-2.0,
center=0.5,
width=0.6,
- unit='meV',
+ x_unit='meV',
)
@pytest.mark.parametrize(
@@ -108,11 +121,15 @@ def test_property_setters(
invalid_value,
invalid_message,
):
- # set valid
+ # WHEN
+
+ # THEN : set a valid value
setattr(dho, prop, valid_value)
+
+ # EXPECT
assert getattr(dho, prop).value == valid_value
- # invalid
+ # WHEN: set an invalid value — THEN EXPECT
with pytest.raises(TypeError, match=invalid_message):
setattr(dho, prop, invalid_value)
@@ -167,12 +184,12 @@ def test_area_matches_parameter(self, dho: DampedHarmonicOscillator):
# EXPECT
assert numerical_area == pytest.approx(dho.area.value, rel=2e-3)
- def test_convert_unit(self, dho: DampedHarmonicOscillator):
+ def test_convert_x_unit(self, dho: DampedHarmonicOscillator):
# WHEN THEN
- dho.convert_unit('microeV')
+ dho.convert_x_unit('microeV')
# EXPECT
- assert dho.unit == 'microeV'
+ assert dho.x_unit == 'microeV'
assert dho.area.value == pytest.approx(2 * 1e3)
assert dho.center.value == pytest.approx(1.5 * 1e3)
assert dho.width.value == pytest.approx(0.3 * 1e3)
@@ -194,7 +211,7 @@ def test_copy(self, dho: DampedHarmonicOscillator):
assert dho_copy.width.value == dho.width.value
assert dho_copy.width.fixed == dho.width.fixed
- assert dho_copy.unit == dho.unit
+ assert dho_copy.x_unit == dho.x_unit
def test_repr(self, dho: DampedHarmonicOscillator):
# WHEN THEN
@@ -202,8 +219,91 @@ def test_repr(self, dho: DampedHarmonicOscillator):
# EXPECT
assert 'DampedHarmonicOscillator' in repr_str
- assert "name='TestDHOName'" in repr_str
- assert 'unit=meV' in repr_str
- assert 'area=' in repr_str
- assert 'center=' in repr_str
- assert 'width=' in repr_str
+ assert 'name = TestDHOName' in repr_str
+ assert 'x_unit = meV' in repr_str
+ assert 'area =' in repr_str
+ assert 'center =' in repr_str
+ assert 'width =' in repr_str
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ dho = DampedHarmonicOscillator(
+ area=1.0, center=1.0, width=0.3, x_unit='meV', y_unit='1/meV'
+ )
+ # EXPECT
+ assert dho.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, dho: DampedHarmonicOscillator):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ dho.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless'
+ dho = DampedHarmonicOscillator(
+ area=1.0, center=1.0, width=0.3, x_unit='meV', y_unit='1/meV'
+ )
+
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ dho.convert_y_unit('1/eV')
+
+ # EXPECT: y_unit updated and area value rescaled (1e3 factor)
+ assert dho.y_unit == '1/eV'
+ assert dho.area.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, dho: DampedHarmonicOscillator):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ dho.convert_y_unit(123)
+
+ def test_evaluate_scipp_output(self, dho: DampedHarmonicOscillator):
+ # WHEN
+ x = np.linspace(0.5, 5.0, 50)
+
+ # THEN
+ result = dho.evaluate(x, output='scipp')
+
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 50
+ np.testing.assert_allclose(result.values, dho.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ dho = DampedHarmonicOscillator(
+ area=1.0, center=1.0, width=0.3, x_unit='meV', y_unit='1/meV'
+ )
+ x = np.linspace(0.5, 5.0, 50)
+
+ # THEN
+ result = dho.evaluate(x, output='scipp')
+
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ def test_convert_x_unit_invalid_type_raises(self, dho: DampedHarmonicOscillator):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'):
+ dho.convert_x_unit(123)
+
+ def test_convert_x_unit_rollback_on_failure(self, dho: DampedHarmonicOscillator):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ dho.convert_x_unit('m')
+ # EXPECT: state rolled back
+ assert dho.x_unit == 'meV'
+ assert dho.area.value == pytest.approx(2.0)
+ assert dho.center.value == pytest.approx(1.5)
+ assert dho.width.value == pytest.approx(0.3)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ dho = DampedHarmonicOscillator(area=1.0, center=1.0, width=0.3, x_unit='meV')
+ # THEN
+ with pytest.raises(UnitError):
+ dho.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert dho.y_unit == 'dimensionless'
+ assert dho.area.value == pytest.approx(1.0)
diff --git a/tests/unit/easydynamics/sample_model/components/test_delta_function.py b/tests/unit/easydynamics/sample_model/components/test_delta_function.py
index de3148b7d..50baf2317 100644
--- a/tests/unit/easydynamics/sample_model/components/test_delta_function.py
+++ b/tests/unit/easydynamics/sample_model/components/test_delta_function.py
@@ -20,7 +20,7 @@ def delta_function(self):
display_name='TestDeltaFunction',
area=2.0,
center=0.5,
- unit='meV',
+ x_unit='meV',
)
def test_init_no_inputs(self):
@@ -31,7 +31,8 @@ def test_init_no_inputs(self):
assert delta_function.display_name == 'DeltaFunction'
assert delta_function.area.value == pytest.approx(1.0)
assert delta_function.center.value == pytest.approx(0.0)
- assert delta_function.unit == 'meV'
+ assert delta_function.x_unit == 'meV'
+ assert delta_function.y_unit == 'dimensionless'
assert delta_function.center.fixed is True
def test_initialization(self, delta_function: DeltaFunction):
@@ -39,21 +40,25 @@ def test_initialization(self, delta_function: DeltaFunction):
assert delta_function.display_name == 'TestDeltaFunction'
assert delta_function.area.value == pytest.approx(2.0)
assert delta_function.center.value == pytest.approx(0.5)
- assert delta_function.unit == 'meV'
+ assert delta_function.x_unit == 'meV'
@pytest.mark.parametrize(
'kwargs, expected_message',
[
(
- {'area': 'invalid', 'center': 0.5, 'unit': 'meV'},
+ {'area': 'invalid', 'center': 0.5, 'x_unit': 'meV'},
'area must be a number',
),
(
- {'area': 2.0, 'center': 'invalid', 'unit': 'meV'},
+ {'area': 2.0, 'center': 'invalid', 'x_unit': 'meV'},
'center must be ',
),
(
- {'area': 2.0, 'center': 0.5, 'unit': 123},
+ {'area': 2.0, 'center': 0.5, 'x_unit': 123},
+ 'unit must be ',
+ ),
+ (
+ {'area': 2.0, 'center': 0.5, 'x_unit': 'meV', 'y_unit': 123},
'unit must be ',
),
],
@@ -65,7 +70,7 @@ def test_input_type_validation_raises(self, kwargs, expected_message):
def test_negative_area_warns(self):
# WHEN THEN EXPECT
with pytest.warns(UserWarning, match='may not be physically meaningful'):
- DeltaFunction(display_name='TestDeltaFunction', area=-2.0, center=0.5, unit='meV')
+ DeltaFunction(display_name='TestDeltaFunction', area=-2.0, center=0.5, x_unit='meV')
@pytest.mark.parametrize(
'prop, valid_value, invalid_value, invalid_message',
@@ -82,11 +87,13 @@ def test_property_setters(
invalid_value,
invalid_message,
):
- # set valid
+ # WHEN: set a valid value
setattr(delta_function, prop, valid_value)
+
+ # THEN EXPECT
assert getattr(delta_function, prop).value == valid_value
- # invalid
+ # WHEN: set an invalid value — THEN EXPECT
with pytest.raises(TypeError, match=invalid_message):
setattr(delta_function, prop, invalid_value)
@@ -147,7 +154,7 @@ def test_evaluate_with_incompatible_unit_raises(self, delta_function: DeltaFunct
# THEN EXPECT
with pytest.raises(
UnitError,
- match='Input x has unit nm, but DeltaFunction component ',
+ match='Input x has unit nm',
):
delta_function.evaluate(x)
@@ -190,12 +197,12 @@ def test_get_all_parameters(self, delta_function: DeltaFunction):
actual_names = {param.name for param in params}
assert actual_names == expected_names
- def test_convert_unit(self, delta_function: DeltaFunction):
+ def test_convert_x_unit(self, delta_function: DeltaFunction):
# WHEN THEN
- delta_function.convert_unit('microeV')
+ delta_function.convert_x_unit('microeV')
# EXPECT
- assert delta_function.unit == 'microeV'
+ assert delta_function.x_unit == 'microeV'
assert delta_function.area.value == pytest.approx(2 * 1e3)
assert delta_function.center.value == pytest.approx(0.5 * 1e3)
@@ -213,7 +220,7 @@ def test_copy(self, delta_function: DeltaFunction):
assert delta_copy.center.value == delta_function.center.value
assert delta_copy.center.fixed == delta_function.center.fixed
- assert delta_copy.unit == delta_function.unit
+ assert delta_copy.x_unit == delta_function.x_unit
def test_repr(self, delta_function: DeltaFunction):
# WHEN THEN
@@ -221,7 +228,103 @@ def test_repr(self, delta_function: DeltaFunction):
# EXPECT
assert 'DeltaFunction' in repr_str
- assert "name='DeltaFunctionName'" in repr_str
- assert 'unit=meV' in repr_str
- assert 'area=' in repr_str
- assert 'center=' in repr_str
+ assert 'name = DeltaFunctionName' in repr_str
+ assert 'x_unit = meV' in repr_str
+ assert 'area =' in repr_str
+ assert 'center =' in repr_str
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ delta = DeltaFunction(area=1.0, x_unit='meV', y_unit='1/meV')
+
+ # EXPECT
+ assert delta.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, delta_function: DeltaFunction):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ delta_function.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless'
+ delta = DeltaFunction(area=1.0, x_unit='meV', y_unit='1/meV')
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ delta.convert_y_unit('1/eV')
+ # EXPECT: y_unit updated and area value rescaled (1e3 factor)
+ assert delta.y_unit == '1/eV'
+ assert delta.area.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, delta_function: DeltaFunction):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ delta_function.convert_y_unit(123)
+
+ def test_evaluate_scipp_output(self, delta_function: DeltaFunction):
+ # WHEN
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = delta_function.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 50
+ np.testing.assert_allclose(result.values, delta_function.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ delta = DeltaFunction(area=1.0, x_unit='meV', y_unit='1/meV')
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = delta.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ @pytest.mark.parametrize(
+ 'x, center, expected_idx',
+ [
+ (np.array([0.5, 1.0, 2.0]), 0.5, 0), # center at first element (line 202)
+ (np.array([0.0, 1.0, 1.5]), 1.5, 2), # center at last element (line 207)
+ ],
+ ids=['center_at_first', 'center_at_last'],
+ )
+ def test_evaluate_center_at_boundary(self, x, center, expected_idx):
+ # WHEN
+ area = 1.0
+ delta = DeltaFunction(area=area, center=center, x_unit='meV')
+
+ # THEN
+ result = delta.evaluate(x)
+
+ # EXPECT
+ # All elements except the boundary one should be zero
+ assert result[expected_idx] > 0.0
+ other_indices = [i for i in range(len(x)) if i != expected_idx]
+ assert all(result[i] == pytest.approx(0.0) for i in other_indices)
+ # Boundary bin width: both left and right are set to the single adjacent spacing
+ bin_width = x[1] - x[0] if expected_idx == 0 else x[-1] - x[-2]
+ assert np.isclose(result[expected_idx], area / bin_width, rtol=1e-10)
+
+ def test_convert_x_unit_invalid_type_raises(self, delta_function: DeltaFunction):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'):
+ delta_function.convert_x_unit(123)
+
+ def test_convert_x_unit_rollback_on_failure(self, delta_function: DeltaFunction):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ delta_function.convert_x_unit('m')
+ # EXPECT: state rolled back
+ assert delta_function.x_unit == 'meV'
+ assert delta_function.area.value == pytest.approx(2.0)
+ assert delta_function.center.value == pytest.approx(0.5)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ delta = DeltaFunction(area=1.0, center=0.0, x_unit='meV', y_unit='dimensionless')
+ # THEN
+ with pytest.raises(UnitError):
+ delta.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert delta.y_unit == 'dimensionless'
+ assert delta.area.value == pytest.approx(1.0)
diff --git a/tests/unit/easydynamics/sample_model/components/test_exponential.py b/tests/unit/easydynamics/sample_model/components/test_exponential.py
index b7fa55031..701bf73a9 100644
--- a/tests/unit/easydynamics/sample_model/components/test_exponential.py
+++ b/tests/unit/easydynamics/sample_model/components/test_exponential.py
@@ -5,6 +5,7 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
from scipp import UnitError
@@ -20,7 +21,7 @@ def exponential(self):
amplitude=2.0,
center=0.5,
rate=1.2,
- unit='meV',
+ x_unit='meV',
)
def test_init_no_inputs(self):
@@ -32,7 +33,8 @@ def test_init_no_inputs(self):
assert exponential.amplitude.value == pytest.approx(1.0)
assert exponential.center.value == pytest.approx(0.0)
assert exponential.rate.value == pytest.approx(1.0)
- assert exponential.unit == 'meV'
+ assert exponential.x_unit == 'meV'
+ assert exponential.y_unit == 'dimensionless'
def test_initialization(self, exponential: Exponential):
# WHEN THEN EXPECT
@@ -40,23 +42,33 @@ def test_initialization(self, exponential: Exponential):
assert exponential.amplitude.value == pytest.approx(2.0)
assert exponential.center.value == pytest.approx(0.5)
assert exponential.rate.value == pytest.approx(1.2)
- assert exponential.unit == 'meV'
+ assert exponential.x_unit == 'meV'
@pytest.mark.parametrize(
'kwargs, expected_message',
[
(
- {'amplitude': 'invalid', 'center': 0.5, 'rate': 1.0, 'unit': 'meV'},
+ {'amplitude': 'invalid', 'center': 0.5, 'rate': 1.0, 'x_unit': 'meV'},
'amplitude must be a number',
),
(
- {'amplitude': 2.0, 'center': 'invalid', 'rate': 1.0, 'unit': 'meV'},
+ {'amplitude': 2.0, 'center': 'invalid', 'rate': 1.0, 'x_unit': 'meV'},
'center must be None or a number',
),
(
- {'amplitude': 2.0, 'center': 0.5, 'rate': 'invalid', 'unit': 'meV'},
+ {'amplitude': 2.0, 'center': 0.5, 'rate': 'invalid', 'x_unit': 'meV'},
'rate must be a number',
),
+ (
+ {
+ 'amplitude': 2.0,
+ 'center': 0.5,
+ 'rate': 1.0,
+ 'x_unit': 'meV',
+ 'y_unit': 123,
+ },
+ 'unit must be None, a string',
+ ),
],
)
def test_input_type_validation_raises(self, kwargs, expected_message):
@@ -67,12 +79,12 @@ def test_input_type_validation_raises(self, kwargs, expected_message):
'kwargs, expected_message',
[
(
- {'amplitude': np.nan, 'center': 0.5, 'rate': 1.0, 'unit': 'meV'},
- 'amplitude must be a finite number or a Parameter',
+ {'amplitude': np.nan, 'center': 0.5, 'rate': 1.0, 'x_unit': 'meV'},
+ 'amplitude must be finite',
),
(
- {'amplitude': 2.0, 'center': 0.5, 'rate': np.nan, 'unit': 'meV'},
- 'rate must be a finite number or a Parameter',
+ {'amplitude': 2.0, 'center': 0.5, 'rate': np.nan, 'x_unit': 'meV'},
+ 'rate must be finite',
),
],
)
@@ -96,11 +108,12 @@ def test_property_setters(
invalid_value,
invalid_message,
):
- # set valid
+ # WHEN: set a valid value
setattr(exponential, prop, valid_value)
+ # THEN EXPECT
assert getattr(exponential, prop).value == valid_value
- # invalid
+ # WHEN: set an invalid value — THEN EXPECT
with pytest.raises(TypeError, match=invalid_message):
setattr(exponential, prop, invalid_value)
@@ -144,12 +157,14 @@ def test_get_all_parameters(self, exponential: Exponential):
assert actual_names == expected_names
- def test_convert_unit(self, exponential: Exponential):
+ def test_convert_x_unit(self, exponential: Exponential):
# WHEN
- exponential.convert_unit('microeV')
- # THEN EXPECT
- assert exponential.unit == 'microeV'
+ # THEN
+ exponential.convert_x_unit('microeV')
+
+ # EXPECT
+ assert exponential.x_unit == 'microeV'
assert exponential.amplitude.value == pytest.approx(2.0 * 1e3)
assert exponential.center.value == pytest.approx(0.5 * 1e3)
@@ -158,21 +173,21 @@ def test_convert_unit(self, exponential: Exponential):
assert exponential.rate.value == pytest.approx(1.2 / 1e3)
assert str(exponential.rate.unit) == '1/ueV'
- def test_convert_unit_incorrect_unit_raises(self, exponential: Exponential):
+ def test_convert_x_unit_incorrect_unit_raises(self, exponential: Exponential):
# WHEN THEN EXPECT
with pytest.raises(TypeError, match=r'unit must be a string or sc.Unit'):
- exponential.convert_unit(123)
+ exponential.convert_x_unit(123)
- def test_convert_unit_rollback(self, exponential: Exponential):
- # WHEN
+ def test_convert_x_unit_rollback(self, exponential: Exponential):
+ # WHEN THEN
with pytest.raises(
UnitError,
match=r'Failed to convert unit: Conversion from `meV` to `m` is not valid.',
):
- exponential.convert_unit('m')
+ exponential.convert_x_unit('m')
- # THEN EXPECT - values should be unchanged
- assert exponential.unit == 'meV'
+ # EXPECT - values should be unchanged
+ assert exponential.x_unit == 'meV'
assert exponential.amplitude.value == pytest.approx(2.0)
assert exponential.amplitude.unit == 'meV'
assert exponential.center.value == pytest.approx(0.5)
@@ -182,9 +197,11 @@ def test_convert_unit_rollback(self, exponential: Exponential):
def test_copy(self, exponential: Exponential):
# WHEN
+
+ # THEN
exponential_copy = copy(exponential)
- # THEN EXPECT
+ # EXPECT
assert exponential_copy is not exponential
assert exponential_copy.display_name == exponential.display_name
@@ -197,7 +214,8 @@ def test_copy(self, exponential: Exponential):
assert exponential_copy.rate.value == exponential.rate.value
assert exponential_copy.rate.fixed == exponential.rate.fixed
- assert exponential_copy.unit == exponential.unit
+ assert exponential_copy.x_unit == exponential.x_unit
+ assert exponential_copy.y_unit == exponential.y_unit
def test_repr(self, exponential: Exponential):
# WHEN
@@ -205,8 +223,76 @@ def test_repr(self, exponential: Exponential):
# THEN EXPECT
assert 'Exponential' in repr_str
- assert "name='ExponentialName'" in repr_str
- assert 'unit=meV' in repr_str
- assert 'amplitude=' in repr_str
- assert 'center=' in repr_str
- assert 'rate=' in repr_str
+ assert 'name = ExponentialName' in repr_str
+ assert 'x_unit = meV' in repr_str
+ assert 'amplitude =' in repr_str
+ assert 'center =' in repr_str
+ assert 'rate =' in repr_str
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV', y_unit='1/meV')
+ # EXPECT
+ assert exp.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, exponential: Exponential):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ exponential.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: x_unit='meV', y_unit='1/meV' → amplitude_unit='dimensionless'
+ exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV', y_unit='1/meV')
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ exp.convert_y_unit('1/eV')
+ # EXPECT: y_unit updated and amplitude value rescaled (1e3 factor)
+ assert exp.y_unit == '1/eV'
+ assert exp.amplitude.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, exponential: Exponential):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ exponential.convert_y_unit(123)
+
+ def test_evaluate_scipp_output(self, exponential: Exponential):
+ # WHEN
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = exponential.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 50
+ np.testing.assert_allclose(result.values, exponential.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV', y_unit='1/meV')
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = exp.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ def test_init_rejects_parameter_amplitude(self):
+ # WHEN THEN EXPECT
+ amplitude_param = Parameter(name='amp', value=3.0, unit='meV')
+ with pytest.raises(TypeError, match='amplitude must be a number'):
+ Exponential(amplitude=amplitude_param, x_unit='meV')
+
+ def test_init_rejects_parameter_rate(self):
+ # WHEN THEN EXPECT
+ rate_param = Parameter(name='rate', value=0.5, unit='1/meV')
+ with pytest.raises(TypeError, match='rate must be a number'):
+ Exponential(rate=rate_param, x_unit='meV')
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV')
+ # THEN
+ with pytest.raises(UnitError):
+ exp.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert exp.y_unit == 'dimensionless'
+ assert exp.amplitude.value == pytest.approx(1.0)
diff --git a/tests/unit/easydynamics/sample_model/components/test_expression_component.py b/tests/unit/easydynamics/sample_model/components/test_expression_component.py
index 48fbafe26..b509f61b4 100644
--- a/tests/unit/easydynamics/sample_model/components/test_expression_component.py
+++ b/tests/unit/easydynamics/sample_model/components/test_expression_component.py
@@ -1,13 +1,18 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause
+import warnings
from copy import copy
import numpy as np
import pytest
+import scipp as sc
+from easyscience.variable import DescriptorNumber
from easyscience.variable import Parameter
from easydynamics.sample_model import ExpressionComponent
+from easydynamics.sample_model import Gaussian
+from easydynamics.sample_model import Lorentzian
class TestExpressionComponent:
@@ -16,14 +21,14 @@ def expr(self):
return ExpressionComponent(
'A * exp(-(x - x0)**2 / (2*sigma**2))',
parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
- unit='meV',
+ x_unit='meV',
display_name='TestExpression',
)
def test_init_valid(self, expr: ExpressionComponent):
# WHEN THEN EXPECT
assert expr.display_name == 'TestExpression'
- assert expr.unit == 'meV'
+ assert expr.x_unit == 'meV'
assert expr.A.value == pytest.approx(2.0)
assert expr.x0.value == pytest.approx(0.5)
@@ -44,6 +49,50 @@ def test_init_with_parameter(self):
# EXPECT
assert expr.A.value == pytest.approx(3.0)
+ def test_parameters_are_dimensionless_by_default(self, expr: ExpressionComponent):
+ # WHEN THEN EXPECT
+ assert str(expr.A.unit) == 'dimensionless'
+ assert str(expr.x0.unit) == 'dimensionless'
+ assert str(expr.sigma.unit) == 'dimensionless'
+
+ def test_init_with_parameter_units(self):
+ # WHEN THEN
+ expr = ExpressionComponent(
+ 'A * exp(-(x - x0)**2 / (2*sigma**2))',
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ parameter_units={'x0': 'meV', 'sigma': 'meV'},
+ )
+
+ # EXPECT: the named parameters get their units, values unchanged; others stay
+ # dimensionless
+ assert str(expr.x0.unit) == 'meV'
+ assert str(expr.sigma.unit) == 'meV'
+ assert str(expr.A.unit) == 'dimensionless'
+ assert expr.x0.value == pytest.approx(0.5)
+ assert expr.sigma.value == pytest.approx(0.6)
+
+ def test_init_parameter_units_overrides_parameter_instance_unit(self):
+ # WHEN: a Parameter instance in eV, but parameter_units requests meV
+ A = Parameter('A', 3.0, unit='eV')
+
+ # THEN: A*x evaluates to meV**2, which mismatches the default dimensionless y_unit
+ with pytest.warns(UserWarning, match='does not match'):
+ expr = ExpressionComponent('A * x', parameters={'A': A}, parameter_units={'A': 'meV'})
+
+ # EXPECT: the unit is relabelled without rescaling the value
+ assert str(expr.A.unit) == 'meV'
+ assert expr.A.value == pytest.approx(3.0)
+
+ def test_init_parameter_units_unknown_name_raises(self):
+ # WHEN THEN EXPECT
+ with pytest.raises(ValueError, match='unknown parameter'):
+ ExpressionComponent('A * x', parameter_units={'B': 'meV'})
+
+ def test_init_parameter_units_not_a_dict_raises(self):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match='parameter_units must be None or a dictionary'):
+ ExpressionComponent('A * x', parameter_units='meV')
+
def test_invalid_expression_raises(self):
# WHEN THEN EXPECT
with pytest.raises(ValueError, match='Invalid expression'):
@@ -131,10 +180,110 @@ def test_expression_is_read_only(self, expr: ExpressionComponent):
with pytest.raises(AttributeError, match='cannot be changed'):
expr.expression = 'x'
- def test_convert_unit_not_implemented(self, expr: ExpressionComponent):
+ def test_set_unit_relabels_without_rescaling(self, expr: ExpressionComponent):
+ # WHEN
+ expr.sigma.min = 0.1
+ expr.sigma.max = 10.0
+
+ # THEN: relabelling only sigma leaves the expression unit-inconsistent, which warns
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr.set_unit('sigma', 'meV')
+
+ # EXPECT: unit relabelled; value and bounds keep their numeric values
+ assert str(expr.sigma.unit) == 'meV'
+ assert expr.sigma.value == pytest.approx(0.6)
+ assert expr.sigma.min == pytest.approx(0.1)
+ assert expr.sigma.max == pytest.approx(10.0)
+
+ def test_set_unit_accepts_scipp_unit(self, expr: ExpressionComponent):
+ # WHEN THEN: relabelling only x0 leaves the expression unit-inconsistent, which warns
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr.set_unit('x0', sc.Unit('meV'))
+
+ # EXPECT
+ assert str(expr.x0.unit) == 'meV'
+
+ def test_set_unit_leaves_other_parameters_untouched(self, expr: ExpressionComponent):
+ # WHEN THEN
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr.set_unit('x0', 'meV')
+
+ # EXPECT
+ assert str(expr.A.unit) == 'dimensionless'
+ assert str(expr.sigma.unit) == 'dimensionless'
+
+ def test_set_unit_unknown_parameter_raises(self, expr: ExpressionComponent):
+ # WHEN THEN EXPECT
+ with pytest.raises(KeyError, match="No parameter named 'B'"):
+ expr.set_unit('B', 'meV')
+
+ def test_set_unit_invalid_unit_raises(self, expr: ExpressionComponent):
+ # WHEN THEN EXPECT
+ with pytest.raises(ValueError, match='not a valid scipp unit'):
+ expr.set_unit('A', 'not_a_unit')
+
+ @pytest.mark.parametrize(
+ 'name, unit, expected_message',
+ [
+ (123, 'meV', 'name must be a string'),
+ ('A', 123, 'unit must be a string or sc.Unit'),
+ ],
+ ids=['non_string_name', 'non_unit_unit'],
+ )
+ def test_set_unit_type_validation(
+ self, expr: ExpressionComponent, name, unit, expected_message
+ ):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=expected_message):
+ expr.set_unit(name, unit)
+
+ def test_set_unit_dependent_parameter_raises(self, expr: ExpressionComponent):
+ # WHEN: make A dependent on sigma
+ expr.A.make_dependent_on(
+ dependency_expression='2 * sigma',
+ dependency_map={'sigma': expr.sigma},
+ )
+
+ # THEN EXPECT
+ with pytest.raises(AttributeError, match='dependent parameter'):
+ expr.set_unit('A', 'meV')
+
+ def test_parameter_units_survive_copy(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'A * x + B',
+ parameters={'A': 2.0, 'B': 3.0},
+ parameter_units={'A': '1/meV'},
+ )
+
+ # THEN
+ expr_copy = copy(expr)
+
+ # EXPECT: per-parameter units round-trip through to_dict/from_dict
+ assert str(expr_copy.A.unit) == '1/meV'
+ assert str(expr_copy.B.unit) == 'dimensionless'
+ assert expr_copy.A.value == pytest.approx(2.0)
+
+ def test_convert_x_unit_not_implemented(self, expr: ExpressionComponent):
+ # WHEN THEN EXPECT
+ with pytest.raises(NotImplementedError, match='not implemented'):
+ expr.convert_x_unit('microeV')
+
+ def test_convert_y_unit_not_implemented(self, expr: ExpressionComponent):
# WHEN THEN EXPECT
with pytest.raises(NotImplementedError, match='not implemented'):
- expr.convert_unit('microeV')
+ expr.convert_y_unit('1/meV')
+
+ def test_evaluate_scipp_output(self, expr: ExpressionComponent):
+ # WHEN
+ x = np.linspace(-2, 2, 30)
+ # THEN
+ result = expr.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 30
+ np.testing.assert_allclose(result.values, expr.evaluate(x, output='numpy'))
def test_missing_parameter_defaults(self):
# WHEN THEN
@@ -166,20 +315,19 @@ def test_repr(self, expr: ExpressionComponent):
def test_evaluate_scalar_input(self, expr: ExpressionComponent):
# WHEN
x = 0.5
- result = expr.evaluate(x)
-
# THEN
+ result = expr.evaluate(x)
+ # EXPECT
expected = 2.0 * np.exp(-((x - 0.5) ** 2) / (2 * 0.6**2))
assert np.isclose(result, expected)
def test_reserved_name_not_parameter(self):
# WHEN
expr = ExpressionComponent('x + A', parameters={'A': 2.0})
-
# THEN
params = expr.get_all_variables()
names = {p.name for p in params}
-
+ # EXPECT
assert 'A' in names
assert 'x' not in names # x is reserved
@@ -191,7 +339,8 @@ def test_copy(self, expr: ExpressionComponent):
assert expr_copy is not expr
assert isinstance(expr_copy, ExpressionComponent)
assert expr_copy.expression == expr.expression
- assert expr_copy.unit == expr.unit
+ assert expr_copy.x_unit == expr.x_unit
+ assert expr_copy.y_unit == expr.y_unit
assert expr_copy.display_name == expr.display_name
assert expr_copy.A.value == pytest.approx(expr.A.value)
@@ -209,3 +358,416 @@ def test_erf(self):
# EXPECT
expected = np.array([-0.84270079, 0.0, 0.84270079]) # erf(-1), erf(0), erf(1)
np.testing.assert_allclose(result, expected, rtol=1e-5)
+
+
+GAUSSIAN_EXPRESSION = 'A / (sigma*sqrt(2*pi)) * exp(-(x - x0)**2 / (2*sigma**2))'
+GAUSSIAN_UNITS = {'A': 'meV', 'x0': 'meV', 'sigma': 'meV'}
+
+
+class TestExpressionComponentUnitCorrectness:
+ """Compare unit-aware expressions against the built-in components."""
+
+ @pytest.fixture
+ def gaussian_expr(self):
+ return ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ parameter_units=GAUSSIAN_UNITS,
+ x_unit='meV',
+ )
+
+ def test_matches_gaussian_component(self, gaussian_expr):
+ # WHEN
+ gaussian = Gaussian(area=2.0, center=0.5, width=0.6, x_unit='meV')
+ x = np.linspace(-3, 3, 101)
+
+ # THEN EXPECT: same values on a plain numpy grid
+ np.testing.assert_allclose(gaussian_expr.evaluate(x), gaussian.evaluate(x))
+
+ def test_matches_gaussian_component_scipp_input_and_output(self, gaussian_expr):
+ # WHEN
+ gaussian = Gaussian(area=2.0, center=0.5, width=0.6, x_unit='meV')
+ x = sc.linspace('energy', -3.0, 3.0, 101, unit='meV')
+
+ # THEN
+ expr_result = gaussian_expr.evaluate(x, output='scipp')
+ gaussian_result = gaussian.evaluate(x, output='scipp')
+
+ # EXPECT: same values and the same output unit
+ np.testing.assert_allclose(expr_result.values, gaussian_result.values)
+ assert expr_result.unit == gaussian_result.unit
+
+ def test_matches_gaussian_component_in_other_unit(self):
+ # WHEN: the built-in Gaussian auto-converts its meV parameters to a ueV grid; the
+ # expression is given the physically identical parameters expressed in ueV directly.
+ gaussian = Gaussian(area=2.0, center=0.5, width=0.6, x_unit='meV')
+ expr = ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2000.0, 'x0': 500.0, 'sigma': 600.0},
+ parameter_units={'A': 'ueV', 'x0': 'ueV', 'sigma': 'ueV'},
+ x_unit='ueV',
+ )
+ x = sc.linspace('energy', -3000.0, 3000.0, 101, unit='ueV')
+
+ # THEN EXPECT: identical physical results (the Gaussian output is dimensionless, so no
+ # y-scale factor is involved)
+ np.testing.assert_allclose(expr.evaluate(x), gaussian.evaluate(x))
+
+ def test_matches_lorentzian_component(self):
+ # WHEN
+ lorentzian = Lorentzian(area=2.0, center=0.5, width=0.3, x_unit='meV')
+ # Note: 'gamma' would collide with sympy's gamma function, so use 'hwhm'
+ expr = ExpressionComponent(
+ 'A / pi * hwhm / ((x - x0)**2 + hwhm**2)',
+ parameters={'A': 2.0, 'x0': 0.5, 'hwhm': 0.3},
+ parameter_units={'A': 'meV', 'x0': 'meV', 'hwhm': 'meV'},
+ x_unit='meV',
+ )
+ x = np.linspace(-3, 3, 101)
+
+ # THEN EXPECT
+ np.testing.assert_allclose(expr.evaluate(x), lorentzian.evaluate(x))
+
+ def test_consistent_units_do_not_warn(self):
+ # WHEN THEN EXPECT: a fully unit-consistent expression constructs without warnings
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ parameter_units=GAUSSIAN_UNITS,
+ x_unit='meV',
+ )
+
+ def test_unit_agnostic_expression_does_not_warn(self):
+ # WHEN THEN EXPECT: without any parameter units the consistency check stays silent,
+ # even though x_unit is meV and the parameters are dimensionless
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ x_unit='meV',
+ )
+
+
+class TestExpressionComponentOutputUnit:
+ def test_output_unit_gaussian_is_dimensionless(self):
+ # WHEN: area in meV divided by sigma in meV
+ expr = ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ parameter_units=GAUSSIAN_UNITS,
+ x_unit='meV',
+ )
+
+ # THEN EXPECT
+ assert expr.output_unit == 'dimensionless'
+
+ def test_output_unit_with_dimensionless_area_is_one_over_x(self):
+ # WHEN: dimensionless amplitude over a meV width gives 1/meV
+ with pytest.warns(UserWarning, match='does not match'):
+ expr = ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ parameter_units={'x0': 'meV', 'sigma': 'meV'},
+ x_unit='meV',
+ )
+
+ # THEN EXPECT
+ assert sc.Unit(expr.output_unit) == sc.Unit('1/meV')
+
+ def test_output_unit_matching_y_unit_does_not_warn(self):
+ # WHEN THEN EXPECT: declaring the matching y_unit silences the mismatch warning
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ expr = ExpressionComponent(
+ GAUSSIAN_EXPRESSION,
+ parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6},
+ parameter_units={'x0': 'meV', 'sigma': 'meV'},
+ x_unit='meV',
+ y_unit='1/meV',
+ )
+ assert sc.Unit(expr.output_unit) == sc.Unit(expr.y_unit)
+
+ def test_output_unit_preserved_by_abs(self):
+ # WHEN
+ with pytest.warns(UserWarning, match='does not match'):
+ expr = ExpressionComponent(
+ 'abs(x - x0)', parameters={'x0': 0.5}, parameter_units={'x0': 'meV'}
+ )
+
+ # THEN EXPECT
+ assert sc.Unit(expr.output_unit) == sc.Unit('meV')
+
+ def test_output_unit_sign_is_dimensionless(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'A * sign(x - x0)', parameters={'A': 1.0, 'x0': 0.5}, parameter_units={'x0': 'meV'}
+ )
+
+ # THEN EXPECT
+ assert expr.output_unit == 'dimensionless'
+
+ def test_output_unit_sqrt_of_squared_quantity(self):
+ # WHEN: sqrt(sigma**2) is a half-integer power of a quantity with units
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ expr = ExpressionComponent(
+ 'x / sqrt(2*pi*sigma**2) + offset',
+ parameters={'sigma': 0.5, 'offset': 1.0},
+ parameter_units={'sigma': 'meV'},
+ x_unit='meV',
+ )
+
+ # THEN EXPECT: meV / sqrt(meV**2) + dimensionless = dimensionless
+ assert expr.output_unit == 'dimensionless'
+
+ def test_output_unit_incompatible_addition_raises(self):
+ # WHEN
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr = ExpressionComponent(
+ 'x + A', parameters={'A': 1.0}, parameter_units={'A': 'K'}, x_unit='meV'
+ )
+
+ # THEN EXPECT
+ with pytest.raises(sc.UnitError, match='incompatible units'):
+ _ = expr.output_unit
+
+ def test_output_unit_exp_of_dimensional_quantity_raises(self):
+ # WHEN
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr = ExpressionComponent(
+ 'exp(x * A)', parameters={'A': 1.0}, parameter_units={'A': 'meV'}, x_unit='meV'
+ )
+
+ # THEN EXPECT
+ with pytest.raises(sc.UnitError, match='dimensionless argument'):
+ _ = expr.output_unit
+
+ def test_output_unit_symbolic_power_of_dimensional_base_raises(self):
+ # WHEN
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr = ExpressionComponent(
+ 'sigma**n * x',
+ parameters={'sigma': 0.5, 'n': 2.0},
+ parameter_units={'sigma': 'meV'},
+ )
+
+ # THEN EXPECT
+ with pytest.raises(sc.UnitError, match='symbolic power'):
+ _ = expr.output_unit
+
+ def test_output_unit_exponent_with_unit_raises(self):
+ # WHEN: the exponent itself carries a unit
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr = ExpressionComponent(
+ 'x**A', parameters={'A': 1.0}, parameter_units={'A': 'meV'}, x_unit='meV'
+ )
+
+ # THEN EXPECT
+ with pytest.raises(sc.UnitError, match='exponent must be dimensionless'):
+ _ = expr.output_unit
+
+ def test_output_unit_quarter_power_of_dimensional_base_raises(self):
+ # WHEN: only integer and half-integer powers of quantities with units are supported
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ expr = ExpressionComponent(
+ 'sigma**0.25 * x', parameters={'sigma': 0.5}, parameter_units={'sigma': 'meV'}
+ )
+
+ # THEN EXPECT
+ with pytest.raises(sc.UnitError, match='raised to the power'):
+ _ = expr.output_unit
+
+ def test_output_unit_with_x_unit_none(self):
+ # WHEN: without an x_unit, x is treated as dimensionless in the propagation
+ with pytest.warns(UserWarning, match='does not match'):
+ expr = ExpressionComponent(
+ 'A * x', parameters={'A': 1.0}, parameter_units={'A': 'meV'}, x_unit=None
+ )
+
+ # THEN EXPECT
+ assert sc.Unit(expr.output_unit) == sc.Unit('meV')
+
+ def test_evaluate_warns_for_genuinely_different_x_unit(self):
+ # WHEN: x carries a different (but compatible) unit than x_unit
+ expr = ExpressionComponent(
+ 'A * x', parameters={'A': 1.0}, parameter_units={'A': '1/meV'}, x_unit='meV'
+ )
+ x = sc.linspace('energy', -1.0, 1.0, 11, unit='ueV')
+
+ # THEN EXPECT: ExpressionComponent cannot auto-convert, so it warns and uses raw values
+ with pytest.warns(UserWarning, match='cannot auto-convert'):
+ result = expr.evaluate(x)
+ np.testing.assert_allclose(result, x.values)
+
+ def test_evaluate_does_not_warn_for_equivalent_unit_spelling(self):
+ # WHEN: 'ueV' and scipp's canonical micro-sign spelling are the same unit
+ expr = ExpressionComponent(
+ 'A * x', parameters={'A': 1.0}, parameter_units={'A': '1/ueV'}, x_unit='ueV'
+ )
+ x = sc.linspace('energy', -1.0, 1.0, 11, unit='\u00b5eV')
+
+ # THEN EXPECT: no spurious warning from the differing spellings
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ expr.evaluate(x)
+
+ def test_set_unit_warns_when_breaking_consistency(self):
+ # WHEN: a consistent expression
+ expr = ExpressionComponent(
+ 'A * (x - x0)',
+ parameters={'A': 1.0, 'x0': 0.5},
+ parameter_units={'A': '1/meV', 'x0': 'meV'},
+ x_unit='meV',
+ )
+
+ # THEN EXPECT: relabelling A breaks the output unit
+ with pytest.warns(UserWarning, match='does not match'):
+ expr.set_unit('A', '1/ueV')
+
+
+class TestExpressionComponentPhysicalConstants:
+ def test_kb_constant_value_and_unit(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT: Boltzmann constant in meV/K, a read-only DescriptorNumber
+ assert expr.kb.value == pytest.approx(0.08617333, rel=1e-6)
+ assert sc.Unit(str(expr.kb.unit)) == sc.Unit('meV/K')
+ assert isinstance(expr.kb, DescriptorNumber)
+ assert not isinstance(expr.kb, Parameter)
+
+ def test_hbar_constant_value_and_unit(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-(x * tau / hbar)**2)', parameters={'tau': 1.0}, parameter_units={'tau': 'ps'}
+ )
+
+ # THEN EXPECT: hbar in meV*ps, a read-only DescriptorNumber
+ assert expr.hbar.value == pytest.approx(0.6582120, rel=1e-6)
+ assert sc.Unit(str(expr.hbar.unit)) == sc.Unit('meV*ps')
+ assert isinstance(expr.hbar, DescriptorNumber)
+ assert not isinstance(expr.hbar, Parameter)
+
+ def test_boltzmann_factor_value(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+ x = np.array([0.0, 10.0, 25.0])
+
+ # THEN
+ result = expr.evaluate(x)
+
+ # EXPECT
+ expected = np.exp(-x / (0.08617333262145177 * 300.0))
+ np.testing.assert_allclose(result, expected, rtol=1e-9)
+
+ def test_constants_are_unit_consistent(self):
+ # WHEN THEN EXPECT: x/(kb*T) is dimensionless, so no warning is raised
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+ assert expr.output_unit == 'dimensionless'
+
+ def test_constant_with_dimensionless_temperature_warns(self):
+ # WHEN THEN EXPECT: forgetting the unit on T makes x/(kb*T) carry kelvins
+ with pytest.warns(UserWarning, match='not unit-consistent'):
+ ExpressionComponent('exp(-x / (kb * T))', parameters={'T': 300.0})
+
+ def test_constants_property(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT
+ assert set(expr.constants) == {'kb'}
+
+ def test_constant_not_in_get_all_variables(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT: constants are not fittable variables
+ names = {p.name for p in expr.get_all_variables()}
+ assert names == {'T'}
+
+ def test_constant_cannot_be_set(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT
+ with pytest.raises(AttributeError, match='physical constant'):
+ expr.kb = 5.0
+
+ def test_set_unit_on_constant_raises(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT
+ with pytest.raises(AttributeError, match='physical constant'):
+ expr.set_unit('kb', 'meV/K')
+
+ def test_parameter_units_for_constant_raises(self):
+ # WHEN THEN EXPECT
+ with pytest.raises(ValueError, match='physical constant'):
+ ExpressionComponent(
+ 'exp(-x / (kb * T))',
+ parameters={'T': 300.0},
+ parameter_units={'T': 'K', 'kb': 'meV/K'},
+ )
+
+ def test_constant_convert_unit_rescales(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN
+ expr.kb.convert_unit('eV/K')
+
+ # EXPECT: the value is rescaled with the unit
+ assert expr.kb.value == pytest.approx(8.617333e-05, rel=1e-6)
+
+ def test_user_parameter_overrides_constant(self):
+ # WHEN: the user explicitly provides 'kb' as a parameter
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+ expr = ExpressionComponent('exp(-x / (kb * T))', parameters={'kb': 2.0, 'T': 300.0})
+
+ # THEN EXPECT: it is an ordinary (fittable, dimensionless) parameter
+ assert not expr.constants
+ assert expr.kb.value == pytest.approx(2.0)
+ assert expr.kb.fixed is False
+ names = {p.name for p in expr.get_all_variables()}
+ assert names == {'kb', 'T'}
+
+ def test_constant_in_dir(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT
+ assert 'kb' in dir(expr)
+
+ def test_constant_in_repr(self):
+ # WHEN
+ expr = ExpressionComponent(
+ 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'}
+ )
+
+ # THEN EXPECT
+ assert 'kb' in repr(expr)
diff --git a/tests/unit/easydynamics/sample_model/components/test_gaussian.py b/tests/unit/easydynamics/sample_model/components/test_gaussian.py
index a370e74fb..3ef9b01b7 100644
--- a/tests/unit/easydynamics/sample_model/components/test_gaussian.py
+++ b/tests/unit/easydynamics/sample_model/components/test_gaussian.py
@@ -5,7 +5,9 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
+from scipp import UnitError
from scipy.integrate import simpson
from easydynamics.sample_model import Gaussian
@@ -20,7 +22,7 @@ def gaussian(self):
area=2.0,
center=0.5,
width=0.6,
- unit='meV',
+ x_unit='meV',
)
def test_init_no_inputs(self):
@@ -32,7 +34,8 @@ def test_init_no_inputs(self):
assert gaussian.area.value == pytest.approx(1.0)
assert gaussian.center.value == pytest.approx(0.0)
assert gaussian.width.value == pytest.approx(1.0)
- assert gaussian.unit == 'meV'
+ assert gaussian.x_unit == 'meV'
+ assert gaussian.y_unit == 'dimensionless'
assert gaussian.center.fixed is True
def test_initialization(self, gaussian: Gaussian):
@@ -41,33 +44,38 @@ def test_initialization(self, gaussian: Gaussian):
assert gaussian.area.value == pytest.approx(2.0)
assert gaussian.center.value == pytest.approx(0.5)
assert gaussian.width.value == pytest.approx(0.6)
- assert gaussian.unit == 'meV'
+ assert gaussian.x_unit == 'meV'
@pytest.mark.parametrize(
'kwargs, expected_message',
[
(
- {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'unit': 'meV'},
+ {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'x_unit': 'meV'},
'area must be a number',
),
(
- {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'unit': 'meV'},
+ {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'x_unit': 'meV'},
'center must be None or a number',
),
(
- {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'unit': 'meV'},
+ {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'x_unit': 'meV'},
'width must be a number',
),
(
- {'area': 2.0, 'center': 0.5, 'width': 0.6, 'unit': 123},
- 'unit must be None',
+ {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 123},
+ 'unit must be None, a string',
+ ),
+ (
+ {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 'meV', 'y_unit': 123},
+ 'unit must be None, a string',
),
],
ids=[
'invalid area',
'invalid center',
'invalid width',
- 'invalid unit',
+ 'invalid x_unit',
+ 'invalid y_unit',
],
)
def test_input_type_validation_raises(self, kwargs, expected_message):
@@ -84,7 +92,7 @@ def test_negative_width_raises(self):
area=2.0,
center=0.5,
width=-0.6,
- unit='meV',
+ x_unit='meV',
)
def test_negative_area_warns(self):
@@ -95,7 +103,7 @@ def test_negative_area_warns(self):
area=-2.0,
center=0.5,
width=0.6,
- unit='meV',
+ x_unit='meV',
)
@pytest.mark.parametrize(
@@ -109,11 +117,12 @@ def test_negative_area_warns(self):
def test_property_setters(
self, gaussian: Gaussian, prop, valid_value, invalid_value, invalid_message
):
- # set valid
+ # WHEN: set a valid value
setattr(gaussian, prop, valid_value)
+ # THEN EXPECT
assert getattr(gaussian, prop).value == valid_value
- # invalid
+ # WHEN: set an invalid value — THEN EXPECT
with pytest.raises(TypeError, match=invalid_message):
setattr(gaussian, prop, invalid_value)
@@ -177,12 +186,12 @@ def test_area_matches_parameter(self, gaussian: Gaussian):
numerical_area = simpson(y, x)
assert np.isclose(numerical_area, gaussian.area.value, rtol=1e-3)
- def test_convert_unit(self, gaussian: Gaussian):
+ def test_convert_x_unit(self, gaussian: Gaussian):
# WHEN THEN
- gaussian.convert_unit('microeV')
+ gaussian.convert_x_unit('microeV')
# EXPECT
- assert gaussian.unit == 'microeV'
+ assert gaussian.x_unit == 'microeV'
assert gaussian.area.value == pytest.approx(2 * 1e3)
assert gaussian.center.value == pytest.approx(0.5 * 1e3)
assert gaussian.width.value == pytest.approx(0.6 * 1e3)
@@ -216,15 +225,91 @@ def test_copy(self, gaussian: Gaussian):
assert gaussian_copy.width.min == gaussian.width.min
assert gaussian_copy.width.max == gaussian.width.max
- assert gaussian_copy.unit == gaussian.unit
+ assert gaussian_copy.x_unit == gaussian.x_unit
def test_repr(self, gaussian: Gaussian):
# WHEN THEN
repr_str = repr(gaussian)
# EXPECT
assert 'Gaussian' in repr_str
- assert "name='GaussianName'" in repr_str
- assert 'unit=meV' in repr_str
- assert 'area=' in repr_str
- assert 'center=' in repr_str
- assert 'width=' in repr_str
+ assert 'name = GaussianName' in repr_str
+ assert 'x_unit = meV' in repr_str
+ assert 'area =' in repr_str
+ assert 'center =' in repr_str
+ assert 'width =' in repr_str
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ gaussian = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ # EXPECT
+ assert gaussian.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, gaussian: Gaussian):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ gaussian.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: x_unit='meV', y_unit='1/meV' → area_unit ≈ dimensionless
+ gaussian = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ gaussian.convert_y_unit('1/eV')
+
+ # EXPECT: unit updated and area value rescaled (1/eV = 1e-3/meV, so value x 1e3)
+ assert gaussian.y_unit == '1/eV'
+ assert gaussian.area.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, gaussian: Gaussian):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ gaussian.convert_y_unit(123)
+
+ def test_evaluate_scipp_output(self, gaussian: Gaussian):
+ # WHEN
+ x = np.linspace(-5, 5, 100)
+
+ # THEN
+ result = gaussian.evaluate(x, output='scipp')
+
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 100
+ np.testing.assert_allclose(result.values, gaussian.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ gaussian = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ x = np.linspace(-5, 5, 100)
+
+ # WHEN
+ result = gaussian.evaluate(x, output='scipp')
+
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ def test_convert_x_unit_invalid_type_raises(self, gaussian: Gaussian):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'):
+ gaussian.convert_x_unit(123)
+
+ def test_convert_x_unit_rollback_on_failure(self, gaussian: Gaussian):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ gaussian.convert_x_unit('m')
+ # EXPECT: state rolled back
+ assert gaussian.x_unit == 'meV'
+ assert gaussian.area.value == pytest.approx(2.0)
+ assert gaussian.center.value == pytest.approx(0.5)
+ assert gaussian.width.value == pytest.approx(0.6)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ gaussian = Gaussian(area=1.0, center=0.0, width=0.5, x_unit='meV')
+ # THEN
+ with pytest.raises(UnitError):
+ gaussian.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert gaussian.y_unit == 'dimensionless'
+ assert gaussian.area.value == pytest.approx(1.0)
diff --git a/tests/unit/easydynamics/sample_model/components/test_lorentzian.py b/tests/unit/easydynamics/sample_model/components/test_lorentzian.py
index 1aefee6a9..97e02aad9 100644
--- a/tests/unit/easydynamics/sample_model/components/test_lorentzian.py
+++ b/tests/unit/easydynamics/sample_model/components/test_lorentzian.py
@@ -5,7 +5,9 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
+from scipp import UnitError
from scipy.integrate import simpson
from easydynamics.sample_model import Lorentzian
@@ -20,7 +22,7 @@ def lorentzian(self):
area=2.0,
center=0.5,
width=0.6,
- unit='meV',
+ x_unit='meV',
)
def test_init_no_inputs(self):
@@ -32,7 +34,8 @@ def test_init_no_inputs(self):
assert lorentzian.area.value == pytest.approx(1.0)
assert lorentzian.center.value == pytest.approx(0.0)
assert lorentzian.width.value == pytest.approx(1.0)
- assert lorentzian.unit == 'meV'
+ assert lorentzian.x_unit == 'meV'
+ assert lorentzian.y_unit == 'dimensionless'
assert lorentzian.center.fixed is True
def test_initialization(self, lorentzian: Lorentzian):
@@ -41,26 +44,30 @@ def test_initialization(self, lorentzian: Lorentzian):
assert lorentzian.area.value == pytest.approx(2.0)
assert lorentzian.center.value == pytest.approx(0.5)
assert lorentzian.width.value == pytest.approx(0.6)
- assert lorentzian.unit == 'meV'
+ assert lorentzian.x_unit == 'meV'
@pytest.mark.parametrize(
'kwargs, expected_message',
[
(
- {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'unit': 'meV'},
+ {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'x_unit': 'meV'},
'area must be a number',
),
(
- {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'unit': 'meV'},
- 'center must be None',
+ {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'x_unit': 'meV'},
+ 'center must be None or a number',
),
(
- {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'unit': 'meV'},
+ {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'x_unit': 'meV'},
'width must be a number',
),
(
- {'area': 2.0, 'center': 0.5, 'width': 0.6, 'unit': 123},
- 'unit must be None',
+ {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 123},
+ 'unit must be None, a string',
+ ),
+ (
+ {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 'meV', 'y_unit': 123},
+ 'unit must be None, a string',
),
],
)
@@ -78,7 +85,7 @@ def test_negative_width_raises(self):
area=2.0,
center=0.5,
width=-0.6,
- unit='meV',
+ x_unit='meV',
)
def test_negative_area_warns(self):
@@ -89,7 +96,7 @@ def test_negative_area_warns(self):
area=-2.0,
center=0.5,
width=0.6,
- unit='meV',
+ x_unit='meV',
)
@pytest.mark.parametrize(
@@ -103,11 +110,12 @@ def test_negative_area_warns(self):
def test_property_setters(
self, lorentzian: Lorentzian, prop, valid_value, invalid_value, invalid_message
):
- # set valid
+ # WHEN: set a valid value
setattr(lorentzian, prop, valid_value)
+ # THEN EXPECT
assert getattr(lorentzian, prop).value == valid_value
- # invalid
+ # WHEN: set an invalid value — THEN EXPECT
with pytest.raises(TypeError, match=invalid_message):
setattr(lorentzian, prop, invalid_value)
@@ -166,12 +174,12 @@ def test_area_matches_parameter(self, lorentzian: Lorentzian):
# EXPECT
assert numerical_area == pytest.approx(lorentzian.area.value, rel=2e-3)
- def test_convert_unit(self, lorentzian: Lorentzian):
+ def test_convert_x_unit(self, lorentzian: Lorentzian):
# WHEN THEN
- lorentzian.convert_unit('microeV')
+ lorentzian.convert_x_unit('microeV')
# EXPECT
- assert lorentzian.unit == 'microeV'
+ assert lorentzian.x_unit == 'microeV'
assert lorentzian.area.value == pytest.approx(2 * 1e3)
assert lorentzian.center.value == pytest.approx(0.5 * 1e3)
assert lorentzian.width.value == pytest.approx(0.6 * 1e3)
@@ -193,7 +201,7 @@ def test_copy(self, lorentzian: Lorentzian):
assert lorentzian_copy.width.value == lorentzian.width.value
assert lorentzian_copy.width.fixed == lorentzian.width.fixed
- assert lorentzian_copy.unit == lorentzian.unit
+ assert lorentzian_copy.x_unit == lorentzian.x_unit
def test_repr(self, lorentzian: Lorentzian):
# WHEN THEN
@@ -201,8 +209,79 @@ def test_repr(self, lorentzian: Lorentzian):
# EXPECT
assert 'Lorentzian' in repr_str
- assert "name='LorentzianName'" in repr_str
- assert 'unit=meV' in repr_str
- assert 'area=' in repr_str
- assert 'center=' in repr_str
- assert 'width=' in repr_str
+ assert 'name = LorentzianName' in repr_str
+ assert 'x_unit = meV' in repr_str
+ assert 'area =' in repr_str
+ assert 'center =' in repr_str
+ assert 'width =' in repr_str
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV')
+ # EXPECT
+ assert lor.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, lorentzian: Lorentzian):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ lorentzian.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless'
+ lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV')
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ lor.convert_y_unit('1/eV')
+ # EXPECT: y_unit updated and area value rescaled (1e3 factor)
+ assert lor.y_unit == '1/eV'
+ assert lor.area.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, lorentzian: Lorentzian):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ lorentzian.convert_y_unit(123)
+
+ def test_evaluate_scipp_output(self, lorentzian: Lorentzian):
+ # WHEN
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = lorentzian.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 50
+ np.testing.assert_allclose(result.values, lorentzian.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV')
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = lor.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ def test_convert_x_unit_invalid_type_raises(self, lorentzian: Lorentzian):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'):
+ lorentzian.convert_x_unit(123)
+
+ def test_convert_x_unit_rollback_on_failure(self, lorentzian: Lorentzian):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ lorentzian.convert_x_unit('m')
+ # EXPECT: state rolled back
+ assert lorentzian.x_unit == 'meV'
+ assert lorentzian.area.value == pytest.approx(2.0)
+ assert lorentzian.center.value == pytest.approx(0.5)
+ assert lorentzian.width.value == pytest.approx(0.6)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ lor = Lorentzian(area=1.0, center=0.0, width=0.5, x_unit='meV')
+ # THEN
+ with pytest.raises(UnitError):
+ lor.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert lor.y_unit == 'dimensionless'
+ assert lor.area.value == pytest.approx(1.0)
diff --git a/tests/unit/easydynamics/sample_model/components/test_mixins.py b/tests/unit/easydynamics/sample_model/components/test_mixins.py
index e30f4bdd1..bc7a312d1 100644
--- a/tests/unit/easydynamics/sample_model/components/test_mixins.py
+++ b/tests/unit/easydynamics/sample_model/components/test_mixins.py
@@ -18,7 +18,7 @@ def dummy_model(self):
@pytest.mark.parametrize('area_input', [2, 2.0])
def test_create_area_parameter_from_numeric(self, dummy_model, area_input, unit):
# WHEN THEN
- area_param = dummy_model._create_area_parameter(area_input, 'TestModel', unit=unit)
+ area_param = dummy_model._create_area_parameter(area_input, 'TestModel', x_unit=unit)
# EXPECT
assert isinstance(area_param, Parameter)
@@ -59,7 +59,7 @@ def test_negative_area_warns(self, dummy_model):
def test_create_center_parameter_from_numeric(self, dummy_model, center_input, unit):
# WHEN THEN
center_param = dummy_model._create_center_parameter(
- center_input, 'TestModel', fix_if_none=False, unit=unit
+ center_input, 'TestModel', fix_if_none=False, x_unit=unit
)
# EXPECT
assert isinstance(center_param, Parameter)
@@ -104,7 +104,7 @@ def test_create_center_parameter_invalid_numeric_raises(self, dummy_model, non_f
def test_create_width_parameter_from_numeric(self, dummy_model, width_input, unit):
# WHEN THEN
width_param = dummy_model._create_width_parameter(
- width_input, 'TestModel', param_name='width', unit=unit
+ width_input, 'TestModel', param_name='width', x_unit=unit
)
# EXPECT
assert isinstance(width_param, Parameter)
diff --git a/tests/unit/easydynamics/sample_model/components/test_model_component.py b/tests/unit/easydynamics/sample_model/components/test_model_component.py
index 514ced73b..8a1a08678 100644
--- a/tests/unit/easydynamics/sample_model/components/test_model_component.py
+++ b/tests/unit/easydynamics/sample_model/components/test_model_component.py
@@ -5,7 +5,9 @@
import pytest
import scipp as sc
from easyscience.variable import Parameter
+from scipp import UnitError
+from easydynamics.sample_model.components.gaussian import Gaussian
from easydynamics.sample_model.components.model_component import ModelComponent
@@ -15,13 +17,13 @@ def __init__(self):
self.area = Parameter(name='area', value=1.0, unit='meV', fixed=False)
self.center = Parameter(name='center', value=2.0, unit='meV', fixed=True)
self.width = Parameter(name='width', value=3.0, unit='meV', fixed=True)
- self._unit = 'meV'
+ self._x_unit = 'meV'
def get_all_parameters(self):
return [self.area, self.center, self.width]
- def evaluate(self, x):
- return np.zeros_like(x)
+ def _evaluate_values(self, x_vals, _eval_unit):
+ return np.zeros_like(x_vals)
class TestModelComponent:
@@ -31,23 +33,23 @@ def dummy(self):
def test_unit_cannot_be_set_directly(self, dummy: ModelComponent):
# WHEN THEN EXPECT
- with pytest.raises(AttributeError, match='Unit is read-only'):
- dummy.unit = 'K'
+ with pytest.raises(AttributeError, match='read-only'):
+ dummy.x_unit = 'K'
def test_convert_unit(self, dummy: DummyComponent):
# WHEN THEN
- dummy.convert_unit('microeV')
+ dummy.convert_x_unit('microeV')
# EXPECT
- assert dummy.unit == 'microeV'
+ assert dummy.x_unit == 'microeV'
assert dummy.area.value == pytest.approx(1 * 1e3)
assert dummy.center.value == pytest.approx(2 * 1e3)
assert dummy.width.value == pytest.approx(3 * 1e3)
def test_convert_unit_incorrect_unit_raises(self, dummy: DummyComponent):
# WHEN THEN EXPECT
- with pytest.raises(TypeError, match=r'Unit must be a string or sc.Unit'):
- dummy.convert_unit(123)
+ with pytest.raises(TypeError, match=r'unit must be a string or sc.Unit'):
+ dummy.convert_x_unit(123)
def test_free_and_fix_all_parameters(self, dummy):
# WHEN THEN EXPECT
@@ -92,8 +94,11 @@ def test_repr(self, dummy):
],
)
def test_prepare_x_for_evaluate_various_inputs(self, dummy, x_input, expected_array):
- x_prepared = dummy._prepare_x_for_evaluate(x_input)
+ # WHEN THEN
+ result = dummy._prepare_x_for_evaluate(x_input)
+ x_prepared, _detected_unit, _dim = result
+ # EXPECT
assert isinstance(x_prepared, np.ndarray)
assert x_prepared.shape == expected_array.shape
np.testing.assert_array_equal(x_prepared, expected_array)
@@ -137,26 +142,126 @@ def test_prepare_x_for_evaluate_with_incompatible_unit_raises(self, dummy):
# THEN EXPECT
with pytest.raises(
Exception,
- match='Input x has unit nm, but DummyComponent component ',
+ match='Input x has unit nm',
):
dummy._prepare_x_for_evaluate(x)
- def test_prepare_x_for_evaluate_with_different_unit_warns(self, dummy):
+ def test_prepare_x_for_evaluate_with_different_unit_no_warn(self, dummy):
# WHEN
x = sc.array(dims=['x'], values=[1.0, 2.0, 3.0], unit='microeV')
- # THEN EXPECT
- with pytest.warns(
- UserWarning,
- match='Input x has unit [µμ]eV, but DummyComponent component ',
- ):
- x_prepared = dummy._prepare_x_for_evaluate(x)
+ # THEN: compatible units are accepted without warning;
+ # the component's x_unit is NOT mutated and x values are returned as-is.
+ x_prepared, _detected_unit, _dim = dummy._prepare_x_for_evaluate(x)
# EXPECT
assert isinstance(x_prepared, np.ndarray)
assert x_prepared.shape == (3,)
np.testing.assert_array_equal(x_prepared, [1.0, 2.0, 3.0])
- assert dummy.unit == 'µeV' # noqa: RUF001
- assert dummy.area.value == pytest.approx(1.0 * 1e3)
- assert dummy.center.value == pytest.approx(2.0 * 1e3)
- assert dummy.width.value == pytest.approx(3.0 * 1e3)
+ assert dummy.x_unit == 'meV' # component unit unchanged
+ assert dummy.area.value == pytest.approx(1.0) # parameter values unchanged
+ assert dummy.center.value == pytest.approx(2.0)
+ assert dummy.width.value == pytest.approx(3.0)
+
+ def test_resolve_param_value_same_unit_returns_raw_value(self, dummy):
+ # WHEN: target unit matches parameter unit
+
+ # THEN
+ result = dummy._resolve_param_value(dummy.area, 'meV')
+ # EXPECT: raw value returned without conversion
+ assert result == pytest.approx(dummy.area.value)
+
+ def test_resolve_param_value_none_target_returns_raw_value(self, dummy):
+ # WHEN: target unit is None
+
+ # THEN
+ result = dummy._resolve_param_value(dummy.area, None)
+
+ # EXPECT: raw value returned without conversion
+ assert result == pytest.approx(dummy.area.value)
+
+ def test_resolve_param_value_converts_without_mutating(self, dummy):
+ # WHEN: target unit differs from parameter unit
+
+ # THEN
+ result = dummy._resolve_param_value(dummy.area, 'eV')
+ # EXPECT: converted value (1.0 meV → 0.001 eV)
+ assert result == pytest.approx(0.001)
+ # parameter itself is not mutated
+ assert dummy.area.value == pytest.approx(1.0)
+ assert str(dummy.area.unit) == 'meV'
+
+ def test_evaluate_with_compatible_unit_gives_correct_result(self):
+ # WHEN: Gaussian in meV and a physically equivalent Gaussian in eV
+ g_mev = Gaussian(area=1.0, center=0.0, width=0.5, x_unit='meV')
+ g_ev = Gaussian(area=0.001, center=0.0, width=0.0005, x_unit='eV')
+
+ x_ev = sc.array(
+ dims=['energy'],
+ values=np.array([-0.002, -0.001, 0.0, 0.001, 0.002]),
+ unit='eV',
+ )
+ x_ev_np = np.array([-0.002, -0.001, 0.0, 0.001, 0.002])
+
+ # THEN: evaluate meV-Gaussian with x in eV
+ result_mev = g_mev.evaluate(x_ev)
+ result_ev = g_ev.evaluate(x_ev_np)
+
+ # EXPECT: physically identical outputs
+ np.testing.assert_allclose(result_mev, result_ev, rtol=1e-10)
+ # EXPECT: model state is unchanged
+ assert g_mev.x_unit == 'meV'
+ assert g_mev.width.value == pytest.approx(0.5)
+ assert g_mev.area.value == pytest.approx(1.0)
+
+ # ───── Regression tests ─────
+
+ def test_convert_x_unit_rollback_on_failure(self, dummy: DummyComponent):
+ # Conversion to 'm' (length) is incompatible with 'meV' (energy) → triggers rollback
+ with pytest.raises(UnitError):
+ dummy.convert_x_unit('m')
+ # Parameters should be restored to original values after rollback
+ assert dummy.x_unit == 'meV'
+ assert dummy.area.value == pytest.approx(1.0)
+ assert dummy.center.value == pytest.approx(2.0)
+ assert dummy.width.value == pytest.approx(3.0)
+
+ def test_convert_y_unit_not_implemented(self, dummy: DummyComponent):
+ with pytest.raises(NotImplementedError, match='does not support convert_y_unit'):
+ dummy.convert_y_unit('1/meV')
+
+ def test_evaluate_invalid_output_raises(self, dummy: DummyComponent):
+ # WHEN THEN EXPECT
+ with pytest.raises(ValueError, match="output must be 'numpy' or 'scipp'"):
+ dummy.evaluate(np.linspace(-1, 1, 5), output='Scipp')
+
+ def test_eval_area_unit(self, dummy: DummyComponent):
+ # WHEN THEN EXPECT: y_unit is dimensionless, so the area unit equals the eval unit
+ assert sc.Unit(dummy._eval_area_unit('meV')) == sc.Unit('meV')
+
+ def test_eval_area_unit_none_eval_unit(self, dummy: DummyComponent):
+ # WHEN THEN EXPECT: without an eval unit there is no area unit
+ assert dummy._eval_area_unit(None) is None
+
+ def test_eval_area_unit_none_y_unit(self, dummy: DummyComponent):
+ # WHEN
+ dummy._y_unit = None
+
+ # THEN EXPECT
+ assert dummy._eval_area_unit('meV') is None
+
+ def test_evaluate_preserves_dataarray_coord_key_as_dim(self):
+ # WHEN: a Gaussian and a DataArray where the coord key ('energy') differs
+ # from the coord Variable's internal dim name ('x'). This is a valid scipp
+ # non-dimension coordinate: the data's dimension is 'x' and the coord is
+ # labelled 'energy' but lives on the same 'x' axis.
+ g = Gaussian(name='G', area=1.0, center=0.0, width=1.0, x_unit='meV')
+ coord = sc.Variable(dims=['x'], values=np.linspace(-5.0, 5.0, 10), unit='meV')
+ data = sc.Variable(dims=['x'], values=np.ones(10))
+ da = sc.DataArray(data=data, coords={'energy': coord})
+ # THEN: evaluate with scipp output
+ # Before the fix, dim was overwritten with coord.dims[0] = 'x', so the
+ # output Variable had dim 'x' instead of the coord key 'energy'.
+ result = g.evaluate(da, output='scipp')
+ # EXPECT: output dim must be the coord key 'energy', not the Variable dim 'x'.
+ assert result.dims == ('energy',)
diff --git a/tests/unit/easydynamics/sample_model/components/test_polynomial.py b/tests/unit/easydynamics/sample_model/components/test_polynomial.py
index 2ba3f5525..8f6550bee 100644
--- a/tests/unit/easydynamics/sample_model/components/test_polynomial.py
+++ b/tests/unit/easydynamics/sample_model/components/test_polynomial.py
@@ -5,6 +5,7 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
from scipp import UnitError
@@ -27,7 +28,8 @@ def test_init_no_inputs(self):
# EXPECT
assert polynomial.display_name == 'Polynomial'
assert polynomial.coefficients[0].value == pytest.approx(0.0)
- assert polynomial.unit == 'meV'
+ assert polynomial.x_unit == 'meV'
+ assert polynomial.y_unit == 'dimensionless'
def test_initialization(self, polynomial: Polynomial):
# WHEN THEN EXPECT
@@ -48,7 +50,11 @@ def test_initialization(self, polynomial: Polynomial):
'Each coefficient must be ',
),
(
- {'coefficients': [1.0, -2.0, 3.0], 'unit': 123},
+ {'coefficients': [1.0, -2.0, 3.0], 'x_unit': 123},
+ 'unit must be ',
+ ),
+ (
+ {'coefficients': [1.0, -2.0, 3.0], 'x_unit': 'meV', 'y_unit': 123},
'unit must be ',
),
(
@@ -59,6 +65,7 @@ def test_initialization(self, polynomial: Polynomial):
],
)
def test_input_type_validation_raises(self, kwargs, expected_message):
+ # WHEN THEN EXPECT
with pytest.raises(TypeError, match=expected_message):
Polynomial(display_name='TestPolynomial', **kwargs)
@@ -118,13 +125,12 @@ def test_set_coefficients(self, polynomial: Polynomial, values):
assert np.isclose(polynomial.coefficients[i].value, expected)
def test_set_coefficients_wrong_length_raises(self, polynomial: Polynomial):
- """Ensure that setting coefficients with mismatched length
- raises an error."""
+ # WHEN THEN EXPECT
with pytest.raises(ValueError, match='Number of coefficients'):
polynomial.coefficients = [1.0, 2.0] # shorter list
def test_set_coefficients_invalid_type_raises(self, polynomial: Polynomial):
- """Ensure that invalid coefficient types raise a TypeError."""
+ # WHEN THEN EXPECT
with pytest.raises(TypeError):
polynomial.coefficients = [1.0, 'invalid', 3.0]
@@ -137,8 +143,11 @@ def test_set_coefficients_invalid_type_raises(self, polynomial: Polynomial):
],
)
def test_set_coefficients_raises(self, invalid_coeffs, expected_message):
+ # WHEN
+ # THEN
+ polynomial = Polynomial(display_name='TestPolynomial', coefficients=[1.0, -2.0, 3.0])
with pytest.raises(TypeError, match=expected_message):
- polynomial = Polynomial(display_name='TestPolynomial', coefficients=[1.0, -2.0, 3.0])
+ # EXPECT
polynomial.coefficients = invalid_coeffs
def test_coefficient_values(self, polynomial: Polynomial):
@@ -161,20 +170,20 @@ def test_get_all_parameters(self, polynomial: Polynomial):
actual_names = {param.name for param in params}
assert actual_names == expected_names
- def test_convert_unit(self, polynomial: Polynomial):
+ def test_convert_x_unit(self, polynomial: Polynomial):
# WHEN
- polynomial.convert_unit('microeV')
+ polynomial.convert_x_unit('microeV')
# THEN EXPECT
- assert polynomial._unit == 'microeV'
+ assert polynomial._x_unit == 'microeV'
assert np.isclose(polynomial.coefficients[0].value, 1.0)
assert np.isclose(polynomial.coefficients[1].value, -2.0 * 1e-3)
assert np.isclose(polynomial.coefficients[2].value, 3.0 * 1e-6)
- def test_convert_unit_raises_invalid_unit(self, polynomial: Polynomial):
+ def test_convert_x_unit_raises_invalid_unit(self, polynomial: Polynomial):
# WHEN THEN EXPECT
- with pytest.raises(UnitError, match='unit must be '):
- polynomial.convert_unit(123)
+ with pytest.raises(Exception, match='unit must be '):
+ polynomial.convert_x_unit(123)
def test_copy(self, polynomial: Polynomial):
# WHEN THEN
@@ -192,10 +201,101 @@ def test_copy(self, polynomial: Polynomial):
assert copied_coeff.value == original_coeff.value
assert copied_coeff.fixed == original_coeff.fixed
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ p = Polynomial(coefficients=[1.0, 2.0], x_unit='meV', y_unit='1/meV')
+ # EXPECT
+ assert p.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, polynomial: Polynomial):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ polynomial.y_unit = '1/meV'
+
+ def test_convert_y_unit_scales_all_coefficients(self):
+ # WHEN: polynomial with two non-zero coefficients and a physical y_unit
+ p = Polynomial(coefficients=[3.0, 1.0], x_unit='meV', y_unit='meV^-1')
+ x = np.array([2.0])
+ val_before = p.evaluate(x)[0] # 3.0 + 1.0*2.0 = 5.0 [meV^-1]
+
+ # THEN
+ p.convert_y_unit('eV^-1')
+
+ # EXPECT: both coefficients rescaled by 1000 (1 meV^-1 = 1000 eV^-1)
+ assert p.y_unit == 'eV^-1'
+ assert np.isclose(p.coefficients[0].value, 3000.0)
+ assert np.isclose(p.coefficients[1].value, 1000.0)
+ assert np.isclose(p.evaluate(x)[0], val_before * 1000.0)
+
+ def test_evaluate_scipp_output(self):
+ # WHEN
+ p = Polynomial(coefficients=[1.0, 2.0], x_unit='meV', suppress_warnings=True)
+ x = np.linspace(-3, 3, 40)
+ # THEN
+ result = p.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 40
+ np.testing.assert_allclose(result.values, p.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ p = Polynomial(
+ coefficients=[1.0, 2.0],
+ x_unit='meV',
+ y_unit='1/meV',
+ suppress_warnings=True,
+ )
+ x = np.linspace(-3, 3, 40)
+ # THEN
+ result = p.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
def test_repr(self, polynomial: Polynomial):
# WHEN THEN
repr_str = repr(polynomial)
# EXPECT
- assert "name='PolynomialName'" in repr_str
- assert 'coefficients=' in repr_str
+ assert 'name = PolynomialName' in repr_str
+ assert 'coefficients =' in repr_str
+
+ def test_evaluate_with_scipp_x_different_compatible_unit(self):
+ # WHEN: polynomial with x_unit='meV', coefficients [1.0, 1.0] → f(x) = 1 + x
+ p = Polynomial(coefficients=[1.0, 1.0], x_unit='meV')
+ # THEN: evaluate with x in eV (compatible unit) — triggers unit-rescaling branch
+ x_eV = sc.array(dims=['x'], values=np.array([0.001, 0.002]), unit='eV')
+ result = p.evaluate(x_eV)
+ # EXPECT: 0.001 eV = 1 meV → f(1)=2, 0.002 eV = 2 meV → f(2)=3; state not mutated
+ np.testing.assert_allclose(result, [2.0, 3.0], rtol=1e-5)
+ assert p.x_unit == 'meV'
+
+ def test_evaluate_with_equivalent_unit_spelling_does_not_rescale(self):
+ # WHEN: 'ueV' and scipp's canonical micro-sign spelling are the same unit
+ # (regression: a string comparison used to treat them as different and rescale)
+ p = Polynomial(coefficients=[1.0, 1.0], x_unit='ueV', suppress_warnings=True)
+ x = sc.array(dims=['x'], values=np.array([1.0, 2.0]), unit='\u00b5eV')
+
+ # THEN
+ result = p.evaluate(x)
+
+ # EXPECT: f(x) = 1 + x with the raw coefficient values, no rescaling applied
+ np.testing.assert_allclose(result, [2.0, 3.0])
+
+ def test_convert_y_unit_invalid_type_raises(self, polynomial: Polynomial):
+ # WHEN THEN EXPECT
+ with pytest.raises(UnitError, match='new_y_unit must be a string or a scipp unit'):
+ polynomial.convert_y_unit(123)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ p = Polynomial(coefficients=[1.0, 2.0], x_unit='meV')
+ # THEN
+ with pytest.raises(UnitError):
+ p.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert p.y_unit == 'dimensionless'
+ assert np.isclose(p.coefficients[0].value, 1.0)
+ assert np.isclose(p.coefficients[1].value, 2.0)
diff --git a/tests/unit/easydynamics/sample_model/components/test_voigt.py b/tests/unit/easydynamics/sample_model/components/test_voigt.py
index 8b8d44cf3..42eb5099c 100644
--- a/tests/unit/easydynamics/sample_model/components/test_voigt.py
+++ b/tests/unit/easydynamics/sample_model/components/test_voigt.py
@@ -5,7 +5,9 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
+from scipp import UnitError
from scipy.integrate import simpson
from scipy.special import voigt_profile
@@ -22,7 +24,7 @@ def voigt(self):
center=0.5,
gaussian_width=0.6,
lorentzian_width=0.7,
- unit='meV',
+ x_unit='meV',
)
def test_init_no_inputs(self):
@@ -35,7 +37,8 @@ def test_init_no_inputs(self):
assert voigt.center.value == pytest.approx(0.0)
assert voigt.gaussian_width.value == pytest.approx(1.0)
assert voigt.lorentzian_width.value == pytest.approx(1.0)
- assert voigt.unit == 'meV'
+ assert voigt.x_unit == 'meV'
+ assert voigt.y_unit == 'dimensionless'
assert voigt.center.fixed is True
def test_initialization(self, voigt: Voigt):
@@ -45,7 +48,7 @@ def test_initialization(self, voigt: Voigt):
assert voigt.center.value == pytest.approx(0.5)
assert voigt.gaussian_width.value == pytest.approx(0.6)
assert voigt.lorentzian_width.value == pytest.approx(0.7)
- assert voigt.unit == 'meV'
+ assert voigt.x_unit == 'meV'
@pytest.mark.parametrize(
'kwargs, expected_message',
@@ -56,7 +59,7 @@ def test_initialization(self, voigt: Voigt):
'center': 0.5,
'gaussian_width': 0.6,
'lorentzian_width': 0.7,
- 'unit': 'meV',
+ 'x_unit': 'meV',
},
'area must be a number',
),
@@ -66,7 +69,7 @@ def test_initialization(self, voigt: Voigt):
'center': 'invalid',
'gaussian_width': 0.6,
'lorentzian_width': 0.7,
- 'unit': 'meV',
+ 'x_unit': 'meV',
},
'center must be None',
),
@@ -76,7 +79,7 @@ def test_initialization(self, voigt: Voigt):
'center': 0.5,
'gaussian_width': 'invalid',
'lorentzian_width': 0.7,
- 'unit': 'meV',
+ 'x_unit': 'meV',
},
'gaussian_width must be a number',
),
@@ -86,7 +89,7 @@ def test_initialization(self, voigt: Voigt):
'center': 0.5,
'gaussian_width': 0.6,
'lorentzian_width': 'invalid',
- 'unit': 'meV',
+ 'x_unit': 'meV',
},
'lorentzian_width must be a number',
),
@@ -96,13 +99,25 @@ def test_initialization(self, voigt: Voigt):
'center': 0.5,
'gaussian_width': 0.6,
'lorentzian_width': 0.7,
- 'unit': 123,
+ 'x_unit': 123,
},
- 'unit must be None,',
+ 'unit must be None, a string',
+ ),
+ (
+ {
+ 'area': 2.0,
+ 'center': 0.5,
+ 'gaussian_width': 0.6,
+ 'lorentzian_width': 0.7,
+ 'x_unit': 'meV',
+ 'y_unit': 123,
+ },
+ 'unit must be None, a string',
),
],
)
def test_input_type_validation_raises(self, kwargs, expected_message):
+ # WHEN THEN EXPECT
with pytest.raises(TypeError, match=expected_message):
Voigt(display_name='TestVoigt', **kwargs)
@@ -117,7 +132,7 @@ def test_negative_gaussian_width_raises(self):
center=0.5,
gaussian_width=-0.6,
lorentzian_width=0.7,
- unit='meV',
+ x_unit='meV',
)
def test_negative_lorentzian_width_raises(self):
@@ -132,7 +147,7 @@ def test_negative_lorentzian_width_raises(self):
center=0.5,
gaussian_width=0.6,
lorentzian_width=-0.7,
- unit='meV',
+ x_unit='meV',
)
def test_negative_area_warns(self):
@@ -144,7 +159,7 @@ def test_negative_area_warns(self):
center=0.5,
gaussian_width=0.6,
lorentzian_width=0.7,
- unit='meV',
+ x_unit='meV',
)
@pytest.mark.parametrize(
@@ -164,11 +179,12 @@ def test_negative_area_warns(self):
def test_property_setters(
self, voigt: Voigt, prop, valid_value, invalid_value, invalid_message
):
- # set valid
+ # WHEN: set a valid value
setattr(voigt, prop, valid_value)
+ # THEN EXPECT
assert getattr(voigt, prop).value == valid_value
- # invalid
+ # WHEN: set an invalid value — THEN EXPECT
with pytest.raises(TypeError, match=invalid_message):
setattr(voigt, prop, invalid_value)
@@ -214,19 +230,19 @@ def test_center_is_fixed_if_init_to_None(self):
center=None,
gaussian_width=0.6,
lorentzian_width=0.7,
- unit='meV',
+ x_unit='meV',
)
# EXPECT
assert test_voigt.center.value == pytest.approx(0.0)
assert test_voigt.center.fixed is True
- def test_convert_unit(self, voigt: Voigt):
+ def test_convert_x_unit(self, voigt: Voigt):
# WHEN THEN
- voigt.convert_unit('microeV')
+ voigt.convert_x_unit('microeV')
# EXPECT
- assert voigt.unit == 'microeV'
+ assert voigt.x_unit == 'microeV'
assert voigt.area.value == pytest.approx(2 * 1e3)
assert voigt.center.value == pytest.approx(0.5 * 1e3)
assert voigt.gaussian_width.value == pytest.approx(0.6 * 1e3)
@@ -285,7 +301,7 @@ def test_copy(self, voigt: Voigt):
assert voigt_copy.lorentzian_width.value == voigt.lorentzian_width.value
assert voigt_copy.lorentzian_width.fixed == voigt.lorentzian_width.fixed
- assert voigt_copy.unit == voigt.unit
+ assert voigt_copy.x_unit == voigt.x_unit
def test_repr(self, voigt: Voigt):
# WHEN THEN
@@ -293,9 +309,102 @@ def test_repr(self, voigt: Voigt):
# EXPECT
assert 'Voigt' in repr_str
- assert "name='VoigtName'" in repr_str
- assert 'unit=meV' in repr_str
- assert 'area=' in repr_str
- assert 'center=' in repr_str
- assert 'gaussian_width=' in repr_str
- assert 'lorentzian_width=' in repr_str
+ assert 'name = VoigtName' in repr_str
+ assert 'x_unit = meV' in repr_str
+ assert 'area =' in repr_str
+ assert 'center =' in repr_str
+ assert 'gaussian_width =' in repr_str
+ assert 'lorentzian_width =' in repr_str
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ voigt = Voigt(
+ area=1.0,
+ center=0.0,
+ gaussian_width=0.5,
+ lorentzian_width=0.3,
+ x_unit='meV',
+ y_unit='1/meV',
+ )
+ # EXPECT
+ assert voigt.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, voigt: Voigt):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError):
+ voigt.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless'
+ voigt = Voigt(
+ area=1.0,
+ center=0.0,
+ gaussian_width=0.5,
+ lorentzian_width=0.3,
+ x_unit='meV',
+ y_unit='1/meV',
+ )
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ voigt.convert_y_unit('1/eV')
+ # EXPECT: y_unit updated and area value rescaled (1e3 factor)
+ assert voigt.y_unit == '1/eV'
+ assert voigt.area.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, voigt: Voigt):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ voigt.convert_y_unit(123)
+
+ def test_evaluate_scipp_output(self, voigt: Voigt):
+ # WHEN
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = voigt.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('dimensionless')
+ assert len(result.values) == 50
+ np.testing.assert_allclose(result.values, voigt.evaluate(x, output='numpy'))
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ voigt = Voigt(
+ area=1.0,
+ center=0.0,
+ gaussian_width=0.5,
+ lorentzian_width=0.3,
+ x_unit='meV',
+ y_unit='1/meV',
+ )
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = voigt.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ def test_convert_x_unit_invalid_type_raises(self, voigt: Voigt):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'):
+ voigt.convert_x_unit(123)
+
+ def test_convert_x_unit_rollback_on_failure(self, voigt: Voigt):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ voigt.convert_x_unit('m')
+ # EXPECT: state rolled back
+ assert voigt.x_unit == 'meV'
+ assert voigt.area.value == pytest.approx(2.0)
+ assert voigt.center.value == pytest.approx(0.5)
+ assert voigt.gaussian_width.value == pytest.approx(0.6)
+ assert voigt.lorentzian_width.value == pytest.approx(0.7)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN
+ voigt = Voigt(area=1.0, center=0.0, gaussian_width=0.5, lorentzian_width=0.3, x_unit='meV')
+ # THEN
+ with pytest.raises(UnitError):
+ voigt.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert voigt.y_unit == 'dimensionless'
+ assert voigt.area.value == pytest.approx(1.0)
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py
index 7941d4068..4b2137540 100644
--- a/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py
@@ -25,16 +25,60 @@ def brownian_diffusion_model(self):
def test_init_default(self, brownian_diffusion_model):
# WHEN THEN EXPECT
assert brownian_diffusion_model.display_name == 'BrownianTranslationalDiffusion'
- assert brownian_diffusion_model.unit == 'meV'
+ assert brownian_diffusion_model.x_unit == 'meV'
+ assert brownian_diffusion_model.y_unit == 'dimensionless'
assert brownian_diffusion_model.scale.value == pytest.approx(1.0)
+ assert brownian_diffusion_model.scale.unit == 'meV'
assert brownian_diffusion_model.diffusion_coefficient.value == pytest.approx(1.0)
+ def test_convert_x_unit_rescales_widths(self):
+ # WHEN
+ model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=np.array([1.0]))
+ width_mev = model.calculate_width()[0]
+
+ # THEN
+ model.convert_x_unit('ueV')
+
+ # EXPECT: calculated width and the component width follow the new unit
+ assert model.calculate_width()[0] == pytest.approx(width_mev * 1000)
+ collection = model.get_component_collections(0)
+ assert sc.Unit(str(collection[0].width.unit)) == sc.Unit('ueV')
+ assert collection[0].width.value == pytest.approx(width_mev * 1000)
+ assert sc.Unit(str(model.scale.unit)) == sc.Unit('ueV')
+
+ def test_convert_x_unit_converts_collections_in_place(self):
+ # WHEN
+ model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=np.array([1.0]))
+ collection_before = model.get_component_collections(0)
+
+ # THEN
+ model.convert_x_unit('ueV')
+
+ # EXPECT: conversion does not regenerate the collections (regression: rebuilding
+ # replaced the component objects, breaking external references)
+ assert model.get_component_collections(0) is collection_before
+
+ def test_convert_x_unit_persists_after_dependency_update(self):
+ # WHEN
+ model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=np.array([1.0]))
+ width = model.get_component_collections(0)[0].width
+ model.convert_x_unit('ueV')
+ width_uev = width.value
+
+ # THEN: trigger a dependency-graph recompute
+ model.diffusion_coefficient = 4.8e-9
+
+ # EXPECT: the dependent width stays in the new unit (regression: a plain convert_unit
+ # was reverted to the old desired unit by the next graph update)
+ assert sc.Unit(str(width.unit)) == sc.Unit('ueV')
+ assert width.value == pytest.approx(2 * width_uev)
+
@pytest.mark.parametrize(
'kwargs,expected_exception, expected_message',
[
(
{
- 'unit': 123,
+ 'x_unit': 123,
'scale': 1.0,
'diffusion_coefficient': 1.0,
},
@@ -43,7 +87,16 @@ def test_init_default(self, brownian_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'y_unit': 123,
+ 'scale': 1.0,
+ 'diffusion_coefficient': 1.0,
+ },
+ TypeError,
+ None,
+ ),
+ (
+ {
+ 'x_unit': 'meV',
'scale': 'invalid',
'diffusion_coefficient': 1.0,
},
@@ -52,7 +105,7 @@ def test_init_default(self, brownian_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': -123.4,
'diffusion_coefficient': 1.0,
},
@@ -61,7 +114,7 @@ def test_init_default(self, brownian_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': 1.0,
'diffusion_coefficient': 'invalid',
},
@@ -70,7 +123,7 @@ def test_init_default(self, brownian_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': 1.0,
'diffusion_coefficient': -123.4,
},
@@ -78,6 +131,14 @@ def test_init_default(self, brownian_diffusion_model):
'diffusion_coefficient must be non-negative',
),
],
+ ids=[
+ 'invalid_x_unit',
+ 'invalid_y_unit',
+ 'invalid_scale_type',
+ 'invalid_scale_negative',
+ 'invalid_diffusion_coefficient_type',
+ 'invalid_diffusion_coefficient_negative',
+ ],
)
def test_input_type_validation_raises(self, kwargs, expected_exception, expected_message):
with pytest.raises(expected_exception, match=expected_message):
@@ -123,6 +184,17 @@ def test_calculate_width(self, brownian_diffusion_model):
expected_widths = 1.0 * unit_conversion_factor.value * (Q_values**2)
np.testing.assert_allclose(widths, expected_widths, rtol=1e-5)
+ def test_calculate_width_scipp_Q_converts_unit(self, brownian_diffusion_model):
+ # WHEN: the same Q expressed in 1/angstrom (numpy, assumed) and in 1/nm (scipp)
+ Q_values = np.array([0.1, 0.2, 0.3])
+ Q_scipp = sc.Variable(dims=['Q'], values=Q_values * 10, unit='1/nm')
+
+ # THEN EXPECT: scipp input is converted to 1/angstrom before the width calculation
+ np.testing.assert_allclose(
+ brownian_diffusion_model.calculate_width(Q_scipp),
+ brownian_diffusion_model.calculate_width(Q_values),
+ )
+
def test_calculate_EISF(self, brownian_diffusion_model):
# WHEN
Q_values = np.array([0.1, 0.2, 0.3]) # Example Q values in Å^-1
@@ -180,9 +252,11 @@ def test_create_component_collections(self, brownian_diffusion_model, Q):
model = component_collections[model_index]
assert len(model) == 1
component = model[0]
- assert component.width.unit == brownian_diffusion_model.unit
+ assert component.width.unit == brownian_diffusion_model.x_unit
assert np.isclose(component.width.value, expected_widths[model_index])
assert component.width.independent is False
+ # area.unit = area_unit = x_unit * y_unit
+ assert component.area.unit == 'meV'
def test_write_width_dependency_expression(self, brownian_diffusion_model):
# WHEN THEN
@@ -213,6 +287,11 @@ def test_write_area_dependency_expression_raises(self, brownian_diffusion_model)
with pytest.raises(TypeError, match='QISF must be a float'):
brownian_diffusion_model._write_area_dependency_expression('invalid')
+ def test_y_unit_setter_raises(self, brownian_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match=r'read-only'):
+ brownian_diffusion_model.y_unit = '1/meV'
+
def test_repr(self, brownian_diffusion_model):
# WHEN THEN
repr_str = repr(brownian_diffusion_model)
@@ -221,3 +300,5 @@ def test_repr(self, brownian_diffusion_model):
assert 'BrownianTranslationalDiffusion' in repr_str
assert 'diffusion_coefficient' in repr_str
assert 'scale=' in repr_str
+ # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...'
+ assert 'x_unit=meV, y_unit=dimensionless' in repr_str
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py
index b3740db66..1e69bcec7 100644
--- a/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py
@@ -3,6 +3,7 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
from easydynamics.sample_model.components.delta_function import DeltaFunction
@@ -15,6 +16,57 @@ class TestDeltaLorentz:
def delta_lorentz_model(self):
return DeltaLorentz()
+ def test_mean_u_squared_unit_conversion_leaves_physics_unchanged(self):
+ # WHEN: a model with a non-zero Debye-Waller factor
+ model = DeltaLorentz(A_0=0.5, mean_u_squared=1.0, lorentzian_width=0.1, Q=np.array([1.0]))
+ eisf_before = model.calculate_EISF()
+ qisf_before = model.calculate_QISF()
+ delta_area_before = model.get_component_collections(0)[1].area.value
+
+ # THEN: convert mean_u_squared to nm**2 (same physical value, different number)
+ model.mean_u_squared.convert_unit('nm**2')
+
+ # EXPECT: EISF, QISF, and the dependent delta area are unchanged
+ np.testing.assert_allclose(model.calculate_EISF(), eisf_before)
+ np.testing.assert_allclose(model.calculate_QISF(), qisf_before)
+ assert model.get_component_collections(0)[1].area.value == pytest.approx(delta_area_before)
+
+ def test_convert_x_unit_converts_lorentzian_width(self):
+ # WHEN
+ model = DeltaLorentz(lorentzian_width=0.1, Q=np.array([1.0]))
+
+ # THEN
+ model.convert_x_unit('ueV')
+
+ # EXPECT: the width template is rescaled and the regenerated component follows
+ assert sc.Unit(str(model.lorentzian_width.unit)) == sc.Unit('ueV')
+ assert model.lorentzian_width.value == pytest.approx(100.0)
+ collection = model.get_component_collections(0)
+ assert sc.Unit(str(collection[0].width.unit)) == sc.Unit('ueV')
+ assert collection[0].width.value == pytest.approx(100.0)
+
+ def test_convert_x_unit_with_q_varying_width_preserves_state(self):
+ # WHEN: a model with per-Q widths, one of which has been changed from the template
+ model = DeltaLorentz(
+ lorentzian_width=0.1,
+ Q=np.array([1.0, 2.0]),
+ allow_Q_variation={'lorentzian_width': True},
+ )
+ model._lorentzian_width_list[0].value = 0.2
+ collection_before = model.get_component_collections(0)
+
+ # THEN
+ model.convert_x_unit('ueV')
+
+ # EXPECT: conversion happens in place — per-Q values are converted, not reset to the
+ # template, and the collections are not regenerated (regression: conversion used to
+ # rebuild the collections, discarding per-Q state)
+ assert model.get_component_collections(0) is collection_before
+ assert model._lorentzian_width_list[0].value == pytest.approx(200.0)
+ assert model._lorentzian_width_list[1].value == pytest.approx(100.0)
+ for width in model._lorentzian_width_list:
+ assert sc.Unit(str(width.unit)) == sc.Unit('ueV')
+
@pytest.fixture
def delta_lorentz_model_with_Q(self):
Q = np.linspace(0.5, 2, 7)
@@ -38,16 +90,22 @@ def delta_lorentz_model_with_Q_no_variation(self):
def test_init_default(self, delta_lorentz_model):
# WHEN THEN EXPECT
assert delta_lorentz_model.display_name == 'DeltaLorentz'
- assert delta_lorentz_model.unit == 'meV'
+ assert delta_lorentz_model.x_unit == 'meV'
+ assert delta_lorentz_model.y_unit == 'dimensionless'
assert delta_lorentz_model.scale.value == pytest.approx(1.0)
assert delta_lorentz_model.mean_u_squared.value == pytest.approx(0.0)
assert delta_lorentz_model.A_0.value == pytest.approx(1.0)
assert delta_lorentz_model.lorentzian_width.value == pytest.approx(1.0)
+ def test_y_unit_setter_raises(self, delta_lorentz_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match=r'read-only'):
+ delta_lorentz_model.y_unit = '1/meV'
+
def test_init_with_Q(self, delta_lorentz_model_with_Q):
# WHEN THEN EXPECT
assert delta_lorentz_model_with_Q.display_name == 'DeltaLorentz'
- assert delta_lorentz_model_with_Q.unit == 'meV'
+ assert delta_lorentz_model_with_Q.x_unit == 'meV'
assert delta_lorentz_model_with_Q.scale.value == pytest.approx(1.0)
assert delta_lorentz_model_with_Q.mean_u_squared.value == pytest.approx(0.0)
assert delta_lorentz_model_with_Q.A_0.value == pytest.approx(0.5)
@@ -404,7 +462,7 @@ def test_calculate_EISF_with_Q(self, delta_lorentz_model_with_Q):
for i in range(len(eisf)):
expected = delta_lorentz_model_with_Q._A_0_list[i].value * np.exp(
-delta_lorentz_model_with_Q.mean_u_squared.value
- * delta_lorentz_model_with_Q.Q[i] ** 2
+ * delta_lorentz_model_with_Q.Q.values[i] ** 2
)
assert eisf[i] == pytest.approx(expected)
@@ -428,7 +486,7 @@ def test_calculate_QISF_with_Q(self, delta_lorentz_model_with_Q):
for i in range(len(qisf)):
expected = delta_lorentz_model_with_Q._A_1_list[i].value * np.exp(
-delta_lorentz_model_with_Q.mean_u_squared.value
- * delta_lorentz_model_with_Q.Q[i] ** 2
+ * delta_lorentz_model_with_Q.Q.values[i] ** 2
)
assert qisf[i] == pytest.approx(expected)
@@ -459,7 +517,8 @@ def test_create_component_collections_with_Q_variation(self, delta_lorentz_model
assert collection[0].width.independent is True
assert collection[0].area.independent is False
assert (
- 'scale * exp(-mean_u_squared.value *' in collection[0].area.dependency_expression
+ 'scale * exp(-mean_u_squared_ratio.value *'
+ in collection[0].area.dependency_expression
)
assert 'A_1' in collection[0].area.dependency_expression
@@ -467,7 +526,8 @@ def test_create_component_collections_with_Q_variation(self, delta_lorentz_model
assert isinstance(collection[1], DeltaFunction)
assert collection[1].area.independent is False
assert (
- 'scale * exp(-mean_u_squared.value *' in collection[1].area.dependency_expression
+ 'scale * exp(-mean_u_squared_ratio.value *'
+ in collection[1].area.dependency_expression
)
assert 'A_0' in collection[1].area.dependency_expression
@@ -491,7 +551,8 @@ def test_create_component_collections_with_no_Q_variation(
assert collection[0].width.dependency_expression == 'lorentzian_width'
assert collection[0].area.independent is False
assert (
- 'scale * exp(-mean_u_squared.value *' in collection[0].area.dependency_expression
+ 'scale * exp(-mean_u_squared_ratio.value *'
+ in collection[0].area.dependency_expression
)
assert 'A_1' in collection[0].area.dependency_expression
@@ -499,16 +560,17 @@ def test_create_component_collections_with_no_Q_variation(
assert isinstance(collection[1], DeltaFunction)
assert collection[1].area.independent is False
assert (
- 'scale * exp(-mean_u_squared.value *' in collection[1].area.dependency_expression
+ 'scale * exp(-mean_u_squared_ratio.value *'
+ in collection[1].area.dependency_expression
)
assert 'A_0' in collection[1].area.dependency_expression
@pytest.mark.parametrize(
- 'Q_index',
+ ('Q_index', 'expected_exception', 'expected_message'),
[
- -1,
- 100,
- 'string',
+ (-1, IndexError, r'Q_index -1 is out of bounds'),
+ (100, IndexError, r'Q_index 100 is out of bounds'),
+ ('string', TypeError, r'Q_index must be an int or None, got str'),
],
ids=[
'negative index',
@@ -516,9 +578,11 @@ def test_create_component_collections_with_no_Q_variation(
'non-integer index',
],
)
- def test_get_independent_variables_raises(self, delta_lorentz_model_with_Q, Q_index):
+ def test_get_independent_variables_raises(
+ self, delta_lorentz_model_with_Q, Q_index, expected_exception, expected_message
+ ):
# WHEN THEN EXPECT
- with pytest.raises(ValueError, match=r'Q_index must be an integer between 0 and'):
+ with pytest.raises(expected_exception, match=expected_message):
delta_lorentz_model_with_Q.get_independent_variables(Q_index=Q_index)
def test_get_all_variables_no_Q_index(self, delta_lorentz_model_with_Q):
@@ -578,10 +642,10 @@ def test_get_all_variables_with_Q_index_no_Q_variation(
def test_get_all_variables_invalid_Q_index(self, delta_lorentz_model_with_Q):
# WHEN THEN EXPECT
- with pytest.raises(ValueError, match='Q_index must be an integer between 0 and'):
+ with pytest.raises(IndexError, match='Q_index -1 is out of bounds'):
delta_lorentz_model_with_Q.get_all_variables(Q_index=-1)
- with pytest.raises(ValueError, match='Q_index must be an integer between 0 and'):
+ with pytest.raises(IndexError, match=r'Q_index \d+ is out of bounds'):
delta_lorentz_model_with_Q.get_all_variables(Q_index=len(delta_lorentz_model_with_Q.Q))
@pytest.mark.parametrize(
@@ -856,3 +920,22 @@ def test_repr(self, delta_lorentz_model):
assert 'mean_u_squared' in repr_str
assert 'A_0' in repr_str
assert 'lorentzian_width' in repr_str
+ # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...'
+ assert 'x_unit=meV, y_unit=dimensionless' in repr_str
+
+ # ───── Regression tests ─────
+
+ def test_calculate_width_raises_after_clear_Q_when_allow_Q_variation(
+ self, delta_lorentz_model_with_Q
+ ):
+ # WHEN: model with Q-variation enabled for lorentzian_width
+ assert delta_lorentz_model_with_Q._allow_Q_variation['lorentzian_width'] is True
+ assert len(delta_lorentz_model_with_Q._lorentzian_width_list) > 0
+
+ # THEN: clear Q (empties _lorentzian_width_list)
+ delta_lorentz_model_with_Q.clear_Q(confirm=True)
+ assert len(delta_lorentz_model_with_Q._lorentzian_width_list) == 0
+
+ # THEN: before the fix, calculate_width() silently returned [] instead of raising.
+ with pytest.raises(ValueError, match='Lorentzian width Q-variation list is empty'):
+ delta_lorentz_model_with_Q.calculate_width()
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py
index 3364e5104..d3d56c3b6 100644
--- a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py
@@ -3,7 +3,9 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable.parameter import Parameter
+from scipp import UnitError
from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase
@@ -19,7 +21,10 @@ def test_init_default(self, diffusion_model):
assert diffusion_model.name == 'DiffusionModel'
assert diffusion_model.lorentzian_name == 'DiffusionModel'
assert diffusion_model.lorentzian_display_name == 'DiffusionModel'
- assert diffusion_model.unit == 'meV'
+ assert diffusion_model.x_unit == 'meV'
+ assert diffusion_model.y_unit == 'dimensionless'
+ # scale.unit = area_unit = x_unit * y_unit = meV * dimensionless = meV
+ assert diffusion_model.scale.unit == 'meV'
def test_init_raises(self):
# WHEN THEN EXPECT
@@ -29,13 +34,66 @@ def test_init_raises(self):
with pytest.raises(TypeError, match=r'lorentzian_display_name must be a string or None'):
DiffusionModelBase(lorentzian_display_name=123)
- def test_unit_setter_raises(self, diffusion_model):
+ def test_x_unit_setter_raises(self, diffusion_model):
# WHEN THEN EXPECT
with pytest.raises(
AttributeError,
- match=r'Unit is read-only. Use convert_unit to change the unit between allowed types',
+ match=r'read-only',
):
- diffusion_model.unit = 'eV'
+ diffusion_model.x_unit = 'eV'
+
+ def test_y_unit_setter_raises(self, diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match=r'read-only'):
+ diffusion_model.y_unit = '1/meV'
+
+ def test_init_accepts_scipp_units(self):
+ # WHEN THEN
+ model = DiffusionModelBase(x_unit=sc.Unit('meV'), y_unit=sc.Unit('1/meV'))
+
+ # EXPECT: scale.unit = x_unit * y_unit = dimensionless
+ assert str(model.x_unit) == 'meV'
+ assert sc.Unit(str(model.scale.unit)) == sc.Unit('dimensionless')
+
+ def test_convert_x_unit(self, diffusion_model):
+ # WHEN THEN
+ diffusion_model.convert_x_unit('ueV')
+
+ # EXPECT: x_unit and the scale unit (x_unit * y_unit) follow
+ assert sc.Unit(diffusion_model.x_unit) == sc.Unit('ueV')
+ assert sc.Unit(str(diffusion_model.scale.unit)) == sc.Unit('ueV')
+
+ def test_convert_x_unit_invalid_type_raises(self, diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'x_unit must be a string or sc.Unit'):
+ diffusion_model.convert_x_unit(123)
+
+ @pytest.mark.parametrize('unit', ['m', '1/ps'], ids=['length', 'frequency'])
+ def test_convert_x_unit_non_energy_raises(self, diffusion_model, unit):
+ # WHEN THEN EXPECT: only energy axes are supported for now
+ with pytest.raises(UnitError, match=r'convertible to meV'):
+ diffusion_model.convert_x_unit(unit)
+
+ def test_convert_y_unit(self):
+ # WHEN
+ model = DiffusionModelBase(y_unit='1/meV')
+
+ # THEN
+ model.convert_y_unit('1/eV')
+
+ # EXPECT: y_unit and the scale unit follow (meV * 1/eV)
+ assert sc.Unit(model.y_unit) == sc.Unit('1/eV')
+ assert sc.Unit(str(model.scale.unit)) == sc.Unit('meV/eV')
+
+ def test_convert_y_unit_invalid_type_raises(self, diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=r'y_unit must be a string or sc.Unit'):
+ diffusion_model.convert_y_unit(123)
+
+ def test_convert_y_unit_incompatible_raises(self, diffusion_model):
+ # WHEN THEN EXPECT: dimensionless -> K changes the dimension of the scale unit
+ with pytest.raises(UnitError):
+ diffusion_model.convert_y_unit('K')
@pytest.mark.parametrize(
('attribute', 'value', 'expected'),
@@ -131,7 +189,9 @@ def test_Q_property(self, diffusion_model):
diffusion_model.Q = [1.0, 2.0, 3.0]
# EXPECT
- np.testing.assert_allclose(diffusion_model.Q, [1.0, 2.0, 3.0])
+ assert isinstance(diffusion_model.Q, sc.Variable)
+ assert diffusion_model.Q.unit == sc.Unit('1/angstrom')
+ np.testing.assert_allclose(diffusion_model.Q.values, [1.0, 2.0, 3.0])
# THEN EXPECT
with pytest.raises(ValueError, match=r'New Q values are not similar to the old ones'):
@@ -147,13 +207,24 @@ def test_Q_property(self, diffusion_model):
# EXPECT
assert diffusion_model.Q is None
+ def test_Q_setter_accepts_equivalent_scipp_Q_in_other_unit(self, diffusion_model):
+ # WHEN
+ diffusion_model.Q = [1.0, 2.0, 3.0]
+
+ # THEN: the same Q in 1/nm is equivalent after conversion, so no error is raised
+ diffusion_model.Q = sc.Variable(dims=['Q'], values=[10.0, 20.0, 30.0], unit='1/nm')
+
+ # EXPECT
+ np.testing.assert_allclose(diffusion_model.Q.values, [1.0, 2.0, 3.0])
+
def test_repr(self, diffusion_model):
# WHEN THEN
repr_str = repr(diffusion_model)
# EXPECT
assert 'DiffusionModelBase' in repr_str
- assert 'unit=meV' in repr_str
+ # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...'
+ assert 'x_unit=meV, y_unit=dimensionless' in repr_str
def test_get_independent_variables(self, diffusion_model):
# WHEN THEN EXPECT
@@ -163,11 +234,11 @@ def test_get_independent_variables(self, diffusion_model):
assert independent_vars == []
@pytest.mark.parametrize(
- 'Q_index',
+ ('Q_index', 'expected_exception', 'expected_message'),
[
- -1,
- 100,
- 'string',
+ (-1, ValueError, 'Q is None'),
+ (100, ValueError, 'Q is None'),
+ ('string', TypeError, 'Q_index must be an int or None, got str'),
],
ids=[
'negative index',
@@ -175,17 +246,19 @@ def test_get_independent_variables(self, diffusion_model):
'non-integer index',
],
)
- def test_get_independent_variables_raises(self, diffusion_model, Q_index):
+ def test_get_independent_variables_raises(
+ self, diffusion_model, Q_index, expected_exception, expected_message
+ ):
# WHEN THEN EXPECT
- with pytest.raises(ValueError, match=r'Q_index must be an integer between 0 and'):
+ with pytest.raises(expected_exception, match=expected_message):
diffusion_model.get_independent_variables(Q_index=Q_index)
@pytest.mark.parametrize(
- 'Q_index',
+ ('Q_index', 'expected_exception', 'expected_message'),
[
- -1,
- 100,
- 'string',
+ (-1, ValueError, 'Q is None'),
+ (100, ValueError, 'Q is None'),
+ ('string', TypeError, 'Q_index must be an int or None, got str'),
],
ids=[
'negative index',
@@ -193,9 +266,11 @@ def test_get_independent_variables_raises(self, diffusion_model, Q_index):
'non-integer index',
],
)
- def test_get_all_variables_raises(self, diffusion_model, Q_index):
+ def test_get_all_variables_raises(
+ self, diffusion_model, Q_index, expected_exception, expected_message
+ ):
# WHEN THEN EXPECT
- with pytest.raises(ValueError, match=r'Q_index must be an integer between 0 and'):
+ with pytest.raises(expected_exception, match=expected_message):
diffusion_model.get_all_variables(Q_index=Q_index)
def test_create_component_collections_no_Q(self, diffusion_model):
diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py
index 014a221f7..7fcf4dacd 100644
--- a/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py
+++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py
@@ -25,8 +25,10 @@ def jump_diffusion_model(self):
def test_init_default(self, jump_diffusion_model):
# WHEN THEN EXPECT
assert jump_diffusion_model.display_name == 'JumpTranslationalDiffusion'
- assert jump_diffusion_model.unit == 'meV'
+ assert jump_diffusion_model.x_unit == 'meV'
+ assert jump_diffusion_model.y_unit == 'dimensionless'
assert jump_diffusion_model.scale.value == pytest.approx(1.0)
+ assert jump_diffusion_model.scale.unit == 'meV'
assert jump_diffusion_model.diffusion_coefficient.value == pytest.approx(1.0)
assert jump_diffusion_model.relaxation_time.value == pytest.approx(1.0)
@@ -35,7 +37,7 @@ def test_init_default(self, jump_diffusion_model):
[
(
{
- 'unit': 123,
+ 'x_unit': 123,
'scale': 1.0,
'diffusion_coefficient': 1.0,
'relaxation_time': 1.0,
@@ -45,7 +47,17 @@ def test_init_default(self, jump_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'y_unit': 123,
+ 'scale': 1.0,
+ 'diffusion_coefficient': 1.0,
+ 'relaxation_time': 1.0,
+ },
+ TypeError,
+ None,
+ ),
+ (
+ {
+ 'x_unit': 'meV',
'scale': 'invalid',
'diffusion_coefficient': 1.0,
'relaxation_time': 1.0,
@@ -55,7 +67,7 @@ def test_init_default(self, jump_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': 1.0,
'diffusion_coefficient': 'invalid',
'relaxation_time': 1.0,
@@ -65,7 +77,7 @@ def test_init_default(self, jump_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': 1.0,
'diffusion_coefficient': -1.0,
'relaxation_time': 1.0,
@@ -75,7 +87,7 @@ def test_init_default(self, jump_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': 1.0,
'diffusion_coefficient': 1.0,
'relaxation_time': 'invalid',
@@ -85,7 +97,7 @@ def test_init_default(self, jump_diffusion_model):
),
(
{
- 'unit': 'meV',
+ 'x_unit': 'meV',
'scale': 1.0,
'diffusion_coefficient': 1.0,
'relaxation_time': -1.0,
@@ -94,6 +106,15 @@ def test_init_default(self, jump_diffusion_model):
'relaxation_time must be non-negative',
),
],
+ ids=[
+ 'invalid_x_unit',
+ 'invalid_y_unit',
+ 'invalid_scale_type',
+ 'invalid_diffusion_coefficient_type',
+ 'invalid_diffusion_coefficient_negative',
+ 'invalid_relaxation_time_type',
+ 'invalid_relaxation_time_negative',
+ ],
)
def test_input_type_validation_raises(self, kwargs, expected_exception, expected_message):
with pytest.raises(expected_exception, match=expected_message):
@@ -159,7 +180,7 @@ def test_calculate_width(self, jump_diffusion_model):
# EXPECT
expected_widths = scipp_hbar * diffusion_coefficient_sc * (Q_values**2) / (1 + denominator)
- expected_widths = expected_widths.to(unit=jump_diffusion_model.unit)
+ expected_widths = expected_widths.to(unit=jump_diffusion_model.x_unit)
np.testing.assert_allclose(widths, expected_widths.values, rtol=1e-5)
@@ -221,9 +242,11 @@ def test_create_component_collections(self, jump_diffusion_model, Q):
model = component_collections[model_index]
assert len(model) == 1
component = model[0]
- assert component.width.unit == jump_diffusion_model.unit
+ assert component.width.unit == jump_diffusion_model.x_unit
assert np.isclose(component.width.value, expected_widths[model_index])
assert component.width.independent is False
+ # area.unit = area_unit = x_unit * y_unit
+ assert component.area.unit == 'meV'
def test_write_width_dependency_expression(self, jump_diffusion_model):
# WHEN THEN
@@ -257,6 +280,11 @@ def test_write_area_dependency_expression_raises(self, jump_diffusion_model):
with pytest.raises(TypeError, match='QISF must be a float'):
jump_diffusion_model._write_area_dependency_expression('invalid')
+ def test_y_unit_setter_raises(self, jump_diffusion_model):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match=r'read-only'):
+ jump_diffusion_model.y_unit = '1/meV'
+
def test_repr(self, jump_diffusion_model):
# WHEN THEN
repr_str = repr(jump_diffusion_model)
@@ -265,3 +293,5 @@ def test_repr(self, jump_diffusion_model):
assert 'JumpTranslationalDiffusion' in repr_str
assert 'diffusion_coefficient' in repr_str
assert 'scale=' in repr_str
+ # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...'
+ assert 'x_unit=meV, y_unit=dimensionless' in repr_str
diff --git a/tests/unit/easydynamics/sample_model/test_background_model.py b/tests/unit/easydynamics/sample_model/test_background_model.py
index a698a1bf0..ff1233d84 100644
--- a/tests/unit/easydynamics/sample_model/test_background_model.py
+++ b/tests/unit/easydynamics/sample_model/test_background_model.py
@@ -18,14 +18,14 @@ def background_model(self):
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ x_unit='meV',
)
component2 = Lorentzian(
display_name='TestLorentzian1',
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ x_unit='meV',
)
component_collection = ComponentCollection()
component_collection.append_component(component1)
@@ -33,19 +33,19 @@ def background_model(self):
return BackgroundModel(
display_name='InitModel',
components=component_collection,
- unit='meV',
+ x_unit='meV',
Q=np.array([1.0, 2.0, 3.0]),
)
def test_init(self, background_model):
# WHEN THEN
- model = background_model
# EXPECT
- assert model.display_name == 'InitModel'
- assert model.unit == 'meV'
- assert len(model.components) == 2
- np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ assert background_model.display_name == 'InitModel'
+ assert background_model.x_unit == 'meV'
+ assert background_model.y_unit == 'dimensionless'
+ assert len(background_model.components) == 2
+ np.testing.assert_array_equal(background_model.Q.values, np.array([1.0, 2.0, 3.0]))
@pytest.mark.parametrize(
'invalid_component, expected_error_msg',
@@ -80,3 +80,19 @@ def test_init_raises_with_invalid_components(self, invalid_component, expected_e
collection = ComponentCollection()
collection.append_component(invalid_component)
BackgroundModel(components=collection)
+
+ def test_y_unit_setter_raises(self, background_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(AttributeError):
+ background_model.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ model = BackgroundModel(components=g, x_unit='meV')
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ model.convert_y_unit('1/eV')
+ # EXPECT
+ assert model.y_unit == '1/eV'
+ assert model.components[0].y_unit == '1/eV'
+ assert g.area.value == pytest.approx(1e3)
diff --git a/tests/unit/easydynamics/sample_model/test_component_collection.py b/tests/unit/easydynamics/sample_model/test_component_collection.py
index 7ae15ed65..d72b84ce8 100644
--- a/tests/unit/easydynamics/sample_model/test_component_collection.py
+++ b/tests/unit/easydynamics/sample_model/test_component_collection.py
@@ -5,10 +5,12 @@
import numpy as np
import pytest
+import scipp as sc
from easyscience.variable import Parameter
from scipy.integrate import simpson
from easydynamics.sample_model import ComponentCollection
+from easydynamics.sample_model import ExpressionComponent
from easydynamics.sample_model import Gaussian
from easydynamics.sample_model import Lorentzian
from easydynamics.sample_model import Polynomial
@@ -24,7 +26,7 @@ def component_collection(self):
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ x_unit='meV',
)
component2 = Lorentzian(
name='TestLorentzian1Name',
@@ -32,7 +34,7 @@ def component_collection(self):
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ x_unit='meV',
)
model.append_component(component1)
model.append_component(component2)
@@ -45,10 +47,12 @@ def test_init(self):
# EXPECT
assert component_collection.display_name == 'InitModel'
assert not component_collection
+ assert component_collection.x_unit == 'meV'
+ assert component_collection.y_unit == 'dimensionless'
def test_init_with_component(self):
# WHEN THEN
- component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, unit='meV')
+ component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, x_unit='meV')
component_collection = ComponentCollection(display_name='InitModel', components=component1)
# EXPECT
@@ -58,9 +62,9 @@ def test_init_with_component(self):
def test_init_with_components(self):
# WHEN THEN
- component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, unit='meV')
+ component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, x_unit='meV')
component2 = Lorentzian(
- name='TestLorentzian1', area=2.0, center=1.0, width=0.5, unit='meV'
+ name='TestLorentzian1', area=2.0, center=1.0, width=0.5, x_unit='meV'
)
component_collection = ComponentCollection(
display_name='InitModel', components=[component1, component2]
@@ -91,13 +95,13 @@ def test_init_with_invalid_list_of_components_raises(self):
def test_init_with_invalid_unit_raises(self):
# WHEN THEN EXPECT
with pytest.raises(TypeError, match='unit must be'):
- ComponentCollection(unit=123)
+ ComponentCollection(x_unit=123)
# ───── Component Management ─────
def test_append_component(self, component_collection):
# WHEN
- component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV')
+ component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV')
# THEN
component_collection.append_component(component)
# EXPECT
@@ -105,7 +109,7 @@ def test_append_component(self, component_collection):
def test_append_component_collection(self, component_collection):
# WHEN
- component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV')
+ component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV')
component_collection2 = ComponentCollection()
component_collection2.append_component(component)
# THEN
@@ -127,7 +131,7 @@ def test_append_invalid_component_raises(self, component_collection):
def test_getitem(self, component_collection):
# WHEN
- component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV')
+ component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV')
# THEN
component_collection.append_component(component)
# EXPECT
@@ -140,7 +144,7 @@ def test_is_empty(self):
assert component_collection.is_empty is True
# WHEN THEN
- component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV')
+ component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV')
component_collection.append_component(component)
# EXPECT
assert component_collection.is_empty is False
@@ -158,68 +162,100 @@ def test_list_component_names(self, component_collection):
assert components[0] == 'TestGaussian1Name'
assert components[1] == 'TestLorentzian1Name'
- def test_convert_unit(self, component_collection):
+ def test_convert_x_unit(self, component_collection):
# WHEN THEN
- component_collection.convert_unit('eV')
+ component_collection.convert_x_unit('eV')
# EXPECT
for component in component_collection:
- assert component.unit == 'eV'
+ assert component.x_unit == 'eV'
- def test_convert_unit_incorrect_unit_raises(self, component_collection):
+ def test_convert_x_unit_incorrect_unit_raises(self, component_collection):
# WHEN THEN EXPECT
- with pytest.raises(TypeError, match=r'Unit must be a string or sc.Unit'):
- component_collection.convert_unit(123)
+ with pytest.raises(TypeError, match=r'unit must be a string or sc.Unit'):
+ component_collection.convert_x_unit(123)
- def test_convert_unit_failure_rolls_back(self, component_collection):
+ def test_convert_x_unit_failure_rolls_back(self, component_collection):
# WHEN THEN
# Introduce a faulty component that will fail conversion
class FaultyComponent(Gaussian):
- def convert_unit(self, _unit: str) -> None:
+ def convert_x_unit(self, _unit: str) -> None:
raise RuntimeError('Conversion failed.')
faulty_component = FaultyComponent(
- name='FaultyComponent', area=1.0, center=0.0, width=1.0, unit='meV'
+ name='FaultyComponent', area=1.0, center=0.0, width=1.0, x_unit='meV'
)
component_collection.append_component(faulty_component)
- original_units = {component.name: component.unit for component in component_collection}
+ original_units = {component.name: component.x_unit for component in component_collection}
# EXPECT
with pytest.raises(RuntimeError, match=r'Conversion failed.'):
- component_collection.convert_unit('eV')
+ component_collection.convert_x_unit('eV')
# Check that all components have their original units
for component in component_collection:
- assert component.unit == original_units[component.name]
+ assert component.x_unit == original_units[component.name]
- def test_set_unit(self, component_collection):
+ def test_set_x_unit(self, component_collection):
# WHEN THEN EXPECT
with pytest.raises(
AttributeError,
- match=r'Unit is read-only. Use convert_unit to change the unit',
+ match=r'read-only',
):
- component_collection.unit = 'eV'
+ component_collection.x_unit = 'eV'
def test_evaluate(self, component_collection):
# WHEN
x = np.linspace(-5, 5, 100)
+
+ # THEN
result = component_collection.evaluate(x)
# EXPECT
expected_result = component_collection[0].evaluate(x) + component_collection[1].evaluate(x)
np.testing.assert_allclose(result, expected_result, rtol=1e-5)
def test_evaluate_no_components_returns_zero(self):
- # WHEN THEN
+ # WHEN
component_collection = ComponentCollection(display_name='EmptyModel')
x = np.linspace(-5, 5, 100)
- # EXPECT
+ # THEN
result = component_collection.evaluate(x)
+
+ # EXPECT
assert np.all(result == pytest.approx(0.0))
assert result.shape == x.shape
+ def test_evaluate_no_components_scipp_output(self):
+ # WHEN
+ component_collection = ComponentCollection(display_name='EmptyModel', y_unit='1/meV')
+ x = np.linspace(-5, 5, 100)
+
+ # THEN
+ result = component_collection.evaluate(x, output='scipp')
+
+ # EXPECT: an sc.Variable of zeros carrying the collection's y_unit
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+ assert np.all(result.values == pytest.approx(0.0))
+
+ def test_evaluate_no_components_scipp_input(self):
+ # WHEN
+ component_collection = ComponentCollection(display_name='EmptyModel')
+ x = sc.linspace('energy', -5.0, 5.0, 100, unit='meV')
+
+ # THEN
+ result = component_collection.evaluate(x, output='scipp')
+
+ # EXPECT: zeros on the input grid, keeping the input's dimension name
+ assert isinstance(result, sc.Variable)
+ assert result.dims == ('energy',)
+ assert np.all(result.values == pytest.approx(0.0))
+
def test_evaluate_component(self, component_collection):
- # WHEN THEN
+ # WHEN
x = np.linspace(-5, 5, 100)
+
+ # THEN
result1 = component_collection.evaluate_component(x, 'TestGaussian1Name')
result2 = component_collection.evaluate_component(x, 'TestLorentzian1Name')
@@ -290,13 +326,57 @@ def test_normalize_area_not_finite_area_raises(self, component_collection, area_
def test_normalize_area_non_area_component_warns(self, component_collection):
# WHEN
- component1 = Polynomial(display_name='TestPolynomial', coefficients=[1, 2, 3], unit='meV')
+ component1 = Polynomial(
+ display_name='TestPolynomial', coefficients=[1, 2, 3], x_unit='meV'
+ )
component_collection.append_component(component1)
# THEN EXPECT
with pytest.warns(UserWarning, match="does not have an 'area' "):
component_collection.normalize_area()
+ def test_convert_x_unit_rollback_skipped_when_old_unit_none(self):
+ # WHEN: a collection without an x_unit of its own
+ collection = ComponentCollection(
+ components=Gaussian(name='G', area=1.0, width=0.5, x_unit='meV'), x_unit=None
+ )
+
+ # THEN: an incompatible unit fails; the outer rollback is skipped (no old unit to
+ # restore) but the component's own atomic rollback keeps it consistent
+ with pytest.raises(sc.UnitError):
+ collection.convert_x_unit('m')
+
+ # EXPECT
+ assert collection[0].x_unit == 'meV'
+
+ def test_normalize_area_only_non_area_components_raises(self):
+ # WHEN: no component in the collection has an area attribute
+ collection = ComponentCollection(
+ components=Polynomial(display_name='OnlyPolynomial', coefficients=[1, 2])
+ )
+
+ # THEN EXPECT
+ with (
+ pytest.warns(UserWarning, match="does not have an 'area' "),
+ pytest.raises(ValueError, match='No components with an area attribute'),
+ ):
+ collection.normalize_area()
+
+ def test_normalize_area_mixed_units(self):
+ # WHEN: two Gaussians with compatible but different area units (meV and ueV)
+ gaussian_mev = Gaussian(name='G1', area=1.0, width=1.0, x_unit='meV')
+ gaussian_uev = Gaussian(name='G2', area=1000.0, width=500.0, x_unit='ueV')
+ collection = ComponentCollection(components=[gaussian_mev, gaussian_uev])
+
+ # THEN: both areas are physically 1 meV, so each should end up at half its value
+ collection.normalize_area()
+
+ # EXPECT: areas sum to 1 in the first component's unit (meV)
+ total_mev = gaussian_mev.area.value + gaussian_uev.area.value / 1000.0
+ assert total_mev == pytest.approx(1.0)
+ assert gaussian_mev.area.value == pytest.approx(0.5)
+ assert gaussian_uev.area.value == pytest.approx(500.0)
+
def test_get_all_parameters(self, component_collection):
# WHEN THEN
parameters = component_collection.get_all_parameters()
@@ -364,6 +444,7 @@ def test_fix_and_free_all_parameters(self, component_collection):
assert param.fixed is False
def test_contains(self, component_collection):
+ # WHEN THEN EXPECT: membership by name and by object
assert 'TestGaussian1Name' in component_collection
assert 'TestLorentzian1Name' in component_collection
assert 'NonExistentComponent' not in component_collection
@@ -373,9 +454,10 @@ def test_contains(self, component_collection):
assert gaussian_component in component_collection
assert lorentzian_component in component_collection
- # WHEN THEN
- fake_component = Gaussian(name='FakeGaussian', area=1.0, center=0.0, width=1.0, unit='meV')
- # EXPECT
+ # WHEN: a component not in the collection — THEN EXPECT
+ fake_component = Gaussian(
+ name='FakeGaussian', area=1.0, center=0.0, width=1.0, x_unit='meV'
+ )
assert fake_component not in component_collection
assert 123 not in component_collection # Invalid type
@@ -392,13 +474,15 @@ def test_to_dict(self, component_collection):
# EXPECT
assert model_dict['display_name'] == component_collection.display_name
- assert model_dict['unit'] == component_collection.unit
+ assert model_dict['x_unit'] == component_collection.x_unit
+ assert model_dict['y_unit'] == component_collection.y_unit
assert len(model_dict['components']) == len(component_collection)
for comp, comp_dict in zip(component_collection, model_dict['components'], strict=True):
assert comp_dict['@class'] == type(comp).__name__
assert comp_dict['display_name'] == comp.display_name
- assert comp_dict['unit'] == comp.unit
+ assert comp_dict['x_unit'] == comp.x_unit
+ assert comp_dict['y_unit'] == comp.y_unit
def test_from_dict(self, component_collection):
# WHEN
@@ -409,12 +493,15 @@ def test_from_dict(self, component_collection):
# EXPECT
assert new_model.display_name == component_collection.display_name
+ assert new_model.x_unit == component_collection.x_unit
+ assert new_model.y_unit == component_collection.y_unit
assert len(new_model) == len(component_collection)
for orig_comp, new_comp in zip(component_collection, new_model, strict=True):
assert type(new_comp) is type(orig_comp)
assert new_comp.display_name == orig_comp.display_name
- assert new_comp.unit == orig_comp.unit
+ assert new_comp.x_unit == orig_comp.x_unit
+ assert new_comp.y_unit == orig_comp.y_unit
orig_params = orig_comp.get_all_parameters()
new_params = new_comp.get_all_parameters()
@@ -426,6 +513,13 @@ def test_from_dict(self, component_collection):
assert param_new.value == param_orig.value
assert param_new.fixed == param_orig.fixed
+ @pytest.mark.parametrize('missing_key', ['x_unit', 'y_unit', 'components', 'name'])
+ def test_from_dict_requires_all_keys(self, component_collection, missing_key):
+ model_dict = component_collection.to_dict()
+ del model_dict[missing_key]
+ with pytest.raises(KeyError):
+ ComponentCollection.from_dict(model_dict)
+
def test_copy(self, component_collection):
# WHEN
component_collection[0].area.min = 0.5
@@ -451,7 +545,8 @@ def test_copy(self, component_collection):
# Same type and display name
assert type(copied_comp) is type(orig_comp)
assert copied_comp.display_name == orig_comp.display_name
- assert copied_comp.unit == orig_comp.unit
+ assert copied_comp.x_unit == orig_comp.x_unit
+ assert copied_comp.y_unit == orig_comp.y_unit
# Parameters are deep-copied and equivalent
orig_params = orig_comp.get_all_parameters()
@@ -487,3 +582,95 @@ def test_no_warning_with_unique_names(self, recwarn):
ComponentCollection(components=[g1, g2])
user_warnings = [w for w in recwarn.list if issubclass(w.category, UserWarning)]
assert not user_warnings
+
+ def test_y_unit_custom(self):
+ # WHEN THEN
+ cc = ComponentCollection(y_unit='1/meV')
+ # EXPECT
+ assert cc.y_unit == '1/meV'
+
+ def test_y_unit_setter_raises(self, component_collection):
+ # WHEN THEN EXPECT
+ with pytest.raises(AttributeError, match=r'read-only'):
+ component_collection.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: components with y_unit='1/meV' so area_unit ≈ dimensionless
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV')
+ cc = ComponentCollection(components=[g, lor])
+
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ cc.convert_y_unit('1/eV')
+
+ # EXPECT
+ assert cc.y_unit == '1/eV'
+ for component in cc:
+ assert component.y_unit == '1/eV'
+ assert g.area.value == pytest.approx(1e3)
+ assert lor.area.value == pytest.approx(1e3)
+
+ def test_convert_y_unit_invalid_type_raises(self, component_collection):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ component_collection.convert_y_unit(123)
+
+ def test_convert_x_unit_rollback_on_failure(self):
+ # WHEN: collection whose first Gaussian converts fine, but second has an
+ # ExpressionComponent that raises NotImplementedError for convert_x_unit.
+ g = Gaussian(area=1.0, x_unit='meV')
+ expr = ExpressionComponent('A * x', parameters={'A': 1.0}, x_unit='meV')
+ cc = ComponentCollection(components=[g, expr])
+ original_area = g.area.value
+
+ # THEN: attempt a unit conversion that will fail on the ExpressionComponent
+ with pytest.raises(NotImplementedError):
+ cc.convert_x_unit('microeV')
+
+ # EXPECT: Gaussian is rolled back to its original state
+ assert cc.x_unit == 'meV'
+ assert g.x_unit == 'meV'
+ assert g.area.value == pytest.approx(original_area)
+
+ def test_convert_y_unit_rollback_on_failure(self):
+ # WHEN: collection where first Gaussian converts successfully but second
+ # ExpressionComponent always raises NotImplementedError for convert_y_unit.
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ expr = ExpressionComponent('A * x', parameters={'A': 1.0}, x_unit='meV')
+ cc = ComponentCollection(components=[g, expr], y_unit='1/meV')
+ original_area = g.area.value
+
+ # THEN: attempt y_unit conversion that will fail on the ExpressionComponent
+ with pytest.raises(NotImplementedError):
+ cc.convert_y_unit('1/eV')
+
+ # EXPECT: collection y_unit and Gaussian are both rolled back
+ assert cc.y_unit == '1/meV'
+ assert g.y_unit == '1/meV'
+ assert g.area.value == pytest.approx(original_area)
+
+ def test_evaluate_scipp_output_with_y_unit(self):
+ # WHEN
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ cc = ComponentCollection(components=[g], y_unit='1/meV')
+ x = np.linspace(-5, 5, 50)
+ # THEN
+ result = cc.evaluate(x, output='scipp')
+ # EXPECT
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/meV')
+
+ # ───── Regression tests ─────
+
+ def test_evaluate_scipp_output_multi_component_does_not_raise(self, component_collection):
+ # WHEN: collection with two components (Gaussian + Lorentzian)
+ x = sc.Variable(dims=['energy'], values=np.linspace(-5.0, 5.0, 100), unit='meV')
+ # THEN: evaluate with scipp output
+ # Before the fix, sum() started from int 0 → '0 + sc.Variable' raised TypeError.
+ result = component_collection.evaluate(x, output='scipp')
+ # EXPECT: returns a Variable whose values are the sum of both components
+ assert isinstance(result, sc.Variable)
+ expected = component_collection[0].evaluate(x, output='scipp') + component_collection[
+ 1
+ ].evaluate(x, output='scipp')
+ assert sc.allclose(result, expected)
diff --git a/tests/unit/easydynamics/sample_model/test_instrument_model.py b/tests/unit/easydynamics/sample_model/test_instrument_model.py
index cb716fc42..a802b1961 100644
--- a/tests/unit/easydynamics/sample_model/test_instrument_model.py
+++ b/tests/unit/easydynamics/sample_model/test_instrument_model.py
@@ -68,12 +68,13 @@ def test_init(self, instrument_model):
# EXPECT
assert model.display_name == 'TestInstrumentModel'
- np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ assert isinstance(model.Q, sc.Variable)
+ assert model.Q.unit == sc.Unit('1/angstrom')
+ np.testing.assert_array_equal(model.Q.values, np.array([1.0, 2.0, 3.0]))
assert isinstance(model.background_model, BackgroundModel)
assert isinstance(model.resolution_model, ResolutionModel)
- np.testing.assert_array_equal(model.background_model.Q, np.array([1.0, 2.0, 3.0]))
- np.testing.assert_array_equal(model.resolution_model.Q, np.array([1.0, 2.0, 3.0]))
- np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ np.testing.assert_array_equal(model.background_model.Q.values, np.array([1.0, 2.0, 3.0]))
+ np.testing.assert_array_equal(model.resolution_model.Q.values, np.array([1.0, 2.0, 3.0]))
def test_init_defaults(self):
# WHEN THEN
@@ -113,7 +114,7 @@ def test_init_sample_model_as_resolution_model(self, sample_model):
'energy_offset must be a number',
),
(
- {'unit': 123},
+ {'x_unit': 123},
TypeError,
'unit must be',
),
@@ -122,7 +123,7 @@ def test_init_sample_model_as_resolution_model(self, sample_model):
'invalid resolution_model',
'invalid background_model',
'invalid energy_offset',
- 'invalid unit',
+ 'invalid x unit',
],
)
def test_instrument_model_init_invalid_inputs(
@@ -209,13 +210,10 @@ def test_clear_Q_raises_without_confirm(self, instrument_model):
with pytest.raises(ValueError, match='Clearing Q values requires confirmation'):
instrument_model.clear_Q()
- def test_unit_setter_raises(self, instrument_model):
+ def test_x_unit_setter_raises(self, instrument_model):
# WHEN / THEN / EXPECT
- with pytest.raises(
- AttributeError,
- match=r'Unit is read-only. Use convert_unit to change the unit between allowed types ',
- ):
- instrument_model.unit = 'meV'
+ with pytest.raises(AttributeError):
+ instrument_model.x_unit = 'meV'
def test_energy_offset_setter(self, instrument_model):
# WHEN
@@ -272,16 +270,16 @@ def test_get_energy_offset_no_Q_raises(self, instrument_model):
):
instrument_model.get_energy_offset(0)
- def test_convert_unit_calls_all_children(self, instrument_model):
+ def test_convert_x_unit_calls_all_children(self, instrument_model):
# WHEN
new_unit = 'eV'
# THEN
# Ensure energy offsets are built before mocking
instrument_model._ensure_energy_offsets_current()
- # Mock downstream convert_unit calls
- instrument_model._background_model.convert_unit = MagicMock()
- instrument_model._resolution_model.convert_unit = MagicMock()
+ # Mock downstream convert_x_unit calls
+ instrument_model._background_model.convert_x_unit = MagicMock()
+ instrument_model._resolution_model.convert_x_unit = MagicMock()
instrument_model._energy_offset.convert_unit = MagicMock()
for offset in instrument_model._energy_offsets:
offset.convert_unit = MagicMock()
@@ -290,28 +288,28 @@ def test_convert_unit_calls_all_children(self, instrument_model):
'easydynamics.sample_model.instrument_model._validate_unit',
return_value=new_unit,
) as mock_validate:
- instrument_model.convert_unit(new_unit)
+ instrument_model.convert_x_unit(new_unit)
# EXPECT
mock_validate.assert_called_once_with(new_unit)
- instrument_model._background_model.convert_unit.assert_called_once_with(new_unit)
- instrument_model._resolution_model.convert_unit.assert_called_once_with(new_unit)
+ instrument_model._background_model.convert_x_unit.assert_called_once_with(new_unit)
+ instrument_model._resolution_model.convert_x_unit.assert_called_once_with(new_unit)
instrument_model._energy_offset.convert_unit.assert_called_once_with(new_unit)
for offset in instrument_model._energy_offsets:
offset.convert_unit.assert_called_once_with(new_unit)
# final state
- assert instrument_model.unit == new_unit
+ assert instrument_model.x_unit == new_unit
- def test_convert_unit_None_raises(self, instrument_model):
+ def test_convert_x_unit_None_raises(self, instrument_model):
# WHEN / THEN / EXPECT
with pytest.raises(
ValueError,
match=' must be a valid unit',
):
- instrument_model.convert_unit(None)
+ instrument_model.convert_x_unit(None)
def test_fix_resolution_parameters(self, instrument_model):
# WHEN
@@ -420,7 +418,9 @@ def test_generate_energy_offsets_Q_none(self, instrument_model):
def test_generate_energy_offsets(self, instrument_model):
# WHEN
- instrument_model._Q = np.array([1.0, 2.0, 3.0, 4.0])
+ instrument_model._Q = sc.Variable(
+ dims=['Q'], values=[1.0, 2.0, 3.0, 4.0], unit='1/angstrom'
+ )
# THEN
instrument_model._generate_energy_offsets()
@@ -429,7 +429,7 @@ def test_generate_energy_offsets(self, instrument_model):
assert len(instrument_model._energy_offsets) == 4
for offset in instrument_model._energy_offsets:
assert offset.name == 'energy_offset'
- assert offset.unit == instrument_model.unit
+ assert offset.unit == instrument_model.x_unit
assert offset.value == instrument_model.energy_offset.value
def test_Q_setter(self, instrument_model_without_Q):
@@ -441,8 +441,12 @@ def test_Q_setter(self, instrument_model_without_Q):
# EXPECT
assert instrument_model_without_Q._energy_offsets_is_dirty is True
- np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q)
- np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q)
+ np.testing.assert_array_equal(
+ instrument_model_without_Q.background_model.Q.values, first_new_Q
+ )
+ np.testing.assert_array_equal(
+ instrument_model_without_Q.resolution_model.Q.values, first_new_Q
+ )
# THEN
new_Q = np.array([4.0, 5.0, 6.0])
@@ -457,8 +461,12 @@ def test_Q_setter(self, instrument_model_without_Q):
# EXPECT
# Q values remain unchanged
- np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q)
- np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q)
+ np.testing.assert_array_equal(
+ instrument_model_without_Q.background_model.Q.values, first_new_Q
+ )
+ np.testing.assert_array_equal(
+ instrument_model_without_Q.resolution_model.Q.values, first_new_Q
+ )
# THEN - set Q to an equivalent scipp Variable; values match so should be accepted
new_Q = sc.Variable(dims=['Q'], values=[1.0, 2.0, 3.0], unit='1/angstrom')
@@ -466,8 +474,12 @@ def test_Q_setter(self, instrument_model_without_Q):
# EXPECT - Q propagated to child models, offsets marked dirty again
assert instrument_model_without_Q._energy_offsets_is_dirty is True
- np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q)
- np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q)
+ np.testing.assert_array_equal(
+ instrument_model_without_Q.background_model.Q.values, first_new_Q
+ )
+ np.testing.assert_array_equal(
+ instrument_model_without_Q.resolution_model.Q.values, first_new_Q
+ )
def test_fix_and_free_offset(self, instrument_model):
# WHEN
@@ -516,13 +528,13 @@ def test_fix_or_free_energy_offset_nonint_Q_index_raises(self, instrument_model)
# WHEN / THEN / EXPECT
with pytest.raises(
TypeError,
- match='Q_index must be an int or None, got str',
+ match='Q_index must be an int',
):
instrument_model.fix_energy_offset(Q_index='invalid_index')
with pytest.raises(
TypeError,
- match='Q_index must be an int or None, got str',
+ match='Q_index must be an int',
):
instrument_model.free_energy_offset(Q_index='invalid_index')
@@ -567,7 +579,7 @@ def test_repr_contains_expected_fields(self, instrument_model):
# EXPECT
assert repr_str.startswith('InstrumentModel(')
assert f'unique_name={instrument_model.unique_name!r}' in repr_str
- assert f'unit={instrument_model.unit}' in repr_str
+ assert f'x_unit={instrument_model.x_unit}' in repr_str
assert 'Q_len=3' in repr_str
assert f'resolution_model={instrument_model._resolution_model!r}' in repr_str
assert f'background_model={instrument_model._background_model!r}' in repr_str
diff --git a/tests/unit/easydynamics/sample_model/test_model_base.py b/tests/unit/easydynamics/sample_model/test_model_base.py
index 5b58370d6..0b19950b8 100644
--- a/tests/unit/easydynamics/sample_model/test_model_base.py
+++ b/tests/unit/easydynamics/sample_model/test_model_base.py
@@ -23,7 +23,7 @@ def model_base(self):
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ x_unit='meV',
)
component2 = Lorentzian(
name='TestLorentzian1Name',
@@ -31,7 +31,7 @@ def model_base(self):
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ x_unit='meV',
)
component_collection = ComponentCollection()
component_collection.append_component(component1)
@@ -39,19 +39,22 @@ def model_base(self):
return ModelBase(
display_name='InitModel',
components=component_collection,
- unit='meV',
+ x_unit='meV',
Q=np.array([1.0, 2.0, 3.0]),
)
def test_init(self, model_base):
# WHEN THEN
- model = model_base
# EXPECT
- assert model.display_name == 'InitModel'
- assert model.unit == 'meV'
- assert len(model.components) == 2
- np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ assert model_base.display_name == 'InitModel'
+ assert model_base.x_unit == 'meV'
+ assert model_base.y_unit == 'dimensionless'
+ assert len(model_base.components) == 2
+ assert isinstance(model_base.Q, sc.Variable)
+ assert model_base.Q.dims == ('Q',)
+ assert model_base.Q.unit == sc.Unit('1/angstrom')
+ np.testing.assert_array_equal(model_base.Q.values, np.array([1.0, 2.0, 3.0]))
def test_init_raises_with_invalid_components(self):
# WHEN / THEN / EXPECT
@@ -78,8 +81,8 @@ def test_evaluate_calls_all_component_collections(self, model_base):
result = model_base.evaluate(x)
# EXPECT
- collection1.evaluate.assert_called_once_with(x)
- collection2.evaluate.assert_called_once_with(x)
+ collection1.evaluate.assert_called_once_with(x, output='numpy')
+ collection2.evaluate.assert_called_once_with(x, output='numpy')
np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0]))
np.testing.assert_allclose(result[1], np.array([4.0, 5.0, 6.0]))
@@ -128,7 +131,7 @@ def test_get_all_variables(self, model_base):
# WHEN
all_vars = model_base.get_all_variables()
- # THEN
+ # EXPECT
expected_var_display_names = {
'TestGaussian1Name area',
'TestGaussian1Name center',
@@ -168,7 +171,7 @@ def test_get_all_variables_with_invalid_Q_index_raises(self, model_base):
# WHEN / THEN / EXPECT
with pytest.raises(
IndexError,
- match='Q_index 5 is out of bounds for component collections of length 3',
+ match='Q_index 5 is out of bounds for Q of length 3',
):
model_base.get_all_variables(Q_index=5)
@@ -198,7 +201,7 @@ def test_get_component_collection_invalid_index_raises(self, model_base):
# WHEN THEN EXPECT
with pytest.raises(
IndexError,
- match='Q_index 5 is out of bounds for ',
+ match='Q_index 5 is out of bounds for Q of length 3',
):
model_base.get_component_collection(Q_index=5)
@@ -246,36 +249,105 @@ def test_append_component_invalid_type_raises(self, model_base):
with pytest.raises(TypeError, match=' must be '):
model_base.append_component('invalid_component')
- def test_unit_property(self, model_base):
+ def test_x_unit_property(self, model_base):
# WHEN
- unit = model_base.unit
+ unit = model_base.x_unit
# THEN / EXPECT
assert unit == 'meV'
- def test_unit_setter_raises(self, model_base):
+ def test_x_unit_setter_raises(self, model_base):
# WHEN / THEN / EXPECT
- with pytest.raises(AttributeError, match='Use convert_unit to change '):
- model_base.unit = 'K'
+ with pytest.raises(AttributeError):
+ model_base.x_unit = 'K'
+
+ def test_convert_x_unit(self, model_base):
+ # Build collections before conversion so we can verify in-place update
+ _ = model_base.get_component_collection(0)
+ assert model_base._component_collections_is_dirty is False
+ collection_before = model_base._component_collections[0]
- def test_convert_unit(self, model_base):
# WHEN
- model_base.convert_unit('eV')
+ model_base.convert_x_unit('eV')
- # THEN / EXPECT
- assert model_base.unit == 'eV'
+ # THEN / EXPECT: dirty flag NOT set and same collections reused (not rebuilt)
+ assert model_base._component_collections_is_dirty is False
+ assert model_base._component_collections[0] is collection_before
+
+ assert model_base.x_unit == 'eV'
for component in model_base.components:
- assert component.unit == 'eV'
+ assert component.x_unit == 'eV'
+ for collection in model_base._component_collections:
+ for component in collection:
+ assert component.x_unit == 'eV'
- def test_convert_unit_invalid_raises(self, model_base):
+ def test_convert_x_unit_invalid_raises(self, model_base):
# WHEN / THEN / EXPECT
with pytest.raises(UnitError):
- model_base.convert_unit('invalid_unit')
+ model_base.convert_x_unit('invalid_unit')
- def test_convert_unit_incorrect_unit_raises(self, model_base):
+ def test_convert_x_unit_incorrect_unit_raises(self, model_base):
# WHEN THEN EXPECT
with pytest.raises(TypeError, match=r'Unit must be a string or sc.Unit'):
- model_base.convert_unit(123)
+ model_base.convert_x_unit(123)
+
+ def test_components_setter_none(self, model_base):
+ # WHEN THEN
+ model_base.components = None
+ # EXPECT
+ assert len(model_base.components) == 0
+
+ def test_convert_x_unit_rollback_when_old_unit_none(self):
+ # WHEN: model with _x_unit=None (rollback branch is skipped when old_unit is None)
+ component = Gaussian(name='G', area=1.0, center=0.0, width=0.5, x_unit='meV')
+ model = ModelBase(display_name='M', x_unit=None, components=component)
+ model._x_unit = None
+ # THEN
+ with pytest.raises(UnitError):
+ model.convert_x_unit('m') # incompatible unit triggers failure
+ # EXPECT: Gaussian's own atomic rollback keeps it at 'meV' even though
+ # ModelBase's outer rollback loop is skipped when old_unit is None
+ assert component.x_unit == 'meV'
+
+ def test_convert_x_unit_rollback_on_failure(self, model_base):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ model_base.convert_x_unit('m')
+ # EXPECT: state rolled back
+ assert model_base.x_unit == 'meV'
+ for component in model_base.components:
+ assert component.x_unit == 'meV'
+
+ def test_convert_y_unit_rollback_on_failure(self, model_base):
+ # WHEN THEN
+ with pytest.raises(UnitError):
+ model_base.convert_y_unit('K')
+ # EXPECT: state rolled back
+ assert model_base.y_unit == 'dimensionless'
+
+ def test_convert_x_unit_rollback_restores_collections(self):
+ # WHEN: a model with built per-Q collections
+ component = Gaussian(name='G', area=1.0, center=0.0, width=0.5, x_unit='meV')
+ model = ModelBase(display_name='M', components=component, Q=[1.0, 2.0])
+ collection = model.get_component_collection(0)
+
+ # THEN: an incompatible unit fails and triggers the rollback of components and
+ # collections
+ with pytest.raises(UnitError):
+ model.convert_x_unit('m')
+
+ # EXPECT
+ assert model.x_unit == 'meV'
+ assert component.x_unit == 'meV'
+ assert collection[0].x_unit == 'meV'
+
+ def test_component_collections_empty_without_Q(self):
+ # WHEN: a model without Q regenerates its collections
+ model = ModelBase(display_name='M', components=Gaussian(name='G'))
+
+ # THEN EXPECT: no per-Q collections and therefore no variables
+ assert model.get_all_variables() == []
+ assert model._component_collections == []
def test_components_setter(self, model_base):
# WHEN
@@ -320,8 +392,9 @@ def test_Q_setter_raises_if_Q_is_not_similar(self, model_base):
[1.0, 2.0, 3.0],
np.array([1.0, 2.0, 3.0]),
sc.Variable(dims=['Q'], values=[1.0, 2.0, 3.0], unit='1/angstrom'),
+ sc.Variable(dims=['Q'], values=[10.0, 20.0, 30.0], unit='1/nm'),
],
- ids=['list', 'numpy_array', 'scipp_variable'],
+ ids=['list', 'numpy_array', 'scipp_variable', 'scipp_variable_other_unit'],
)
def test_Q_setter_with_similar_Q(self, model_base, new_Q):
# WHEN
@@ -331,7 +404,7 @@ def test_Q_setter_with_similar_Q(self, model_base, new_Q):
model_base.Q = new_Q
# EXPECT
- np.testing.assert_array_equal(model_base.Q, old_Q)
+ np.testing.assert_array_equal(model_base.Q.values, old_Q.values)
def test_Q_setter_with_none(self, model_base):
# WHEN
@@ -352,7 +425,20 @@ def test_Q_setter_when_current_Q_is_none(self, model_base):
model_base.Q = new_Q
# EXPECT
- np.testing.assert_array_equal(model_base.Q, np.array(new_Q))
+ np.testing.assert_array_equal(model_base.Q.values, np.array(new_Q))
+
+ def test_Q_stored_as_scipp_in_inverse_angstrom(self, model_base):
+ # WHEN: a scipp Q in 1/nm
+ model_base._Q = None
+ new_Q = sc.Variable(dims=['Q'], values=[5.0, 10.0], unit='1/nm')
+
+ # THEN
+ model_base.Q = new_Q
+
+ # EXPECT: stored canonically in 1/angstrom
+ assert isinstance(model_base.Q, sc.Variable)
+ assert model_base.Q.unit == sc.Unit('1/angstrom')
+ np.testing.assert_allclose(model_base.Q.values, [0.5, 1.0])
def test_clear_Q(self, model_base):
# WHEN
@@ -388,3 +474,42 @@ def test_repr(self, model_base):
assert 'unit' in repr_str
assert 'Q=' in repr_str
assert 'components=' in repr_str
+
+ def test_y_unit_setter_raises(self, model_base):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(AttributeError):
+ model_base.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN: model with components where y_unit='1/meV' so area_unit ≈ dimensionless
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV')
+ cc = ComponentCollection(components=[g, lor])
+ model = ModelBase(components=cc, x_unit='meV', Q=np.array([1.0]))
+
+ # Build collections before conversion so we can verify in-place update
+ _ = model.get_component_collection(0)
+ assert model._component_collections_is_dirty is False
+ collection_before = model._component_collections[0]
+
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ model.convert_y_unit('1/eV')
+
+ # EXPECT: dirty flag NOT set and same collections reused (not rebuilt)
+ assert model._component_collections_is_dirty is False
+ assert model._component_collections[0] is collection_before
+
+ # EXPECT: model y_unit and template components updated
+ assert model.y_unit == '1/eV'
+ for component in model.components:
+ assert component.y_unit == '1/eV'
+ assert g.area.value == pytest.approx(1e3)
+ assert lor.area.value == pytest.approx(1e3)
+ # EXPECT: component collections updated in-place (not rebuilt from templates)
+ for component in collection_before:
+ assert component.y_unit == '1/eV'
+
+ def test_convert_y_unit_invalid_raises(self, model_base):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError):
+ model_base.convert_y_unit(123)
diff --git a/tests/unit/easydynamics/sample_model/test_resolution_model.py b/tests/unit/easydynamics/sample_model/test_resolution_model.py
index f894daf20..102363f60 100644
--- a/tests/unit/easydynamics/sample_model/test_resolution_model.py
+++ b/tests/unit/easydynamics/sample_model/test_resolution_model.py
@@ -21,14 +21,14 @@ def resolution_model(self):
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ x_unit='meV',
)
component2 = Lorentzian(
display_name='TestLorentzian1',
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ x_unit='meV',
)
component_collection = ComponentCollection()
component_collection.append_component(component1)
@@ -36,7 +36,7 @@ def resolution_model(self):
return ResolutionModel(
display_name='InitModel',
components=component_collection,
- unit='meV',
+ x_unit='meV',
Q=np.array([1.0, 2.0, 3.0]),
)
@@ -48,7 +48,7 @@ def sample_model(self):
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ x_unit='meV',
)
component2 = Lorentzian(
name='TestLorentzian1Name',
@@ -56,7 +56,7 @@ def sample_model(self):
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ x_unit='meV',
)
component_collection = ComponentCollection()
component_collection.append_component(component1)
@@ -65,20 +65,20 @@ def sample_model(self):
return SampleModel(
display_name='InitModel',
components=component_collection,
- unit='meV',
+ x_unit='meV',
Q=np.array([1.0, 2.0, 3.0]),
temperature=10.0,
)
def test_init(self, resolution_model):
# WHEN THEN
- model = resolution_model
# EXPECT
- assert model.display_name == 'InitModel'
- assert model.unit == 'meV'
- assert len(model.components) == 2
- np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ assert resolution_model.display_name == 'InitModel'
+ assert resolution_model.x_unit == 'meV'
+ assert resolution_model.y_unit == 'dimensionless'
+ assert len(resolution_model.components) == 2
+ np.testing.assert_array_equal(resolution_model.Q.values, np.array([1.0, 2.0, 3.0]))
@pytest.mark.parametrize(
'invalid_component, expected_error_msg',
@@ -215,10 +215,10 @@ def test_from_sample_model(
# EXPECT
assert resolution_model.display_name == 'InitModel'
- assert resolution_model.unit == 'meV'
+ assert resolution_model.x_unit == 'meV'
assert len(resolution_model.components) == 2
np.testing.assert_array_equal(
- resolution_model.Q,
+ resolution_model.Q.values,
np.array([1.0, 2.0, 3.0]),
)
@@ -242,6 +242,21 @@ def test_from_sample_model(
variables = resolution_model.get_all_variables()
assert all(var.fixed for var in variables) is all_fixed
+ def test_from_sample_model_installed_collections_are_stable(self, sample_model):
+ # WHEN
+ resolution_model = ResolutionModel.from_sample_model(sample_model)
+
+ # EXPECT: the installed collections are not scheduled for a rebuild (regression: init
+ # callbacks used to set the dirty flag, so the next access silently regenerated the
+ # collections from templates, discarding the normalization)
+ assert resolution_model._component_collections_is_dirty is False
+
+ # and repeated access returns the same installed, normalized collections
+ first_access = resolution_model.get_component_collection(0)
+ second_access = resolution_model.get_component_collection(0)
+ assert first_access is second_access
+ assert (first_access[0].area.value + first_access[1].area.value) == pytest.approx(1.0)
+
def test_from_sample_model_with_no_Q(self, sample_model):
# WHEN
sample_model_no_Q = SampleModel(
@@ -301,3 +316,19 @@ def test_from_sample_model_invalid_components(self, sample_model):
match='cannot be a DeltaFunction',
):
ResolutionModel.from_sample_model(sample_model)
+
+ def test_y_unit_setter_raises(self, resolution_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(AttributeError):
+ resolution_model.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ model = ResolutionModel(components=g, x_unit='meV')
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ model.convert_y_unit('1/eV')
+ # EXPECT
+ assert model.y_unit == '1/eV'
+ assert model.components[0].y_unit == '1/eV'
+ assert g.area.value == pytest.approx(1e3)
diff --git a/tests/unit/easydynamics/sample_model/test_sample_model.py b/tests/unit/easydynamics/sample_model/test_sample_model.py
index 8086c92e6..424a41f0b 100644
--- a/tests/unit/easydynamics/sample_model/test_sample_model.py
+++ b/tests/unit/easydynamics/sample_model/test_sample_model.py
@@ -6,6 +6,7 @@
import numpy as np
import pytest
+import scipp as sc
from scipp import UnitError
from easydynamics.sample_model import ComponentCollection
@@ -28,7 +29,7 @@ def sample_model(self):
area=1.0,
center=0.0,
width=1.0,
- unit='meV',
+ x_unit='meV',
)
component2 = Lorentzian(
name='TestLorentzian1Name',
@@ -36,7 +37,7 @@ def sample_model(self):
area=2.0,
center=1.0,
width=0.5,
- unit='meV',
+ x_unit='meV',
)
component_collection = ComponentCollection()
component_collection.append_component(component1)
@@ -50,7 +51,7 @@ def sample_model(self):
display_name='InitModel',
components=component_collection,
diffusion_models=diffusion_model,
- unit='meV',
+ x_unit='meV',
Q=np.array([1.0, 2.0, 3.0]),
temperature=10.0,
)
@@ -62,7 +63,8 @@ def test_init(self, sample_model):
# EXPECT
assert model.display_name == 'InitModel'
- assert model.unit == 'meV'
+ assert model.x_unit == 'meV'
+ assert model.y_unit == 'dimensionless'
assert len(model.components) == 2
assert isinstance(model.diffusion_models, list)
assert len(model.diffusion_models) == 1
@@ -71,7 +73,7 @@ def test_init(self, sample_model):
assert model.normalize_detailed_balance is True
assert model.use_detailed_balance is True
assert isinstance(model.detailed_balance_settings, DetailedBalanceSettings)
- np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0]))
+ np.testing.assert_array_equal(model.Q.values, np.array([1.0, 2.0, 3.0]))
def test_init_custom_input(self):
# WHEN THEN
@@ -308,6 +310,62 @@ def test_convert_temperature_unit_raises_with_invalid_unit(self, sample_model):
):
model.convert_temperature_unit('invalid_unit')
+ def test_convert_x_unit_propagates_to_diffusion_models(self):
+ # WHEN: a SampleModel with an attached diffusion model
+ Q = np.array([1.0, 2.0])
+ brownian = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=Q)
+ model = SampleModel(Q=Q, diffusion_models=brownian)
+ width_mev = model.get_component_collection(0)[0].width.value
+
+ # THEN
+ model.convert_x_unit('ueV')
+
+ # EXPECT: the diffusion model follows the SampleModel's new unit, and the merged
+ # collections carry rescaled widths in the new unit
+ assert sc.Unit(str(brownian.x_unit)) == sc.Unit('ueV')
+ assert sc.Unit(str(brownian.scale.unit)) == sc.Unit('ueV')
+ collection = model.get_component_collection(0)
+ assert sc.Unit(str(collection[0].width.unit)) == sc.Unit('ueV')
+ assert collection[0].width.value == pytest.approx(width_mev * 1000)
+
+ def test_convert_x_unit_does_not_mark_collections_dirty(self):
+ # WHEN: a SampleModel with an attached diffusion model and built collections
+ Q = np.array([1.0, 2.0])
+ brownian = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=Q)
+ model = SampleModel(Q=Q, diffusion_models=brownian)
+ collection_before = model.get_component_collection(0)
+ width = collection_before[0].width
+
+ # THEN
+ model.convert_x_unit('ueV')
+
+ # EXPECT: conversion works in place — no dirty flag, no rebuild, and the converted
+ # dependent width stays in the new unit through later dependency-graph updates
+ # (regression: conversion used to mark the collections dirty, and the rebuild
+ # discarded per-Q state and object references)
+ assert model._component_collections_is_dirty is False
+ assert model.get_component_collection(0) is collection_before
+ width_uev = width.value
+ brownian.diffusion_coefficient = 4.8e-9
+ assert sc.Unit(str(width.unit)) == sc.Unit('ueV')
+ assert width.value == pytest.approx(2 * width_uev)
+
+ def test_convert_y_unit_propagates_to_diffusion_models(self):
+ # WHEN: SampleModel and diffusion model sharing y_unit='1/meV'
+ Q = np.array([1.0])
+ brownian = BrownianTranslationalDiffusion(
+ diffusion_coefficient=2.4e-9, Q=Q, y_unit='1/meV'
+ )
+ model = SampleModel(Q=Q, y_unit='1/meV', diffusion_models=brownian)
+
+ # THEN
+ model.convert_y_unit('1/eV')
+
+ # EXPECT
+ assert sc.Unit(str(brownian.y_unit)) == sc.Unit('1/eV')
+ assert sc.Unit(str(brownian.scale.unit)) == sc.Unit('meV/eV')
+ assert sc.Unit(str(model.y_unit)) == sc.Unit('1/eV')
+
def test_normalize_detailed_balance_setter(self, sample_model):
# WHEN
model = sample_model
@@ -402,12 +460,12 @@ def test_evaluate_calls_dbf(self, sample_model):
energy=x,
temperature=sample_model.temperature,
divide_by_temperature=sample_model.normalize_detailed_balance,
- energy_unit=sample_model.unit,
+ energy_unit=sample_model.x_unit,
)
# Check that evaluate was called on each component
- collection1.evaluate.assert_called_once_with(x)
- collection2.evaluate.assert_called_once_with(x)
+ collection1.evaluate.assert_called_once_with(x, output='numpy')
+ collection2.evaluate.assert_called_once_with(x, output='numpy')
# Check that DBF was applied elementwise
np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0]) * 10.0)
@@ -452,8 +510,8 @@ def test_evaluate_doesnt_call_dbf_when_disabled(
mock_dbf.assert_not_called()
# Check that evaluate was called on each component
- collection1.evaluate.assert_called_once_with(x)
- collection2.evaluate.assert_called_once_with(x)
+ collection1.evaluate.assert_called_once_with(x, output='numpy')
+ collection2.evaluate.assert_called_once_with(x, output='numpy')
# Check that results were not modified by DBF
np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0]))
@@ -531,9 +589,28 @@ def test_repr(self, sample_model):
# THEN / EXPECT
assert 'SampleModel' in repr_str
- assert 'unit=' in repr_str
- assert 'Q=' in repr_str
+ assert 'Q = ' in repr_str
assert 'components' in repr_str
assert 'diffusion_models' in repr_str
assert 'temperature' in repr_str
+ # Regression: a stray ')' and a missing separator used to mangle this into
+ # 'x_unit=meV), y_unit=dimensionlessQ = ...', and the closing ')' was missing
+ assert 'x_unit=meV, y_unit=dimensionless' in repr_str
+ assert repr_str.rstrip().endswith(')')
assert 'normalize_detailed_balance' in repr_str
+
+ def test_y_unit_setter_raises(self, sample_model):
+ # WHEN / THEN / EXPECT
+ with pytest.raises(AttributeError):
+ sample_model.y_unit = '1/meV'
+
+ def test_convert_y_unit(self):
+ # WHEN
+ g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV')
+ model = SampleModel(components=g, x_unit='meV')
+ # THEN: convert y_unit to '1/eV' (same dimension, different scale)
+ model.convert_y_unit('1/eV')
+ # EXPECT
+ assert model.y_unit == '1/eV'
+ assert model.components[0].y_unit == '1/eV'
+ assert g.area.value == pytest.approx(1e3)
diff --git a/tests/unit/easydynamics/settings/test_convolution_settings.py b/tests/unit/easydynamics/settings/test_convolution_settings.py
index 202f5baea..ccb900774 100644
--- a/tests/unit/easydynamics/settings/test_convolution_settings.py
+++ b/tests/unit/easydynamics/settings/test_convolution_settings.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause
+from copy import copy
+
import pytest
from easydynamics.settings.convolution_settings import ConvolutionSettings
@@ -22,6 +24,21 @@ def test_init(self, default_convolution_settings):
assert default_convolution_settings.extension_factor == pytest.approx(0.2)
assert default_convolution_settings.convolution_plan_is_valid is False
+ def test_copy(self):
+ # WHEN
+ settings = ConvolutionSettings(
+ upsample_factor=10, extension_factor=0.5, suppress_warnings=True
+ )
+
+ # THEN
+ settings_copy = copy(settings)
+
+ # EXPECT: a distinct instance with the same knob values
+ assert settings_copy is not settings
+ assert settings_copy.upsample_factor == 10
+ assert settings_copy.extension_factor == pytest.approx(0.5)
+ assert settings_copy.suppress_warnings is True
+
def test_init_with_custom_parameters(self):
"""
Test initialization of ConvolutionSettings with custom
@@ -135,12 +152,13 @@ def test_upsample_factor_setter_invalid(
@pytest.mark.parametrize(
'value',
- [0.0, 0.2, 1, 5.5],
+ [0.0, 0.2, 1, 5.5, None],
ids=[
'zero',
'typical_fraction',
'integer',
'float',
+ 'none_valid',
],
)
def test_extension_factor_setter_valid(self, default_convolution_settings, value):
@@ -152,7 +170,8 @@ def test_extension_factor_setter_valid(self, default_convolution_settings, value
default_convolution_settings.extension_factor = value
# EXPECT
- assert default_convolution_settings.extension_factor == pytest.approx(float(value))
+ expected = pytest.approx(float(value)) if value is not None else None
+ assert default_convolution_settings.extension_factor == expected
assert default_convolution_settings.convolution_plan_is_valid is False
@pytest.mark.parametrize(
diff --git a/tests/unit/easydynamics/test_exceptions.py b/tests/unit/easydynamics/test_exceptions.py
index 494f6f7c9..bc7731a5e 100644
--- a/tests/unit/easydynamics/test_exceptions.py
+++ b/tests/unit/easydynamics/test_exceptions.py
@@ -7,11 +7,14 @@
class TestAmbiguousNameError:
def test_initialization(self):
+ # WHEN
name = 'test'
matches = ['test1', 'test2', 'test3']
+ # THEN
error = AmbiguousNameError(name, matches)
+ # EXPECT
assert error.name == name
assert error.matches == matches
assert str(error) == (
@@ -19,11 +22,14 @@ def test_initialization(self):
)
def test_empty_matches(self):
+ # WHEN
name = 'unknown'
matches = []
+ # THEN
error = AmbiguousNameError(name, matches)
+ # EXPECT
assert error.name == name
assert error.matches == matches
assert str(error) == ("Ambiguous name 'unknown' matches 0 elements: []")
diff --git a/tests/unit/easydynamics/test_import.py b/tests/unit/easydynamics/test_import.py
index e062efcbe..d2ffe77ba 100644
--- a/tests/unit/easydynamics/test_import.py
+++ b/tests/unit/easydynamics/test_import.py
@@ -3,6 +3,5 @@
def test_import_easydynamics():
- import easydynamics
-
- assert easydynamics is not None
+ # WHEN THEN EXPECT: importing raises no error
+ import easydynamics # noqa: F401
diff --git a/tests/unit/easydynamics/utils/test_detailed_balance.py b/tests/unit/easydynamics/utils/test_detailed_balance.py
index bebf5407a..2d2284d36 100644
--- a/tests/unit/easydynamics/utils/test_detailed_balance.py
+++ b/tests/unit/easydynamics/utils/test_detailed_balance.py
@@ -9,6 +9,7 @@
from scipp.constants import Boltzmann as kB
from easydynamics.utils import detailed_balance_factor
+from easydynamics.utils.detailed_balance import _convert_to_scipp_variable
kB_meV_per_K = sc.to_unit(kB, 'meV/K').value
@@ -305,3 +306,30 @@ def test_incompatible_temperature_unit_raises(self):
energy_unit=energy_unit,
temperature_unit=temperature_unit,
)
+
+
+class TestConvertToScippVariable:
+ """Tests for _convert_to_scipp_variable internal helper."""
+
+ @pytest.mark.parametrize(
+ 'name, expected_match',
+ [
+ ('energy', 'energy must be a number'),
+ ('temperature', 'temperature must be a number'),
+ ],
+ ids=['energy_name', 'other_name'],
+ )
+ def test_invalid_type_raises_type_error(self, name, expected_match):
+ # WHEN THEN EXPECT
+ with pytest.raises(TypeError, match=expected_match):
+ _convert_to_scipp_variable({'invalid': 'type'}, name=name, unit='meV')
+
+ def test_invalid_unit_scalar_raises_unit_error(self):
+ # WHEN THEN EXPECT
+ with pytest.raises(UnitError, match='Invalid unit string'):
+ _convert_to_scipp_variable(1.0, name='energy', unit='not_a_real_unit_xyz')
+
+ def test_invalid_unit_array_raises_unit_error(self):
+ # WHEN THEN EXPECT
+ with pytest.raises(UnitError, match='Invalid unit string'):
+ _convert_to_scipp_variable([1.0, 2.0], name='energy', unit='not_a_real_unit_xyz')
diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py
index 680d2a28a..6cd13b998 100644
--- a/tests/unit/easydynamics/utils/test_utils.py
+++ b/tests/unit/easydynamics/utils/test_utils.py
@@ -21,12 +21,14 @@ class TestValidateAndConvertQ:
],
)
def test_validate_and_convert_Q_numeric_and_array(self, Q_input, expected):
- # WHEN THEN
+ # WHEN THEN: numbers, lists, and numpy arrays are assumed to be in 1/angstrom
result = _validate_and_convert_Q(Q_input)
# EXPECT
- assert isinstance(result, np.ndarray)
- np.testing.assert_allclose(result, expected)
+ assert isinstance(result, sc.Variable)
+ assert result.dims == ('Q',)
+ assert result.unit == sc.Unit('1/angstrom')
+ np.testing.assert_allclose(result.values, expected)
def test_validate_and_convert_Q_scipp_variable(self):
# WHEN
@@ -36,8 +38,20 @@ def test_validate_and_convert_Q_scipp_variable(self):
result = _validate_and_convert_Q(Q)
# EXPECT
- assert isinstance(result, np.ndarray)
- np.testing.assert_allclose(result, [1.0, 2.0])
+ assert isinstance(result, sc.Variable)
+ assert result.unit == sc.Unit('1/angstrom')
+ np.testing.assert_allclose(result.values, [1.0, 2.0])
+
+ def test_validate_and_convert_Q_scipp_variable_other_unit(self):
+ # WHEN: a scipp Q in 1/nm
+ Q = sc.array(dims=['Q'], values=[10.0, 20.0], unit='1/nm')
+
+ # THEN
+ result = _validate_and_convert_Q(Q)
+
+ # EXPECT: converted to 1/angstrom
+ assert result.unit == sc.Unit('1/angstrom')
+ np.testing.assert_allclose(result.values, [1.0, 2.0])
def test_validate_and_convert_Q_none(self):
# WHEN THEN EXPECT
@@ -87,16 +101,24 @@ class TestValidateUnit:
],
)
def test_validate_unit_valid(self, unit_input):
+ # WHEN
+
+ # THEN
unit = _validate_unit(unit_input)
+ # EXPECT
if unit_input is None:
assert unit is None
else:
assert isinstance(unit, str)
def test_validate_unit_string_conversion(self):
+ # WHEN
+
+ # THEN
unit = _validate_unit(sc.Unit('meV'))
+ # EXPECT
assert isinstance(unit, str)
assert unit == 'meV'
@@ -111,6 +133,7 @@ def test_validate_unit_string_conversion(self):
],
)
def test_validate_unit_invalid_type(self, unit_input):
+ # WHEN THEN EXPECT
with pytest.raises(TypeError, match='unit must be None, a string, or a scipp Unit'):
_validate_unit(unit_input)