diff --git a/docs/docs/tutorials/analysis.ipynb b/docs/docs/tutorials/analysis.ipynb index 41e6c7ed1..b00d430fb 100644 --- a/docs/docs/tutorials/analysis.ipynb +++ b/docs/docs/tutorials/analysis.ipynb @@ -338,9 +338,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -351,10 +351,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.4" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/analysis1d.ipynb b/docs/docs/tutorials/analysis1d.ipynb index 784f2be50..8d921318f 100644 --- a/docs/docs/tutorials/analysis1d.ipynb +++ b/docs/docs/tutorials/analysis1d.ipynb @@ -81,9 +81,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -94,10 +94,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.5" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/component_collection.ipynb b/docs/docs/tutorials/component_collection.ipynb index 54f763871..4138d5cb7 100644 --- a/docs/docs/tutorials/component_collection.ipynb +++ b/docs/docs/tutorials/component_collection.ipynb @@ -67,9 +67,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -80,10 +80,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.4" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/components.ipynb b/docs/docs/tutorials/components.ipynb index 70070f1e0..3b072d315 100644 --- a/docs/docs/tutorials/components.ipynb +++ b/docs/docs/tutorials/components.ipynb @@ -39,7 +39,9 @@ "gaussian = sm.Gaussian(display_name='Gaussian', width=0.5, area=1)\n", "dho = sm.DampedHarmonicOscillator(display_name='DHO', center=1.0, width=0.3, area=2.0)\n", "lorentzian = sm.Lorentzian(display_name='Lorentzian', center=-1.0, width=0.2, area=1.0)\n", - "polynomial = sm.Polynomial(display_name='Polynomial', coefficients=[0.1, 0, 0.5]) # y=0.1+0.5*x^2\n", + "polynomial = sm.Polynomial(\n", + " display_name='Polynomial', coefficients=[-0.2, 0, 0.5]\n", + ") # y=-0.2+0.5*x^2\n", "exponential = sm.Exponential(display_name='Exponential', amplitude=1.0, rate=-0.5)\n", "\n", "x = np.linspace(-2, 2, 100)\n", @@ -59,6 +61,18 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "6db00a3e", + "metadata": {}, + "outputs": [], + "source": [ + "# Suppress the warning from the Polynomial:\n", + "polynomial.suppress_warnings = True\n", + "y = polynomial.evaluate(x)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -172,7 +186,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "default", "language": "python", "name": "python3" }, @@ -191,4 +205,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/convolution.ipynb b/docs/docs/tutorials/convolution.ipynb index d37c38c14..90129696e 100644 --- a/docs/docs/tutorials/convolution.ipynb +++ b/docs/docs/tutorials/convolution.ipynb @@ -210,9 +210,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -223,10 +223,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.4" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/data/fake_advanced_data.hdf5 b/docs/docs/tutorials/data/fake_advanced_data.hdf5 index d83903fb0..2a21cb285 100644 Binary files a/docs/docs/tutorials/data/fake_advanced_data.hdf5 and b/docs/docs/tutorials/data/fake_advanced_data.hdf5 differ diff --git a/docs/docs/tutorials/data/fake_simple_data.hdf5 b/docs/docs/tutorials/data/fake_simple_data.hdf5 index 14afb8306..bef9761a8 100644 Binary files a/docs/docs/tutorials/data/fake_simple_data.hdf5 and b/docs/docs/tutorials/data/fake_simple_data.hdf5 differ diff --git a/docs/docs/tutorials/delta_lorentz.ipynb b/docs/docs/tutorials/delta_lorentz.ipynb index 351358124..9d41573d6 100644 --- a/docs/docs/tutorials/delta_lorentz.ipynb +++ b/docs/docs/tutorials/delta_lorentz.ipynb @@ -96,9 +96,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -109,10 +109,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.4" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/detailed_balance.ipynb b/docs/docs/tutorials/detailed_balance.ipynb index 12d926a6d..69a22f71a 100644 --- a/docs/docs/tutorials/detailed_balance.ipynb +++ b/docs/docs/tutorials/detailed_balance.ipynb @@ -86,9 +86,9 @@ ], "metadata": { "kernelspec": { - "display_name": "easydynamics_newbase", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -99,8 +99,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/docs/docs/tutorials/diffusion_model.ipynb b/docs/docs/tutorials/diffusion_model.ipynb index ddf9e9a0e..7468efc59 100644 --- a/docs/docs/tutorials/diffusion_model.ipynb +++ b/docs/docs/tutorials/diffusion_model.ipynb @@ -85,9 +85,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -98,10 +98,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.4" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/experiment.ipynb b/docs/docs/tutorials/experiment.ipynb index 0e6b7c5b4..124d83ff1 100644 --- a/docs/docs/tutorials/experiment.ipynb +++ b/docs/docs/tutorials/experiment.ipynb @@ -70,9 +70,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python (Pixi)", "language": "python", - "name": "pixi-kernel-python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -83,10 +83,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/instrument_model.ipynb b/docs/docs/tutorials/instrument_model.ipynb index 27ccbb93a..988f56ca5 100644 --- a/docs/docs/tutorials/instrument_model.ipynb +++ b/docs/docs/tutorials/instrument_model.ipynb @@ -75,9 +75,9 @@ ], "metadata": { "kernelspec": { - "display_name": "easydynamics_newbase", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -88,10 +88,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/sample_model.ipynb b/docs/docs/tutorials/sample_model.ipynb index fa34e8da1..be0d967dc 100644 --- a/docs/docs/tutorials/sample_model.ipynb +++ b/docs/docs/tutorials/sample_model.ipynb @@ -61,7 +61,7 @@ " diffusion_models=diffusion_model,\n", " components=component_collection,\n", " Q=Q,\n", - " unit='meV',\n", + " x_unit='meV',\n", " display_name='MySampleModel',\n", " temperature=10,\n", ")\n", @@ -127,9 +127,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3" }, "language_info": { "codemirror_mode": { @@ -140,10 +140,9 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.5" + "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/tutorials/tutorial0_basics.ipynb b/docs/docs/tutorials/tutorial0_basics.ipynb index 548c6c714..176aaf170 100644 --- a/docs/docs/tutorials/tutorial0_basics.ipynb +++ b/docs/docs/tutorials/tutorial0_basics.ipynb @@ -386,7 +386,19 @@ "id": "842c1f01", "metadata": {}, "source": [ - "The final step in this tutorial is to fit the are of the `Gaussian` to a straight line. For this, we use the `ParameterAnalysis` class. We create a `Polynomial` with two coefficients for the fit function. We create a `FitBinding`, telling the class we want to fit the parameter named `Gaussian area` with the fit function that we define." + "The final step in this tutorial is to fit the area of the `Gaussian` as a function of Q using a straight line. For this, we use the `ParameterAnalysis` class.\n", + "\n", + "We create a `Polynomial` with two coefficients as the fit function. We then create a `FitBinding` to connect the parameter named `Gaussian area` to the fit function, and pass both to a `ParameterAnalysis` object:\n", + "\n", + "
\n", + " Note\n", + "
\n", + "Two units must be set on the fit function:\n", + "\n", + "- `x_unit='1/angstrom'` — because `ParameterAnalysis` always uses Q as its x-axis.\n", + "- `y_unit='meV'` — because we are fitting `Gaussian area`, which has unit `meV`.\n", + "
\n", + "
" ] }, { @@ -397,10 +409,13 @@ "outputs": [], "source": [ "fit_func = sm.Polynomial(\n", - " coefficients=[3.7, -0.5], name='Straight line', display_name='Straight line'\n", + " coefficients=[3.7, -0.5],\n", + " x_unit='1/angstrom',\n", + " y_unit='meV',\n", + " name='Straight line',\n", ")\n", "\n", - "binding = edyn.FitBinding(parameter_name='Gaussian area', model=fit_func)\n", + "binding = edyn.FitBinding(model=fit_func, targets='Gaussian area')\n", "\n", "parameter_analysis = edyn.ParameterAnalysis(\n", " parameters=analysis,\n", @@ -450,7 +465,7 @@ "id": "dc33728c", "metadata": {}, "source": [ - "To see the parameters we can use the `get_all_parameters()` method. We can also see only the parameters that can be fitted:" + "To see the parameters we can use the `get_all_parameters()` method. We can also see only the parameters that can be fitted using `get_fittable_parameters`:" ] }, { @@ -476,7 +491,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "default", "language": "python", "name": "python3" }, diff --git a/docs/docs/tutorials/tutorial0_more_advanced.ipynb b/docs/docs/tutorials/tutorial0_more_advanced.ipynb index f203dbdcb..ece987107 100644 --- a/docs/docs/tutorials/tutorial0_more_advanced.ipynb +++ b/docs/docs/tutorials/tutorial0_more_advanced.ipynb @@ -144,6 +144,13 @@ "
\n", " SampleModel, BackgroundModel and ResolutionModel (introduced later) all take components in several different ways: you can add a single component like in the previous tutorial, append components using the `append_component` method, or give a ComponentCollection like we do here. \n", "
\n", + "\n", + "\n", + "
\n", + " 💡 Tip\n", + "
\n", + " Polynomials will warn if they produce negative values, since negative backgrounds are unphysical. To suppress the warning, set suppress_warnings to True, either when constructing the Polynomial (polynomial=sm.Polynomial(coefficients=[0.0], suppress_warnings=True)), or afterwards (polynomial.suppress_warnings = True).\n", + "
\n", "
\n" ] }, @@ -300,7 +307,19 @@ "id": "0eadbd91", "metadata": {}, "source": [ - "With apologies for the lack of creativity, these all appear like straight lines. We can fit them individually or all together using `ParameterAnalysis`" + "With apologies for the lack of creativity, these all appear like straight lines. We can fit them individually or all together using `ParameterAnalysis`.\n", + "\n", + "For each parameter we want to fit, we create a fit function (here a `Polynomial`) and a `FitBinding` connecting the parameter name to the fit function. We then pass all bindings to a single `ParameterAnalysis` object:\n", + "\n", + "
\n", + " Note\n", + "
\n", + "Two units must be set on each fit function:\n", + "\n", + "- `x_unit='1/angstrom'` — because `ParameterAnalysis` always uses Q as its x-axis.\n", + "- `y_unit='meV'` — because all three parameters (`Gaussian area`, `DHO area`, `DHO center`) have unit `meV`.\n", + "
\n", + "
" ] }, { @@ -310,17 +329,21 @@ "metadata": {}, "outputs": [], "source": [ - "gauss_fit_func = sm.Polynomial(coefficients=[3.7, -0.5], unit='1/angstrom', name='Gauss area fit')\n", - "dho_area_fit_func = sm.Polynomial(coefficients=[2.0, 0.12], unit='1/angstrom', name='DHO area fit')\n", + "gauss_fit_func = sm.Polynomial(\n", + " coefficients=[3.7, -0.5], x_unit='1/angstrom', y_unit='meV', name='Gauss area fit'\n", + ")\n", + "dho_area_fit_func = sm.Polynomial(\n", + " coefficients=[2.0, 0.12], x_unit='1/angstrom', y_unit='meV', name='DHO area fit'\n", + ")\n", "dho_center_fit_func = sm.Polynomial(\n", - " coefficients=[1.1, 0.2], unit='1/angstrom', name='DHO center fit'\n", + " coefficients=[1.1, 0.2], x_unit='1/angstrom', y_unit='meV', name='DHO center fit'\n", ")\n", "\n", - "binding1 = edyn.FitBinding(parameter_name='Gaussian area', model=gauss_fit_func)\n", + "binding1 = edyn.FitBinding(model=gauss_fit_func, targets='Gaussian area')\n", "\n", - "binding2 = edyn.FitBinding(parameter_name='DHO area', model=dho_area_fit_func)\n", + "binding2 = edyn.FitBinding(model=dho_area_fit_func, targets='DHO area')\n", "\n", - "binding3 = edyn.FitBinding(parameter_name='DHO center', model=dho_center_fit_func)\n", + "binding3 = edyn.FitBinding(model=dho_center_fit_func, targets='DHO center')\n", "\n", "parameter_analysis = edyn.ParameterAnalysis(\n", " parameters=analysis,\n", @@ -335,7 +358,7 @@ "id": "32bc1efc", "metadata": {}, "source": [ - "The start guesses look reasonable, so we fit:" + "The start guesses look reasonable, so we fit and plot the result:" ] }, { @@ -352,7 +375,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "default", "language": "python", "name": "python3" }, diff --git a/docs/docs/tutorials/tutorial1_brownian.ipynb b/docs/docs/tutorials/tutorial1_brownian.ipynb index 269bb2599..0965bfaa7 100644 --- a/docs/docs/tutorials/tutorial1_brownian.ipynb +++ b/docs/docs/tutorials/tutorial1_brownian.ipynb @@ -451,7 +451,7 @@ "$$\n", "where $\\Gamma(Q) = D Q^2$ and $D$ is the diffusion coefficient. $S$ is an overall scale.\n", "\n", - "To fit the Brownian translational diffusion model to the data, we use the `ParameterAnalysis`. This time, we wish to fit the `Lorentzian`, and we wish to fit both its area (scale, $S$) and width to the diffusion model. We do this by creating a `FitBinding`, saying we want to fit both the area an width of the component called `Lorentzian`" + "To fit the Brownian translational diffusion model to the data, we use the `ParameterAnalysis`. This time, we wish to fit the `Lorentzian`, and we wish to fit both its area (scale, $S$) and width to the diffusion model. We do this by creating a `FitBinding`. Because the diffusion model's `lorentzian_name` matches the fitted component's name, the default targets automatically fit both the area and the width against the `Lorentzian area` and `Lorentzian width` parameters" ] }, { @@ -462,14 +462,13 @@ "outputs": [], "source": [ "brownian_diffusion_model = sm.BrownianTranslationalDiffusion(\n", - " name='Brownian Translational Diffusion', diffusion_coefficient=2.4e-9, scale=0.5\n", + " name='Brownian Translational Diffusion',\n", + " lorentzian_name='Lorentzian',\n", + " diffusion_coefficient=2.4e-9,\n", + " scale=0.5,\n", ")\n", "\n", - "binding = edyn.FitBinding(\n", - " parameter_name='Lorentzian',\n", - " model=brownian_diffusion_model,\n", - " modes=['area', 'width'],\n", - ")\n", + "binding = edyn.FitBinding(model=brownian_diffusion_model)\n", "\n", "parameter_analysis = edyn.ParameterAnalysis(\n", " parameters=diffusion_analysis,\n", @@ -691,7 +690,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "default", "language": "python", "name": "python3" }, diff --git a/docs/docs/tutorials/tutorial2_nanoparticles.ipynb b/docs/docs/tutorials/tutorial2_nanoparticles.ipynb index 22de24900..da2188bcf 100644 --- a/docs/docs/tutorials/tutorial2_nanoparticles.ipynb +++ b/docs/docs/tutorials/tutorial2_nanoparticles.ipynb @@ -593,7 +593,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "default", "language": "python", "name": "python3" }, diff --git a/src/easydynamics/analysis/analysis.py b/src/easydynamics/analysis/analysis.py index 7291ea917..25a600a14 100644 --- a/src/easydynamics/analysis/analysis.py +++ b/src/easydynamics/analysis/analysis.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from copy import copy from typing import Any import numpy as np @@ -456,10 +457,7 @@ def data_and_model_to_datagroup( self._verify_bool(include_components, 'include_components') self._verify_bool(include_residuals, 'include_residuals') - energy = self._verify_energy(energy) - - if energy is None: - energy = self.energy + energy = self._verify_energy(energy) if energy is not None else self.energy data_and_model = { 'Data': self.experiment.binned_data, @@ -532,6 +530,7 @@ def parameters_to_dataset(self) -> sc.Dataset: units[name] = p.unit elif units[name] != p.unit: try: + p = copy(p) p.convert_unit(units[name]) except Exception as e: raise UnitError( @@ -587,7 +586,7 @@ def plot_parameters( ds = self.parameters_to_dataset() - if not names: + if names is None: names = list(ds.keys()) if isinstance(names, str): @@ -686,9 +685,8 @@ def _on_convolution_settings_changed(self) -> None: def _ensure_analysis_list_current(self) -> None: """Rebuild the analysis list if any dependency has changed since it was last built.""" - if self._analysis_list_is_dirty: - if self.Q is not None: - self._create_analysis_list() + if self._analysis_list_is_dirty and self.Q is not None: + self._create_analysis_list() self._analysis_list_is_dirty = False def _create_analysis_list(self) -> None: @@ -698,6 +696,8 @@ def _create_analysis_list(self) -> None: """ self._analysis_list = [] for Q_index in range(len(self.Q)): + # The ConvolutionSettings object is shared so user changes reach every Q index; + # plan validity is tracked per convolver, not on the settings object. analysis = Analysis1d( display_name=f'{self.display_name}_Q{Q_index}', experiment=self.experiment, @@ -756,16 +756,17 @@ def _fit_all_Q_simultaneously(self) -> FitResults: ys = [] ws = [] + # TODO: consider using scipp built-in masking instead of numpy boolean masks, # noqa: FIX002 TD002 TD003 + # once the EasyScience fitter accepts scipp Variables directly. for analysis1d in self.analysis_list: - x, y, weight, _ = self.experiment._extract_x_y_weights_only_finite( # noqa: SLF001 - analysis1d.Q_index - ) + x, y, weight, _ = self.experiment.extract_x_y_weights_only_finite(analysis1d.Q_index) xs.append(x) ys.append(y) ws.append(weight) - # Make sure the convolver is up to date for this Q index - analysis1d.refresh_convolver(energy=x) + analysis1d.refresh_convolver( + energy=self.experiment.get_masked_energy(Q_index=analysis1d.Q_index) + ) mf = MultiFitter( fit_objects=self.analysis_list, diff --git a/src/easydynamics/analysis/analysis1d.py b/src/easydynamics/analysis/analysis1d.py index 341103e16..b45a3ebc7 100644 --- a/src/easydynamics/analysis/analysis1d.py +++ b/src/easydynamics/analysis/analysis1d.py @@ -170,7 +170,7 @@ def calculate(self, energy: sc.Variable | None = None) -> np.ndarray: """ Calculate the model prediction for the chosen Q index. - Makes sure the convolver is up to date before calculating. + Creates a new convolver before calculating without touching the stored convolver. Parameters ---------- @@ -184,14 +184,12 @@ def calculate(self, energy: sc.Variable | None = None) -> np.ndarray: The calculated model prediction. """ energy = self._verify_energy(energy) - self._convolver = self._create_convolver(energy=energy) - # Mark dirty so the next fit() call rebuilds the convolver with the standard - # (unmasked) energy grid rather than reusing this plot-path grid. - self._convolver_is_dirty = True + convolver = self._create_convolver(energy=energy) + return self._calculate(energy=energy, convolver=convolver) - return self._calculate(energy=energy) - - def _calculate(self, energy: sc.Variable | None = None) -> np.ndarray: + def _calculate( + self, energy: sc.Variable | None = None, convolver: Convolution | None = None + ) -> np.ndarray: """ Calculate the model prediction for the chosen Q index. @@ -202,18 +200,21 @@ def _calculate(self, energy: sc.Variable | None = None) -> np.ndarray: energy : sc.Variable | None, default=None Optional energy grid to use for calculation. If None, the energy grid from the experiment is used. + convolver : Convolution | None, default=None + Optional convolver to use. If None, uses self._convolver. Returns ------- np.ndarray The calculated model prediction. """ - + if convolver is None: + convolver = self._convolver Q_index = self._require_Q_index() sample = self._evaluate_with_convolution( self.sample_model.get_component_collection(Q_index), energy, - convolver=self._convolver, + convolver=convolver, ) background = self._evaluate_direct( self.instrument_model.background_model.get_component_collection(Q_index), @@ -254,7 +255,7 @@ def fit(self) -> FitResults: fit_function=self.as_fit_function(), ) - x, y, weights, _ = self.experiment._extract_x_y_weights_only_finite( # noqa: SLF001 + x, y, weights, _ = self.experiment.extract_x_y_weights_only_finite( Q_index=self._require_Q_index() ) fit_result = fitter.fit(x=x, y=y, weights=weights) @@ -438,7 +439,7 @@ def data_and_model_to_datagroup( energy = self._masked_energy data_and_model = { - 'Data': self.experiment.binned_data['Q', self.Q_index], + 'Data': self.experiment.get_masked_binned_data(Q_index=self.Q_index), 'Model': self._create_model_array(energy=energy), } @@ -516,6 +517,10 @@ def _on_Q_index_changed(self) -> None: This method is called whenever the Q index is changed. It updates the masked energy from the experiment for the new Q index and marks the convolver as dirty. """ + if self._Q_index is None: + self._masked_energy = None + self._convolver_is_dirty = True + return masked_energy = self.experiment.get_masked_energy(Q_index=self._Q_index) self._masked_energy = masked_energy self._convolver_is_dirty = True @@ -523,6 +528,9 @@ def _on_Q_index_changed(self) -> None: def _on_experiment_changed(self) -> None: """Mark the convolver as dirty when the experiment changes.""" super()._on_experiment_changed() + # Refresh masked energy if Q_index is already set (i.e. post-init experiment swap). + if getattr(self, '_Q_index', None) is not None and self.experiment is not None: + self._masked_energy = self.experiment.get_masked_energy(Q_index=self._Q_index) self._convolver_is_dirty = True def _on_sample_model_changed(self) -> None: @@ -561,28 +569,15 @@ def _calculate_energy_with_offset( energy_offset : Parameter The energy offset to apply. - Raises - ------ - sc.UnitError - If the energy and energy offset have incompatible units. - Returns ------- sc.Variable The energy grid with the offset applied. """ - if energy.unit != energy_offset.unit: - try: - energy_offset.convert_unit(str(energy.unit)) - except Exception as e: - raise sc.UnitError( - f'Energy and energy offset must have compatible units. ' - f'Got {energy.unit} and {energy_offset.unit}.' - ) from e - - energy_with_offset = energy.copy(deep=True) - energy_with_offset.values -= energy_offset.value + offset_value = sc.to_unit(energy_offset.full_value, energy.unit).value + energy_with_offset = energy.copy() + energy_with_offset.values = energy.values - offset_value return energy_with_offset ############# @@ -640,18 +635,15 @@ def _evaluate_with_convolution( energy=energy_with_offset, temperature=self.temperature, divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, - energy_unit=self.unit, + energy_unit=self.x_unit, ) return result - return Convolution( - energy=energy, + return self._build_convolution( sample_components=components, resolution_components=resolution, + energy=energy, energy_offset=energy_offset, - convolution_settings=self.convolution_settings, - temperature=self.temperature, - detailed_balance_settings=self.detailed_balance_settings, ).convolution() def _evaluate_direct( @@ -719,11 +711,25 @@ def _create_convolver( if resolution_components.is_empty: return None + return self._build_convolution( + sample_components=sample_components, + resolution_components=resolution_components, + energy=energy, + energy_offset=self.instrument_model.get_energy_offset(Q_index), + ) + + def _build_convolution( + self, + sample_components: ComponentCollection | ModelComponent, + resolution_components: ComponentCollection, + energy: sc.Variable, + energy_offset: Parameter, + ) -> Convolution: return Convolution( energy=energy, sample_components=sample_components, resolution_components=resolution_components, - energy_offset=self.instrument_model.get_energy_offset(Q_index), + energy_offset=energy_offset, convolution_settings=self.convolution_settings, temperature=self.temperature, detailed_balance_settings=self.detailed_balance_settings, @@ -769,7 +775,7 @@ def _create_residuals_array(self) -> sc.DataArray: if self.Q_index is None: raise ValueError('Q_index must be set to calculate residuals.') - data = self.experiment.binned_data['Q', self.Q_index] + data = self.experiment.get_masked_binned_data(Q_index=self.Q_index) model = self._create_model_array() return data.copy(deep=True) - model diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py index 598daea67..80be2833f 100644 --- a/src/easydynamics/analysis/analysis_base.py +++ b/src/easydynamics/analysis/analysis_base.py @@ -13,6 +13,7 @@ from easydynamics.sample_model import SampleModel from easydynamics.settings.convolution_settings import ConvolutionSettings from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings +from easydynamics.utils.utils import verify_Q_index class AnalysisBase(EasyDynamicsModelBase): @@ -503,13 +504,6 @@ def _verify_Q_index(self, Q_index: int | None) -> int | None: Q_index : int | None The Q index to verify. - Raises - ------ - TypeError - If Q_index is not an integer or None. - IndexError - If the Q index is not valid. - Returns ------- int | None @@ -517,12 +511,7 @@ def _verify_Q_index(self, Q_index: int | None) -> int | None: """ if Q_index is None: return None - - if not isinstance(Q_index, int): - raise TypeError('Q_index must be an integer or None.') - - if Q_index < 0 or (self.Q is not None and Q_index >= len(self.Q)): - raise IndexError('Q_index must be a valid index for the Q values.') + verify_Q_index(Q_index, self.Q) return Q_index def _verify_energy(self, energy: sc.Variable | None) -> sc.Variable | None: diff --git a/src/easydynamics/analysis/fit_binding.py b/src/easydynamics/analysis/fit_binding.py index da74e383f..e43ced4ab 100644 --- a/src/easydynamics/analysis/fit_binding.py +++ b/src/easydynamics/analysis/fit_binding.py @@ -3,145 +3,110 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import replace from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase - -if TYPE_CHECKING: - from collections.abc import Callable +from easydynamics.utils.fit_target import FitTarget class FitBinding(EasyDynamicsBase): """ - Contract between dataset, model, and fit function for ParameterAnalysis. This class - encapsulates the necessary information to bind a dataset key to a model and convert it into a - fit function callable. + Contract between dataset, model, and fit functions for ParameterAnalysis. A binding maps the + model's fittable predictions (its FitTargets) onto keys of the parameters Dataset they should + be fitted against. Examples -------- - **Basic usage with a ModelComponent** + **Fitting a component model to one parameter** - Bind a fit parameter name to any ``ModelComponent`` or ``ComponentCollection``: + Component models (e.g. a Polynomial) have a single prediction — their ``evaluate`` — so + ``targets`` is simply the dataset key to fit against: ```python import easydynamics as edyn import easydynamics.sample_model as sm fit_func = sm.Polynomial(coefficients=[3.7, -0.5], display_name='Straight line') - binding = edyn.FitBinding(parameter_name='Gaussian area', model=fit_func) + binding = edyn.FitBinding(model=fit_func, targets='Gaussian area') ``` - **Usage with a DiffusionModelBase and specific modes** + **Fitting a diffusion model with default dataset keys** - For diffusion models, use ``modes`` to select which parameter arrays are fitted: + Diffusion models declare their predictions (``'area'``, ``'width'``, and for DeltaLorentz also + ``'delta_area'``). With ``targets=None`` all predictions are fitted against default dataset + keys derived from the model's component names: ```python brownian = sm.BrownianTranslationalDiffusion( - display_name='Brownian Translational Diffusion', diffusion_coefficient=2.4e-9, scale=0.5, + lorentzian_name='Lorentzian', ) + binding = edyn.FitBinding(model=brownian) # fits 'Lorentzian area' and 'Lorentzian width' + ``` + + **Selecting predictions or mapping them to custom dataset keys** + + Pass a list of prediction names, or a dict mapping prediction names to dataset keys: + ```python + binding = edyn.FitBinding(model=brownian, targets=['width']) + + delta_lorentz = sm.DeltaLorentz(A_0=0.5, lorentzian_width=0.1) binding = edyn.FitBinding( - parameter_name='Lorentzian', - model=brownian, - modes=['area', 'width'], + model=delta_lorentz, + targets={ + 'width': 'Lorentzian width', + 'area': 'Lorentzian area', + 'delta_area': 'Elastic area', + }, ) ``` """ def __init__( self, - parameter_name: str, model: ModelComponent | ComponentCollection | DiffusionModelBase, - modes: str | list[str] | None = None, + targets: str | list[str] | dict[str, str] | None = None, display_name: str | None = None, unique_name: str | None = None, ) -> None: """ Initialize a FitBinding. + Validation raises ``TypeError`` if model or targets have an invalid type, and + ``ValueError`` if targets names a prediction the model does not declare. + Parameters ---------- - parameter_name : str - The name of the parameter to fit. This should correspond to a key in the dataset. model : ModelComponent | ComponentCollection | DiffusionModelBase The model to fit. This can be a single ModelComponent, a ComponentCollection, or a DiffusionModelBase. - modes : str | list[str] | None, default=None - The modes to fit for diffusion models. This can be a single string, a list of strings, - or None (which defaults to ["area", "width"]). Only applicable if the model is a - DiffusionModelBase. Default is None. + targets : str | list[str] | dict[str, str] | None, default=None + Which predictions of the model to fit, and against which dataset keys. For component + models this must be a string: the dataset key to fit the model's ``evaluate`` against. + For diffusion models: None fits all predictions against their default dataset keys; a + string or list of strings selects predictions by name (default keys); a dict maps + prediction names to custom dataset keys. display_name : str | None, default=None An optional display name for the FitBinding. If None, the unique_name will be used. Default is None. unique_name : str | None, default=None An optional unique name for the FitBinding. If None, a unique name will be generated. Default is None. - - Raises - ------ - TypeError - If parameter_name is not a string, if model is not a ModelComponent, - ComponentCollection or DiffusionModelBase, or if modes is not a string, list of - strings, or None. """ super().__init__(display_name=display_name, unique_name=unique_name) - if not isinstance(parameter_name, str): - raise TypeError('parameter_name must be a string') - - if not isinstance(model, (ModelComponent, ComponentCollection, DiffusionModelBase)): - raise TypeError( - 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase' - ) - - if modes is not None and not isinstance(modes, (str, list)): - raise TypeError('modes must be a string, list of strings, or None') - - if isinstance(modes, list) and not all(isinstance(mode, str) for mode in modes): - raise TypeError('All modes in the list must be strings') - - self._parameter_name = parameter_name + self._validate_model(model) + self._validate_targets(model, targets) self._model = model - self._modes = modes + self._targets = targets # ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ - @property - def parameter_name(self) -> str: - """ - The name of the parameter to fit. This should correspond to a key in the dataset. - - Returns - ------- - str - The name of the parameter to fit. - """ - return self._parameter_name - - @parameter_name.setter - def parameter_name(self, value: str) -> None: - """ - Set the name of the parameter to fit. - - Parameters - ---------- - value : str - The new name of the parameter to fit. - - Raises - ------ - TypeError - If the value is not a string. - """ - if not isinstance(value, str): - raise TypeError('parameter_name must be a string') - self._parameter_name = value - @property def model(self) -> ModelComponent | ComponentCollection | DiffusionModelBase: """ @@ -160,162 +125,168 @@ def model(self, value: ModelComponent | ComponentCollection | DiffusionModelBase """ Set the model to fit. + Validation raises ``TypeError`` if the value has an invalid type, and ``ValueError`` if the + current targets name a prediction the new model does not declare. + Parameters ---------- value : ModelComponent | ComponentCollection | DiffusionModelBase The new model to fit. - - Raises - ------ - TypeError - If the value is not a ModelComponent, ComponentCollection, or DiffusionModelBase. """ - if not isinstance(value, (ModelComponent, ComponentCollection, DiffusionModelBase)): - raise TypeError( - 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase.' - ) + self._validate_model(value) + self._validate_targets(value, self._targets) self._model = value @property - def modes(self) -> str | list[str] | None: + def targets(self) -> str | list[str] | dict[str, str] | None: """ - The modes to fit for diffusion models. This can be a single string, a list of strings, or - None (which defaults to ["area", "width"]). + Which predictions of the model to fit, and against which dataset keys. Returns ------- - str | list[str] | None - The modes to fit for diffusion models. + str | list[str] | dict[str, str] | None + The targets specification (see ``__init__``). """ - return self._modes + return self._targets - @modes.setter - def modes(self, value: str | list[str] | None) -> None: + @targets.setter + def targets(self, value: str | list[str] | dict[str, str] | None) -> None: """ - Set the modes to fit for diffusion models. + Set which predictions of the model to fit, and against which dataset keys. + + Validation raises ``TypeError`` if the value has an invalid type for the current model, and + ``ValueError`` if it names a prediction the model does not declare. Parameters ---------- - value : str | list[str] | None - The new modes to fit for diffusion models. - - Raises - ------ - TypeError - If the value is not a string, list of strings, or None. + value : str | list[str] | dict[str, str] | None + The new targets specification (see ``__init__``). """ - if value is not None and not isinstance(value, (str, list)): - raise TypeError('modes must be a string, list of strings, or None') - - if isinstance(value, str): - value = [value] - if isinstance(value, list) and not all(isinstance(mode, str) for mode in value): - raise TypeError('All modes in the list must be strings') - self._modes = value + self._validate_targets(self._model, value) + self._targets = value # ------------------------------------------------------------------ # Other methods # ------------------------------------------------------------------ - def build_callables(self) -> list[Callable]: + def get_targets(self) -> list[FitTarget]: """ - Build the callables for fitting based on the model and modes. + Get the FitTargets this binding fits, with dataset keys resolved. - Returns - ------- - list[Callable] - A list of callables for fitting. - """ - modes = self._get_modes() - - if isinstance(self.model, DiffusionModelBase): - return [self._build_diffusion_callable(mode) for mode in modes] - - return [lambda x, **_: self.model.evaluate(x)] - - def get_model_names(self) -> list[str]: - """ - Get the names of the models based on the current modes. + Targets are built from the model at call time, so their units and default dataset keys + reflect the model's current state. Returns ------- - list[str] - A list of model names. + list[FitTarget] + The resolved fit targets. """ - modes = self._get_modes() - if isinstance(self.model, DiffusionModelBase): - return [f'{self.model.display_name} {mode}' for mode in modes] - - return [self.model.display_name] - - def get_parameter_names(self) -> list[str]: - """ - Get the names of the parameters based on the current modes. - - Returns - ------- - list[str] - A list of parameter names. - """ - modes = self._get_modes() - - if isinstance(self.model, DiffusionModelBase): - # This needs to be generalised. - # TODO: Generalise this for different diffusion models and modes. # noqa TD002 TD003 - if 'delta' in modes: - return [f'{self.parameter_name} area' for mode in modes] - - return [f'{self.parameter_name} {mode}' for mode in modes] - - return [self.parameter_name] + available = {target.name: target for target in self.model.get_fit_targets()} + if self._targets is None: + return list(available.values()) + if isinstance(self._targets, str): + return [available[self._targets]] + if isinstance(self._targets, list): + return [available[name] for name in self._targets] + return [ + replace(available[name], dataset_key=dataset_key) + for name, dataset_key in self._targets.items() + ] + + return [ + FitTarget( + name='value', + dataset_key=self._targets, + function=lambda x, model=self.model, **_: model.evaluate(x), + label=self.model.display_name, + x_unit=self.model.x_unit, + y_unit=self.model.y_unit, + ) + ] # ------------------------------------------------------------------ # Private methods # ------------------------------------------------------------------ - def _build_diffusion_callable(self, mode: str) -> Callable: + @staticmethod + def _validate_model( + model: ModelComponent | ComponentCollection | DiffusionModelBase, + ) -> None: """ - Build a callable for a specific diffusion mode. + Validate the model type. Parameters ---------- - mode : str - The diffusion mode ("area" or "width"). - - Returns - ------- - Callable - A callable for the specified diffusion mode. + model : ModelComponent | ComponentCollection | DiffusionModelBase + The model to validate. Raises ------ - ValueError - If the mode is unknown. + TypeError + If model is not a ModelComponent, ComponentCollection, or DiffusionModelBase. """ - model = self.model - - if mode == 'area': - return lambda x, **_: model.calculate_QISF(x) * model.scale.value - - if mode == 'width': - return lambda x, **_: model.calculate_width(x) - - if mode == 'delta': - return lambda x, **_: model.calculate_EISF(x) * model.scale.value - - raise ValueError(f'Unknown diffusion mode: {mode}') + if not isinstance(model, (ModelComponent, ComponentCollection, DiffusionModelBase)): + raise TypeError( + 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase' + ) - def _get_modes(self) -> list[str]: + @staticmethod + def _validate_targets( + model: ModelComponent | ComponentCollection | DiffusionModelBase, + targets: str | list[str] | dict[str, str] | None, + ) -> None: """ - Get the modes to fit for diffusion models, defaulting to ["area", "width"] if not set. + Validate a targets specification against a model. - Returns - ------- - list[str] - The modes to fit for diffusion models. + Parameters + ---------- + model : ModelComponent | ComponentCollection | DiffusionModelBase + The model the targets apply to. + targets : str | list[str] | dict[str, str] | None + The targets specification to validate. + + Raises + ------ + TypeError + If targets has an invalid type for the given model. + ValueError + If targets names a prediction the model does not declare. """ - return ['area', 'width'] if self.modes is None else self.modes + if isinstance(model, DiffusionModelBase): + if targets is None: + return + if isinstance(targets, str): + requested = [targets] + elif isinstance(targets, list): + requested = targets + elif isinstance(targets, dict): + requested = list(targets.keys()) + if not all(isinstance(key, str) for key in targets.values()): + raise TypeError('targets dict values must be dataset keys (strings)') + else: + raise TypeError( + 'targets must be None, a prediction name, a list of prediction names, ' + 'or a dict mapping prediction names to dataset keys' + ) + if not all(isinstance(name, str) for name in requested): + raise TypeError('prediction names in targets must be strings') + + available = [target.name for target in model.get_fit_targets()] + unknown = sorted(set(requested) - set(available)) + if unknown: + raise ValueError( + f'Unknown prediction(s) {", ".join(unknown)} for ' + f'{model.__class__.__name__}. Available predictions: ' + f'{", ".join(available)}.' + ) + return + + if not isinstance(targets, str): + raise TypeError( + 'For component models, targets must be the dataset key (a string) to fit ' + "the model's evaluate against." + ) # ------------------------------------------------------------------ # dunder methods @@ -331,10 +302,8 @@ def __repr__(self) -> str: A string representation of the FitBinding. """ return ( - f'{self.__class__.__name__}(' - f'parameter_name={self.parameter_name!r},\n' - f' model={self.model.display_name!r},\n' - f' modes={self.modes},\n' - f' display_name={self.display_name!r},\n' - f' unique_name={self.unique_name!r})' + f'{self.__class__.__name__}(\n' + f' model={self.model.display_name},\n' + f' targets={self.targets},\n' + f')' ) diff --git a/src/easydynamics/analysis/parameter_analysis.py b/src/easydynamics/analysis/parameter_analysis.py index 142f0b1a4..7e022e008 100644 --- a/src/easydynamics/analysis/parameter_analysis.py +++ b/src/easydynamics/analysis/parameter_analysis.py @@ -15,6 +15,7 @@ from easydynamics.analysis.analysis import Analysis from easydynamics.analysis.fit_binding import FitBinding from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase +from easydynamics.utils.fit_target import FitTarget from easydynamics.utils.utils import _in_notebook @@ -25,15 +26,15 @@ class ParameterAnalysis(EasyDynamicsModelBase): Can be used to fit parameters to ModelComponents, ComponentCollections, or DiffusionModelBase objects, and to plot the parameters and fit results. The parameters to be analyzed can be provided as a sc.Dataset or directly as an Analysis object. Multiple parameters can be fitted - simultaneously, and the fit functions can be customized for each parameter. For diffusion - models, the area and width can be fitted separately (or not at all) by specifying fit settings. + simultaneously, and each binding maps its model's predictions onto the dataset keys they are + fitted against (for diffusion models e.g. 'area', 'width', or 'delta_area'). Examples -------- **Fitting Lorentzian widths to a diffusion model** - After a full Analysis fit, pass the Analysis directly and bind each parameter to a fit function - using a ``FitBinding``: + After a full Analysis fit, pass the Analysis directly and bind the model's predictions to + dataset keys using a ``FitBinding``: ```python import easydynamics as edyn import easydynamics.sample_model as sm @@ -41,9 +42,8 @@ class ParameterAnalysis(EasyDynamicsModelBase): # analysis is an edyn.Analysis object with previously fitted parameters diffusion_model = sm.BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, scale=0.5) binding = edyn.FitBinding( - parameter_name='Lorentzian width', model=diffusion_model, - modes=['width'], + targets={'width': 'Lorentzian width'}, ) param_analysis = edyn.ParameterAnalysis( @@ -58,8 +58,8 @@ class ParameterAnalysis(EasyDynamicsModelBase): ```python area_binding = edyn.FitBinding( - parameter_name='Lorentzian area', model=sm.Polynomial(coefficients=[0.5, 0.0]), + targets='Lorentzian area', ) param_analysis = edyn.ParameterAnalysis( parameters=analysis, @@ -183,23 +183,24 @@ def fit(self) -> FitResults: funcs, models = [], [] for binding in self.bindings: - param_names = binding.get_parameter_names() - callables = binding.build_callables() - - for pname, func in zip(param_names, callables, strict=True): - if pname not in self.parameters: + for target in binding.get_targets(): + if target.dataset_key not in self.parameters: raise ValueError( - f"Parameter '{pname}' from binding '{binding.unique_name}' " - f'not found in parameters Dataset.' + f"Parameter '{target.dataset_key}' from binding " + f"'{binding.unique_name}' not found in parameters Dataset." ) - x, y, weight = self._get_xyweight_from_dataset(pname) + x, y, weight = self._get_xyweight_from_dataset(target.dataset_key) + x_factor, y_factor = self._get_unit_conversions(target) + x = x * x_factor + y = y * y_factor + weight = weight / y_factor xs.append(x) ys.append(y) ws.append(weight) - funcs.append(func) + funcs.append(target.function) models.append(binding.model) mf = MultiFitter( @@ -258,7 +259,7 @@ def plot( names = list(self.parameters.keys()) else: for b in self.bindings: - names.extend(b.get_parameter_names()) + names.extend(target.dataset_key for target in b.get_targets()) names = self._normalize_names(names) @@ -282,14 +283,13 @@ def plot( data_arrays = {} model_arrays = {} - # map parameter names to model names + # map dataset keys to model labels param_to_model = {} if self.bindings is not None: for b in self.bindings: - param_names = b.get_parameter_names() - model_names = b.get_model_names() - - param_to_model.update(dict(zip(param_names, model_names, strict=True))) + param_to_model.update({ + target.dataset_key: target.label for target in b.get_targets() + }) for pname in names: data_arrays[pname] = self.parameters[pname] @@ -354,22 +354,20 @@ def calculate_model_dataset(self, bindings: list[FitBinding]) -> sc.Dataset: arrays = {} for b in bindings: - param_names = b.get_parameter_names() - model_names = b.get_model_names() - callables = b.build_callables() - - for pname, mname, func in zip(param_names, model_names, callables, strict=True): - if pname not in self.parameters: + for target in b.get_targets(): + if target.dataset_key not in self.parameters: raise ValueError( - f"Parameter '{pname}' from binding '{b.unique_name}' " + f"Parameter '{target.dataset_key}' from binding '{b.unique_name}' " f'not found in parameters Dataset.' ) - da = self.parameters[pname] + da = self.parameters[target.dataset_key] x = da.coords['Q'] + x_factor, y_factor = self._get_unit_conversions(target) - y_model = func(x.values) + y_model = target.function(x.values * x_factor) + y_model = y_model / y_factor - arrays[mname] = sc.DataArray( + arrays[target.label] = sc.DataArray( data=sc.array(dims=['Q'], values=y_model, unit=da.unit), coords={'Q': x}, ) @@ -521,6 +519,57 @@ def _normalize_names(self, names: str | list[str] | None) -> list[str] | None: # Private methods ############# + def _get_unit_conversions(self, target: FitTarget) -> tuple[float, float]: + """ + Return (x_factor, y_factor) to convert dataset values into the target's declared units. + + x_factor converts Q coordinate values from their stored unit to target.x_unit (e.g. + 1/angstrom for diffusion-model predictions). y_factor converts parameter values from their + stored unit to target.y_unit. A factor is 1.0 when the target declares the corresponding + unit as None (its function takes raw values). + + Parameters + ---------- + target : FitTarget + The fit target whose unit contract defines the target units. + + Returns + ------- + tuple[float, float] + ``(x_factor, y_factor)`` scale factors to apply before/after model evaluation. + + Raises + ------ + sc.UnitError + If x or y units are physically incompatible (e.g. meV vs 1/angstrom). + """ + da = self.parameters[target.dataset_key] + x_factor = 1.0 + y_factor = 1.0 + + if target.x_unit is not None: + q_unit = str(da.coords['Q'].unit) + try: + x_factor = sc.to_unit(sc.scalar(1.0, unit=q_unit), str(target.x_unit)).value + except Exception as e: + raise sc.UnitError( + f"Q coordinate unit '{q_unit}' is incompatible with " + f"the x_unit '{target.x_unit}' of fit target '{target.label}' " + f"for parameter '{target.dataset_key}'." + ) from e + + if target.y_unit is not None: + param_unit = str(da.unit) + try: + y_factor = sc.to_unit(sc.scalar(1.0, unit=param_unit), str(target.y_unit)).value + except Exception as e: + raise sc.UnitError( + f"Parameter '{target.dataset_key}' unit '{param_unit}' is incompatible " + f"with the y_unit '{target.y_unit}' of fit target '{target.label}'." + ) from e + + return x_factor, y_factor + def _get_xyweight_from_dataset( self, parameter_name: str ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -549,21 +598,29 @@ def _get_xyweight_from_dataset( raise ValueError(f"Parameter name '{parameter_name}' not found in parameters Dataset.") variances = self._parameters[parameter_name].variances + values = self._parameters[parameter_name].values + q_values = self._parameters[parameter_name].coords['Q'].values + if variances is None: - weight = np.ones_like(self._parameters[parameter_name].values) - elif np.any(~np.isfinite(variances)) or np.any(variances <= 0): + return q_values, values, np.ones_like(values) + + # NaN variances arise when a parameter is absent for a given Q (parameters_to_dataset + # fills np.nan for missing parameters). Filter those rows silently; other non-finite or + # non-positive variances are errors. + nan_mask = np.isnan(variances) + if np.any(~nan_mask & (~np.isfinite(variances) | (variances <= 0))): raise ValueError( f"Non-finite variances found for parameter '{parameter_name}', " f'cannot compute weights.' ) - else: - weight = 1 / np.sqrt(variances) + valid_mask = ~nan_mask + if not np.any(valid_mask): + raise ValueError( + f"No finite positive variances found for parameter '{parameter_name}', " + f'cannot compute weights.' + ) - return ( - self._parameters[parameter_name].coords['Q'].values, - self._parameters[parameter_name].values, - weight, - ) + return q_values[valid_mask], values[valid_mask], 1 / np.sqrt(variances[valid_mask]) ############# # Dunder methods @@ -587,9 +644,8 @@ def __repr__(self) -> str: binding_info = [ { - 'parameter': b.parameter_name, 'model': b.model.display_name, - 'modes': b.modes, + 'targets': b.targets, } for b in self._bindings ] diff --git a/src/easydynamics/base_classes/easydynamics_modelbase.py b/src/easydynamics/base_classes/easydynamics_modelbase.py index 6d002f71e..3e48fe82c 100644 --- a/src/easydynamics/base_classes/easydynamics_modelbase.py +++ b/src/easydynamics/base_classes/easydynamics_modelbase.py @@ -14,7 +14,8 @@ class EasyDynamicsModelBase(NameMixin, ModelBase): def __init__( self, *args: object, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'MyEasyDynamicsModel', display_name: str | None = None, unique_name: str | None = None, @@ -27,8 +28,10 @@ def __init__( ---------- *args : object Positional arguments to pass to the parent class. - unit : str | sc.Unit, default='meV' - Unit of the model. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis. name : str, default='MyEasyDynamicsModel' Name of the model. display_name : str | None, default=None @@ -58,37 +61,43 @@ def __init__( **kwargs, ) - self._unit = _validate_unit(unit) + self._x_unit = _validate_unit(x_unit) + self._y_unit = _validate_unit(y_unit) @property - def unit(self) -> str | sc.Unit | None: + def x_unit(self) -> str | sc.Unit | None: """ - Get the unit of the model. + Get the unit of the x-axis. Returns ------- str | sc.Unit | None - The unit of the model. + The unit of the x-axis. """ + return self._x_unit - return self._unit + @x_unit.setter + def x_unit(self, _: str) -> None: + raise AttributeError( + f'x_unit is read-only. Use convert_x_unit to change the unit ' + f'or create a new {self.__class__.__name__} with the desired unit.' + ) - @unit.setter - def unit(self, _unit_str: str) -> None: + @property + def y_unit(self) -> str | sc.Unit | None: """ - Unit is read-only and cannot be set directly. - - Parameters - ---------- - _unit_str : str - The new unit to set (ignored). + Get the unit of the model output. - Raises - ------ - AttributeError - Always raised to indicate that the unit is read-only. + Returns + ------- + str | sc.Unit | None + The unit of the y-axis. """ + return self._y_unit + + @y_unit.setter + def y_unit(self, _: str) -> None: raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' + f'y_unit is read-only. Use convert_y_unit to change the unit ' f'or create a new {self.__class__.__name__} with the desired unit.' ) diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index 535bea989..efcdf6f52 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -491,6 +491,7 @@ def __repr__(self) -> str: f'{self.__class__.__name__}(' f'display_name={self.display_name!r}, ' f'unique_name={self.unique_name!r}, ' - f'unit={self._unit}, ' + f'x_unit={self.x_unit}, ' + f'y_unit={self.y_unit}, ' f'energy_len={len(self.energy)})' ) diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index d4e8aafee..89c7f2ca6 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -80,7 +80,7 @@ class Convolution(NumericalConvolutionBase): # needs to be rebuilt. # Note: the public 'energy' property setter always writes to '_energy', so '_energy' alone # is sufficient — listing 'energy' separately would cause a double invalidation. - _invalidate_plan_on_change: ClassVar[dict[str, object]] = { + _invalidate_plan_on_change: ClassVar[set[str]] = { '_energy', '_energy_grid', '_sample_components', @@ -101,7 +101,8 @@ def __init__( temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', detailed_balance_settings: DetailedBalanceSettings | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', display_name: str | None = 'MyConvolution', unique_name: str | None = None, ) -> None: @@ -113,7 +114,7 @@ def __init__( energy : np.ndarray | sc.Variable 1D array of energy values where the convolution is evaluated. sample_components : ComponentCollection | ModelComponent - The sample components to be convolved. + The sample components to be convolved. resolution_components : ComponentCollection | ModelComponent The resolution components to convolve with. energy_offset : Numeric | Parameter, default=0.0 @@ -126,8 +127,10 @@ def __init__( The unit of the temperature parameter. detailed_balance_settings : DetailedBalanceSettings | None, default=None The settings for detailed balance. If None, default settings will be used. - unit : str | sc.Unit, default='meV' - The unit of the energy. + x_unit : str | sc.Unit, default='meV' + The unit of the energy axis. + y_unit : str | sc.Unit, default='dimensionless' + The unit of the model output (intensity). display_name : str | None, default='MyConvolution' Display name of the model. unique_name : str | None, default=None @@ -144,7 +147,8 @@ def __init__( temperature=temperature, temperature_unit=temperature_unit, detailed_balance_settings=detailed_balance_settings, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, display_name=display_name, unique_name=unique_name, ) @@ -168,7 +172,7 @@ def convolution( np.ndarray The convolved values evaluated at energy. """ - if not self.convolution_settings.convolution_plan_is_valid: + if not self._convolution_plan_is_current(): self._build_convolution_plan() total = np.zeros_like(self.energy.values, dtype=float) @@ -287,7 +291,7 @@ def _build_convolution_plan(self) -> None: # Update convolvers self._set_convolvers() - self.convolution_settings.convolution_plan_is_valid = True + self._mark_convolution_plan_current() def _set_convolvers(self) -> None: """ @@ -303,6 +307,8 @@ def _set_convolvers(self) -> None: energy_offset=self.energy_offset, sample_components=self._analytical_sample_components, resolution_components=self._resolution_components, + x_unit=self.x_unit, + y_unit=self.y_unit, ) else: self._analytical_convolver = None @@ -317,11 +323,27 @@ def _set_convolvers(self) -> None: temperature=self.temperature, temperature_unit=self._temperature_unit, detailed_balance_settings=self.detailed_balance_settings, - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) else: self._numerical_convolver = None + def convert_y_unit(self, unit: str) -> None: + """ + Convert the y-axis unit and propagate it to the analytical and numerical sub-convolvers. + + Parameters + ---------- + unit : str + The new y-axis unit. + """ + super().convert_y_unit(unit) + if getattr(self, '_analytical_convolver', None) is not None: + self._analytical_convolver._y_unit = self.y_unit # noqa: SLF001 + if getattr(self, '_numerical_convolver', None) is not None: + self._numerical_convolver._y_unit = self.y_unit # noqa: SLF001 + # Update some setters so the internal sample models are updated def __setattr__(self, name: str, value: any) -> None: """ @@ -343,6 +365,7 @@ def __setattr__(self, name: str, value: any) -> None: # Only rebuild the convolution plan if reactions are enabled, to # avoid issues during __init__ if getattr(self, '_reactions_enabled', False) and name in self._invalidate_plan_on_change: + self._plan_is_valid = False self.convolution_settings.convolution_plan_is_valid = False def __repr__(self) -> str: @@ -350,7 +373,8 @@ def __repr__(self) -> str: f'{self.__class__.__name__}(' f'display_name={self.display_name!r}, ' f'unique_name={self.unique_name!r}, ' - f'unit={self.unit}, ' + f'x_unit={self.x_unit}, ' + f'y_unit={self.y_unit}, ' f'energy_len={len(self.energy)}, ' f'temperature={self.temperature})' ) diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index c8516cb0b..2301b70ba 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +import contextlib + import numpy as np import scipp as sc from easyscience.variable import Parameter @@ -9,6 +11,7 @@ from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric +from easydynamics.utils.utils import energy_to_scipp class ConvolutionBase(EasyDynamicsModelBase): @@ -23,7 +26,8 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', energy_offset: Numeric | Parameter = 0.0, display_name: str | None = 'MyConvolution', unique_name: str | None = None, @@ -39,8 +43,10 @@ def __init__( The sample model to be convolved. resolution_components : ComponentCollection | ModelComponent | None, default=None The resolution model to convolve with. - unit : str | sc.Unit, default='meV' - The unit of the energy. + x_unit : str | sc.Unit, default='meV' + The unit of the energy axis. + y_unit : str | sc.Unit, default='dimensionless' + The unit of the model output (intensity). energy_offset : Numeric | Parameter, default=0.0 The energy offset applied to the convolution. display_name : str | None, default='MyConvolution' @@ -58,7 +64,8 @@ def __init__( """ super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, display_name=display_name, unique_name=unique_name, ) @@ -70,10 +77,12 @@ def __init__( raise TypeError(f'Energy must be a numpy ndarray or a scipp Variable. Got {energy}') if isinstance(energy, np.ndarray): - energy = sc.array(dims=['energy'], values=energy, unit=unit) + energy = energy_to_scipp(energy, x_unit) if isinstance(energy_offset, Numeric): - energy_offset = Parameter(name='energy_offset', value=float(energy_offset), unit=unit) + energy_offset = Parameter( + name='energy_offset', value=float(energy_offset), unit=x_unit + ) if not isinstance(energy_offset, Parameter): raise TypeError('Energy_offset must be a number or a Parameter.') @@ -147,8 +156,9 @@ def energy_with_offset(self) -> sc.Variable: sc.Variable The energy values with the offset applied. """ + offset_value = sc.to_unit(self.energy_offset.full_value, self._energy.unit).value energy_with_offset = self.energy.copy() - energy_with_offset.values = self.energy.values - self.energy_offset.value + energy_with_offset.values = self.energy.values - offset_value return energy_with_offset @property @@ -187,15 +197,15 @@ def energy(self, energy: np.ndarray | sc.Variable) -> None: raise TypeError('Energy must be a Number, a numpy ndarray or a scipp Variable.') if isinstance(energy, np.ndarray): - self._energy = sc.array(dims=['energy'], values=energy, unit=self._energy.unit) + self._energy = energy_to_scipp(energy, self._energy.unit) if isinstance(energy, sc.Variable): self._energy = energy - self._unit = energy.unit + self._x_unit = energy.unit - def convert_unit(self, unit: str | sc.Unit) -> None: + def convert_x_unit(self, unit: str | sc.Unit) -> None: """ - Convert the energy and energy_offset to the specified unit. + Convert the energy axis, energy_offset, and all components to the specified unit. Parameters ---------- @@ -213,20 +223,62 @@ def convert_unit(self, unit: str | sc.Unit) -> None: raise TypeError('Energy unit must be a string or scipp unit.') old_energy = self.energy.copy() + old_x_unit = str(self.x_unit) + old_offset_unit = str(self.energy_offset.unit) + try: self.energy = sc.to_unit(self.energy, unit) - except Exception as e: + self._energy_offset.convert_unit(unit) + if self.sample_components is not None: + self.sample_components.convert_x_unit(unit) + if self.resolution_components is not None: + self.resolution_components.convert_x_unit(unit) + except Exception: self.energy = old_energy - raise e + # Roll back energy_offset if it was already converted to the new unit. + if str(self._energy_offset.unit) != old_offset_unit: + self._energy_offset.convert_unit(old_offset_unit) + # Roll back component collections that may have been partially converted. + if self.sample_components is not None: + with contextlib.suppress(Exception): + self.sample_components.convert_x_unit(old_x_unit) + if self.resolution_components is not None: + with contextlib.suppress(Exception): + self.resolution_components.convert_x_unit(old_x_unit) + raise + + self._x_unit = unit + + def convert_y_unit(self, unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit of the sample components. - old_energy_offset = self.energy_offset - try: - self.energy_offset.convert_unit(unit) - except Exception as e: - self.energy_offset = old_energy_offset - raise e + Only propagates to sample components (resolution is normalised and unit-independent). + + Parameters + ---------- + unit : str | sc.Unit + The new y-axis unit. - self._unit = unit + Raises + ------ + TypeError + If unit is not a string or scipp unit. + Exception + If any component raises during unit conversion. On failure, attempts to roll back. + """ + if not isinstance(unit, (str, sc.Unit)): + raise TypeError('y_unit must be a string or scipp unit.') + old_y_unit = self.y_unit + try: + if self.sample_components is not None: + self.sample_components.convert_y_unit(unit) + self._y_unit = str(unit) if isinstance(unit, sc.Unit) else unit + except Exception: + if self.sample_components is not None: + with contextlib.suppress(Exception): + self.sample_components.convert_y_unit(old_y_unit) + raise @property def sample_components(self) -> ComponentCollection | ModelComponent: diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index e0ddde05d..fbe36b42a 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np +import scipp as sc from scipy.signal import fftconvolve from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase @@ -31,9 +32,9 @@ def convolution( """ # Make sure the convolver is updated with the latest convolution # settings before convolution. - if not self.convolution_settings.convolution_plan_is_valid: + if not self._convolution_plan_is_current(): self._energy_grid = self._create_energy_grid() - self.convolution_settings.convolution_plan_is_valid = True + self._mark_convolution_plan_current() # Give warnings if peaks are very wide or very narrow if not self.convolution_settings.suppress_warnings: @@ -46,18 +47,24 @@ def convolution( model_name='resolution model', ) + # Unit-convert the energy offset to match the energy grid unit. + # sc.to_unit returns a new scalar — self.energy_offset is never mutated. + offset_value = sc.to_unit(self.energy_offset.full_value, self.energy.unit).value + # Evaluate sample model. If called via the Convolution class, # delta functions are already filtered out. sample_vals = self.sample_components.evaluate( self._energy_grid.energy_dense - self._energy_grid.energy_even_length_offset - - self.energy_offset.value + - offset_value ) # Detailed balance correction if self.temperature is not None and self.detailed_balance_settings.use_detailed_balance: detailed_balance_factor_correction = detailed_balance_factor( - energy=self._energy_grid.energy_dense - self.energy_offset.value, + energy=self._energy_grid.energy_dense + - self._energy_grid.energy_even_length_offset + - offset_value, temperature=self.temperature, energy_unit=self.energy.unit, divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, @@ -90,7 +97,8 @@ def __repr__(self) -> str: f'{self.__class__.__name__}(' f'display_name={self.display_name!r}, ' f'unique_name={self.unique_name!r}, ' - f'unit={self.unit}, ' + f'x_unit={self.x_unit}, ' + f'y_unit={self.y_unit}, ' f'energy_len={len(self.energy)}, ' f'temperature={self.temperature})' ) diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index 188a69dbc..e837a4497 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -43,7 +43,8 @@ def __init__( temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', detailed_balance_settings: DetailedBalanceSettings | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', display_name: str | None = 'MyConvolution', unique_name: str | None = None, ) -> None: @@ -68,8 +69,10 @@ def __init__( The unit of the temperature parameter. detailed_balance_settings : DetailedBalanceSettings | None, default=None The settings for detailed balance. If None, default settings will be used. - unit : str | sc.Unit, default='meV' - The unit of the energy. + x_unit : str | sc.Unit, default='meV' + The unit of the energy axis. + y_unit : str | sc.Unit, default='dimensionless' + The unit of the model output (intensity). display_name : str | None, default='MyConvolution' Display name of the model. unique_name : str | None, default=None @@ -85,7 +88,8 @@ def __init__( energy=energy, sample_components=sample_components, resolution_components=resolution_components, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, energy_offset=energy_offset, display_name=display_name, unique_name=unique_name, @@ -116,7 +120,36 @@ def __init__( # When upsample_factor>1, we evaluate on this grid and # interpolate back to the original values at the end self._energy_grid = self._create_energy_grid() - self._convolution_settings.convolution_plan_is_valid = True + self._mark_convolution_plan_current() + + def _convolution_plan_is_current(self) -> bool: + """ + Check whether this convolver's plan is up to date. + + Plan validity is tracked per convolver so several convolvers can share one + ConvolutionSettings object: a local flag covers convolver-local invalidations (e.g. a new + energy grid), while the settings' plan version detects settings changes this convolver has + not consumed yet — even if a sibling convolver already rebuilt and set the shared flag. + Explicitly setting ``convolution_plan_is_valid = True`` on the settings remains an escape + hatch that suppresses rebuilds for all convolvers. + + Returns + ------- + bool + True if the plan does not need to be rebuilt. + """ + if not getattr(self, '_plan_is_valid', False): + return False + return self.convolution_settings._plan_valid_for( # noqa: SLF001 + self._plan_settings_version_seen + ) + + def _mark_convolution_plan_current(self) -> None: + """Record that this convolver's plan matches its current state and settings.""" + self._plan_is_valid = True + self._plan_settings_version_seen = ( + self.convolution_settings._mark_plan_built() # noqa: SLF001 + ) @property def convolution_settings(self) -> ConvolutionSettings: @@ -149,6 +182,7 @@ def convolution_settings(self, settings: ConvolutionSettings) -> None: if not isinstance(settings, ConvolutionSettings): raise TypeError('settings must be a ConvolutionSettings instance.') self._convolution_settings = settings + self._plan_is_valid = False self._convolution_settings.convolution_plan_is_valid = False @ConvolutionBase.energy.setter @@ -162,6 +196,8 @@ def energy(self, energy: np.ndarray) -> None: The new energy array. """ ConvolutionBase.energy.fset(self, energy) + self._energy_grid = self._create_energy_grid() + self._plan_is_valid = False self.convolution_settings.convolution_plan_is_valid = False @property @@ -437,12 +473,12 @@ def __repr__(self) -> str: """ return ( f'{self.__class__.__name__}(' - f' energy=array of shape {self.energy.values.shape},\n' - f' sample_components={self.sample_components!r},\n' - f' resolution_components={self.resolution_components!r},\n' - f' unit={self.unit}, ' - f' upsample_factor={self.upsample_factor}, ' - f' extension_factor={self.extension_factor}, ' - f' temperature={self.temperature}, ' - f' detailed_balance={self.detailed_balance_settings!r})' + f'energy=array of shape {self.energy.values.shape},\n ' + f'sample_components={self.sample_components!r}, \n' + f'resolution_components={self.resolution_components!r},\n ' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, ' + f'upsample_factor={self.upsample_factor}, ' + f'extension_factor={self.extension_factor}, ' + f'temperature={self.temperature}, ' + f'detailed_balance={self.detailed_balance_settings!r})' ) diff --git a/src/easydynamics/experiment/experiment.py b/src/easydynamics/experiment/experiment.py index a2fb04fe4..d791ee5db 100644 --- a/src/easydynamics/experiment/experiment.py +++ b/src/easydynamics/experiment/experiment.py @@ -13,6 +13,7 @@ from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase from easydynamics.utils.utils import _in_notebook +from easydynamics.utils.utils import verify_Q_index class Experiment(EasyDynamicsBase): @@ -239,11 +240,6 @@ def get_masked_energy(self, Q_index: int) -> sc.Variable | None: Q_index : int The Q index to get the masked energy values for. - Raises - ------ - IndexError - If Q_index is not a valid index for the Q values. - Returns ------- sc.Variable | None @@ -252,19 +248,58 @@ def get_masked_energy(self, Q_index: int) -> sc.Variable | None: if self.binned_data is None: return None - if ( - not isinstance(Q_index, int) - or Q_index < 0 - or (self.Q is not None and Q_index >= len(self.Q)) - ): - raise IndexError('Q_index must be a valid index for the Q values.') + verify_Q_index(Q_index, self.Q) energy = self.binned_data.coords['energy'] - _, _, _, mask = self._extract_x_y_weights_only_finite(Q_index=Q_index) - - mask_var = sc.array(dims=['energy'], values=mask) + mask_var = self.get_finite_energy_mask(Q_index=Q_index) return energy[mask_var] + def get_finite_energy_mask(self, Q_index: int) -> sc.Variable | None: + """ + Get a boolean scipp Variable selecting energy points with finite intensity at the given Q. + + Parameters + ---------- + Q_index : int + The Q index to get the mask for. + + Returns + ------- + sc.Variable | None + Boolean scipp Variable of length n_energy with dim ``'energy'``, or None if no data is + loaded. + """ + if self.binned_data is None: + return None + + verify_Q_index(Q_index, self.Q) + + _, _, _, mask = self.extract_x_y_weights_only_finite(Q_index=Q_index) + return sc.array(dims=['energy'], values=mask) + + def get_masked_binned_data(self, Q_index: int) -> sc.DataArray | None: + """ + Get the binned data for a single Q slice with non-finite points masked out. + + Parameters + ---------- + Q_index : int + The Q index to extract. + + Returns + ------- + sc.DataArray | None + The binned data for the given Q index with NaN/Inf points removed, or None if no data + is loaded. + """ + if self.binned_data is None: + return None + + verify_Q_index(Q_index, self.Q) + + mask_var = self.get_finite_energy_mask(Q_index=Q_index) + return self.binned_data['Q', Q_index][mask_var] + ########### # Handle data ########### @@ -546,7 +581,7 @@ def _extract_x_y_var(self, Q_index: int) -> tuple[np.ndarray, np.ndarray, np.nda var = data.variances return x, y, var - def _extract_x_y_weights_only_finite( + def extract_x_y_weights_only_finite( self, Q_index: int ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ diff --git a/src/easydynamics/sample_model/background_model.py b/src/easydynamics/sample_model/background_model.py index 34d313932..031d9493c 100644 --- a/src/easydynamics/sample_model/background_model.py +++ b/src/easydynamics/sample_model/background_model.py @@ -47,7 +47,8 @@ def __init__( self, display_name: str | None = 'MyBackgroundModel', unique_name: str | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', components: ModelComponent | ComponentCollection | None = None, Q: Q_type | None = None, ) -> None: @@ -60,8 +61,10 @@ def __init__( Display name of the model. unique_name : str | None, default=None Unique name of the model. If None, a unique name will be generated. - unit : str | sc.Unit, default='meV' - Unit of the model. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis (energy, Q, etc.). + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). components : ModelComponent | ComponentCollection | None, default=None Template components of the model. If None, no components are added. These components are copied into ComponentCollections for each Q value. @@ -71,7 +74,8 @@ def __init__( super().__init__( display_name=display_name, unique_name=unique_name, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, components=components, Q=Q, ) @@ -80,7 +84,7 @@ def __repr__(self) -> str: return ( f'{self.__class__.__name__}(' f'unique_name={self.unique_name!r}, ' - f'unit={self.unit}, ' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, ' f'Q_len={None if self._Q is None else len(self._Q)}, ' f'components={self.components})' ) diff --git a/src/easydynamics/sample_model/component_collection.py b/src/easydynamics/sample_model/component_collection.py index 54a6b0010..840b18709 100644 --- a/src/easydynamics/sample_model/component_collection.py +++ b/src/easydynamics/sample_model/component_collection.py @@ -9,14 +9,14 @@ import numpy as np import scipp as sc -from easyscience.variable import DescriptorBase -from easyscience.variable import Parameter from easydynamics.base_classes.easydynamics_list import EasyDynamicsList from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.sample_model.components.model_component import ModelComponent if TYPE_CHECKING: + from easyscience.variable import DescriptorBase + from easydynamics.utils.utils import Numeric @@ -54,7 +54,8 @@ class ComponentCollection(EasyDynamicsList, EasyDynamicsModelBase): def __init__( self, components: ModelComponent | list[ModelComponent] | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'ComponentCollection', display_name: str | None = None, unique_name: str | None = None, @@ -66,8 +67,10 @@ def __init__( ---------- components : ModelComponent | list[ModelComponent] | None, default=None Initial model components to add to the ComponentCollection. - unit : str | sc.Unit, default='meV' - Unit of the collection. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis (energy, Q, etc.). + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). name : str, default='ComponentCollection' Name of the collection. display_name : str | None, default=None @@ -86,12 +89,14 @@ def __init__( components = [components] elif not isinstance(components, list): raise TypeError( - f'components must be a ModelComponent or a list of ModelComponent, got {type(components).__name__} instead.' # noqa: E501 + f'components must be a ModelComponent or a list of ModelComponent, ' + f'got {type(components).__name__} instead.' ) for comp in components: if not isinstance(comp, ModelComponent): raise TypeError( - f'All items in components must be instances of ModelComponent, got {type(comp).__name__} instead.' # noqa: E501 + f'All items in components must be instances of ModelComponent, ' + f'got {type(comp).__name__} instead.' ) EasyDynamicsList.__init__( @@ -102,7 +107,8 @@ def __init__( EasyDynamicsModelBase.__init__( self, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, @@ -146,39 +152,74 @@ def is_empty(self, _value: bool) -> None: 'whether the collection has components.' ) - def convert_unit(self, unit: str | sc.Unit) -> None: + # ------------------------------------------------------------------ + # Unit conversion + # ------------------------------------------------------------------ + + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert the x-axis unit of the ComponentCollection and all its components. + + Parameters + ---------- + new_x_unit : str | sc.Unit + The target x-axis unit to convert to. + """ + self._convert_axis_unit(new_x_unit, axis='x') + + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: """ - Convert the unit of the ComponentCollection and all its components. + Convert the y-axis unit of the ComponentCollection and all its components. + + Parameters + ---------- + new_y_unit : str | sc.Unit + The target y-axis unit to convert to. + """ + self._convert_axis_unit(new_y_unit, axis='y') + + def _convert_axis_unit(self, unit: str | sc.Unit, axis: str) -> None: + """ + Convert one axis unit on all components in the collection. + + Converts every component via its ``convert__unit`` method and updates the + collection's own unit attribute. On failure, attempts a best-effort rollback of all + components to the old unit before re-raising. Parameters ---------- unit : str | sc.Unit - The target unit to convert to. + The new unit to convert to. + axis : str + Which axis to convert: ``'x'`` or ``'y'``. Raises ------ TypeError - If unit is not a string or sc.Unit. + If the provided unit is not a string or sc.Unit. Exception If any component cannot be converted to the specified unit. """ - if not isinstance(unit, (str, sc.Unit)): - raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') - - old_unit = self._unit + raise TypeError(f'{axis}_unit must be a string or sc.Unit, got {type(unit).__name__}') + method = f'convert_{axis}_unit' + old_unit = self.x_unit if axis == 'x' else self.y_unit try: for component in self: - component.convert_unit(unit) - self._unit = unit + getattr(component, method)(unit) + unit_str = str(unit) if isinstance(unit, sc.Unit) else unit + if axis == 'x': + self._x_unit = unit_str + else: + self._y_unit = unit_str except Exception as e: - # Attempt to rollback on failure - try: - for component in self: - component.convert_unit(old_unit) - except Exception: # noqa: S110 - pass # Best effort rollback + if old_unit is not None: + try: + for component in self: + getattr(component, method)(old_unit) + except Exception: # noqa: S110 + pass raise e # ------------------------------------------------------------------ @@ -211,7 +252,6 @@ def list_component_names(self) -> list[str]: list[str] List of names of the components in the collection. """ - return [component.name for component in self] def normalize_area(self) -> None: @@ -230,28 +270,38 @@ def normalize_area(self) -> None: raise ValueError('No components in the model to normalize.') area_params = [] - total_area = Parameter(name='total_area', value=0.0, unit=self._unit) for component in self: if hasattr(component, 'area'): area_params.append(component.area) - total_area += component.area else: warnings.warn( f"Component '{component.name}' does not have an 'area' attribute " - f'and will be skipped in normalization.', + 'and will be skipped in normalization.', UserWarning, stacklevel=2, ) - if total_area.value == 0: + if not area_params: + raise ValueError('No components with an area attribute; cannot normalize.') + + # Sum the areas in a common unit so components with different (but compatible) + # units normalize correctly. Dividing each value by the total expressed in the + # reference unit makes the areas sum to 1 in that unit. + reference_unit = str(area_params[0].unit) + total_area_value = sum( + sc.to_unit(sc.scalar(p.value, unit=str(p.unit)), reference_unit).value + for p in area_params + ) + + if total_area_value == 0: raise ValueError('Total area is zero; cannot normalize.') - if not np.isfinite(total_area.value): + if not np.isfinite(total_area_value): raise ValueError('Total area is not finite; cannot normalize.') for param in area_params: - param.value /= total_area.value + param.value /= total_area_value # ------------------------------------------------------------------ # Other methods @@ -259,17 +309,20 @@ def normalize_area(self) -> None: def get_all_variables(self) -> list[DescriptorBase]: """ - Get all parameters from the model component. + Get all parameters from all model components. Returns ------- list[DescriptorBase] - List of parameters in the component. + List of parameters in the collection. """ - return [var for component in self for var in component.get_all_variables()] - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: + def evaluate( + self, + x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, + output: str = 'numpy', + ) -> np.ndarray | sc.Variable: """ Evaluate the sum of all components. @@ -277,22 +330,35 @@ def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) ---------- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray Energy axis. + output : str, default='numpy' + 'numpy' returns np.ndarray; 'scipp' returns sc.Variable with y_unit. Returns ------- - np.ndarray + np.ndarray | sc.Variable Evaluated model values. """ - if not self: - return np.zeros_like(x) - return sum(component.evaluate(x) for component in self) + if isinstance(x, (sc.Variable, sc.DataArray)): + values = np.zeros_like(x.values, dtype=float) + dim = x.dims[0] if x.dims else 'x' + else: + values = np.zeros_like(x, dtype=float) + dim = 'x' + if output == 'scipp': + return sc.array(dims=[dim], values=values, unit=self.y_unit) + return values + # This is needed to handle both scipp and numpy output - a normal call to sum does not work + gen = (component.evaluate(x, output=output) for component in self) + first = next(gen) + return sum(gen, first) def evaluate_component( self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, name: str, - ) -> np.ndarray: + output: str = 'numpy', + ) -> np.ndarray | sc.Variable: """ Evaluate a single component by name. @@ -302,6 +368,8 @@ def evaluate_component( Energy axis. name : str Component name. + output : str, default='numpy' + 'numpy' returns np.ndarray; 'scipp' returns sc.Variable with y_unit. Raises ------ @@ -314,22 +382,17 @@ def evaluate_component( Returns ------- - np.ndarray + np.ndarray | sc.Variable Evaluated values for the specified component. """ if not self: raise ValueError('No components in the model to evaluate.') - if not isinstance(name, str): raise TypeError(f'Component name must be a string, got {type(name)} instead.') - matches = [comp for comp in self if comp.name == name] if not matches: raise KeyError(f"No component named '{name}' exists.") - - component = matches[0] - - return component.evaluate(x) + return matches[0].evaluate(x, output=output) def fix_all_parameters(self) -> None: """Fix all free parameters in the model.""" @@ -376,11 +439,10 @@ def __repr__(self) -> str: String representation of the ComponentCollection. """ comp_names = ', '.join(c.name for c in self) or 'No components' - return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, unit={self.unit},\n' - f' components=[{comp_names}])' + f"{self.__class__.__name__}(name='{self.name}', " + f"x_unit='{self.x_unit}', y_unit='{self.y_unit}',\n" + f'Components: {comp_names})' ) def to_dict(self) -> dict: @@ -395,7 +457,8 @@ def to_dict(self) -> dict: return { '@module': self.__class__.__module__, '@class': self.__class__.__name__, - 'unit': str(self.unit), + 'x_unit': str(self.x_unit), + 'y_unit': str(self.y_unit), 'name': self.name, 'display_name': self.display_name, 'components': [c.to_dict() for c in self._data], @@ -433,13 +496,14 @@ def deserialise_component(d: dict) -> ModelComponent: cls = getattr(module, d['@class']) return cls.from_dict(d) - components = [deserialise_component(c) for c in obj_dict.get('components', [])] + components = [deserialise_component(c) for c in obj_dict['components']] return cls( components=components, - unit=obj_dict.get('unit', 'meV'), - name=obj_dict.get('name', 'ComponentCollection'), - display_name=obj_dict.get('display_name'), + x_unit=obj_dict['x_unit'], + y_unit=obj_dict['y_unit'], + name=obj_dict['name'], + display_name=obj_dict['display_name'], ) def __copy__(self) -> ComponentCollection: @@ -451,5 +515,4 @@ def __copy__(self) -> ComponentCollection: ComponentCollection A deep copy of the ComponentCollection. """ - return self.from_dict(self.to_dict()) diff --git a/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py b/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py index 9d4d1ed75..90a63562f 100644 --- a/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py +++ b/src/easydynamics/sample_model/components/damped_harmonic_oscillator.py @@ -20,8 +20,11 @@ class DampedHarmonicOscillator(CreateParametersMixin, ModelComponent): r""" Model of a Damped Harmonic Oscillator (DHO). - The intensity is given by $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2 - \gamma x)^2 \right)}, $$ where $A$ is the area, $x_0$ is the center, and $\gamma$ is the width. + $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2\gamma x)^2 \right)} $$ + + where $A$ is the area (``area``), $x_0$ is the center (``center``), and $\gamma$ is the half + width at half max (``width``). area has unit = x_unit * y_unit; center and width have unit = + x_unit. Examples -------- @@ -55,56 +58,57 @@ def __init__( area: Numeric = 1.0, center: Numeric = 1.0, width: Numeric = 1.0, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'DampedHarmonicOscillator', display_name: str | None = None, unique_name: str | None = None, ) -> None: """ - Initialize the Damped Harmonic Oscillator. + Initialize the Damped Harmonic Oscillator component. Parameters ---------- area : Numeric, default=1.0 - Area under the curve. + Integrated area under the DHO profile. Unit is ``x_unit * y_unit``. center : Numeric, default=1.0 - Resonance frequency, approximately the peak position. + Resonance frequency (x_0) in x_unit; approximately the peak position. Must be strictly + positive; a minimum of ``DHO_MINIMUM_CENTER`` (1e-10) is enforced. width : Numeric, default=1.0 - Damping constant, approximately the half width at half max (HWHM) of the peaks. By - default, 1.0. - unit : str | sc.Unit, default='meV' - Unit of the parameters. + Damping coefficient (gamma) in x_unit. Must be strictly positive. Approximately equal + to the HWHM of each peak. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. center and width are stored in this unit. area_unit = x_unit * + y_unit. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='DampedHarmonicOscillator' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name of the component. + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. If None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. """ - super().__init__( name=name, display_name=display_name, unique_name=unique_name, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, ) - # These methods live in ValidationMixin - area = self._create_area_parameter(area=area, name=name, unit=self._unit) - center = self._create_center_parameter( + # These methods live in CreateParametersMixin + self._area = self._create_area_parameter( + area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit + ) + self._center = self._create_center_parameter( center=center, name=name, fix_if_none=False, - unit=self._unit, + x_unit=self.x_unit, enforce_minimum_center=True, ) - - width = self._create_width_parameter(width=width, name=name, unit=self._unit) - - self._area = area - self._center = center - self._width = width + self._width = self._create_width_parameter(width=width, name=name, x_unit=self.x_unit) @property def area(self) -> Parameter: @@ -114,24 +118,22 @@ def area(self) -> Parameter: Returns ------- Parameter - The area parameter. + The area Parameter with unit ``x_unit * y_unit``. """ return self._area @area.setter def area(self, value: Numeric) -> None: """ - Set the value of the area parameter. - Parameters ---------- value : Numeric - The new value for the area parameter. + New area value (in current area unit = x_unit * y_unit). Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ if not isinstance(value, Numeric): raise TypeError('area must be a number') @@ -140,35 +142,32 @@ def area(self, value: Numeric) -> None: @property def center(self) -> Parameter: """ - Get the center parameter. + Get the center parameter (resonance frequency). Returns ------- Parameter - The center parameter. + The resonance frequency (x_0) Parameter with unit ``x_unit``. """ return self._center @center.setter def center(self, value: Numeric) -> None: """ - Set the value of the center parameter. - Parameters ---------- value : Numeric - The new value for the center parameter. + New resonance frequency in x_unit. Must be strictly positive. Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. ValueError - If the value is not positive. + If *value* is not positive. """ if not isinstance(value, Numeric): raise TypeError('center must be a number') - if float(value) <= 0: raise ValueError('center must be positive') self._center.value = value @@ -176,66 +175,94 @@ def center(self, value: Numeric) -> None: @property def width(self) -> Parameter: """ - Get the width parameter. + Get the width parameter (damping coefficient). Returns ------- Parameter - The width parameter. + The damping coefficient (gamma) Parameter with unit ``x_unit``. """ return self._width @width.setter def width(self, value: Numeric) -> None: """ - Set the value of the width parameter. - Parameters ---------- value : Numeric - The new value for the width parameter. + New damping coefficient in x_unit. Must be strictly positive. Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. ValueError - If the value is not positive. + If *value* is not positive. """ if not isinstance(value, Numeric): raise TypeError('width must be a number') - if float(value) <= 0: raise ValueError('width must be positive') - self._width.value = value - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: r""" - Evaluate the Damped Harmonic Oscillator at the given x values. + Evaluate the DHO at x_vals. + + $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2\gamma x)^2 \right)} $$ - If x is a scipp Variable, the unit of the DHO will be converted to match x. The intensity - is given by $$ I(x) = \frac{2 A x_0^2 \gamma}{\pi \left( (x^2 - x_0^2)^2 + (2 \gamma x)^2 - \right)}, $$ where $A$ is the area, $x_0$ is the center, and $\gamma$ is the width. + where *A* is ``area``, *x*₀ is ``center`` (resonance frequency), and *gamma* is ``width`` + (damping coefficient). Here *I* is the scattered intensity. Parameters in the model's own + units are temporarily converted to eval_unit for the computation. Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the DHO. + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray - The intensity of the DHO at the given x values. + Evaluated DHO values at x_vals. """ + center = self._resolve_param_value(self._center, eval_unit) + width = self._resolve_param_value(self._width, eval_unit) + area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit)) - x = self._prepare_x_for_evaluate(x) + normalization = 2 * center**2 * width / np.pi + denominator = (x_vals**2 - center**2) ** 2 + (2 * width * x_vals) ** 2 + # denominator cannot reach zero: center > 0 enforced by DHO_MINIMUM_CENTER + return area * normalization / denominator - normalization = 2 * self.center.value**2 * self.width.value / np.pi - # No division by zero here, width>0 enforced in setter - denominator = (x**2 - self.center.value**2) ** 2 + (2 * self.width.value * x) ** 2 + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert x-axis parameters (center, width) and area to new_x_unit. + + Parameters + ---------- + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. + """ + self._convert_x_unit_area_based( + new_x_unit=new_x_unit, + x_params=[self._center, self._width], + area_param=self._area, + ) - return self.area.value * normalization / (denominator) + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit by rescaling the area parameter. + + The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + """ + self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area) def __repr__(self) -> str: """ @@ -247,10 +274,9 @@ def __repr__(self) -> str: A string representation of the Damped Harmonic Oscillator. """ return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self._unit},\n' - f' area={self.area},\n' - f' center={self.center},\n' - f' width={self.width})' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n ' + f' area = {self.area},\n ' + f' center = {self.center},\n ' + f' width = {self.width})' ) diff --git a/src/easydynamics/sample_model/components/delta_function.py b/src/easydynamics/sample_model/components/delta_function.py index bca8c4239..9bea615ab 100644 --- a/src/easydynamics/sample_model/components/delta_function.py +++ b/src/easydynamics/sample_model/components/delta_function.py @@ -11,8 +11,7 @@ from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric -EPSILON = 1e-8 # small number to avoid floating point issues - +EPSILON = 1e-8 # tolerance for bin-edge comparisons if TYPE_CHECKING: import scipp as sc @@ -23,9 +22,12 @@ class DeltaFunction(CreateParametersMixin, ModelComponent): """ Delta function. - Evaluates to zero everywhere, except in convolutions, where it acts as an identity. This is - handled by the Convolution method. If the center is not provided, it will be centered at 0 and - fixed, which is typically what you want in QENS. + When called directly, returns zero everywhere except at the bin nearest to ``center``, where it + returns ``area / bin_width``. In convolutions it acts as an identity element (handled by the + ``Convolution`` class). area has unit = x_unit * y_unit; center has unit = x_unit. + + If the center is not provided, it will be centered at 0 and fixed, which is typically what you + want in QENS. Examples -------- @@ -57,7 +59,8 @@ def __init__( self, center: Numeric | None = None, area: Numeric = 1.0, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'DeltaFunction', display_name: str | None = None, unique_name: str | None = None, @@ -68,35 +71,35 @@ def __init__( Parameters ---------- center : Numeric | None, default=None - Center of the delta function. If None, it will be centered at 0 and fixed. + Position of the delta function in x_unit. If None, defaults to 0 and the center + parameter is fixed. area : Numeric, default=1.0 - Total area under the curve. - unit : str | sc.Unit, default='meV' - Unit of the parameters. + Integrated area (weight) of the delta function. Unit is ``x_unit * y_unit``. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. center is stored in this unit. area_unit = x_unit * y_unit. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='DeltaFunction' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name of the component. + Display name of the component, shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. If None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. """ - # Validate inputs and create Parameters if not given super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, ) - # These methods live in ValidationMixin - area = self._create_area_parameter(area=area, name=name, unit=self._unit) - center = self._create_center_parameter( - center=center, name=name, fix_if_none=True, unit=self._unit + self._area = self._create_area_parameter( + area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit + ) + self._center = self._create_center_parameter( + center=center, name=name, fix_if_none=True, x_unit=self.x_unit ) - - self._area = area - self._center = center @property def area(self) -> Parameter: @@ -106,27 +109,23 @@ def area(self) -> Parameter: Returns ------- Parameter - The area parameter. + The area Parameter with unit ``x_unit * y_unit``. """ - return self._area @area.setter def area(self, value: Numeric) -> None: """ - Set the value of the area parameter. - Parameters ---------- value : Numeric - The new value for the area parameter. + New area value (in current area unit = x_unit * y_unit). Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ - if not isinstance(value, Numeric): raise TypeError('area must be a number') self._area.value = value @@ -139,27 +138,24 @@ def center(self) -> Parameter: Returns ------- Parameter - The center parameter. + The center Parameter with unit ``x_unit``. """ - return self._center @center.setter def center(self, value: Numeric | None) -> None: """ - Set the center parameter value. - Parameters ---------- value : Numeric | None - The new value for the center parameter. If None, defaults to 0 and is fixed. + New center value in x_unit. If None, the center is set to 0 and the parameter is + fixed. Raises ------ TypeError - If the value is not a number or None. + If *value* is not None and not a numeric type. """ - if value is None: value = 0.0 self._center.fixed = True @@ -167,53 +163,87 @@ def center(self, value: Numeric | None) -> None: raise TypeError('center must be a number') self._center.value = value - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: """ - Evaluate the Delta function at the given x values. + Evaluate the Delta function at x_vals. - The Delta function evaluates to zero everywhere, except at the center. Its numerical - integral is equal to the area. It acts as an identity in convolutions. + Parameters in the model's own units are temporarily converted to eval_unit for the + computation. Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the Delta function. + x_vals : np.ndarray + Raw x values expressed in eval_unit. Assumed sorted for the bin-width computation. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray - The evaluated Delta function at the given x values. + Zero everywhere, with a single non-zero bin nearest the center when center falls within + the x range. + + Notes + ----- + When ``center`` falls within the x range, the bin nearest to ``center`` receives ``area / + bin_width`` rather than zero. In convolutions, the DeltaFunction acts as an identity + element (handled by the Convolution class). """ + center = self._resolve_param_value(self._center, eval_unit) + area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit)) - # x assumed sorted, 1D numpy array - x = self._prepare_x_for_evaluate(x) - model = np.zeros_like(x, dtype=float) - center = self.center.value - area = self.area.value + model = np.zeros_like(x_vals, dtype=float) - if x.min() - EPSILON <= center <= x.max() + EPSILON: + if x_vals.min() - EPSILON <= center <= x_vals.max() + EPSILON: # nearest index - i = np.argmin(np.abs(x - center)) + i = np.argmin(np.abs(x_vals - center)) # left half-width - if i == 0: # noqa: SIM108 - left = x[1] - x[0] if x.size > 1 else 0.5 + if i == 0: + left = x_vals[1] - x_vals[0] if x_vals.size > 1 else 0.5 else: - left = x[i] - x[i - 1] + left = x_vals[i] - x_vals[i - 1] # right half-width - if i == x.size - 1: # noqa: SIM108 - right = x[-1] - x[-2] if x.size > 1 else 0.5 + if i == x_vals.size - 1: + right = x_vals[-1] - x_vals[-2] if x_vals.size > 1 else 0.5 else: - right = x[i + 1] - x[i] + right = x_vals[i + 1] - x_vals[i] # effective bin width: half left + half right bin_width = 0.5 * (left + right) - model[i] = area / bin_width return model + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert x-axis parameters (center) and area to new_x_unit. + + Parameters + ---------- + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. + """ + self._convert_x_unit_area_based( + new_x_unit=new_x_unit, + x_params=[self._center], + area_param=self._area, + ) + + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit by rescaling the area parameter. + + The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + """ + self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area) + def __repr__(self) -> str: """ Return a string representation of the Delta function. @@ -223,11 +253,9 @@ def __repr__(self) -> str: str A string representation of the Delta function. """ - return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self.unit},\n' - f' area={self.area},\n' - f' center={self.center})' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n' + f' area = {self.area},\n' + f' center = {self.center})' ) diff --git a/src/easydynamics/sample_model/components/exponential.py b/src/easydynamics/sample_model/components/exponential.py index 8b7e3241a..25b1048e9 100644 --- a/src/easydynamics/sample_model/components/exponential.py +++ b/src/easydynamics/sample_model/components/exponential.py @@ -16,11 +16,10 @@ class Exponential(CreateParametersMixin, ModelComponent): r""" Model of an exponential function. - The intensity is given by + $$ I(x) = A e^{B (x-x_0)} $$ - $$ I(x) = A e^{B (x-x_0)}, $$ - - where $A$ is the amplitude, $x_0$ is the center, and $B$ describes the rate of decay or growth. + where $A$ is the amplitude, $x_0$ is the center, and $B$ is the rate. amplitude has unit = + x_unit * y_unit; center has unit = x_unit; rate has unit = 1/x_unit. Examples -------- @@ -53,7 +52,8 @@ def __init__( amplitude: Numeric = 1.0, center: Numeric | None = None, rate: Numeric = 1.0, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'Exponential', display_name: str | None = None, unique_name: str | None = None, @@ -64,61 +64,59 @@ def __init__( Parameters ---------- amplitude : Numeric, default=1.0 - Amplitude of the Exponential. + Pre-exponential factor A. Unit is ``x_unit * y_unit``. center : Numeric | None, default=None - Center of the Exponential. If None, the center is fixed at 0. + Reference point x_0 in x_unit. If None, defaults to 0 and the center parameter is + fixed. rate : Numeric, default=1.0 - Decay or growth constant of the Exponential. - unit : str | sc.Unit, default='meV' - Unit of the parameters. + Exponential rate B in units of ``1/x_unit``. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. center is stored in this unit; rate is stored in ``1/x_unit``. + amplitude_unit = x_unit * y_unit. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='Exponential' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name of the component. + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. If None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. Raises ------ TypeError - If amplitude, center, or rate are not numbers or Parameters. + If *amplitude* or *rate* is not numeric. ValueError - If amplitude, center or rate are not finite numbers. + If *amplitude* or *rate* is not finite. """ - # Validate inputs and create Parameters if not given super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, ) - if not isinstance(amplitude, (Parameter, Numeric)): - raise TypeError('amplitude must be a number or a Parameter.') - - if isinstance(amplitude, Numeric): - if not np.isfinite(amplitude): - raise ValueError('amplitude must be a finite number or a Parameter') + x_unit_str = str(x_unit) if isinstance(x_unit, sc.Unit) else x_unit + amplitude_unit = str(sc.Unit(x_unit_str) * sc.Unit(self.y_unit)) - amplitude = Parameter(name=name + ' amplitude', value=float(amplitude), unit=unit) - - center = self._create_center_parameter( - center=center, name=name, fix_if_none=True, unit=self._unit + if not isinstance(amplitude, Numeric): + raise TypeError('amplitude must be a number.') + if not np.isfinite(amplitude): + raise ValueError('amplitude must be finite.') + self._amplitude = Parameter( + name=name + ' amplitude', value=float(amplitude), unit=amplitude_unit ) - if not isinstance(rate, (Parameter, Numeric)): - raise TypeError('rate must be a number or a Parameter.') - - if isinstance(rate, Numeric): - if not np.isfinite(rate): - raise ValueError('rate must be a finite number or a Parameter') - - rate = Parameter(name=name + ' rate', value=float(rate), unit='1/' + str(unit)) + self._center = self._create_center_parameter( + center=center, name=name, fix_if_none=True, x_unit=self.x_unit + ) - self._amplitude = amplitude - self._center = center - self._rate = rate + if not isinstance(rate, Numeric): + raise TypeError('rate must be a number.') + if not np.isfinite(rate): + raise ValueError('rate must be finite.') + self._rate = Parameter(name=name + ' rate', value=float(rate), unit='1/' + x_unit_str) @property def amplitude(self) -> Parameter: @@ -128,27 +126,23 @@ def amplitude(self) -> Parameter: Returns ------- Parameter - The amplitude parameter. + The amplitude Parameter with unit ``x_unit * y_unit``. """ - return self._amplitude @amplitude.setter def amplitude(self, value: Numeric) -> None: """ - Set the value of the amplitude parameter. - Parameters ---------- value : Numeric - The new value for the amplitude parameter. + New amplitude value (in current amplitude unit = x_unit * y_unit). Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ - if not isinstance(value, Numeric): raise TypeError('amplitude must be a number') self._amplitude.value = value @@ -161,31 +155,27 @@ def center(self) -> Parameter: Returns ------- Parameter - The center parameter. + The center (x_0) Parameter with unit ``x_unit``. """ - return self._center @center.setter def center(self, value: Numeric | None) -> None: """ - Set the center parameter value. - Parameters ---------- value : Numeric | None - The new value for the center parameter. + New center value in x_unit. If None, the center is set to 0 and the parameter is + fixed. Raises ------ TypeError - If the value is not a number. + If *value* is not None and not a numeric type. """ - if value is None: value = 0.0 self._center.fixed = True - if not isinstance(value, Numeric): raise TypeError('center must be a number') self._center.value = value @@ -198,94 +188,84 @@ def rate(self) -> Parameter: Returns ------- Parameter - The rate parameter. + The exponential rate (B) Parameter with unit ``1/x_unit``. """ return self._rate @rate.setter def rate(self, value: Numeric) -> None: """ - Set the rate parameter value. - Parameters ---------- value : Numeric - The new value for the rate parameter. + New exponential rate in ``1/x_unit``. Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ if not isinstance(value, Numeric): raise TypeError('rate must be a number') - self._rate.value = value - def evaluate( - self, - x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, - ) -> np.ndarray: + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: r""" - Evaluate the Exponential at the given x values. + Evaluate the Exponential at x_vals. - If x is a scipp Variable, the unit of the Exponential will be converted to match x. The - intensity is given by $$ I(x) = A \exp\left( r (x - x_0) \right) $$ - - where $A$ is the amplitude, $x_0$ is the center, and $r$ is the rate. + Parameters in the model's own units are temporarily converted to eval_unit for the + computation. Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the Exponential. + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray - The intensity of the Exponential at the given x values. + Evaluated exponential values at x_vals. """ + eval_rate_unit = None if eval_unit is None else '1/' + str(eval_unit) - x = self._prepare_x_for_evaluate(x) - exponent = self.rate.value * (x - self.center.value) + center = self._resolve_param_value(self._center, eval_unit) + rate = self._resolve_param_value(self._rate, eval_rate_unit) + amplitude = self._resolve_param_value(self._amplitude, self._eval_area_unit(eval_unit)) - return self.amplitude.value * np.exp(exponent) + exponent = rate * (x_vals - center) + return amplitude * np.exp(exponent) - def convert_unit(self, unit: str | sc.Unit) -> None: + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: """ - Convert the unit of the Parameters in the component. + Convert center and amplitude to new_x_unit, rate to 1/new_x_unit. Parameters ---------- - unit : str | sc.Unit - The new unit to convert to. + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. The + rate unit is set to ``1/new_x_unit``. + """ + self._convert_x_unit_area_based( + new_x_unit=new_x_unit, + x_params=[self._center], + area_param=self._amplitude, + inverse_params=[self._rate], + ) - Raises - ------ - TypeError - If unit is not a string or sc.Unit. - Exception - If conversion fails for any parameter. + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: """ + Convert the y-axis unit by rescaling the amplitude parameter. + + The amplitude is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``. - if not isinstance(unit, (str, sc.Unit)): - raise TypeError('unit must be a string or sc.Unit') - - old_unit = self._unit - pars = [self.amplitude, self.center] - try: - for p in pars: - p.convert_unit(unit) - self.rate.convert_unit('1/' + str(unit)) - self._unit = unit - except Exception as e: - # Attempt to rollback on failure - try: - for p in pars: - p.convert_unit(old_unit) - self.rate.convert_unit('1/' + str(old_unit)) - except Exception: # noqa: S110 - pass # Best effort rollback - raise e + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + """ + self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._amplitude) def __repr__(self) -> str: """ @@ -296,12 +276,10 @@ def __repr__(self) -> str: str A string representation of the Exponential. """ - return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self._unit},\n' - f' amplitude={self.amplitude},\n' - f' center={self.center},\n' - f' rate={self.rate})' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n ' + f' amplitude = {self.amplitude},\n ' + f' center = {self.center},\n ' + f' rate = {self.rate})' ) diff --git a/src/easydynamics/sample_model/components/expression_component.py b/src/easydynamics/sample_model/components/expression_component.py index feda61c6f..3c85e57e9 100644 --- a/src/easydynamics/sample_model/components/expression_component.py +++ b/src/easydynamics/sample_model/components/expression_component.py @@ -3,19 +3,24 @@ from __future__ import annotations +import warnings +from copy import copy from typing import TYPE_CHECKING from typing import ClassVar +import scipp as sc import sympy as sp +from easyscience.variable import DescriptorNumber from easyscience.variable import Parameter from scipy.special import erf -from easydynamics.sample_model.components.model_component import ModelComponent -from easydynamics.utils.utils import Numeric - if TYPE_CHECKING: import numpy as np - import scipp as sc + +from easydynamics.sample_model.components.model_component import ModelComponent +from easydynamics.utils.utils import Numeric +from easydynamics.utils.utils import hbar +from easydynamics.utils.utils import kb class ExpressionComponent(ModelComponent): @@ -40,7 +45,7 @@ class ExpressionComponent(ModelComponent): expr = sm.ExpressionComponent( 'A * exp(-(x - x0)**2 / (2*sigma**2))', parameters={'A': 10, 'x0': 0, 'sigma': 1}, - unit='meV', + x_unit='meV', display_name='Gaussian Peak', ) x = np.linspace(-3, 3, 100) @@ -54,18 +59,45 @@ class ExpressionComponent(ModelComponent): expr.A = 5 expr.sigma = 0.5 ``` + + **Giving parameters units** + + Parameters are dimensionless by default. Units can be given per parameter at construction, or + relabelled later with ``set_unit`` (the numeric value is kept as-is). When units are in use, + the unit of the evaluated expression is derived from the parameter units and x_unit (see + ``output_unit``), and a warning is issued if it does not match y_unit: + ```python + expr = sm.ExpressionComponent( + 'A * exp(-(x - x0)**2 / (2*sigma**2))', + parameters={'A': 10, 'x0': 0, 'sigma': 1}, + parameter_units={'A': '1/meV', 'x0': 'meV', 'sigma': 'meV'}, + y_unit='1/meV', + ) + expr.set_unit('A', '1/meV') + ``` + + **Physical constants** + + The symbols ``hbar`` (in meV*ps) and ``kb`` (in meV/K) are provided automatically as read-only + constants (DescriptorNumbers) when they appear in the expression: + ```python + boltzmann = sm.ExpressionComponent( + 'exp(-x / (kb * T))', + parameters={'T': 300.0}, + parameter_units={'T': 'K'}, + ) + ``` + Use e.g. ``boltzmann.kb.convert_unit('eV/K')`` to work in another unit (this rescales the + value, unlike ``set_unit``). """ - # ------------------------- - # Allowed symbolic functions - # ------------------------- _ALLOWED_FUNCS: ClassVar[dict[str, object]] = { - # Exponentials & logs + # exponential / logarithmic 'exp': sp.exp, 'log': sp.log, 'ln': sp.log, 'sqrt': sp.sqrt, - # Trigonometric + # trigonometric 'sin': sp.sin, 'cos': sp.cos, 'tan': sp.tan, @@ -73,25 +105,23 @@ class ExpressionComponent(ModelComponent): 'cot': sp.cot, 'sec': sp.sec, 'csc': sp.csc, + # inverse trigonometric 'asin': sp.asin, 'acos': sp.acos, 'atan': sp.atan, - # Hyperbolic + # hyperbolic 'sinh': sp.sinh, 'cosh': sp.cosh, 'tanh': sp.tanh, - # Misc + # rounding / sign 'abs': sp.Abs, 'sign': sp.sign, 'floor': sp.floor, 'ceil': sp.ceiling, - # Special functions + # special functions 'erf': sp.erf, } - # ------------------------- - # Allowed constants - # ------------------------- _ALLOWED_CONSTANTS: ClassVar[dict[str, object]] = { 'pi': sp.pi, 'E': sp.E, @@ -99,11 +129,30 @@ class ExpressionComponent(ModelComponent): _RESERVED_NAMES: ClassVar[dict[str, object]] = {'x'} + # Physical constants provided automatically as read-only DescriptorNumbers when their + # symbol appears in the expression: name -> (source constant from utils, default unit). + _PHYSICAL_CONSTANTS: ClassVar[dict[str, tuple[DescriptorNumber, str]]] = { + 'hbar': (hbar, 'meV*ps'), + 'kb': (kb, 'meV/K'), + } + + # Sympy functions that preserve the unit of their argument; all other allowed functions + # require a dimensionless argument and return a dimensionless result. + _UNIT_PRESERVING_FUNCS: ClassVar[tuple] = (sp.Abs, sp.floor, sp.ceiling) + # Functions that accept any unit but return a dimensionless result. + _DIMENSIONLESS_RESULT_FUNCS: ClassVar[tuple] = (sp.sign,) + + # parameter_units is applied at construction only; the resulting units are captured in + # the serialized Parameter dicts, so the argument is skipped during serialization. + _REDIRECT: ClassVar[dict[str, object]] = {'parameter_units': None} + def __init__( self, expression: str, parameters: dict[str, Numeric] | None = None, - unit: str | sc.Unit = 'meV', + parameter_units: dict[str, str | sc.Unit] | None = None, + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'Expression', display_name: str | None = None, unique_name: str | None = None, @@ -114,26 +163,43 @@ def __init__( Parameters ---------- expression : str - The symbolic expression as a string. Must contain 'x' as the independent variable. + The symbolic expression as a string. Must contain 'x' as the independent variable. The + symbols ``hbar`` and ``kb`` are provided automatically as read-only physical constants + (in meV*ps and meV/K respectively) unless overridden via *parameters*. parameters : dict[str, Numeric] | None, default=None - Dictionary of parameter names and their initial values. - unit : str | sc.Unit, default='meV' - Unit of the output. + Dictionary of parameter names and their initial values. Parameters that are not given a + unit are dimensionless. + parameter_units : dict[str, str | sc.Unit] | None, default=None + Optional units per parameter name. Each entry sets the unit of the named parameter + without rescaling its value (see :meth:`set_unit`), and takes precedence over the unit + of a Parameter instance given in *parameters*. When units are in use, a warning is + issued if the expression's output unit does not match y_unit. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='Expression' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name for the component. + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None Unique name for the component. Raises ------ ValueError - If the expression is invalid or does not contain 'x'. + If the expression is invalid or does not contain 'x', or if parameter_units names a + parameter that is not in the expression. TypeError - If any parameter value is not numeric. + If any parameter value is not numeric, or if parameter_units is not a dictionary. """ - super().__init__(unit=unit, name=name, display_name=display_name, unique_name=unique_name) + super().__init__( + x_unit=x_unit, + y_unit=y_unit, + name=name, + display_name=display_name, + unique_name=unique_name, + ) if 'np.' in expression: raise ValueError( @@ -158,15 +224,14 @@ def __init__( if 'x' not in symbol_names: raise ValueError("Expression must contain 'x' as independent variable") - # Reject unknown functions early so invalid expressions fail at init, # not later during numerical evaluation. allowed_function_names = set(self._ALLOWED_FUNCS) | { func.__name__ for func in self._ALLOWED_FUNCS.values() } - # Walk all function-call nodes in the parsed expression (e.g. sin(x), foo(x)). # Keep only function names that are not in our allowlist. + unknown_function_names: set[str] = set() function_atoms = self._expr.atoms(sp.Function) for function_atom in function_atoms: @@ -175,12 +240,10 @@ def __init__( unknown_function_names.add(function_name) unknown_functions = sorted(unknown_function_names) - if unknown_functions: raise ValueError( f'Unsupported function(s) in expression: {", ".join(unknown_functions)}' ) - # Create parameters if parameters is not None and not isinstance(parameters, dict): raise TypeError( @@ -196,34 +259,65 @@ def __init__( ) parameters = parameters or {} self._parameters: dict[str, Parameter] = {} + self._constants: dict[str, DescriptorNumber] = {} self._symbol_names = symbol_names for name in self._symbol_names: if name in self._RESERVED_NAMES: continue + # Physical constants are provided automatically, unless the user explicitly + # supplies a parameter with the same name. + if name in self._PHYSICAL_CONSTANTS and name not in parameters: + self._constants[name] = self._create_physical_constant(name) + continue + value = parameters.get(name, 1.0) if isinstance(value, Parameter): self._parameters[name] = value - elif isinstance(value, dict) and value.get('@class') == 'Parameter': self._parameters[name] = Parameter.from_dict(value) else: self._parameters[name] = Parameter( name=name, value=value, - unit=self._unit, + unit='dimensionless', + ) + + if parameter_units is not None: + if not isinstance(parameter_units, dict): + raise TypeError( + f'parameter_units must be None or a dictionary, ' + f'got {type(parameter_units).__name__}' + ) + constant_names = sorted(set(parameter_units) & set(self._constants)) + if constant_names: + raise ValueError( + f'parameter_units given for physical constant(s): ' + f'{", ".join(constant_names)}. Use e.g. ' + f'expr.{constant_names[0]}.convert_unit(...) to change the unit of a ' + f'constant (this rescales its value).' ) + unknown_names = sorted(set(parameter_units) - set(self._parameters)) + if unknown_names: + raise ValueError( + f'parameter_units given for unknown parameter(s): {", ".join(unknown_names)}' + ) + for parameter_name, parameter_unit in parameter_units.items(): + # Relabel without the per-call output-unit warning; the consistency check + # runs once below, after all units are applied. + self._relabel_parameter_unit(parameter_name, parameter_unit) # Create numerical function ordered_symbols = [sp.Symbol(name) for name in self._symbol_names] - self._func = sp.lambdify( ordered_symbols, self._expr, modules=[{'erf': erf}, 'numpy'], ) + self._warn_if_output_unit_mismatch() + # ------------------------- # Properties # ------------------------- @@ -255,31 +349,48 @@ def expression(self, _new_expr: str) -> None: AttributeError Always raised to prevent changing the expression. """ + raise AttributeError('Expression cannot be changed after initialization') - def evaluate( - self, - x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, - ) -> np.ndarray: + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: """ - Evaluate the expression for given x values. + Evaluate the expression for the given x values. + + Unit conversion of parameters is not supported for ExpressionComponent. If x_vals is + expressed in a different unit than x_unit, a warning is issued and the values are used + as-is. Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - Input values for the independent variable. + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray Evaluated results. """ - x = self._prepare_x_for_evaluate(x) + if ( + eval_unit is not None + and self.x_unit is not None + and sc.Unit(eval_unit) != sc.Unit(self.x_unit) + ): + warnings.warn( + f'Input x has unit {eval_unit} but {self.__class__.__name__} has ' + f'x_unit {self.x_unit}. ExpressionComponent cannot auto-convert parameters. ' + 'x values are used as-is.', + UserWarning, + stacklevel=3, + ) args = [] for name in self._symbol_names: if name == 'x': - args.append(x) + args.append(x_vals) + elif name in self._constants: + args.append(self._constants[name].value) else: args.append(self._parameters[name].value) @@ -296,11 +407,302 @@ def get_all_variables(self) -> list[Parameter]: """ return list(self._parameters.values()) - def convert_unit(self, _new_unit: str | sc.Unit) -> None: + def set_unit(self, name: str, unit: str | sc.Unit) -> None: + """ + Set the unit of a parameter without rescaling its value. + + This relabels the unit: the numeric value, bounds, and variance are kept as-is. Use + ``Parameter.convert_unit`` instead to rescale a value into a compatible unit. Issues a + warning if the resulting output unit of the expression no longer matches y_unit. Raises the + same exceptions as :meth:`_relabel_parameter_unit` on invalid input. + + Parameters + ---------- + name : str + Name of the parameter whose unit to set. + unit : str | sc.Unit + The new unit. """ - Convert the unit of the expression. + self._relabel_parameter_unit(name, unit) + self._warn_if_output_unit_mismatch() - Unit conversion is not implemented for ExpressionComponent. + def _relabel_parameter_unit(self, name: str, unit: str | sc.Unit) -> None: + """ + Validate and apply a unit relabel to a parameter (see set_unit). + + Parameters + ---------- + name : str + Name of the parameter whose unit to set. + unit : str | sc.Unit + The new unit. + + Raises + ------ + TypeError + If name is not a string or unit is not a string or sc.Unit. + KeyError + If no parameter with the given name exists. + ValueError + If unit is not a valid scipp unit. + AttributeError + If the parameter is a physical constant or a dependent parameter. + """ + if not isinstance(name, str): + raise TypeError(f'name must be a string, got {type(name).__name__}') + if '_constants' in self.__dict__ and name in self._constants: + raise AttributeError( + f"'{name}' is a physical constant. Use expr.{name}.convert_unit(...) to change " + f'its unit (this rescales its value).' + ) + if '_parameters' not in self.__dict__ or name not in self._parameters: + raise KeyError(f"No parameter named '{name}' in this {self.__class__.__name__}.") + if not isinstance(unit, (str, sc.Unit)): + raise TypeError(f'unit must be a string or sc.Unit, got {type(unit).__name__}') + try: + new_unit = sc.Unit(str(unit)) + except sc.UnitError as e: + raise ValueError(f"'{unit}' is not a valid scipp unit.") from e + + param = self._parameters[name] + if not param.independent: + raise AttributeError( + f"Cannot set the unit of dependent parameter '{name}'; its unit is controlled " + f'by the dependency expression.' + ) + + # Parameter.unit is read-only and convert_unit rescales the value, so a pure relabel + # has to swap the underlying scipp scalars (value and bounds) directly. + param._scalar = sc.scalar( # noqa: SLF001 + param.value, + unit=new_unit, + variance=param._scalar.variance, # noqa: SLF001 + ) + param._min = sc.scalar(param._min.value, unit=new_unit) # noqa: SLF001 + param._max = sc.scalar(param._max.value, unit=new_unit) # noqa: SLF001 + + @classmethod + def _create_physical_constant(cls, name: str) -> DescriptorNumber: + """ + Create a DescriptorNumber for a supported physical constant. + + The constant is a per-instance copy of the shared constant in ``easydynamics.utils``, + converted to its default unit, so converting its unit on one component does not affect + others. + + Parameters + ---------- + name : str + The constant's symbol name (a key of ``_PHYSICAL_CONSTANTS``). + + Returns + ------- + DescriptorNumber + The constant's value in its default unit. + """ + source, default_unit = cls._PHYSICAL_CONSTANTS[name] + constant = copy(source) + constant.convert_unit(default_unit) + return constant + + @property + def constants(self) -> dict[str, DescriptorNumber]: + """ + Get the physical constants used by the expression. + + Returns + ------- + dict[str, DescriptorNumber] + The automatically provided constants (e.g. ``hbar``, ``kb``) keyed by symbol name. + """ + return dict(self._constants) + + @property + def output_unit(self) -> str: + """ + Get the unit of the evaluated expression, derived from x_unit and the parameter units. + + The unit is propagated through the expression tree: addition requires compatible units, + multiplication and powers combine units, and functions like ``exp`` or ``sin`` require a + dimensionless argument. Propagation raises ``sc.UnitError`` if the expression is not + unit-consistent (e.g. adding meV to a dimensionless quantity, or taking ``exp`` of a + quantity with a unit). + + Returns + ------- + str + The unit of the evaluated expression. + """ + return str(self._canonical_unit(self._propagate_unit(self._expr))) + + def _canonical_unit(self, unit: sc.Unit) -> sc.Unit: + """ + Replace a unit by an equal, cleaner-printing candidate where possible. + + Unit algebra accumulates tiny floating-point scale factors (e.g. ``meV * 1/meV`` prints as + ``0.999999999999999889`` rather than ``dimensionless``); scipp's unit equality still + matches, so map such units onto the common candidates for readable output. + + Parameters + ---------- + unit : sc.Unit + The unit to canonicalize. + + Returns + ------- + sc.Unit + An equal unit with a cleaner string representation, or the input unit unchanged. + """ + candidate_strings = ['dimensionless', self.y_unit, self.x_unit] + for candidate_string in candidate_strings: + if candidate_string is None: + continue + candidate = sc.Unit(candidate_string) + if unit == candidate: + return candidate + return unit + + def _propagate_unit(self, node: sp.Basic) -> sc.Unit: + """ + Recursively derive the unit of a sympy expression node. + + Parameters + ---------- + node : sp.Basic + The sympy expression node to derive the unit of. + + Raises + ------ + sc.UnitError + If the node mixes incompatible units, applies a function that requires a dimensionless + argument to a quantity with a unit, or raises a quantity with a unit to a symbolic or + unsupported power. + + Returns + ------- + sc.Unit + The unit of the node. + """ + dimensionless = sc.Unit('dimensionless') + + if isinstance(node, sp.Symbol): + name = str(node) + if name == 'x': + return sc.Unit(self.x_unit) if self.x_unit is not None else dimensionless + if name in self._constants: + return sc.Unit(str(self._constants[name].unit)) + return sc.Unit(str(self._parameters[name].unit)) + + if node.is_number: + return dimensionless + + if isinstance(node, sp.Add): + units = [self._propagate_unit(argument) for argument in node.args] + for unit in units[1:]: + try: + sc.to_unit(sc.scalar(1.0, unit=unit), units[0]) + except sc.UnitError as e: + raise sc.UnitError( + f'The expression adds quantities with incompatible units ' + f'{units[0]} and {unit}.' + ) from e + return units[0] + + if isinstance(node, sp.Mul): + result = dimensionless + for argument in node.args: + result = result * self._propagate_unit(argument) + return result + + if isinstance(node, sp.Pow): + base_unit = self._propagate_unit(node.base) + exponent_unit = self._propagate_unit(node.exp) + if exponent_unit != dimensionless: + raise sc.UnitError(f'An exponent must be dimensionless, got {exponent_unit}.') + if base_unit == dimensionless: + return dimensionless + if not node.exp.is_number: + raise sc.UnitError( + f'Cannot determine units: quantity with unit {base_unit} is raised to a ' + f'symbolic power.' + ) + exponent = float(node.exp) + if exponent == int(exponent): + return base_unit ** int(exponent) + if 2 * exponent == int(2 * exponent): + # Half-integer power: the square root of an integer power. + return sc.sqrt(sc.scalar(1.0, unit=base_unit ** int(2 * exponent))).unit + raise sc.UnitError( + f'Cannot determine units: quantity with unit {base_unit} is raised to the ' + f'power {exponent}.' + ) + + if isinstance(node, self._UNIT_PRESERVING_FUNCS): + return self._propagate_unit(node.args[0]) + + if isinstance(node, self._DIMENSIONLESS_RESULT_FUNCS): + self._propagate_unit(node.args[0]) + return dimensionless + + if isinstance(node, sp.Function): + for argument in node.args: + argument_unit = self._propagate_unit(argument) + try: + sc.to_unit(sc.scalar(1.0, unit=argument_unit), dimensionless) + except sc.UnitError as e: + raise sc.UnitError( + f'{node.func.__name__} requires a dimensionless argument, ' + f'got {argument_unit}.' + ) from e + return dimensionless + + raise sc.UnitError( + f'Cannot determine units for expression node {node} of type {type(node).__name__}.' + ) + + def _warn_if_output_unit_mismatch(self) -> None: + """ + Warn if the expression's output unit does not match y_unit. + + The check only runs when units are in use, i.e. when the expression uses physical constants + or any parameter has a unit other than dimensionless. Unit-agnostic expressions (all + parameters dimensionless) stay silent. + """ + units_in_use = bool(self._constants) or any( + str(parameter.unit) != 'dimensionless' for parameter in self._parameters.values() + ) + if not units_in_use: + return + + try: + output_unit = sc.Unit(self.output_unit) + except sc.UnitError as e: + warnings.warn( + f'The expression is not unit-consistent: {e}', + UserWarning, + stacklevel=3, + ) + return + + y_unit = sc.Unit(self.y_unit) if self.y_unit is not None else sc.Unit('dimensionless') + if output_unit != y_unit: + warnings.warn( + f'The expression evaluates to unit {output_unit}, which does not match ' + f'y_unit {y_unit}. The evaluated values are labelled with y_unit; adjust the ' + f'parameter units or y_unit to make them consistent.', + UserWarning, + stacklevel=3, + ) + + def convert_x_unit(self, _new_unit: str | sc.Unit) -> None: + """ + Convert the x-axis unit of the expression. + + Unit conversion is not implemented for ExpressionComponent. Should it ever be needed, the + viable path is dimensional analysis on the parameter units: for each parameter, determine + the power n of the x-dimension in its unit and rescale its value by the x-unit conversion + factor to the power n (the generalization of Polynomial's power-law rescaling). This only + works when x_unit has a single unambiguous dimension. Parameters ---------- @@ -312,21 +714,39 @@ def convert_unit(self, _new_unit: str | sc.Unit) -> None: NotImplementedError Always raised to indicate unit conversion is not supported. """ + raise NotImplementedError('Unit conversion is not implemented for ExpressionComponent') + + def convert_y_unit(self, _new_unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit of the expression. + + Unit conversion is not implemented for ExpressionComponent. See convert_x_unit for the + approach that would make it possible. + + Parameters + ---------- + _new_unit : str | sc.Unit + The new unit to convert to (ignored). + Raises + ------ + NotImplementedError + Always raised to indicate unit conversion is not supported. + """ raise NotImplementedError('Unit conversion is not implemented for ExpressionComponent') # ------------------------- # dunder methods # ------------------------- - def __getattr__(self, name: str) -> Parameter: + def __getattr__(self, name: str) -> Parameter | DescriptorNumber: """ - Allow access to parameters as attributes. + Allow access to parameters and physical constants as attributes. Parameters ---------- name : str - Name of the parameter to access. + Name of the parameter or constant to access. Raises ------ @@ -335,11 +755,13 @@ def __getattr__(self, name: str) -> Parameter: Returns ------- - Parameter - The parameter with the given name. + Parameter | DescriptorNumber + The parameter or constant with the given name. """ if '_parameters' in self.__dict__ and name in self._parameters: return self._parameters[name] + if '_constants' in self.__dict__ and name in self._constants: + return self._constants[name] raise AttributeError(f"{self.__class__.__name__} has no attribute '{name}'") def __setattr__(self, name: str, value: Numeric) -> None: @@ -355,30 +777,31 @@ def __setattr__(self, name: str, value: Numeric) -> None: Raises ------ + AttributeError + If the name refers to a physical constant. TypeError If the value is not numeric. """ + if '_constants' in self.__dict__ and name in self._constants: + raise AttributeError(f"'{name}' is a physical constant and cannot be set.") if '_parameters' in self.__dict__ and name in self._parameters: param = self._parameters[name] - if not isinstance(value, Numeric): raise TypeError(f'{name} must be numeric') - param.value = value else: - # For other attributes, use default behavior super().__setattr__(name, value) def __dir__(self) -> list[str]: """ - Include parameter names in dir() output for better IDE support. + Include parameter and constant names in dir() output for better IDE support. Returns ------- list[str] - List of attribute names, including parameters. + List of attribute names, including parameters and constants. """ - return super().__dir__() + list(self._parameters.keys()) + return super().__dir__() + list(self._parameters.keys()) + list(self._constants.keys()) def __repr__(self) -> str: """ @@ -390,10 +813,15 @@ def __repr__(self) -> str: String representation of the ExpressionComponent. """ param_str = ', '.join(f'{k}={v.value}' for k, v in self._parameters.items()) + constants_str = '' + if self._constants: + constants_str = ',\n constants={ ' + ', '.join( + f'{k}={v.value} {v.unit}' for k, v in self._constants.items() + ) + constants_str += ' }' return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self._unit},\n' - f' expr={self._expression_str!r},\n' - f' parameters={{{param_str}}})' + f'{self.__class__.__name__}(name={self.name}, display_name={self.display_name}, ' + f'x_unit={self.x_unit}, y_unit={self.y_unit},\n' + f" expr='{self._expression_str}',\n" + f' parameters={{ {param_str} }}{constants_str} )' ) diff --git a/src/easydynamics/sample_model/components/gaussian.py b/src/easydynamics/sample_model/components/gaussian.py index 6a1805e13..364e89aba 100644 --- a/src/easydynamics/sample_model/components/gaussian.py +++ b/src/easydynamics/sample_model/components/gaussian.py @@ -20,12 +20,11 @@ class Gaussian(CreateParametersMixin, ModelComponent): r""" Model of a Gaussian function. - The intensity is given by - $$ I(x) = \frac{A}{\sigma \sqrt{2\pi}} \exp\left( -\frac{1}{2} \left(\frac{x - x_0}{\sigma}\right)^2 \right) $$ - where $A$ is the area, $x_0$ is the center, and $\sigma$ is the width. + where $A$ is the area, $x_0$ is the center, and $\sigma$ is the width. area has unit = x_unit * + y_unit; center and width have unit = x_unit. If the center is not provided, it will be centered at 0 and fixed, which is typically what you want in QENS. @@ -62,7 +61,8 @@ def __init__( area: Numeric = 1.0, center: Numeric | None = None, width: Numeric = 1.0, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'Gaussian', display_name: str | None = None, unique_name: str | None = None, @@ -73,39 +73,38 @@ def __init__( Parameters ---------- area : Numeric, default=1.0 - Area of the Gaussian. + Integrated area under the Gaussian. Unit is ``x_unit * y_unit``. center : Numeric | None, default=None - Center of the Gaussian. If None, defaults to 0 and is fixed. + Peak position in x_unit. If None, defaults to 0 and the center parameter is fixed. width : Numeric, default=1.0 - Standard deviation. - unit : str | sc.Unit, default='meV' - Unit of the parameters. + Standard deviation (sigma) in x_unit. Must be strictly positive. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. center and width are stored in this unit. area_unit = x_unit * + y_unit. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='Gaussian' - Name of the component for indexing. - display_name : str | None, default=None Name of the component. + display_name : str | None, default=None + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. if None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. """ - # Validate inputs and create Parameters if not given super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, ) - # These methods live in ValidationMixin - area = self._create_area_parameter(area=area, name=name, unit=self._unit) - center = self._create_center_parameter( - center=center, name=name, fix_if_none=True, unit=self._unit + self._area = self._create_area_parameter( + area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit ) - width = self._create_width_parameter(width=width, name=name, unit=self._unit) - - self._area = area - self._center = center - self._width = width + self._center = self._create_center_parameter( + center=center, name=name, fix_if_none=True, x_unit=self.x_unit + ) + self._width = self._create_width_parameter(width=width, name=name, x_unit=self.x_unit) @property def area(self) -> Parameter: @@ -115,27 +114,23 @@ def area(self) -> Parameter: Returns ------- Parameter - The area parameter. + The area Parameter with unit ``x_unit * y_unit``. """ - return self._area @area.setter def area(self, value: Numeric) -> None: """ - Set the value of the area parameter. - Parameters ---------- value : Numeric - The new value for the area parameter. + New area value (in current area unit = x_unit * y_unit). Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ - if not isinstance(value, Numeric): raise TypeError('area must be a number') self._area.value = value @@ -148,27 +143,24 @@ def center(self) -> Parameter: Returns ------- Parameter - The center parameter. + The center Parameter with unit ``x_unit``. """ - return self._center @center.setter def center(self, value: Numeric | None) -> None: """ - Set the center parameter value. - Parameters ---------- value : Numeric | None - The new value for the center parameter. If None, defaults to 0 and is fixed. + New center value in x_unit. If None, the center is set to 0 and the parameter is + fixed. Raises ------ TypeError - If the value is not a number or None. + If *value* is not None and not a numeric type. """ - if value is None: value = 0.0 self._center.fixed = True @@ -179,48 +171,43 @@ def center(self, value: Numeric | None) -> None: @property def width(self) -> Parameter: """ - Get the width parameter (standard deviation). + Get the width parameter (sigma). Returns ------- Parameter - The width parameter. + The width (sigma) Parameter with unit ``x_unit``. """ return self._width @width.setter def width(self, value: Numeric) -> None: """ - Set the width parameter value. - Parameters ---------- value : Numeric - The new value for the width parameter. + New width value in x_unit. Must be strictly positive. Raises ------ TypeError - If the value is not a number or None. + If *value* is not a numeric type. ValueError - If the value is not positive. + If *value* is not positive. """ if not isinstance(value, Numeric): raise TypeError('width must be a number') - if float(value) <= 0: raise ValueError('width must be positive') - self._width.value = value - def evaluate( - self, - x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, - ) -> np.ndarray: + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: r""" - Evaluate the Gaussian at the given x values. + Evaluate the Gaussian at x_vals. + + Parameters in the model's own units are temporarily converted to eval_unit for the + computation. - If x is a scipp Variable, the unit of the Gaussian will be converted to match x. The intensity is given by $$ I(x) = \frac{A}{\sigma \sqrt{2\pi}} \exp\left( -\frac{1}{2} \left(\frac{x - x_0}{\sigma}\right)^2 \right) $$ @@ -228,21 +215,51 @@ def evaluate( Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the Gaussian. + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray - The intensity of the Gaussian at the given x values. + Evaluated Gaussian values at x_vals. + """ + center = self._resolve_param_value(self._center, eval_unit) + width = self._resolve_param_value(self._width, eval_unit) + area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit)) + + normalization = 1 / (np.sqrt(2 * np.pi) * width) + exponent = -0.5 * ((x_vals - center) / width) ** 2 + return area * normalization * np.exp(exponent) + + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert x-axis parameters (center, width) and area to new_x_unit. + + Parameters + ---------- + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. """ + self._convert_x_unit_area_based( + new_x_unit=new_x_unit, + x_params=[self._center, self._width], + area_param=self._area, + ) - x = self._prepare_x_for_evaluate(x) + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: + """ + Convert the y-axis (output) unit by rescaling the area parameter. - normalization = 1 / (np.sqrt(2 * np.pi) * self.width.value) - exponent = -0.5 * ((x - self.center.value) / self.width.value) ** 2 + The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``. - return self.area.value * normalization * np.exp(exponent) + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + """ + self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area) def __repr__(self) -> str: """ @@ -253,12 +270,10 @@ def __repr__(self) -> str: str A string representation of the Gaussian. """ - return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self._unit},\n' - f' area={self.area},\n' - f' center={self.center},\n' - f' width={self.width})' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n' + f' area = {self.area},\n' + f' center = {self.center},\n' + f' width = {self.width})' ) diff --git a/src/easydynamics/sample_model/components/lorentzian.py b/src/easydynamics/sample_model/components/lorentzian.py index f2da78c84..fe36340fc 100644 --- a/src/easydynamics/sample_model/components/lorentzian.py +++ b/src/easydynamics/sample_model/components/lorentzian.py @@ -20,9 +20,10 @@ class Lorentzian(CreateParametersMixin, ModelComponent): r""" Model of a Lorentzian function. - The intensity is given by $$ I(x) = \frac{A}{\pi} \frac{\Gamma}{(x - x_0)^2 + \Gamma^2}, $$ - where $A$ is the area, $x_0$ is the center, and $\Gamma$ is the half width at half maximum - (HWHM). + $$ I(x) = \frac{A}{\pi} \frac{\Gamma}{(x - x_0)^2 + \Gamma^2} $$ + + where $A$ is the area, $x_0$ is the center, and $\Gamma$ is the hald width at half max (HWHM). + area has unit = x_unit * y_unit; center and width have unit = x_unit. If the center is not provided, it will be centered at 0 and fixed, which is typically what you want in QENS. @@ -58,7 +59,8 @@ def __init__( area: Numeric = 1.0, center: Numeric | None = None, width: Numeric = 1.0, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'Lorentzian', display_name: str | None = None, unique_name: str | None = None, @@ -69,39 +71,38 @@ def __init__( Parameters ---------- area : Numeric, default=1.0 - Area of the Lorentzian. + Integrated area under the Lorentzian. Unit is ``x_unit * y_unit``. center : Numeric | None, default=None - Center of the Lorentzian. If None, defaults to 0 and is fixed. + Peak position in x_unit. If None, defaults to 0 and the center parameter is fixed. width : Numeric, default=1.0 - Half width at half maximum (HWHM). - unit : str | sc.Unit, default='meV' - Unit of the parameters. + Half-width at half-maximum (HWHM, gamma) in x_unit. Must be strictly positive. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. center and width are stored in this unit. area_unit = x_unit * + y_unit. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='Lorentzian' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name for the component. + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. If None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. """ - super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, ) - # These methods live in ValidationMixin - area = self._create_area_parameter(area=area, name=name, unit=self._unit) - center = self._create_center_parameter( - center=center, name=name, fix_if_none=True, unit=self._unit + self._area = self._create_area_parameter( + area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit ) - width = self._create_width_parameter(width=width, name=name, unit=self._unit) - - self._area = area - self._center = center - self._width = width + self._center = self._create_center_parameter( + center=center, name=name, fix_if_none=True, x_unit=self.x_unit + ) + self._width = self._create_width_parameter(width=width, name=name, x_unit=self.x_unit) @property def area(self) -> Parameter: @@ -111,24 +112,22 @@ def area(self) -> Parameter: Returns ------- Parameter - The area parameter. + The area Parameter with unit ``x_unit * y_unit``. """ return self._area @area.setter def area(self, value: Numeric) -> None: """ - Set the value of the area parameter. - Parameters ---------- value : Numeric - The new value for the area parameter. + New area value (in current area unit = x_unit * y_unit). Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ if not isinstance(value, Numeric): raise TypeError('area must be a number') @@ -142,26 +141,24 @@ def center(self) -> Parameter: Returns ------- Parameter - The center parameter. + The center Parameter with unit ``x_unit``. """ return self._center @center.setter def center(self, value: Numeric | None) -> None: """ - Set the value of the center parameter. - Parameters ---------- value : Numeric | None - The new value for the center parameter. If None, defaults to 0 and is fixed. + New center value in x_unit. If None, the center is set to 0 and the parameter is + fixed. Raises ------ TypeError - If the value is not a number or None. + If *value* is not None and not a numeric type. """ - if value is None: value = 0.0 self._center.fixed = True @@ -177,63 +174,85 @@ def width(self) -> Parameter: Returns ------- Parameter - The width parameter. + The HWHM (gamma) Parameter with unit ``x_unit``. """ return self._width @width.setter def width(self, value: Numeric) -> None: """ - Set the width parameter value (HWHM). - Parameters ---------- value : Numeric - The new value for the width parameter. + New HWHM value in x_unit. Must be strictly positive. Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. ValueError - If the value is not positive. + If *value* is not positive. """ if not isinstance(value, Numeric): raise TypeError('width must be a number') - if float(value) <= 0: raise ValueError('width must be positive') self._width.value = value - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: r""" - Evaluate the Lorentzian at the given x values. + Evaluate the Lorentzian at x_vals. - If x is a scipp Variable, the unit of the Lorentzian will be converted to match x. The - intensity is given by - - $$ I(x) = \frac{A}{\pi} \frac{\Gamma}{(x - x_0)^2 + \Gamma^2}, $$ - - where $A$ is the area, $x_0$ is the center, and $\Gamma$ is the half width at half maximum - (HWHM). + Parameters in the model's own units are temporarily converted to eval_unit for the + computation. Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the Lorentzian. + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray - The intensity of the Lorentzian at the given x values. + Evaluated Lorentzian values at x_vals. """ + center = self._resolve_param_value(self._center, eval_unit) + width = self._resolve_param_value(self._width, eval_unit) + area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit)) - x = self._prepare_x_for_evaluate(x) + normalization = width / np.pi + denominator = (x_vals - center) ** 2 + width**2 + return area * normalization / denominator - normalization = self.width.value / np.pi - denominator = (x - self.center.value) ** 2 + self.width.value**2 + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert x-axis parameters (center, width) and area to new_x_unit. - return self.area.value * normalization / denominator + Parameters + ---------- + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. + """ + self._convert_x_unit_area_based( + new_x_unit=new_x_unit, + x_params=[self._center, self._width], + area_param=self._area, + ) + + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: + """ + Convert the y-axis (output) unit by rescaling the area parameter. + + The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + """ + self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area) def __repr__(self) -> str: """ @@ -245,10 +264,9 @@ def __repr__(self) -> str: A string representation of the Lorentzian. """ return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self._unit},\n' - f' area={self.area},\n' - f' center={self.center},\n' - f' width={self.width})' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n' + f' area = {self.area},\n' + f' center = {self.center},\n' + f' width = {self.width})' ) diff --git a/src/easydynamics/sample_model/components/mixins.py b/src/easydynamics/sample_model/components/mixins.py index 7552d2e80..82be533f3 100644 --- a/src/easydynamics/sample_model/components/mixins.py +++ b/src/easydynamics/sample_model/components/mixins.py @@ -18,52 +18,54 @@ class CreateParametersMixin: """ Provides parameter creation and validation methods for model components. - This mixin provides methods to create and validate common physics parameters (area, center, - width) with appropriate bounds and type checking. + area_unit = x_unit * y_unit, so when y_unit='dimensionless', area_unit = x_unit. """ def _create_area_parameter( self, area: Numeric, name: str, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', minimum_area: float = MINIMUM_AREA, ) -> Parameter: """ - Validate and convert a number to a Parameter describing the area of a function. - - If the area is negative, a warning is raised. If the area is non-negative, its minimum is - set to 0 to avoid it accidentally becoming negative during fitting. + Create a Parameter for the area with unit = x_unit * y_unit. Parameters ---------- area : Numeric - The area value. + Initial area value. name : str - The name of the model component. - unit : str | sc.Unit, default='meV' - The unit of the area Parameter. + Base name used to label the Parameter (``name + ' area'``). + x_unit : str | sc.Unit, default='meV' + X-axis unit. The resulting area unit is ``x_unit * y_unit``. + y_unit : str | sc.Unit, default='dimensionless' + Y-axis unit. The resulting area unit is ``x_unit * y_unit``. minimum_area : float, default=MINIMUM_AREA - The minimum allowed area. + Lower bound applied to the Parameter when the area is non-negative. When *area* is + negative no lower bound is set and a :class:`UserWarning` is issued. + + Returns + ------- + Parameter + Configured area Parameter with ``unit = x_unit * y_unit``. Raises ------ TypeError - If area is not a number. + If *area* is not a numeric type. ValueError - If area is not a finite number. - - Returns - ------- - Parameter - The validated area Parameter. + If *area* is not finite. """ if not isinstance(area, Numeric): raise TypeError('area must be a number.') if not np.isfinite(area): raise ValueError('area must be a finite number.') - area_param = Parameter(name=name + ' area', value=float(area), unit=unit) + + area_unit = str(sc.Unit(x_unit) * sc.Unit(y_unit)) + area_param = Parameter(name=name + ' area', value=float(area), unit=area_unit) if area_param.value < 0: warnings.warn( @@ -81,36 +83,38 @@ def _create_center_parameter( center: Numeric | None, name: str, fix_if_none: bool, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', enforce_minimum_center: bool = False, ) -> Parameter: """ - Validate and convert a number to a Parameter describing the center of a function. + Create a Parameter for the center with unit = x_unit. Parameters ---------- center : Numeric | None - The center value. + Initial center value. If None, the center is set to 0.0 and ``fixed`` is controlled by + *fix_if_none*. name : str - The name of the model component. + Base name used to label the Parameter (``name + ' center'``). fix_if_none : bool - Whether to fix the center Parameter if center is None. - unit : str | sc.Unit, default='meV' - The unit of the center Parameter. + Whether to fix the Parameter when *center* is None. + x_unit : str | sc.Unit, default='meV' + X-axis unit, applied to the center Parameter. enforce_minimum_center : bool, default=False - Whether to enforce a minimum center value to avoid zero center in DHO. + If True, the Parameter's lower bound is raised to ``DHO_MINIMUM_CENTER`` (1e-10) to + prevent a zero center. + + Returns + ------- + Parameter + Configured center Parameter with ``unit = x_unit``. Raises ------ TypeError - If center is not None or a number. + If *center* is not None and not a numeric type. ValueError - If center is a number but not finite. - - Returns - ------- - Parameter - The validated center Parameter. + If *center* is not None and not finite. """ if center is not None and not isinstance(center, Numeric): raise TypeError('center must be None or a number.') @@ -119,16 +123,18 @@ def _create_center_parameter( center_param = Parameter( name=name + ' center', value=0.0, - unit=unit, + unit=x_unit, fixed=fix_if_none, ) else: if not np.isfinite(center): raise ValueError('center must be None or a finite number.') + center_param = Parameter(name=name + ' center', value=float(center), unit=x_unit) - center_param = Parameter(name=name + ' center', value=float(center), unit=unit) if enforce_minimum_center and center_param.min < DHO_MINIMUM_CENTER: center_param.min = DHO_MINIMUM_CENTER + if center_param.value < DHO_MINIMUM_CENTER: + center_param.value = DHO_MINIMUM_CENTER return center_param def _create_width_parameter( @@ -136,36 +142,37 @@ def _create_width_parameter( width: Numeric, name: str, param_name: str = 'width', - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', minimum_width: float = MINIMUM_WIDTH, ) -> Parameter: """ - Validate and convert a number to a Parameter describing the width of a function. + Create a Parameter for the width with unit = x_unit. Parameters ---------- width : Numeric - The width value. + Initial width value. Must be strictly positive (>= *minimum_width*). name : str - The name of the model component. + Base name used to label the Parameter (``name + ' ' + param_name``). param_name : str, default='width' - The name of the width parameter. - unit : str | sc.Unit, default='meV' - The unit of the width Parameter. + Logical name of the parameter used in the label and error messages (e.g. + ``'gaussian_width'``, ``'lorentzian_width'``). + x_unit : str | sc.Unit, default='meV' + X-axis unit, applied to the width Parameter. minimum_width : float, default=MINIMUM_WIDTH - The minimum allowed width. + Absolute lower bound for the width to prevent division-by-zero. + + Returns + ------- + Parameter + Configured width Parameter with ``unit = x_unit`` and ``min = minimum_width``. Raises ------ TypeError - If width is not a number. + If *width* is not a numeric type. ValueError - If width is non-positive. - - Returns - ------- - Parameter - The validated width Parameter. + If *width* is not finite or is smaller than *minimum_width*. """ if not isinstance(width, Numeric): raise TypeError(f'{param_name} must be a number.') @@ -180,6 +187,6 @@ def _create_width_parameter( return Parameter( name=name + ' ' + param_name, value=float(width), - unit=unit, + unit=x_unit, min=minimum_width, ) diff --git a/src/easydynamics/sample_model/components/model_component.py b/src/easydynamics/sample_model/components/model_component.py index 878f501fa..1d2c98d0f 100644 --- a/src/easydynamics/sample_model/components/model_component.py +++ b/src/easydynamics/sample_model/components/model_component.py @@ -3,8 +3,8 @@ from __future__ import annotations -import warnings from abc import abstractmethod +from typing import TYPE_CHECKING import numpy as np import scipp as sc @@ -12,6 +12,10 @@ from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.utils.utils import Numeric +from easydynamics.utils.utils import convert_parameter_unit + +if TYPE_CHECKING: + from easyscience.variable import Parameter class ModelComponent(EasyDynamicsModelBase): @@ -19,107 +23,145 @@ class ModelComponent(EasyDynamicsModelBase): def __init__( self, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'ModelComponent', display_name: str | None = None, unique_name: str | None = None, ) -> None: """ - Initialize the ModelComponent. - Parameters ---------- - unit : str | sc.Unit, default='meV' - The unit of the model component. + x_unit : str | sc.Unit, default='meV' + Unit for the x-axis (independent variable) of this component. + y_unit : str | sc.Unit, default='dimensionless' + Unit for the y-axis (dependent variable / output) of this component. name : str, default='ModelComponent' - The name of the model component for indexing. + Internal name used for parameter labelling and logging. display_name : str | None, default=None - A human-readable name for the component. + Human-readable name shown in plots and reports. Falls back to *name* if None. unique_name : str | None, default=None - A unique identifier for the component. + Globally unique identifier. Auto-generated if None. """ super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, ) @property - def unit(self) -> str: + def x_unit(self) -> str | None: """ - Get the unit. - + Get the unit of the x axis. Returns ------- - str - The unit of the model component. + str | None + The current x-axis unit as a string, or None if no unit is set. """ - return str(self._unit) + return str(self._x_unit) if self._x_unit is not None else None - @unit.setter - def unit(self, _unit_str: str) -> None: + @x_unit.setter + def x_unit(self, _: str) -> None: """ - Unit is read-only. + Unit is read-only; raises AttributeError always. - Use convert_unit to change the unit between allowed types or create a new ModelComponent - with the desired unit. + Use :meth:`convert_x_unit` to change the unit, or create a new instance with the desired + unit. - Parameters - ---------- - _unit_str : str - The new unit to set. + Raises + ------ + AttributeError + Always raised when this setter is called. + """ + raise AttributeError( + f'x_unit is read-only. Use convert_x_unit to change the unit ' + f'or create a new {self.__class__.__name__} with the desired unit.' + ) + + @property + def y_unit(self) -> str | None: + """ + Get the unit of the y axis + + Returns + ------- + str | None + The current y-axis unit as a string, or None if no unit is set. + """ + return str(self._y_unit) if self._y_unit is not None else None + + @y_unit.setter + def y_unit(self, _: str) -> None: + """ + Unit is read-only; raises AttributeError always. + + Use :meth:`convert_y_unit` to change the unit, or create a new instance with the desired + unit. Raises ------ AttributeError - Always raised since unit is read-only. + Always raised when this setter is called. """ raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' + f'y_unit is read-only. Use convert_y_unit to change the unit ' f'or create a new {self.__class__.__name__} with the desired unit.' ) def fix_all_parameters(self) -> None: - """Fix all parameters in the model component.""" + """ + Fix all parameters in the model component. - pars = self.get_fittable_parameters() - for p in pars: + Sets ``fixed=True`` on every fittable parameter returned by + :meth:`get_fittable_parameters`. + """ + for p in self.get_fittable_parameters(): p.fixed = True def free_all_parameters(self) -> None: - """Free all parameters in the model component.""" + """ + Free all parameters in the model component. + + Sets ``fixed=False`` on every fittable parameter returned by + :meth:`get_fittable_parameters`. + """ for p in self.get_fittable_parameters(): p.fixed = False def _prepare_x_for_evaluate( self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray - ) -> np.ndarray: + ) -> tuple[np.ndarray, str | None, str]: """ - Prepare the input x for evaluation by handling units and converting to a numpy array. + Validate x and extract its values, detected unit, and dimension name. + + x is never converted. When x carries a unit, the caller is responsible for resolving + parameter values to that unit via _resolve_param_value. Parameters ---------- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The input data to prepare. + Input x values to validate and extract. + + Returns + ------- + tuple[np.ndarray, str | None, str] + x_values : np.ndarray of raw float values (no unit conversion) detected_unit : str unit + of x if scipp input, else None dim : scipp dimension name if scipp input, else 'x' Raises ------ - ValueError - If x contains NaN or infinite values, or if a sc.DataArray has more than one - coordinate. UnitError - If x has incompatible units that cannot be converted to the component's unit. - - Returns - ------- - np.ndarray - The prepared input data as a numpy array. + If x has a unit incompatible with the model's x_unit. + ValueError + If x contains NaN or infinite values, or if a DataArray has more than one coordinate. """ + detected_unit: str | None = None + dim: str = 'x' + dim_from_dataarray: bool = False - # Handle units if isinstance(x, sc.DataArray): - # Check that there's exactly one coordinate coords = dict(x.coords) ncoords = len(coords) if ncoords != 1: @@ -128,31 +170,25 @@ def _prepare_x_for_evaluate( f'scipp.DataArray must have exactly one coordinate to be used as input `x`. ' f'Found {ncoords} coordinates: {coord_names}.' ) - # get the coordinate, it's a sc.Variable - _, coord_obj = next(iter(coords.items())) + dim, coord_obj = next(iter(coords.items())) x = coord_obj + dim_from_dataarray = True + if isinstance(x, sc.Variable): - # Need to check if the units are consistent, - # and convert if not. + detected_unit = str(x.unit) + if not dim_from_dataarray: + dim = x.dims[0] if x.dims else 'x' x_in = x.value if x.sizes == {} else x.values - if self._unit is not None and x.unit != self._unit: - self_unit_for_warning = self._unit + + # Validate that x's unit is compatible with model's x_unit + if self.x_unit is not None and detected_unit != self.x_unit: try: - self.convert_unit(x.unit.name) + sc.to_unit(sc.scalar(1.0, unit=detected_unit), self.x_unit) except Exception as e: raise UnitError( - f'Input x has unit {x.unit}, but {self.__class__.__name__} component \ - has unit {self._unit}. \ - Failed to convert {self.__class__.__name__} to {x.unit}.' + f'Input x has unit {detected_unit}, which is incompatible with ' + f'{self.__class__.__name__} x_unit {self.x_unit}.' ) from e - - warnings.warn( - f'Input x has unit {x.unit}, but {self.__class__.__name__} component \ - has unit {self_unit_for_warning}. \ - Converting {self.__class__.__name__} to {x.unit}.', - UserWarning, - stacklevel=3, - ) else: x_in = x @@ -167,60 +203,262 @@ def _prepare_x_for_evaluate( if any(np.isinf(x_in)): raise ValueError('Input x contains infinite values.') - return np.sort(x_in) + return x_in, detected_unit, dim + + def _resolve_param_value(self, param: Parameter, target_unit: str | None) -> float: + """ + Return param's value converted to target_unit without mutating param. + + If target_unit is None or already matches param's unit, returns param.value directly. Uses + a temporary scipp scalar for the conversion. + + Parameters + ---------- + param : Parameter + The parameter whose value should be resolved. + target_unit : str | None + The unit to which the parameter value should be converted. When None (or equal to the + parameter's own unit) the raw value is returned without any conversion. + + Returns + ------- + float + The parameter value expressed in *target_unit*. + """ + if target_unit is None or str(param.unit) == str(target_unit): + return param.value + return sc.to_unit(sc.scalar(param.value, unit=str(param.unit)), target_unit).value - def convert_unit(self, unit: str | sc.Unit) -> None: + def _convert_x_unit_area_based( + self, + new_x_unit: str | sc.Unit, + x_params: list, + area_param: Parameter, + inverse_params: list | None = None, + ) -> None: """ - Convert the unit of the Parameters in the component. + Shared convert_x_unit logic for components with an area parameter (area = x_unit * y_unit). + + Validates the input type, converts all x-axis parameters, the area parameter, and any + reciprocal (``1/x_unit``) parameters to the new unit, and updates ``_x_unit``. Rolls back + all conversions if any step fails. Parameters ---------- - unit : str | sc.Unit - The new unit to convert to. + new_x_unit : str | sc.Unit + Target x-axis unit. + x_params : list + Parameters whose unit equals *x_unit* (e.g. center, width). + area_param : Parameter + The parameter whose unit equals ``x_unit * y_unit``. + inverse_params : list | None, default=None + Parameters whose unit equals ``1/x_unit`` (e.g. an exponential rate). Raises ------ TypeError - If the provided unit is not a str or sc.Unit. + If *new_x_unit* is not a ``str`` or ``sc.Unit``. Exception - If the provided unit is invalid or incompatible with the component's parameters. + If the conversion fails; all parameters are rolled back to their original units. + """ + if not isinstance(new_x_unit, (str, sc.Unit)): + raise TypeError(f'x_unit must be a string or sc.Unit, got {type(new_x_unit).__name__}') + inverse_params = inverse_params or [] + old_x_unit = self.x_unit + new_x_str = str(new_x_unit) if isinstance(new_x_unit, sc.Unit) else new_x_unit + new_area_unit = str(sc.Unit(new_x_str) * sc.Unit(self.y_unit)) + try: + for p in x_params: + convert_parameter_unit(p, new_x_unit) + convert_parameter_unit(area_param, new_area_unit) + for p in inverse_params: + convert_parameter_unit(p, '1/' + new_x_str) + self._x_unit = new_x_str + except Exception as e: + try: + old_area_unit = str(sc.Unit(old_x_unit) * sc.Unit(self.y_unit)) + for p in x_params: + convert_parameter_unit(p, old_x_unit) + convert_parameter_unit(area_param, old_area_unit) + for p in inverse_params: + convert_parameter_unit(p, '1/' + str(old_x_unit)) + except Exception: # noqa: S110 + pass + raise e + + def _convert_y_unit_area_based( + self, + new_y_unit: str | sc.Unit, + area_param: Parameter, + ) -> None: """ - if not isinstance(unit, (str, sc.Unit)): - raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') + Shared convert_y_unit logic for components with an area parameter (area = x_unit * y_unit). - old_unit = self._unit + Validates the input type, rescales the area parameter from ``x_unit * old_y_unit`` to + ``x_unit * new_y_unit``, and updates ``_y_unit``. Rolls back on failure. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + area_param : Parameter + The parameter whose unit equals ``x_unit * y_unit``. + + Raises + ------ + TypeError + If *new_y_unit* is not a ``str`` or ``sc.Unit``. + Exception + If the conversion fails; the area parameter is rolled back to its original unit. + """ + if not isinstance(new_y_unit, (str, sc.Unit)): + raise TypeError(f'y_unit must be a string or sc.Unit, got {type(new_y_unit).__name__}') + old_y_unit = self.y_unit + new_area_unit = str(sc.Unit(self.x_unit) * sc.Unit(new_y_unit)) + try: + convert_parameter_unit(area_param, new_area_unit) + self._y_unit = str(new_y_unit) if isinstance(new_y_unit, sc.Unit) else new_y_unit + except Exception as e: + try: + old_area_unit = str(sc.Unit(self.x_unit) * sc.Unit(old_y_unit)) + convert_parameter_unit(area_param, old_area_unit) + except Exception: # noqa: S110 + pass + raise e + + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert the x-axis unit of the component. + + The base implementation converts all parameters. Subclasses with mixed-unit parameters + (e.g. area ≠ x_unit) should override this method. + + Parameters + ---------- + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. + + Raises + ------ + TypeError + If *new_x_unit* is not a ``str`` or ``sc.Unit``. + Exception + If the conversion between the current unit and *new_x_unit* fails. On failure the + component is rolled back to its original unit. + """ + if not isinstance(new_x_unit, (str, sc.Unit)): + raise TypeError(f'x_unit must be a string or sc.Unit, got {type(new_x_unit).__name__}') + + old_unit = self.x_unit pars = self.get_all_parameters() try: for p in pars: - p.convert_unit(unit) - self._unit = unit + convert_parameter_unit(p, new_x_unit) + self._x_unit = str(new_x_unit) if isinstance(new_x_unit, sc.Unit) else new_x_unit except Exception as e: - # Attempt to rollback on failure try: for p in pars: - if hasattr(p, 'convert_unit'): - p.convert_unit(old_unit) + convert_parameter_unit(p, old_unit) except Exception: # noqa: S110 - pass # Best effort rollback + pass raise e - @abstractmethod - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: + """ + Convert the y-axis (output) unit. Subclasses with an area parameter should override this. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + + Raises + ------ + NotImplementedError + Always raised in this base implementation. Subclasses that carry an area parameter + (area_unit = x_unit * y_unit) must override this method to rescale the area + appropriately. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not support convert_y_unit.') + + def evaluate( + self, + x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, + output: str = 'numpy', + ) -> np.ndarray | sc.Variable: """ - Abstract method to evaluate the model component at input x. + Evaluate the model component at input x. - Must be implemented by subclasses. + When x carries a unit (scipp input), parameter values are temporarily converted to that + unit for the computation without mutating the parameters. Parameters ---------- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the component. + Input x values. + output : str, default='numpy' + 'numpy' returns np.ndarray; 'scipp' returns sc.Variable with y_unit. + + Raises + ------ + ValueError + If output is not 'numpy' or 'scipp'. + + Returns + ------- + np.ndarray | sc.Variable + Evaluated model values at x. + """ + if output not in ('numpy', 'scipp'): + raise ValueError(f"output must be 'numpy' or 'scipp', got {output!r}") + x_vals, detected_unit, dim = self._prepare_x_for_evaluate(x) + eval_unit = detected_unit or self.x_unit + values = self._evaluate_values(x_vals, eval_unit) + if output == 'scipp': + return sc.array(dims=[dim], values=values, unit=self.y_unit) + return values + + @abstractmethod + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: + """ + Compute the component's values for raw x values expressed in eval_unit. + + Implementations must resolve their parameters to *eval_unit* (via + :meth:`_resolve_param_value`) and return plain numpy values; the public :meth:`evaluate` + handles input validation and output wrapping. + + Parameters + ---------- + x_vals : np.ndarray + Raw x values, expressed in *eval_unit*. + eval_unit : str | None + The unit of *x_vals* (the detected input unit, falling back to the component's x_unit), + or None if no unit information is available. Returns ------- np.ndarray - Evaluated function values. + Evaluated model values at x_vals. + """ + + def _eval_area_unit(self, eval_unit: str | None) -> str | None: + """ + Get the area unit (``eval_unit * y_unit``) matching an evaluation unit. + + Parameters + ---------- + eval_unit : str | None + The unit x values are expressed in during evaluation. + + Returns + ------- + str | None + The corresponding area unit, or None when either unit is unknown (in which case + parameter values are used unconverted). """ + if eval_unit is None or self.y_unit is None: + return None + return str(sc.Unit(eval_unit) * sc.Unit(self.y_unit)) def __repr__(self) -> str: """ @@ -231,5 +469,7 @@ def __repr__(self) -> str: str A string representation of the ModelComponent. """ - - return f'{self.__class__.__name__}(unique_name={self.unique_name!r}, unit={self._unit})' + return ( + f'{self.__class__.__name__}(unique_name={self.unique_name}, ' + f'x_unit={self.x_unit}, y_unit={self.y_unit})' + ) diff --git a/src/easydynamics/sample_model/components/polynomial.py b/src/easydynamics/sample_model/components/polynomial.py index a36700dff..683af8d9f 100644 --- a/src/easydynamics/sample_model/components/polynomial.py +++ b/src/easydynamics/sample_model/components/polynomial.py @@ -23,8 +23,10 @@ class Polynomial(ModelComponent): r""" Polynomial function component. - The intensity is given by $$ I(x) = c_0 + c_1 x + c_2 x^2 + ... + c_N x^N, $$ where $C_i$ are - the coefficients. + $$ I(x) = c_0 + c_1 x + c_2 x^2 + ... + c_N x^N $$ + + Coefficients are stored as dimensionless Parameters. When x_unit changes, the coefficient + values are rescaled so the evaluated result stays the same. The output unit is y_unit. Examples -------- @@ -53,39 +55,46 @@ class Polynomial(ModelComponent): def __init__( self, coefficients: Sequence[Numeric | Parameter] = (0.0,), - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'Polynomial', display_name: str | None = None, unique_name: str | None = None, + suppress_warnings: bool = False, ) -> None: """ - Initialize the Polynomial component. - Parameters ---------- coefficients : Sequence[Numeric | Parameter], default=(0.0,) - Coefficients c0, c1, ..., cN. - unit : str | sc.Unit, default='meV' - Unit of the Polynomial component. + Ordered list of polynomial coefficients ``[c0, c1, ..., cN]`` where the polynomial is + ``c0 + c1*x + c2*x^2 + ... + cN*x^N``. Each element may be a plain numeric value + (wrapped into a dimensionless :class:`Parameter`) or an existing :class:`Parameter` + instance. Must contain at least one element. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. When the x_unit is changed via :meth:`convert_x_unit`, coefficient + values are rescaled by power-law factors so the evaluated output remains unchanged. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='Polynomial' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name of the Polynomial component. + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. If None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. + suppress_warnings : bool, default=False + Whether to suppress warnings Raises ------ TypeError - If coefficients is not a sequence of numbers or Parameters or if any item in - coefficients is not a number or Parameter. + If *coefficients* is not a list, tuple, or ndarray, or if any element is neither + numeric nor a :class:`Parameter`. ValueError - If coefficients is an empty sequence. + If *coefficients* is empty. """ - super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, @@ -93,17 +102,15 @@ def __init__( if not isinstance(coefficients, (list, tuple, np.ndarray)): raise TypeError( - 'coefficients must be a sequence (list/tuple/ndarray) \ - of numbers or Parameter objects.' + 'coefficients must be a sequence (list/tuple/ndarray) ' + 'of numbers or Parameter objects.' ) if len(coefficients) == 0: raise ValueError('At least one coefficient must be provided.') - # Internal storage of Parameter objects self._coefficients: list[Parameter] = [] - # Coefficients are treated as dimensionless Parameters for i, coef in enumerate(coefficients): if isinstance(coef, Parameter): param = coef @@ -113,19 +120,20 @@ def __init__( raise TypeError('Each coefficient must be either a numeric value or a Parameter.') self._coefficients.append(param) - # Helper scipp scalar to track unit conversions - # (value initialized to 1 with provided unit) - self._unit_conversion_helper = sc.scalar(value=1.0, unit=unit) + # Tracks the current x_unit scale for convert_x_unit power-law rescaling + self._x_unit_helper = sc.scalar(value=1.0, unit=x_unit) + + self.suppress_warnings = suppress_warnings @property def coefficients(self) -> list[Parameter]: """ Get the coefficients of the polynomial as a list of Parameters. - Returns ------- list[Parameter] - The coefficients of the polynomial. + A shallow copy of the internal coefficient list ``[c0, c1, ..., cN]``. Modifying the + returned list does not affect the model; use the setter to replace values. """ return list(self._coefficients) @@ -133,25 +141,24 @@ def coefficients(self) -> list[Parameter]: def coefficients(self, coeffs: Sequence[Numeric | Parameter]) -> None: """ Set the coefficients of the polynomial. - - Length must match current number of coefficients. - Parameters ---------- coeffs : Sequence[Numeric | Parameter] - New coefficients as a sequence of numbers or Parameters. + New coefficient values. Must be a list, tuple, or ndarray and must have the same + length as the current number of coefficients. Numeric values update the existing + Parameter's ``.value``; a Parameter instance replaces the stored Parameter entirely. Raises ------ TypeError - If coeffs is not a sequence of numbers or Parameters or if any item in coeffs is not a - number or Parameter. + If *coeffs* is not a list, tuple, or ndarray, or if any element is neither numeric nor + a Parameter. ValueError - If the length of coeffs does not match the existing number of coefficients. + If the length of *coeffs* does not match the current number of coefficients. """ if not isinstance(coeffs, (list, tuple, np.ndarray)): raise TypeError( - 'coefficients must be a sequence (list/tuple/ndarray) of numbers or Parameter .' + 'coefficients must be a sequence (list/tuple/ndarray) of numbers or Parameter.' ) if len(coeffs) != len(self._coefficients): raise ValueError( @@ -159,7 +166,6 @@ def coefficients(self, coeffs: Sequence[Numeric | Parameter]) -> None: ) for i, coef in enumerate(coeffs): if isinstance(coef, Parameter): - # replace parameter self._coefficients[i] = coef elif isinstance(coef, Numeric): self._coefficients[i].value = float(coef) @@ -169,56 +175,20 @@ def coefficients(self, coeffs: Sequence[Numeric | Parameter]) -> None: def coefficient_values(self) -> list[float]: """ Get the coefficients of the polynomial as a list. - Returns ------- list[float] - The coefficient values of the polynomial. + Current numeric values of all coefficients ``[c0.value, c1.value, ..., cN.value]``. """ return [param.value for param in self._coefficients] - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: - r""" - Evaluate the Polynomial at the given x values. - - The intensity is given by $$ I(x) = c_0 + c_1 x + c_2 x^2 + ... - + c_N x^N, $$ where $C_i$ are the coefficients. - - Parameters - ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the Polynomial. - - Returns - ------- - np.ndarray - The evaluated Polynomial at the given x values. - """ - - x = self._prepare_x_for_evaluate(x) - - result = np.zeros_like(x, dtype=float) - for i, param in enumerate(self._coefficients): - result += param.value * np.power(x, i) - - if any(result < 0): - warnings.warn( - f'The Polynomial with unique_name {self.unique_name} has negative values, ' - 'which may not be physically meaningful.', - UserWarning, - stacklevel=2, - ) - return result - @property def degree(self) -> int: """ - Get the degree of the polynomial. - Returns ------- int - The degree of the polynomial. + Polynomial degree, equal to ``len(coefficients) - 1``. """ return len(self._coefficients) - 1 @@ -230,73 +200,167 @@ def degree(self, _value: int) -> None: Parameters ---------- _value : int - The new degree of the polynomial. + Ignored; this setter always raises :exc:`AttributeError`. Raises ------ AttributeError - Always raised since degree cannot be set directly. + Always raised when this setter is called. """ raise AttributeError( 'The degree of the polynomial is determined by the number of coefficients ' 'and cannot be set directly.' ) - def get_all_variables(self) -> list[DescriptorBase]: + @property + def suppress_warnings(self) -> bool: + """ + Get whether or not to suppress warnings. + """ + return self._suppress_warnings + + @suppress_warnings.setter + def suppress_warnings(self, value: bool) -> None: """ - Get all variables from the model component. + Choose whether or not to suppress warnings. + + Parameters + ---------- + value : bool + Whether or not to suppress warnings + + Raises + ------ + TypeError + If suppress_warnings is not True or False + """ + if not isinstance(value, bool): + raise TypeError('Suppress_warnings must be True or False') + self._suppress_warnings = value + + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: + r""" + Evaluate the Polynomial at x_vals. + + When x_vals is expressed in a different unit than the stored x_unit, coefficient values are + temporarily rescaled (same power-law logic as convert_x_unit) without mutation. + + Parameters + ---------- + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. + Returns + ------- + np.ndarray + Evaluated polynomial values. + """ + if ( + eval_unit is not None + and self.x_unit is not None + and sc.Unit(eval_unit) != sc.Unit(self.x_unit) + ): + # Temporary coefficient rescaling — no mutation + helper = sc.scalar(1.0, unit=self.x_unit) + helper_in_x = sc.to_unit(helper, eval_unit) + scale = helper.value / helper_in_x.value + coeff_vals = [p.value * scale**i for i, p in enumerate(self._coefficients)] + else: + coeff_vals = [p.value for p in self._coefficients] + + result = np.zeros_like(x_vals, dtype=float) + for i, cv in enumerate(coeff_vals): + result += cv * np.power(x_vals, i) + + if not self._suppress_warnings and any(result < 0): + warnings.warn( + f'The Polynomial with unique_name {self.unique_name} has negative values, ' + 'which may not be physically meaningful.', + UserWarning, + stacklevel=3, + ) + + return result + + def get_all_variables(self) -> list[DescriptorBase]: + """ Returns ------- list[DescriptorBase] - List of variables in the component. + The coefficient Parameters that constitute the fittable variables of this polynomial + component. """ return list(self._coefficients) - def convert_unit(self, unit: str | sc.Unit) -> None: + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: """ - Convert the unit of the polynomial. + Convert the x-axis unit by rescaling coefficients with power-law factors. + + Each coefficient ``c_i`` is rescaled by ``(old_scale / new_scale) ** i`` so the evaluated + polynomial output is unchanged after the conversion. Parameters ---------- - unit : str | sc.Unit - The target unit to convert to. + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. Raises ------ UnitError - If the provided unit is not a string or sc.Unit. + If *new_x_unit* is not a valid unit string or ``sc.Unit``, or if the conversion between + the current unit and *new_x_unit* fails. """ + if not isinstance(new_x_unit, (str, sc.Unit)): + raise UnitError('new_x_unit must be a string or a scipp unit.') - if not isinstance(unit, (str, sc.Unit)): - raise UnitError('unit must be a string or a scipp unit.') - - # Find out how much the unit changes - # by converting a helper variable - conversion_value_before = self._unit_conversion_helper.value - self._unit_conversion_helper = sc.to_unit(self._unit_conversion_helper, unit=unit) - conversion_value_after = self._unit_conversion_helper.value + conversion_value_before = self._x_unit_helper.value + self._x_unit_helper = sc.to_unit(self._x_unit_helper, unit=new_x_unit) + conversion_value_after = self._x_unit_helper.value for i, param in enumerate(self._coefficients): - param.value *= ( - conversion_value_before / conversion_value_after - ) ** i # set the values directly to the appropriate power + param.value *= (conversion_value_before / conversion_value_after) ** i - self._unit = unit + self._x_unit = str(new_x_unit) if isinstance(new_x_unit, sc.Unit) else new_x_unit - def __repr__(self) -> str: + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: """ - Return a string representation of the Polynomial. + Rescale all coefficients so the evaluated output remains the same physical value. - Returns - ------- - str - A string representation of the Polynomial. + All coefficients are multiplied by the conversion factor from ``old_y_unit`` to + ``new_y_unit`` so that ``I(x) [new_y_unit]`` represents the same physical quantity as + ``I(x) [old_y_unit]``. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. Must be dimensionally compatible with the current y_unit. + + Raises + ------ + UnitError + If *new_y_unit* is not a valid unit string or ``sc.Unit``, or if the conversion between + the current y_unit and *new_y_unit* fails. """ + if not isinstance(new_y_unit, (str, sc.Unit)): + raise UnitError('new_y_unit must be a string or a scipp unit.') + old_y_unit = self.y_unit or 'dimensionless' + new_y_str = str(new_y_unit) if isinstance(new_y_unit, sc.Unit) else new_y_unit + + # Compute conversion factor: 1 old_y_unit expressed in new_y_unit + y_helper = sc.scalar(1.0, unit=old_y_unit) + y_helper_new = sc.to_unit(y_helper, new_y_str) + scale = y_helper_new.value / y_helper.value + + for param in self._coefficients: + param.value *= scale + self._y_unit = new_y_str + + def __repr__(self) -> str: coeffs_str = ', '.join(f'{param.name}={param.value}' for param in self._coefficients) return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self._unit},\n' - f' coefficients=[{coeffs_str}])' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n' + f' coefficients = [{coeffs_str}])' ) diff --git a/src/easydynamics/sample_model/components/voigt.py b/src/easydynamics/sample_model/components/voigt.py index 5ca29105c..ee8c29046 100644 --- a/src/easydynamics/sample_model/components/voigt.py +++ b/src/easydynamics/sample_model/components/voigt.py @@ -19,13 +19,14 @@ class Voigt(CreateParametersMixin, ModelComponent): r""" - Voigt profile, a convolution of Gaussian and Lorentzian. + Voigt profile — convolution of Gaussian and Lorentzian. + + Uses ``scipy.special.voigt_profile`` to evaluate the profile. area has unit = x_unit * y_unit; + center, gaussian_width, and lorentzian_width have unit = x_unit. If the center is not provided, it will be centered at 0 and fixed, which is typically what you want in QENS. - Use scipy.special.voigt_profile to evaluate the Voigt profile. - Examples -------- **Creating a Voigt profile with a fixed center (typical QENS use)** @@ -60,7 +61,8 @@ def __init__( center: Numeric | Parameter | None = None, gaussian_width: Numeric | Parameter = 1.0, lorentzian_width: Numeric | Parameter = 1.0, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'Voigt', display_name: str | None = None, unique_name: str | None = None, @@ -71,54 +73,52 @@ def __init__( Parameters ---------- area : Numeric | Parameter, default=1.0 - Total area under the curve. + Integrated area under the Voigt profile. Unit is ``x_unit * y_unit``. center : Numeric | Parameter | None, default=None - Center of the Voigt profile. + Peak position in x_unit. If None, defaults to 0 and the center parameter is fixed. gaussian_width : Numeric | Parameter, default=1.0 - Standard deviation of the Gaussian part. + Gaussian component standard deviation (sigma) in x_unit. Must be strictly positive. lorentzian_width : Numeric | Parameter, default=1.0 - Half width at half max (HWHM) of the Lorentzian part. - unit : str | sc.Unit, default='meV' - Unit of the parameters. + Lorentzian component HWHM (gamma) in x_unit. Must be strictly positive. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. center, gaussian_width, and lorentzian_width are stored in this + unit. area_unit = x_unit * y_unit. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). name : str, default='Voigt' - Name of the component for indexing. + Name of the component. display_name : str | None, default=None - Display name of the component. + Display name shown when plotting. Falls back to *name* if None. unique_name : str | None, default=None - Unique name of the component. If None, a unique_name is automatically generated. By - default, None. + Globally unique identifier. Auto-generated if None. """ - super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, name=name, display_name=display_name, unique_name=unique_name, ) - # These methods live in ValidationMixin - area = self._create_area_parameter(area=area, name=name, unit=self._unit) - center = self._create_center_parameter( - center=center, name=name, fix_if_none=True, unit=self._unit + self._area = self._create_area_parameter( + area=area, name=name, x_unit=self.x_unit, y_unit=self.y_unit ) - gaussian_width = self._create_width_parameter( + self._center = self._create_center_parameter( + center=center, name=name, fix_if_none=True, x_unit=self.x_unit + ) + self._gaussian_width = self._create_width_parameter( width=gaussian_width, name=name, param_name='gaussian_width', - unit=self._unit, + x_unit=self.x_unit, ) - lorentzian_width = self._create_width_parameter( + self._lorentzian_width = self._create_width_parameter( width=lorentzian_width, name=name, param_name='lorentzian_width', - unit=self._unit, + x_unit=self.x_unit, ) - self._area = area - self._center = center - self._gaussian_width = gaussian_width - self._lorentzian_width = lorentzian_width - @property def area(self) -> Parameter: """ @@ -127,24 +127,22 @@ def area(self) -> Parameter: Returns ------- Parameter - The area parameter. + The area Parameter with unit ``x_unit * y_unit``. """ return self._area @area.setter def area(self, value: Numeric) -> None: """ - Set the value of the area parameter. - Parameters ---------- value : Numeric - The new value for the area parameter. + New area value (in current area unit = x_unit * y_unit). Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. """ if not isinstance(value, Numeric): raise TypeError('area must be a number') @@ -158,24 +156,23 @@ def center(self) -> Parameter: Returns ------- Parameter - The center parameter. + The center Parameter with unit ``x_unit``. """ return self._center @center.setter def center(self, value: Numeric | None) -> None: """ - Set the value of the center parameter. - Parameters ---------- value : Numeric | None - The new value for the center parameter. If None, defaults to 0 and is fixed. + New center value in x_unit. If None, the center is set to 0 and the parameter is + fixed. Raises ------ TypeError - If the value is not a number. + If *value* is not None and not a numeric type. """ if value is None: value = 0.0 @@ -187,31 +184,30 @@ def center(self, value: Numeric | None) -> None: @property def gaussian_width(self) -> Parameter: """ - Get the Gaussian width parameter. + Get the Gaussian width parameter (sigma). Returns ------- Parameter - The Gaussian width parameter. + The Gaussian component width (sigma) Parameter with unit ``x_unit``. """ return self._gaussian_width @gaussian_width.setter def gaussian_width(self, value: Numeric) -> None: """ - Set the width parameter value. - + Set the gaussian width parameter value. Parameters ---------- value : Numeric - The new value for the width parameter. + New Gaussian width (sigma) in x_unit. Must be strictly positive. Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. ValueError - If the value is not positive. + If *value* is not positive. """ if not isinstance(value, Numeric): raise TypeError('gaussian_width must be a number') @@ -227,7 +223,7 @@ def lorentzian_width(self) -> Parameter: Returns ------- Parameter - The Lorentzian width parameter. + The Lorentzian component HWHM (gamma) Parameter with unit ``x_unit``. """ return self._lorentzian_width @@ -235,18 +231,17 @@ def lorentzian_width(self) -> Parameter: def lorentzian_width(self, value: Numeric) -> None: """ Set the value of the Lorentzian width parameter. - Parameters ---------- value : Numeric - The new value for the Lorentzian width parameter. + New Lorentzian HWHM (gamma) in x_unit. Must be strictly positive. Raises ------ TypeError - If the value is not a number. + If *value* is not a numeric type. ValueError - If the value is not positive. + If *value* is not positive. """ if not isinstance(value, Numeric): raise TypeError('lorentzian_width must be a number') @@ -254,33 +249,60 @@ def lorentzian_width(self, value: Numeric) -> None: raise ValueError('lorentzian_width must be positive') self._lorentzian_width.value = value - def evaluate(self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray) -> np.ndarray: - r""" - Evaluate the Voigt at the given x values. + def _evaluate_values(self, x_vals: np.ndarray, eval_unit: str | None) -> np.ndarray: + """ + Evaluate the Voigt at x_vals. - If x is a scipp Variable, the unit of the Voigt will be converted to match x. The Voigt - evaluates to the convolution of a Gaussian with sigma gaussian_width and a Lorentzian with - half width at half max lorentzian_width, centered at center, with area equal to area. + Parameters in the model's own units are temporarily converted to eval_unit for the + computation. Parameters ---------- - x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values at which to evaluate the Voigt. + x_vals : np.ndarray + Raw x values expressed in eval_unit. + eval_unit : str | None + The unit of x_vals. Returns ------- np.ndarray - The intensity of the Voigt at the given x values. + Evaluated Voigt profile values at x_vals. """ + center = self._resolve_param_value(self._center, eval_unit) + gw = self._resolve_param_value(self._gaussian_width, eval_unit) + lw = self._resolve_param_value(self._lorentzian_width, eval_unit) + area = self._resolve_param_value(self._area, self._eval_area_unit(eval_unit)) + + return area * voigt_profile(x_vals - center, gw, lw) - x = self._prepare_x_for_evaluate(x) + def convert_x_unit(self, new_x_unit: str | sc.Unit) -> None: + """ + Convert x-axis parameters (center, widths) and area to new_x_unit. - return self.area.value * voigt_profile( - x - self.center.value, - self.gaussian_width.value, - self.lorentzian_width.value, + Parameters + ---------- + new_x_unit : str | sc.Unit + Target x-axis unit. Must be dimensionally compatible with the current x_unit. + """ + self._convert_x_unit_area_based( + new_x_unit=new_x_unit, + x_params=[self._center, self._gaussian_width, self._lorentzian_width], + area_param=self._area, ) + def convert_y_unit(self, new_y_unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit by rescaling the area parameter. + + The area is rescaled from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit``. + + Parameters + ---------- + new_y_unit : str | sc.Unit + Target y-axis unit. + """ + self._convert_y_unit_area_based(new_y_unit=new_y_unit, area_param=self._area) + def __repr__(self) -> str: """ Return a string representation of the Voigt. @@ -290,12 +312,11 @@ def __repr__(self) -> str: str A string representation of the Voigt. """ - return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, unit={self._unit},\n' - f' area={self.area},\n' - f' center={self.center},\n' - f' gaussian_width={self.gaussian_width},\n' - f' lorentzian_width={self.lorentzian_width})' + f'{self.__class__.__name__}(name = {self.name}, display_name = {self.display_name}, ' + f'x_unit = {self.x_unit}, y_unit = {self.y_unit},\n' + f' area = {self.area},\n' + f' center = {self.center},\n' + f' gaussian_width = {self.gaussian_width},\n' + f' lorentzian_width = {self.lorentzian_width})' ) diff --git a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py index 7255f88a0..b15d00120 100644 --- a/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py +++ b/src/easydynamics/sample_model/diffusion_model/brownian_translational_diffusion.py @@ -52,7 +52,8 @@ def __init__( scale: Numeric = 1.0, diffusion_coefficient: Numeric = 1.0, Q: Q_type | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'BrownianTranslationalDiffusion', display_name: str | None = 'BrownianTranslationalDiffusion', lorentzian_name: str | None = None, @@ -70,8 +71,10 @@ def __init__( Diffusion coefficient D in m^2/s. Q : Q_type | None, default=None Q values for the model. If None, Q is not set. - unit : str | sc.Unit, default='meV' - Unit of the diffusion model. Must be convertible to meV. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis (energy/frequency). Must be convertible to meV. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). Determines scale.unit = x_unit * y_unit. name : str, default='BrownianTranslationalDiffusion' Name of the diffusion model. display_name : str | None, default='BrownianTranslationalDiffusion' @@ -96,7 +99,8 @@ def __init__( """ super().__init__( Q=Q, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, scale=scale, name=name, display_name=display_name, @@ -186,7 +190,7 @@ def calculate_width(self, Q: Q_type | None = None) -> np.ndarray: Q = self._ensure_Q(Q) unit_conversion_factor = self._hbar * self.diffusion_coefficient / (self._angstrom**2) - unit_conversion_factor.convert_unit(self.unit) + unit_conversion_factor.convert_unit(self.x_unit) return Q**2 * unit_conversion_factor.value def calculate_EISF(self, Q: Q_type | None = None) -> np.ndarray: @@ -240,11 +244,11 @@ def create_component_collections( Lorentzian has a width given by $D*Q^2$ and an area given by the scale parameter multiplied by the QISF (which is 1 for this model). """ - Q = self.Q - if Q is None: + if self.Q is None: self._component_collections = [] return self._component_collections + Q = self.Q.values component_collection_list = [None] * len(Q) # In more complex models, this is used to scale the area of the # Lorentzians and the delta function. @@ -257,31 +261,37 @@ def create_component_collections( component_collection_list[i] = ComponentCollection( name=f'{self.name}_Q{Q_value:.2f}', display_name=f'{self.display_name}_Q{Q_value:.2f}', - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) lorentzian_component = Lorentzian( name=self.lorentzian_name, display_name=self.lorentzian_display_name, - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) # Make the width dependent on Q dependency_expression = self._write_width_dependency_expression(Q[i]) dependency_map = self._write_width_dependency_map_expression() - lorentzian_component.width.make_dependent_on( - dependency_expression=dependency_expression, - dependency_map=dependency_map, - desired_unit=self.unit, - ) - - # Make the area dependent on Q - area_dependency_map = self._write_area_dependency_map_expression() - lorentzian_component.area.make_dependent_on( - dependency_expression=self._write_area_dependency_expression(QISF[i]), - dependency_map=area_dependency_map, - ) + # easyscience propagates inf bounds through arithmetic, producing inf/inf=nan + # as a transient intermediate. Python's min/max ignore nan so the final bounds + # are correct; suppress the spurious numpy RuntimeWarning. + with np.errstate(invalid='ignore'): + lorentzian_component.width.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map=dependency_map, + desired_unit=self.x_unit, + ) + + # Make the area dependent on Q + area_dependency_map = self._write_area_dependency_map_expression() + lorentzian_component.area.make_dependent_on( + dependency_expression=self._write_area_dependency_expression(QISF[i]), + dependency_map=area_dependency_map, + ) component_collection_list[i].append_component(lorentzian_component) @@ -339,43 +349,6 @@ def _write_width_dependency_map_expression(self) -> dict[str, DescriptorNumber]: 'angstrom': self._angstrom, } - def _write_area_dependency_expression(self, QISF: float) -> str: - """ - Write the dependency expression for the area to make dependent Parameters. - - Parameters - ---------- - QISF : float - Quasielastic Incoherent Scattering Function. - - Raises - ------ - TypeError - If QISF is not a float. - - Returns - ------- - str - Dependency expression for the area. - """ - if not isinstance(QISF, (float)): - raise TypeError('QISF must be a float.') - - return f'{QISF} * scale' - - def _write_area_dependency_map_expression(self) -> dict[str, DescriptorNumber]: - """ - Write the dependency map expression to make dependent Parameters. - - Returns - ------- - dict[str, DescriptorNumber] - Dependency map for the area. - """ - return { - 'scale': self.scale, - } - # ------------------------------------------------------------------ # dunder methods # ------------------------------------------------------------------ @@ -392,6 +365,7 @@ def __repr__(self) -> str: return ( f'{self.__class__.__name__}(' f'name={self.name!r}, display_name={self.display_name!r},\n' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n' f' diffusion_coefficient={self.diffusion_coefficient},\n' f' scale={self.scale})' ) diff --git a/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py b/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py index 1e709dbc4..12af54314 100644 --- a/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py +++ b/src/easydynamics/sample_model/diffusion_model/delta_lorentz.py @@ -11,8 +11,11 @@ from easydynamics.sample_model.components import DeltaFunction from easydynamics.sample_model.components import Lorentzian from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase +from easydynamics.utils.fit_target import FitTarget from easydynamics.utils.utils import Numeric from easydynamics.utils.utils import Q_type +from easydynamics.utils.utils import angstrom +from easydynamics.utils.utils import verify_Q_index MINIMUM_WIDTH = 1e-10 # To avoid division by zero @@ -63,7 +66,8 @@ def __init__( lorentzian_width: Numeric = 1.0, allow_Q_variation: dict | None = None, Q: Q_type | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'DeltaLorentz', display_name: str | None = None, lorentzian_name: str = 'Lorentzian', @@ -92,8 +96,10 @@ def __init__( allowed. Q : Q_type | None, default=None Q values for the model. If None, Q is not set. - unit : str | sc.Unit, default='meV' - Unit of the diffusion model. Must be convertible to meV. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis (energy/frequency). Must be convertible to meV. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). Determines scale.unit = x_unit * y_unit. name : str, default='DeltaLorentz' Name of the diffusion model. display_name : str | None, default=None @@ -120,7 +126,8 @@ def __init__( """ super().__init__( scale=scale, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, Q=Q, lorentzian_name=lorentzian_name, lorentzian_display_name=lorentzian_display_name, @@ -134,6 +141,18 @@ def __init__( # -------------------------------------------------------------- self._mean_u_squared = self._create_mean_u_squared_parameter(mean_u_squared) + # Dimensionless /angstrom^2 for use inside exp() in the area dependency + # expressions and the EISF/QISF calculations. Deriving it through the dependency + # graph keeps those quantities correct if mean_u_squared is converted to another + # unit (e.g. nm**2), which a raw `.value` would silently get wrong. + with np.errstate(invalid='ignore'): + self._mean_u_squared_over_angstrom_squared = Parameter.from_dependency( + name='mean_u_squared_over_angstrom_squared', + dependency_expression='mean_u_squared / angstrom**2', + dependency_map={'mean_u_squared': self._mean_u_squared, 'angstrom': angstrom}, + ) + self._mean_u_squared_over_angstrom_squared.set_desired_unit('dimensionless') + self._A_0, self._A_1 = self._create_A0_A1_parameters(A_0) self._lorentzian_width = self._create_lorentzian_width_parameter(lorentzian_width) @@ -398,8 +417,18 @@ def calculate_width(self, Q: Q_type = None) -> np.ndarray: ------- np.ndarray HWHM values in the unit of the model (e.g., meV). + + Raises + ------ + ValueError + If Q-variation is enabled but Q has not been set on the model yet. """ if self._allow_Q_variation['lorentzian_width'] is True: + if not self._lorentzian_width_list: + raise ValueError( + 'Lorentzian width Q-variation list is empty. ' + 'Set Q before calling calculate_width.' + ) widths = [lorentzian_width.value for lorentzian_width in self._lorentzian_width_list] return np.array(widths) @@ -424,12 +453,13 @@ def calculate_EISF(self, Q: Q_type = None) -> np.ndarray: EISF values (dimensionless). """ Q = self._ensure_Q(Q) + mean_u_squared = self._mean_u_squared_over_angstrom_squared.value if self._allow_Q_variation['A_0'] is True: A_0_values = [A_0_.value for A_0_ in self._A_0_list] - return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_0_values) + return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_0_values) A_0_values = [self.A_0.value] * len(Q) - return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_0_values) + return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_0_values) def calculate_QISF(self, Q: Q_type = None) -> np.ndarray: """ @@ -446,12 +476,13 @@ def calculate_QISF(self, Q: Q_type = None) -> np.ndarray: QISF values (dimensionless). """ Q = self._ensure_Q(Q) + mean_u_squared = self._mean_u_squared_over_angstrom_squared.value if self._allow_Q_variation['A_0'] is True: A_1_values = [A_1_.value for A_1_ in self._A_1_list] - return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_1_values) + return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_1_values) A_1_values = [self.A_1.value] * len(Q) - return np.exp(-self.mean_u_squared.value * Q**2 / 3) * np.array(A_1_values) + return np.exp(-mean_u_squared * Q**2 / 3) * np.array(A_1_values) def create_component_collections( self, @@ -465,10 +496,11 @@ def create_component_collections( List of ComponentCollections with Lorentzian and delta function components for each Q value. """ - Q = self.Q - if Q is None: + if self.Q is None: return [] + Q = self.Q.values + if self._allow_Q_variation['A_0'] is True: A_0_list, A_1_list = self._create_A0_A1_parameter_lists(self.A_0) self._A_0_list = A_0_list @@ -484,19 +516,21 @@ def create_component_collections( for i, Q_value in enumerate(Q): component_collection_list[i] = ComponentCollection( display_name=f'{self.display_name}_Q{Q_value:.2f}', - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) # ------------------------------# # Create Lorentzian # ------------------------------# - lorentzian_component = Lorentzian( + lorz_component = Lorentzian( name=self.lorentzian_name, display_name=self.lorentzian_display_name, - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) if self._allow_Q_variation['lorentzian_width'] is True: - lorentzian_component._width = self._lorentzian_width_list[i] # noqa: SLF001 + lorz_component._width = self._lorentzian_width_list[i] # noqa: SLF001 # If the width is allowed to vary with Q it is independent. # If the width is not allowed to vary with Q it must be made @@ -504,10 +538,10 @@ def create_component_collections( if self._allow_Q_variation['lorentzian_width'] is False: dependency_map = self._write_width_dependency_map_expression() - lorentzian_component.width.make_dependent_on( + lorz_component.width.make_dependent_on( dependency_expression=self._write_lorz_width_dependency_expression(Q_value), dependency_map=dependency_map, - desired_unit=self.unit, + desired_unit=self.x_unit, ) # The area is always a dependent parameter in this model, as # it depends on the scale, mean_u_squared and A_1 parameters @@ -520,12 +554,12 @@ def create_component_collections( else: dependency_map = self._write_lorz_area_dependency_map_expression(None) - lorentzian_component.area.make_dependent_on( + lorz_component.area.make_dependent_on( dependency_expression=self._write_lorz_area_dependency_expression(Q_value), dependency_map=dependency_map, ) - component_collection_list[i].append_component(lorentzian_component) + component_collection_list[i].append_component(lorz_component) # ------------------------------# # Create delta function @@ -534,7 +568,8 @@ def create_component_collections( delta_component = DeltaFunction( name=self.delta_name, display_name=self.delta_display_name, - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) if self._allow_Q_variation['A_0'] is True: @@ -551,6 +586,32 @@ def create_component_collections( return component_collection_list + def get_fit_targets(self) -> list[FitTarget]: + """ + Get the fittable predictions of the DeltaLorentz model as FitTargets. + + Extends the base ``'area'`` and ``'width'`` predictions with ``'delta_area'`` (``scale * + EISF(Q)``, the delta function's weight), whose default dataset key is derived from the + delta component's name. + + Returns + ------- + list[FitTarget] + The fittable predictions of this model. + """ + targets = super().get_fit_targets() + targets.append( + FitTarget( + name='delta_area', + dataset_key=f'{self.delta_name} area', + function=lambda Q, model=self, **_: model.calculate_EISF(Q) * model.scale.value, + label=f'{self.display_name} delta_area', + x_unit='1/angstrom', + y_unit=str(self.scale.unit), + ) + ) + return targets + def get_global_variables(self) -> list[Parameter]: """ Get all global variables from the diffusion model. @@ -587,22 +648,9 @@ def get_independent_variables(self, Q_index: int | None = None) -> list[Paramete ------- list[Parameter] List of independent variables in the model. - - Raises - ------ - ValueError - If Q_index is not None and is not a valid index for the Q values in the model. """ - if Q_index is not None and ( - not isinstance(Q_index, int) - or Q_index < 0 - or Q_index >= len(self._component_collections) - ): - raise ValueError( - f'Q_index must be an integer between 0 and ' - f'{len(self._component_collections) - 1}, or None.' - ) + verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True) variables = [] if self._allow_Q_variation['A_0'] is True: @@ -629,22 +677,9 @@ def get_all_variables(self, Q_index: int | None = None) -> list[DescriptorNumber ------- list[DescriptorNumber] List of all variables in the model. - - Raises - ------ - ValueError - If Q_index is not None and is not a valid index for the Q values in the model. """ - if Q_index is not None and ( - not isinstance(Q_index, int) - or Q_index < 0 - or Q_index >= len(self._component_collections) - ): - raise ValueError( - f'Q_index must be an integer between 0 and ' - f'{len(self._component_collections) - 1}, or None.' - ) + verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True) variables = self.get_global_variables() variables.extend(self.get_independent_variables(Q_index=Q_index)) @@ -814,7 +849,7 @@ def _create_lorentzian_width_parameter(self, lorentzian_width: Numeric) -> Param value=float(lorentzian_width), fixed=False, min=MINIMUM_WIDTH, - unit=self.unit, + unit=self.x_unit, ) def _create_A0_A1_parameter_lists( @@ -837,7 +872,7 @@ def _create_A0_A1_parameter_lists( """ A_0_list = [] A_1_list = [] - for _ in self.Q: + for _ in range(len(self.Q)): a0 = Parameter( name='A_0', value=float(A_0.value), @@ -881,9 +916,9 @@ def _create_lorentzian_width_parameter_list( value=float(lorentzian_width.value), fixed=False, min=MINIMUM_WIDTH, - unit=self.unit, + unit=self.x_unit, ) - for _ in self.Q + for _ in range(len(self.Q)) ] # ------------------------------------------------------------------ @@ -914,6 +949,20 @@ def _on_Q_change(self) -> None: self._lorentzian_width_list = [] self._component_collections = self.create_component_collections() + def _convert_extra_x_unit_parameters(self, unit_str: str) -> None: + """ + Convert the Lorentzian width template to the new x-axis unit. + + The per-Q width list (when Q-variation is enabled) holds the very Parameter objects used by + the components, so those are converted in place with the collections. + + Parameters + ---------- + unit_str : str + The new x-axis unit as a string. + """ + self._lorentzian_width.convert_unit(unit_str) + def _write_lorz_width_dependency_expression(self, Q: float) -> str: """ Write the dependency expression for the width as a function of Q to make dependent @@ -976,7 +1025,9 @@ def _write_lorz_area_dependency_expression(self, Q: float) -> str: if not isinstance(Q, (float)): raise TypeError('Q must be a float.') - return f'scale * exp(-mean_u_squared.value * {Q}**2 / 3) * A_1' + # mean_u_squared_ratio is /angstrom^2, kept dimensionless through the dependency + # graph so the expression stays correct if mean_u_squared is converted to another unit. + return f'scale * exp(-mean_u_squared_ratio.value * {Q}**2 / 3) * A_1' def _write_lorz_area_dependency_map_expression( self, Q_index: int | None @@ -998,13 +1049,13 @@ def _write_lorz_area_dependency_map_expression( if Q_index is None: return { 'scale': self.scale, - 'mean_u_squared': self.mean_u_squared, + 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared, 'A_1': self.A_1, } return { 'scale': self.scale, - 'mean_u_squared': self.mean_u_squared, + 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared, 'A_1': self._A_1_list[Q_index], } @@ -1030,7 +1081,9 @@ def _write_delta_area_dependency_expression(self, Q: float) -> str: if not isinstance(Q, (float)): raise TypeError('Q must be a float.') - return f'scale * exp(-mean_u_squared.value * {Q}**2 / 3) * A_0' + # mean_u_squared_ratio is /angstrom^2, kept dimensionless through the dependency + # graph so the expression stays correct if mean_u_squared is converted to another unit. + return f'scale * exp(-mean_u_squared_ratio.value * {Q}**2 / 3) * A_0' def _write_delta_area_dependency_map_expression( self, @@ -1053,12 +1106,12 @@ def _write_delta_area_dependency_map_expression( if Q_index is None: return { 'scale': self.scale, - 'mean_u_squared': self.mean_u_squared, + 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared, 'A_0': self.A_0, } return { 'scale': self.scale, - 'mean_u_squared': self.mean_u_squared, + 'mean_u_squared_ratio': self._mean_u_squared_over_angstrom_squared, 'A_0': self._A_0_list[Q_index], } @@ -1076,10 +1129,10 @@ def __repr__(self) -> str: String representation of the DeltaLorentz model. """ return ( - f'{self.__class__.__name__}(' - f'display_name={self.display_name!r}, unit={self.unit},\n' - f' mean_u_squared={self.mean_u_squared},\n' - f' A_0={self.A_0}, A_1={self.A_1},\n' - f' lorentzian_width={self.lorentzian_width},\n' + f'DeltaLorentz(display_name={self.display_name}, ' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n' + f' mean_u_squared={self.mean_u_squared}, \n' + f' A_0={self.A_0}, A_1={self.A_1}, \n' + f' lorentzian_width={self.lorentzian_width}, \n' f' scale={self.scale})' ) diff --git a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py index 7d752b826..da6533f07 100644 --- a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py +++ b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +import contextlib + import numpy as np import scipp as sc from easyscience.variable import DescriptorNumber @@ -9,9 +11,11 @@ from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.sample_model.component_collection import ComponentCollection +from easydynamics.utils.fit_target import FitTarget from easydynamics.utils.utils import Numeric from easydynamics.utils.utils import Q_type from easydynamics.utils.utils import _validate_and_convert_Q +from easydynamics.utils.utils import verify_Q_index class DiffusionModelBase(EasyDynamicsModelBase): @@ -21,7 +25,8 @@ def __init__( self, scale: Numeric = 1.0, Q: Q_type | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'DiffusionModel', display_name: str | None = 'DiffusionModel', lorentzian_name: str | None = None, @@ -31,14 +36,20 @@ def __init__( """ Initialize a new DiffusionModel. + Unit validation raises ``UnitError`` if x_unit is not a string or scipp Unit, or if it + cannot be converted to meV. + Parameters ---------- scale : Numeric, default=1.0 - Scale factor for the diffusion model. Must be a non-negative number. + Scale factor for the diffusion model. Must be a non-negative number. Its unit equals + area_unit = x_unit * y_unit because scale * QISF/EISF (dimensionless) = component area. Q : Q_type | None, default=None Q values for the model. If None, Q is not set. - unit : str | sc.Unit, default='meV' - Unit of the diffusion model. Must be convertible to meV. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis (energy/frequency). Must be convertible to meV. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). Together with x_unit determines area_unit. name : str, default='DiffusionModel' Name of the diffusion model. display_name : str | None, default='DiffusionModel' @@ -57,21 +68,13 @@ def __init__( ------ TypeError If scale is not a number. - UnitError - If unit is not a string or scipp Unit, or if it cannot be converted to meV. ValueError If scale is negative. """ self._Q = _validate_and_convert_Q(Q) - try: - test = DescriptorNumber(name='test', value=1, unit=unit) - test.convert_unit('meV') - except Exception as e: - raise UnitError( - f'Invalid unit: {unit}. Unit must be a string or scipp Unit and convertible to meV.' # noqa: E501 - ) from e + self._assert_convertible_to_mev(x_unit) if not isinstance(scale, Numeric): raise TypeError('scale must be a number.') @@ -79,10 +82,18 @@ def __init__( if float(scale) < 0: raise ValueError('scale must be non-negative.') - scale = Parameter(name='scale', value=float(scale), fixed=False, min=0.0, unit=unit) - self._scale = scale + area_unit = str(sc.Unit(str(x_unit)) * sc.Unit(str(y_unit))) + self._scale = Parameter( + name='scale', value=float(scale), fixed=False, min=0.0, unit=area_unit + ) - super().__init__(unit=unit, name=name, display_name=display_name, unique_name=unique_name) + super().__init__( + x_unit=x_unit, + y_unit=y_unit, + name=name, + display_name=display_name, + unique_name=unique_name, + ) if lorentzian_name is None: lorentzian_name = name @@ -102,7 +113,10 @@ def __init__( if self.Q is None: self._component_collections = [] else: - self._component_collections = [ComponentCollection()] * len(self.Q) + self._component_collections = [ + ComponentCollection(x_unit=self.x_unit, y_unit=self.y_unit) + for _ in range(len(self.Q)) + ] # ------------------------------------------------------------------ # Properties @@ -145,14 +159,14 @@ def scale(self, scale: Numeric) -> None: self._scale.value = float(scale) @property - def Q(self) -> np.ndarray | None: + def Q(self) -> sc.Variable | None: """ Get the Q values of the SampleModel. Returns ------- - np.ndarray | None - The Q values of the SampleModel, or None if not set. + sc.Variable | None + The Q values of the SampleModel in 1/angstrom, or None if not set. """ return self._Q @@ -184,7 +198,7 @@ def Q(self, value: Q_type | None) -> None: self._on_Q_change() return - if len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q): + if len(old_Q) != len(new_Q) or not sc.allclose(old_Q, new_Q): raise ValueError( 'New Q values are not similar to the old ones. ' 'To change Q values, first run clear_Q().' @@ -274,6 +288,177 @@ def clear_Q(self, confirm: bool = False) -> None: self._Q = None self._on_Q_change() + # ------------------------------------------------------------------ + # Unit conversion + # ------------------------------------------------------------------ + + @staticmethod + def _assert_convertible_to_mev(unit: str | sc.Unit) -> None: + """ + Assert that the given unit is an energy unit (convertible to meV). + + Frequency axes such as 1/ps are not supported for now. + + Parameters + ---------- + unit : str | sc.Unit + The unit to validate. + + Raises + ------ + UnitError + If unit is not a string or scipp Unit, or if it cannot be converted to meV. + """ + try: + test = DescriptorNumber(name='test', value=1, unit=str(unit)) + test.convert_unit('meV') + except Exception as e: + raise UnitError( + f'Invalid unit: {unit}. Unit must be a string or scipp Unit and convertible to meV.' # noqa: E501 + ) from e + + def convert_x_unit(self, unit: str | sc.Unit) -> None: + """ + Convert the x-axis unit of the diffusion model. + + Converts the scale parameter (unit ``x_unit * y_unit``), any subclass-specific x-unit + parameters, and the existing component collections in place — parameter values and object + identity are preserved, and nothing is scheduled for regeneration. Only energy units are + supported (the unit must be convertible to meV). + + Unit validation raises ``UnitError`` when the unit is not convertible to meV. + + Parameters + ---------- + unit : str | sc.Unit + The new x-axis unit. + + Raises + ------ + TypeError + If unit is not a string or sc.Unit. + Exception + If any conversion fails; the already-converted state is rolled back best-effort before + re-raising. + """ + if not isinstance(unit, (str, sc.Unit)): + raise TypeError(f'x_unit must be a string or sc.Unit, got {type(unit).__name__}') + unit_str = str(unit) + self._assert_convertible_to_mev(unit_str) + + old_x_unit = str(self.x_unit) + new_scale_unit = str(sc.Unit(unit_str) * sc.Unit(str(self.y_unit))) + self._scale.convert_unit(new_scale_unit) + try: + self._convert_extra_x_unit_parameters(unit_str) + for collection in self._component_collections: + collection.convert_x_unit(unit_str) + except Exception: + old_scale_unit = str(sc.Unit(old_x_unit) * sc.Unit(str(self.y_unit))) + with contextlib.suppress(Exception): + self._scale.convert_unit(old_scale_unit) + with contextlib.suppress(Exception): + self._convert_extra_x_unit_parameters(old_x_unit) + for collection in self._component_collections: + with contextlib.suppress(Exception): + collection.convert_x_unit(old_x_unit) + raise + + self._x_unit = unit_str + + def convert_y_unit(self, unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit of the diffusion model. + + Converts the scale parameter from ``x_unit * old_y_unit`` to ``x_unit * new_y_unit`` and + the existing component collections in place — parameter values and object identity are + preserved, and nothing is scheduled for regeneration. The new y-unit must be dimensionally + compatible with the current one; the scale conversion raises ``UnitError`` otherwise. + + Parameters + ---------- + unit : str | sc.Unit + The new y-axis unit. + + Raises + ------ + TypeError + If unit is not a string or sc.Unit. + Exception + If any conversion fails; the already-converted state is rolled back best-effort before + re-raising. + """ + if not isinstance(unit, (str, sc.Unit)): + raise TypeError(f'y_unit must be a string or sc.Unit, got {type(unit).__name__}') + unit_str = str(unit) + + old_y_unit = str(self.y_unit) + new_scale_unit = str(sc.Unit(str(self.x_unit)) * sc.Unit(unit_str)) + self._scale.convert_unit(new_scale_unit) + try: + for collection in self._component_collections: + collection.convert_y_unit(unit_str) + except Exception: + old_scale_unit = str(sc.Unit(str(self.x_unit)) * sc.Unit(old_y_unit)) + with contextlib.suppress(Exception): + self._scale.convert_unit(old_scale_unit) + for collection in self._component_collections: + with contextlib.suppress(Exception): + collection.convert_y_unit(old_y_unit) + raise + + self._y_unit = unit_str + + def _convert_extra_x_unit_parameters(self, unit_str: str) -> None: + """ + Convert subclass-specific parameters that carry the x-axis unit. + + The base implementation does nothing; subclasses with x-unit parameters (e.g. a Lorentzian + width template) override this method. + + Parameters + ---------- + unit_str : str + The new x-axis unit as a string. + """ + + # ------------------------------------------------------------------ + # Fit targets + # ------------------------------------------------------------------ + + def get_fit_targets(self) -> list[FitTarget]: + """ + Get the fittable predictions of the diffusion model as FitTargets. + + The base implementation declares ``'area'`` (``scale * QISF(Q)``) and ``'width'`` (the HWHM + ``Gamma(Q)``), with default dataset keys derived from the Lorentzian component's name. + Subclasses with additional predictions (e.g. a delta-function weight) extend this list. The + targets are snapshots: units and default keys reflect the model state at call time. + + Returns + ------- + list[FitTarget] + The fittable predictions of this model. + """ + return [ + FitTarget( + name='area', + dataset_key=f'{self.lorentzian_name} area', + function=lambda Q, model=self, **_: model.calculate_QISF(Q) * model.scale.value, + label=f'{self.display_name} area', + x_unit='1/angstrom', + y_unit=str(self.scale.unit), + ), + FitTarget( + name='width', + dataset_key=f'{self.lorentzian_name} width', + function=lambda Q, model=self, **_: model.calculate_width(Q), + label=f'{self.display_name} width', + x_unit='1/angstrom', + y_unit=str(self.x_unit), + ), + ] + # ------------------------------------------------------------------ # Methods # ------------------------------------------------------------------ @@ -305,21 +490,8 @@ def get_independent_variables(self, Q_index: int | None = None) -> list[Paramete ------- list[Parameter] List of independent variables in the model. - - Raises - ------ - ValueError - If Q_index is not None and is not a valid index for the Q values in the model. """ - if Q_index is not None and ( - not isinstance(Q_index, int) - or Q_index < 0 - or Q_index >= len(self._component_collections) - ): - raise ValueError( - f'Q_index must be an integer between 0 and ' - f'{max(len(self._component_collections) - 1, 0)}, or None.' - ) + verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True) return [] @@ -337,22 +509,8 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: ------- list[Parameter] A list of all Parameters from the diffusion model. - - Raises - ------ - ValueError - If Q_index is out of bounds for the number of ComponentCollections. """ - - if Q_index is not None and ( - not isinstance(Q_index, int) - or Q_index < 0 - or Q_index >= len(self._component_collections) - ): - raise ValueError( - f'Q_index must be an integer between 0 and ' - f'{max(len(self._component_collections) - 1, 0)}, or None.' - ) + verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True) variables = self.get_global_variables() variables.extend(self.get_independent_variables(Q_index)) @@ -448,7 +606,9 @@ def create_component_collections(self) -> list[ComponentCollection]: self._component_collections = [] return self._component_collections - self._component_collections = [ComponentCollection()] * len(self.Q) + self._component_collections = [ + ComponentCollection(x_unit=self.x_unit, y_unit=self.y_unit) for _ in range(len(self.Q)) + ] return self._component_collections @@ -464,43 +624,64 @@ def get_component_collections( The index of the desired ComponentCollection. If None, all ComponentCollections are returned. - Raises - ------ - TypeError - If Q_index is not an int. - IndexError - If Q_index is out of bounds for the number of ComponentCollections. - Returns ------- ComponentCollection | list[ComponentCollection] The ComponentCollection at the specified Q index. If Q_index is None, a list of all ComponentCollections is returned. """ + verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True) if Q_index is None: return self._component_collections - if not isinstance(Q_index, int): - raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}') - if Q_index < 0 or Q_index >= len(self._component_collections): - raise IndexError( - f'Q_index {Q_index} is out of bounds for component collections ' - f'of length {len(self._component_collections)}' - ) return self._component_collections[Q_index] # ------------------------------------------------------------------ # private methods # ------------------------------------------------------------------ + def _write_area_dependency_expression(self, QISF: float) -> str: + """ + Write the dependency expression for the Lorentzian area. + + Parameters + ---------- + QISF : float + Q-dependent incoherent scattering function value. + + Raises + ------ + TypeError + If QISF is not a float. + + Returns + ------- + str + Dependency expression for the area. + """ + if not isinstance(QISF, float): + raise TypeError('QISF must be a float.') + return f'{QISF} * scale' + + def _write_area_dependency_map_expression(self) -> dict[str, DescriptorNumber]: + """ + Write the dependency map for the Lorentzian area. + + Returns + ------- + dict[str, DescriptorNumber] + Dependency map for the area. + """ + return {'scale': self.scale} + def _on_Q_change(self) -> None: """Handle changes to the Q values.""" self.create_component_collections() def _ensure_Q(self, Q: Q_type) -> np.ndarray: """ - Convert Q to a numpy array, ensuring it is not None. Uses the stored Q if no input is - given. + Convert Q to a numpy array of values in 1/angstrom, ensuring it is not None. Uses the + stored Q if no input is given. Parameters ---------- @@ -522,7 +703,7 @@ def _ensure_Q(self, Q: Q_type) -> np.ndarray: if Q is None: raise ValueError('Q must be provided either as an argument or set in the model.') - return _validate_and_convert_Q(Q) + return _validate_and_convert_Q(Q).values # ------------------------------------------------------------------ # dunder methods @@ -538,8 +719,7 @@ def __repr__(self) -> str: String representation of the DiffusionModel. """ return ( - f'{self.__class__.__name__}(' - f'name={self.name!r}, display_name={self.display_name!r}, ' - f'unit={self.unit},\n' + f'{self.__class__.__name__}(name={self.name}, display_name={self.display_name}, ' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n' f' scale={self.scale})' ) diff --git a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py index 0cd8a8284..67a693205 100644 --- a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py +++ b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py @@ -56,7 +56,8 @@ def __init__( diffusion_coefficient: Numeric = 1.0, relaxation_time: Numeric = 1.0, Q: Q_type | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', name: str = 'JumpTranslationalDiffusion', display_name: str | None = 'JumpTranslationalDiffusion', lorentzian_name: str | None = None, @@ -76,8 +77,10 @@ def __init__( Relaxation time t in ps. Q : Q_type | None, default=None Q values for the model. If None, Q is not set. - unit : str | sc.Unit, default='meV' - Unit of the diffusion model. Must be convertible to meV. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis (energy/frequency). Must be convertible to meV. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). Determines scale.unit = x_unit * y_unit. name : str, default='JumpTranslationalDiffusion' Name of the diffusion model. display_name : str | None, default='JumpTranslationalDiffusion' @@ -101,7 +104,8 @@ def __init__( """ super().__init__( Q=Q, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, scale=scale, name=name, display_name=display_name, @@ -246,7 +250,7 @@ def calculate_width(self, Q: Q_type | None = None) -> np.ndarray: unit_conversion_factor_numerator = ( self._hbar * self.diffusion_coefficient / (self._angstrom**2) ) - unit_conversion_factor_numerator.convert_unit(self.unit) + unit_conversion_factor_numerator.convert_unit(self.x_unit) numerator = unit_conversion_factor_numerator.value * Q**2 @@ -306,11 +310,11 @@ def create_component_collections( list[ComponentCollection] List of ComponentCollections with Jump Diffusion Lorentzian components. """ - Q = self.Q - if Q is None: + if self.Q is None: self._component_collections = [] return self._component_collections + Q = self.Q.values component_collection_list = [None] * len(Q) # In more complex models, this is used to scale the area of the # Lorentzians and the delta function. @@ -323,31 +327,37 @@ def create_component_collections( component_collection_list[i] = ComponentCollection( name=f'{self.name}_Q{Q_value:.2f}', display_name=f'{self.display_name}_Q{Q_value:.2f}', - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) lorentzian_component = Lorentzian( name=self.lorentzian_name, display_name=self.lorentzian_display_name, - unit=self.unit, + x_unit=self.x_unit, + y_unit=self.y_unit, ) # Make the width dependent on Q dependency_expression = self._write_width_dependency_expression(Q[i]) dependency_map = self._write_width_dependency_map_expression() - lorentzian_component.width.make_dependent_on( - dependency_expression=dependency_expression, - dependency_map=dependency_map, - desired_unit=self.unit, - ) - - # Make the area dependent on Q - area_dependency_map = self._write_area_dependency_map_expression() - lorentzian_component.area.make_dependent_on( - dependency_expression=self._write_area_dependency_expression(QISF[i]), - dependency_map=area_dependency_map, - ) + # easyscience propagates inf bounds through arithmetic, producing inf/inf=nan + # as a transient intermediate. Python's min/max ignore nan so the final bounds + # are correct; suppress the spurious numpy RuntimeWarning. + with np.errstate(invalid='ignore'): + lorentzian_component.width.make_dependent_on( + dependency_expression=dependency_expression, + dependency_map=dependency_map, + desired_unit=self.x_unit, + ) + + # Make the area dependent on Q + area_dependency_map = self._write_area_dependency_map_expression() + lorentzian_component.area.make_dependent_on( + dependency_expression=self._write_area_dependency_expression(QISF[i]), + dependency_map=area_dependency_map, + ) component_collection_list[i].append_component(lorentzian_component) @@ -405,44 +415,6 @@ def _write_width_dependency_map_expression(self) -> dict[str, DescriptorNumber]: 'angstrom': self._angstrom, } - def _write_area_dependency_expression(self, QISF: float) -> str: - """ - Write the dependency expression for the area to make dependent Parameters. - - Parameters - ---------- - QISF : float - Q-dependent intermediate scattering function. - - Raises - ------ - TypeError - If QISF is not a float. - - Returns - ------- - str - Dependency expression for the area. - """ - - if not isinstance(QISF, (float)): - raise TypeError('QISF must be a float.') - - return f'{QISF} * scale' - - def _write_area_dependency_map_expression(self) -> dict[str, DescriptorNumber]: - """ - Write the dependency map expression to make dependent Parameters. - - Returns - ------- - dict[str, DescriptorNumber] - Dependency map for the area. - """ - return { - 'scale': self.scale, - } - ################################ # dunder methods ################################ @@ -459,6 +431,7 @@ def __repr__(self) -> str: return ( f'{self.__class__.__name__}(' f'name={self.name!r}, display_name={self.display_name!r},\n' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, \n' f' diffusion_coefficient={self.diffusion_coefficient},\n' f' scale={self.scale})' ) diff --git a/src/easydynamics/sample_model/instrument_model.py b/src/easydynamics/sample_model/instrument_model.py index 24c3f1719..c6d731129 100644 --- a/src/easydynamics/sample_model/instrument_model.py +++ b/src/easydynamics/sample_model/instrument_model.py @@ -3,7 +3,6 @@ from copy import copy -import numpy as np import scipp as sc from easyscience.base_classes.new_base import NewBase from easyscience.variable import Parameter @@ -15,6 +14,7 @@ from easydynamics.utils.utils import Q_type from easydynamics.utils.utils import _validate_and_convert_Q from easydynamics.utils.utils import _validate_unit +from easydynamics.utils.utils import verify_Q_index class InstrumentModel(NewBase): @@ -60,7 +60,7 @@ def __init__( resolution_model: ResolutionModel | SampleModel | None = None, background_model: BackgroundModel | None = None, energy_offset: Numeric | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', ) -> None: """ Initialize an InstrumentModel. @@ -83,7 +83,7 @@ def __init__( energy_offset : Numeric | None, default=None Template energy offset of the instrument. Will be copied to each Q value. If None, the energy offset will be 0. - unit : str | sc.Unit, default='meV' + x_unit : str | sc.Unit, default='meV' The unit of the energy axis. Raises @@ -97,7 +97,7 @@ def __init__( unique_name=unique_name, ) - self._unit = _validate_unit(unit) + self._x_unit = _validate_unit(x_unit) if resolution_model is None: self._resolution_model = ResolutionModel() @@ -130,7 +130,7 @@ def __init__( self._energy_offset = Parameter( name='energy_offset', value=float(energy_offset), - unit=self.unit, + unit=self.x_unit, fixed=False, ) self._energy_offsets: list = [] @@ -189,7 +189,6 @@ def background_model(self) -> BackgroundModel: BackgroundModel The background model of the instrument. """ - return self._background_model @background_model.setter @@ -207,7 +206,6 @@ def background_model(self, value: BackgroundModel) -> None: TypeError If value is not a BackgroundModel. """ - if not isinstance(value, BackgroundModel): raise TypeError( f'background_model must be a BackgroundModel, got {type(value).__name__}' @@ -216,14 +214,14 @@ def background_model(self, value: BackgroundModel) -> None: self._on_background_model_change() @property - def Q(self) -> np.ndarray | None: + def Q(self) -> sc.Variable | None: """ Get the Q values of the InstrumentModel. Returns ------- - np.ndarray | None - The Q values of the InstrumentModel, or None if not set. + sc.Variable | None + The Q values of the InstrumentModel in 1/angstrom, or None if not set. """ return self._Q @@ -256,68 +254,63 @@ def Q(self, value: Q_type | None) -> None: self._on_Q_change() return - if len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q): + if len(old_Q) != len(new_Q) or not sc.allclose(old_Q, new_Q): raise ValueError( 'New Q values are not similar to the old ones. ' 'To change Q values, first run clear_Q().' ) @property - def unit(self) -> str | sc.Unit: + def x_unit(self) -> str | sc.Unit | None: """ - Get the unit of the InstrumentModel. + Get the x-axis unit of the InstrumentModel. Returns ------- - str | sc.Unit - The unit of the InstrumentModel. + str | sc.Unit | None + The x-axis unit of the InstrumentModel. """ - return self._unit + return self._x_unit - @unit.setter - def unit(self, _unit_str: str) -> None: + @x_unit.setter + def x_unit(self, _: str) -> None: """ - Set the unit of the InstrumentModel. - - The unit is read-only and cannot be set directly. Use convert_unit to change the unit - between allowed types or create a new InstrumentModel with the desired unit. + x_unit is read-only and cannot be set directly. - Parameters - ---------- - _unit_str : str - The new unit for the InstrumentModel (ignored). + Use convert_x_unit to change the unit between allowed types or create a new InstrumentModel + with the desired unit. Raises ------ AttributeError - Always, as the unit is read-only. + Always, as x_unit is read-only. """ raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' + f'x_unit is read-only. Use convert_x_unit to change the unit between allowed types ' f'or create a new {self.__class__.__name__} with the desired unit.' ) @property def energy_offset(self) -> Parameter: """ - Get the energy offset template parameter of the instrument model. + Get the template energy offset of the instrument. Returns ------- Parameter - The energy offset template parameter of the instrument model. + The energy offset Parameter. Each Q value gets its own copy via get_energy_offset(). """ return self._energy_offset @energy_offset.setter def energy_offset(self, value: Numeric) -> None: """ - Set the offset parameter of the instrument model. + Set the template energy offset value, propagating to all Q-specific offsets. Parameters ---------- value : Numeric - The new value for the energy offset parameter. Will be copied to all Q values. + The new energy offset value in x_unit. Raises ------ @@ -327,7 +320,6 @@ def energy_offset(self, value: Numeric) -> None: if not isinstance(value, Numeric): raise TypeError(f'energy_offset must be a number, got {type(value).__name__}') self._energy_offset.value = value - self._on_energy_offset_change() # -------------------------------------------------------------- @@ -358,32 +350,32 @@ def clear_Q(self, confirm: bool = False) -> None: self.resolution_model.clear_Q(confirm=True) self._on_Q_change() - def convert_unit(self, unit_str: str | sc.Unit) -> None: + def convert_x_unit(self, x_unit: str | sc.Unit) -> None: """ Convert the unit of the InstrumentModel. Parameters ---------- - unit_str : str | sc.Unit + x_unit : str | sc.Unit The unit to convert to. Raises ------ ValueError - If unit_str is not a valid unit string or scipp Unit. + If x_unit is not a valid unit string or scipp Unit. """ - unit = _validate_unit(unit_str) + unit = _validate_unit(x_unit) if unit is None: - raise ValueError('unit_str must be a valid unit string or scipp Unit') + raise ValueError('x_unit must be a valid unit string or scipp Unit') - self._background_model.convert_unit(unit) - self._resolution_model.convert_unit(unit) - self._energy_offset.convert_unit(unit) + self._background_model.convert_x_unit(unit) + self._resolution_model.convert_x_unit(unit) self._ensure_energy_offsets_current() + self._energy_offset.convert_unit(unit) for offset in self._energy_offsets: offset.convert_unit(unit) - self._unit = unit + self._x_unit = unit def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: """ @@ -395,13 +387,6 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: The index of the Q value to get variables for. If None, get variables for all Q values. - Raises - ------ - TypeError - If Q_index is not an int or None. - IndexError - If Q_index is out of bounds for the Q values in the InstrumentModel. - Returns ------- list[Parameter] @@ -413,15 +398,10 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: return [] self._ensure_energy_offsets_current() + verify_Q_index(Q_index=Q_index, Q=self._Q, allow_none=True) if Q_index is None: variables = [self._energy_offsets[i] for i in range(len(self._Q))] else: - if not isinstance(Q_index, int): - raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}') - if Q_index < 0 or Q_index >= len(self._Q): - raise IndexError( - f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}' - ) variables = [self._energy_offsets[Q_index]] variables.extend(self._background_model.get_all_variables(Q_index=Q_index)) @@ -458,10 +438,6 @@ def get_energy_offset( ------ ValueError If no Q values are set in the InstrumentModel. - IndexError - If Q_index is out of bounds. - TypeError - If Q_index is not an int or None. Returns ------- @@ -473,15 +449,10 @@ def get_energy_offset( raise ValueError('No Q values are set in the InstrumentModel.') self._ensure_energy_offsets_current() + verify_Q_index(Q_index=Q_index, Q=self._Q, allow_none=True) if Q_index is None: return self._energy_offsets - if not isinstance(Q_index, int): - raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}') - - if Q_index < 0 or Q_index >= len(self._Q): - raise IndexError(f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}') - return self._energy_offsets[Q_index] def fix_energy_offset(self, Q_index: int | None = None) -> None: @@ -531,27 +502,14 @@ def _fix_or_free_energy_offset(self, Q_index: int | None = None, fixed: bool = T energy offsets for all Q values. fixed : bool, default=True Whether to fix (True) or free (False) the energy offset. - - Raises - ------ - TypeError - If Q_index is not an int or None. - IndexError - If Q_index is out of bounds for the Q values in the InstrumentModel. """ - self._ensure_energy_offsets_current() + verify_Q_index(Q_index=Q_index, Q=self._Q, allow_none=True) if Q_index is None: + self._energy_offset.fixed = fixed for offset in self._energy_offsets: offset.fixed = fixed else: - if not isinstance(Q_index, int): - raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}') - - if Q_index < 0 or Q_index >= len(self._Q): - raise IndexError( - f'Q_index {Q_index} is out of bounds for Q of length {len(self._Q)}' - ) self._energy_offsets[Q_index].fixed = fixed def _ensure_energy_offsets_current(self) -> None: @@ -566,7 +524,7 @@ def _generate_energy_offsets(self) -> None: self._energy_offsets = [] return - self._energy_offsets = [copy(self._energy_offset) for _ in self._Q] + self._energy_offsets = [copy(self._energy_offset) for _ in range(len(self._Q))] def _on_Q_change(self) -> None: """Handle changes to the Q values.""" @@ -601,11 +559,10 @@ def __repr__(self) -> str: str A string representation of the InstrumentModel. """ - return ( f'{self.__class__.__name__}(' f'unique_name={self.unique_name!r}, ' - f'unit={self.unit}, ' + f'x_unit={self.x_unit}, ' f'Q_len={None if self._Q is None else len(self._Q)}, ' f'resolution_model={self._resolution_model!r}, ' f'background_model={self._background_model!r})' diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py index a2f1d7765..f2e13ae9d 100644 --- a/src/easydynamics/sample_model/model_base.py +++ b/src/easydynamics/sample_model/model_base.py @@ -13,6 +13,7 @@ from easydynamics.utils.utils import Numeric from easydynamics.utils.utils import Q_type from easydynamics.utils.utils import _validate_and_convert_Q +from easydynamics.utils.utils import verify_Q_index class ModelBase(EasyDynamicsModelBase): @@ -26,7 +27,8 @@ def __init__( self, display_name: str = 'MyModelBase', unique_name: str | None = None, - unit: str | sc.Unit | None = 'meV', + x_unit: str | sc.Unit | None = 'meV', + y_unit: str | sc.Unit = 'dimensionless', components: ModelComponent | ComponentCollection | None = None, Q: Q_type | None = None, ) -> None: @@ -39,8 +41,10 @@ def __init__( Display name of the model. unique_name : str | None, default=None Unique name of the model. If None, a unique name will be generated. - unit : str | sc.Unit | None, default='meV' - Unit of the model. + x_unit : str | sc.Unit | None, default='meV' + Unit of the x-axis (energy, Q, etc.). + y_unit : str | sc.Unit, default='dimensionless' + Unit of the model output (intensity). components : ModelComponent | ComponentCollection | None, default=None Template components of the model. If None, no components are added. These components are copied into ComponentCollections for each Q value. @@ -53,7 +57,8 @@ def __init__( If components is not a ModelComponent or ComponentCollection. """ super().__init__( - unit=unit, + x_unit=x_unit, + y_unit=y_unit, display_name=display_name, unique_name=unique_name, ) @@ -74,17 +79,19 @@ def __init__( self.append_component(components) def evaluate( - self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray - ) -> list[np.ndarray]: + self, + x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, + output: str = 'numpy', + ) -> list[np.ndarray] | list[sc.Variable]: """ Evaluate the sample model at all Q for the given x values. Parameters ---------- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - Energy axis values to evaluate the model at. If a scipp Variable or DataArray is - provided, the unit of the model will be converted to match the unit of x for - evaluation, and the result will be returned in the same unit as x. + Energy axis values to evaluate the model at. + output : str, default='numpy' + 'numpy' returns np.ndarray per Q; 'scipp' returns sc.Variable per Q. Raises ------ @@ -93,15 +100,16 @@ def evaluate( Returns ------- - list[np.ndarray] - A list of numpy arrays containing the evaluated model values for each Q. The length of - the list will match the number of Q values in the model. + list[np.ndarray] | list[sc.Variable] + A list of arrays containing the evaluated model values for each Q. The length of the + list will match the number of Q values in the model. """ - self._ensure_component_collections_current() if not self._component_collections: raise ValueError('No components in the model to evaluate.') - return [collection.evaluate(x) for collection in self._component_collections] + return [ + collection.evaluate(x, output=output) for collection in self._component_collections + ] # ------------------------------------------------------------------ # Component management @@ -186,14 +194,14 @@ def component_collections_is_dirty(self) -> bool: return self._component_collections_is_dirty @property - def Q(self) -> np.ndarray | None: + def Q(self) -> sc.Variable | None: """ Get the Q values of the SampleModel. Returns ------- - np.ndarray | None - The Q values of the SampleModel, or None if not set. + sc.Variable | None + The Q values of the SampleModel in 1/angstrom, or None if not set. """ return self._Q @@ -225,7 +233,7 @@ def Q(self, value: Q_type | None) -> None: self._on_Q_change() return - if len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q): + if len(old_Q) != len(new_Q) or not sc.allclose(old_Q, new_Q): raise ValueError( 'New Q values are not similar to the old ones. ' 'To change Q values, first run clear_Q().' @@ -257,14 +265,42 @@ def clear_Q(self, confirm: bool = False) -> None: # Other methods # ------------------------------------------------------------------ - def convert_unit(self, unit: str | sc.Unit) -> None: + def convert_x_unit(self, unit: str | sc.Unit) -> None: + """ + Convert the x-axis unit of all components in the model. + + Parameters + ---------- + unit : str | sc.Unit + The new x-axis unit to convert to. + """ + self._convert_axis_unit(unit, axis='x') + + def convert_y_unit(self, unit: str | sc.Unit) -> None: + """ + Convert the y-axis unit of all components in the model. + + Parameters + ---------- + unit : str | sc.Unit + The new y-axis unit to convert to. + """ + self._convert_axis_unit(unit, axis='y') + + def _convert_axis_unit(self, unit: str | sc.Unit, axis: str) -> None: """ - Convert the unit of the ComponentCollection and all its components. + Convert one axis unit on all template components and per-Q collections. + + Converts every child via its ``convert__unit`` method and updates the model's own + unit attribute. On failure, attempts a best-effort rollback of all children to the old unit + before re-raising. Parameters ---------- unit : str | sc.Unit The new unit to convert to. + axis : str + Which axis to convert: ``'x'`` or ``'y'``. Raises ------ @@ -273,24 +309,31 @@ def convert_unit(self, unit: str | sc.Unit) -> None: Exception If the provided unit is not compatible with the current unit. """ - - old_unit = self._unit - if not isinstance(unit, (str, sc.Unit)): raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') + + method = f'convert_{axis}_unit' + old_unit = self.x_unit if axis == 'x' else self.y_unit try: for component in self.components: - component.convert_unit(unit) - self._unit = unit + getattr(component, method)(unit) + for collection in self._component_collections: + getattr(collection, method)(unit) + unit_str = str(unit) if isinstance(unit, sc.Unit) else unit + if axis == 'x': + self._x_unit = unit_str + else: + self._y_unit = unit_str except Exception as e: - # Attempt to rollback on failure - try: - for component in self.components: - component.convert_unit(old_unit) - except Exception: # noqa: S110 - pass # Best effort rollback + if old_unit is not None: + try: + for component in self.components: + getattr(component, method)(old_unit) + for collection in self._component_collections: + getattr(collection, method)(old_unit) + except Exception: # noqa: S110 + pass raise e - self._on_components_change() def fix_all_parameters(self) -> None: """Fix all Parameters in all ComponentCollections.""" @@ -314,21 +357,14 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: If None, get variables for all ComponentCollections. If int, get variables for the ComponentCollection at this index. - Raises - ------ - TypeError - If Q_index is not an int or None. - IndexError - If Q_index is out of bounds for the number of ComponentCollections. - Returns ------- list[Parameter] A list of all Parameters and Descriptors from the ComponentCollections in the ModelBase. """ - self._ensure_component_collections_current() + verify_Q_index(Q_index=Q_index, Q=self.Q, allow_none=True) if Q_index is None: all_vars = [ var @@ -336,13 +372,6 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: for var in collection.get_all_variables() ] else: - if not isinstance(Q_index, int): - raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}') - if Q_index < 0 or Q_index >= len(self._component_collections): - raise IndexError( - f'Q_index {Q_index} is out of bounds for component collections ' - f'of length {len(self._component_collections)}' - ) all_vars = self._component_collections[Q_index].get_all_variables() return all_vars @@ -355,26 +384,13 @@ def get_component_collection(self, Q_index: int) -> ComponentCollection: Q_index : int The index of the desired ComponentCollection. - Raises - ------ - TypeError - If Q_index is not an int. - IndexError - If Q_index is out of bounds for the number of ComponentCollections. - Returns ------- ComponentCollection The ComponentCollection at the given Q index. """ self._ensure_component_collections_current() - if not isinstance(Q_index, int): - raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}') - if Q_index < 0 or Q_index >= len(self._component_collections): - raise IndexError( - f'Q_index {Q_index} is out of bounds for component collections ' - f'of length {len(self._component_collections)}' - ) + verify_Q_index(Q_index=Q_index, Q=self.Q) return self._component_collections[Q_index] def normalize_area(self) -> None: @@ -397,13 +413,12 @@ def _ensure_component_collections_current(self) -> None: def _generate_component_collections(self) -> None: """Generate ComponentCollections for each Q value.""" - if self.Q is None: self._component_collections = [] return self._component_collections = [] - for _ in self.Q: + for _ in range(len(self.Q)): self._component_collections.append(copy(self._components)) def _on_Q_change(self) -> None: @@ -430,7 +445,8 @@ def __repr__(self) -> str: return ( f'{self.__class__.__name__}(' f'unique_name={self.unique_name!r}, ' - f'unit={self.unit}, ' - f'Q={self.Q}, ' + f'x_unit={self.x_unit}, ' + f'y_unit={self.y_unit}, ' + f'Q={None if self.Q is None else self.Q.values}, ' f'components={self.components})' ) diff --git a/src/easydynamics/sample_model/resolution_model.py b/src/easydynamics/sample_model/resolution_model.py index ce5117ff7..a9fc9e1ee 100644 --- a/src/easydynamics/sample_model/resolution_model.py +++ b/src/easydynamics/sample_model/resolution_model.py @@ -51,12 +51,13 @@ def __init__( self, display_name: str = 'MyResolutionModel', unique_name: str | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', components: ModelComponent | ComponentCollection | None = None, Q: Q_type | None = None, ) -> None: """ - Initialize a ResolutionModel. + Initialize the ResolutionModel. Parameters ---------- @@ -64,19 +65,20 @@ def __init__( Display name of the model. unique_name : str | None, default=None Unique name of the model. If None, a unique name will be generated. - unit : str | sc.Unit, default='meV' - Unit of the model. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). components : ModelComponent | ComponentCollection | None, default=None - Template components of the model. If None, no components are added. These components - are copied into ComponentCollections for each Q value. + Template components. DeltaFunction, Polynomial, and Exponential are not allowed. Q : Q_type | None, default=None Q values for the model. If None, Q is not set. """ - super().__init__( display_name=display_name, unique_name=unique_name, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, components=components, Q=Q, ) @@ -85,8 +87,8 @@ def append_component(self, component: ModelComponent | ComponentCollection) -> N """ Append a component to the ResolutionModel. - Does not allow DeltaFunction or Polynomial components, as these are not physical resolution - components. + Does not allow DeltaFunction, Polynomial, or Exponential components, as these are not + physical resolution components. Parameters ---------- @@ -96,7 +98,7 @@ def append_component(self, component: ModelComponent | ComponentCollection) -> N Raises ------ TypeError - If the component is a DeltaFunction or Polynomial. + If the component is a DeltaFunction, Polynomial, or Exponential. """ components = component if isinstance(component, ComponentCollection) else (component,) @@ -151,22 +153,27 @@ def from_sample_model( resolution_model = cls( display_name=sample_model.display_name, - unit=sample_model.unit, + x_unit=sample_model.x_unit, + y_unit=sample_model.y_unit, components=sample_model.components, Q=sample_model.Q, ) if sample_model.Q is not None: - resolution_model._ensure_component_collections_current() - for index in range(len(sample_model.Q)): - resolution_model._component_collections[index] = copy( - sample_model.get_component_collection(Q_index=index) - ) - if normalize_area: - resolution_model.normalize_area() - - if fix_parameters: - resolution_model.fix_all_parameters() + # Prepare the per-Q collections detached from the model so no EasyScience + # callback can schedule a rebuild halfway through, then install them and + # clear the dirty flag in one final step. + collections = [ + copy(sample_model.get_component_collection(Q_index=index)) + for index in range(len(sample_model.Q)) + ] + for collection in collections: + if normalize_area: + collection.normalize_area() + if fix_parameters: + collection.fix_all_parameters() + resolution_model._component_collections = collections + resolution_model._component_collections_is_dirty = False return resolution_model @@ -174,7 +181,8 @@ def __repr__(self) -> str: return ( f'{self.__class__.__name__}(' f'unique_name={self.unique_name!r}, ' - f'unit={self.unit}, ' + f'x_unit={self.x_unit}, ' + f'y_unit={self.y_unit}, ' f'Q_len={None if self._Q is None else len(self._Q)}, ' f'components={self.components})' ) diff --git a/src/easydynamics/sample_model/sample_model.py b/src/easydynamics/sample_model/sample_model.py index 3991a10f9..f0039161f 100644 --- a/src/easydynamics/sample_model/sample_model.py +++ b/src/easydynamics/sample_model/sample_model.py @@ -66,7 +66,8 @@ def __init__( self, display_name: str = 'MySampleModel', unique_name: str | None = None, - unit: str | sc.Unit = 'meV', + x_unit: str | sc.Unit = 'meV', + y_unit: str | sc.Unit = 'dimensionless', components: ModelComponent | ComponentCollection | None = None, Q: Q_type | None = None, diffusion_models: DiffusionModelBase | list[DiffusionModelBase] | None = None, @@ -83,33 +84,31 @@ def __init__( Display name of the model. unique_name : str | None, default=None Unique name of the model. If None, a unique name will be generated. - unit : str | sc.Unit, default='meV' - Unit of the model. If None,. + x_unit : str | sc.Unit, default='meV' + Unit of the x-axis. + y_unit : str | sc.Unit, default='dimensionless' + Unit of the y-axis (output). components : ModelComponent | ComponentCollection | None, default=None - Template components of the model. If None, no components are added. These components - are copied into ComponentCollections for each Q value. + Template components copied into each Q's ComponentCollection. Q : Q_type | None, default=None - Q values for the model. If None, Q is not set. + Q values. If None, Q is not set. diffusion_models : DiffusionModelBase | list[DiffusionModelBase] | None, default=None - Diffusion models to include in the SampleModel. If None, no diffusion models are added. + Diffusion models to include. Each must be a DiffusionModelBase. temperature : float | None, default=None - Temperature for detailed balancing. If None, no detailed balancing is applied. By - default, None. + Sample temperature in temperature_unit. If provided, detailed balance is applied. temperature_unit : str | sc.Unit, default='K' - Unit of the temperature. + Unit for the temperature parameter. detailed_balance_settings : DetailedBalanceSettings | None, default=None - Settings for detailed balancing. + Detailed balance settings. If None, default settings are used. Raises ------ TypeError - If diffusion_models is not a DiffusionModelBase, a list of DiffusionModelBase, or None, - or if temperature is not a number or None, or if detailed_balance_settings is not a - DetailedBalanceSettings instance. + If diffusion_models contains non-DiffusionModelBase items, temperature is not numeric, + or detailed_balance_settings is not a DetailedBalanceSettings instance. ValueError If temperature is negative. """ - if diffusion_models is None: self._diffusion_models = [] elif isinstance(diffusion_models, DiffusionModelBase): @@ -131,7 +130,8 @@ def __init__( super().__init__( display_name=display_name, unique_name=unique_name, - unit=unit, + x_unit=x_unit, + y_unit=y_unit, components=components, Q=Q, ) @@ -178,7 +178,6 @@ def append_diffusion_model(self, diffusion_model: DiffusionModelBase) -> None: TypeError If the diffusion_model is not a DiffusionModelBase. """ - if not isinstance(diffusion_model, DiffusionModelBase): raise TypeError( f'diffusion_model must be a DiffusionModelBase, got {type(diffusion_model).__name__}' # noqa: E501 @@ -201,15 +200,19 @@ def remove_diffusion_model(self, name: str) -> None: ValueError If no DiffusionModel with the given name is found. """ - for i, dm in enumerate(self.diffusion_models): - if dm.name == name: - del self.diffusion_models[i] - self._component_collections_is_dirty = True - return - raise ValueError( - f'No DiffusionModel with name {name} found. \n' - f'The available names are: {[dm.name for dm in self.diffusion_models]}' - ) + matches = [i for i, dm in enumerate(self.diffusion_models) if dm.name == name] + if len(matches) == 0: + raise ValueError( + f'No DiffusionModel with name {name!r} found. ' + f'Available names are: {[dm.name for dm in self.diffusion_models]}' + ) + if len(matches) > 1: + raise ValueError( + f'Multiple DiffusionModels share the name {name!r}. ' + f'Rename them to have unique names before removing.' + ) + del self.diffusion_models[matches[0]] + self._component_collections_is_dirty = True def clear_diffusion_models(self) -> None: """Clear all DiffusionModels from the SampleModel.""" @@ -250,7 +253,6 @@ def diffusion_models( TypeError If value is not a DiffusionModelBase, a list of DiffusionModelBase, or None. """ - if value is None: self._diffusion_models = [] self._on_diffusion_models_change() @@ -292,7 +294,7 @@ def temperature(self, value: Numeric | None) -> None: Parameters ---------- value : Numeric | None - The temperature value to set. Can be a number or None to unset the temperature. + The temperature value. If None, temperature is cleared (no detailed balance). Raises ------ @@ -325,12 +327,12 @@ def temperature(self, value: Numeric | None) -> None: @property def temperature_unit(self) -> str | sc.Unit: """ - Get the temperature unit of the SampleModel. + Get the temperature unit. Returns ------- str | sc.Unit - The unit of the temperature Parameter. + The unit of the temperature parameter. """ return self._temperature_unit @@ -351,8 +353,8 @@ def temperature_unit(self, _value: str | sc.Unit) -> None: """ raise AttributeError( - f'Temperature_unit is read-only. Use convert_temperature_unit to change the unit between allowed types ' # noqa: E501 - f'or create a new {self.__class__.__name__} with the desired unit.' + f'Temperature_unit is read-only. Use convert_temperature_unit to change the unit ' + f'between allowed types or create a new {self.__class__.__name__} with the desired unit.' # noqa: E501 ) def convert_temperature_unit(self, unit: str | sc.Unit) -> None: @@ -382,10 +384,53 @@ def convert_temperature_unit(self, unit: str | sc.Unit) -> None: self._temperature_unit = unit except Exception: # Attempt to rollback on failure + with suppress(Exception): self.temperature.convert_unit(old_unit) raise + def _convert_axis_unit(self, unit: str | sc.Unit, axis: str) -> None: + """ + Convert one axis unit on the SampleModel, including any attached diffusion models. + + Extends the ModelBase conversion (template components and per-Q collections) by also + converting each diffusion model, whose regenerated collections would otherwise come back in + the old unit on the next rebuild. + + Parameters + ---------- + unit : str | sc.Unit + The new unit to convert to. + axis : str + Which axis to convert: ``'x'`` or ``'y'``. + + Raises + ------ + Exception + If any conversion fails; the already-converted diffusion models and the ModelBase state + are rolled back before re-raising. + """ + old_unit = self.x_unit if axis == 'x' else self.y_unit + super()._convert_axis_unit(unit, axis) + + method = f'convert_{axis}_unit' + converted_models = [] + try: + for diffusion_model in self.diffusion_models: + getattr(diffusion_model, method)(unit) + converted_models.append(diffusion_model) + except Exception: + for diffusion_model in converted_models: + with suppress(Exception): + getattr(diffusion_model, method)(old_unit) + if old_unit is not None: + with suppress(Exception): + super()._convert_axis_unit(old_unit, axis) + raise + # Everything is converted in place (the merged collections share the diffusion models' + # component objects), so nothing is marked dirty: conversion must not discard the + # per-Q state by triggering a rebuild. + @property def normalize_detailed_balance(self) -> bool: """ @@ -420,12 +465,12 @@ def normalize_detailed_balance(self, value: bool) -> None: @property def use_detailed_balance(self) -> bool: """ - Get whether to apply detailed balance to the model. + Get whether detailed balance correction is applied. Returns ------- bool - True if detailed balance is applied, False otherwise. + True if detailed balance is applied during evaluation, False otherwise """ return self.detailed_balance_settings.use_detailed_balance @@ -451,12 +496,12 @@ def use_detailed_balance(self, value: bool) -> None: @property def detailed_balance_settings(self) -> DetailedBalanceSettings: """ - Get the DetailedBalanceSettings of the SampleModel. + Get the detailed balance settings. Returns ------- DetailedBalanceSettings - The DetailedBalanceSettings of the SampleModel. + The detailed balance settings object. """ return self._detailed_balance_settings @@ -484,31 +529,33 @@ def detailed_balance_settings(self, value: DetailedBalanceSettings) -> None: # ------------------------------------------------------------------ def evaluate( - self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray - ) -> list[np.ndarray]: + self, + x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray, + output: str = 'numpy', + ) -> list[np.ndarray] | list[sc.Variable]: """ Evaluate the sample model at all Q for the given x values. Parameters ---------- x : Numeric | list | np.ndarray | sc.Variable | sc.DataArray - The x values to evaluate the model at. Can be a number, list, numpy array, scipp - Variable, or scipp DataArray. + The x values to evaluate the model at. + output : str, default='numpy' + 'numpy' returns list of np.ndarray; 'scipp' returns list of sc.Variable. Returns ------- - list[np.ndarray] + list[np.ndarray] | list[sc.Variable] List of evaluated model values for each Q. """ - - y = super().evaluate(x) + y = super().evaluate(x, output=output) if self.temperature is not None and self.detailed_balance_settings.use_detailed_balance: DBF = detailed_balance_factor( energy=x, temperature=self.temperature, divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, - energy_unit=self.unit, + energy_unit=self.x_unit, ) y = [yi * DBF for yi in y] @@ -530,10 +577,8 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: Returns ------- list[Parameter] - List of all Parameters and Descriptors, including temperature if set and all variables - from diffusion models. + All Parameters and Descriptors in the SampleModel. """ - all_vars = super().get_all_variables(Q_index=Q_index) if self.temperature is not None: all_vars.append(self.temperature) @@ -557,8 +602,7 @@ def _generate_component_collections(self) -> None: if self.Q is None: return - # Generate components from diffusion models - # and add to component collections + # Generate components from diffusion models and add to component collections for diffusion_model in self.diffusion_models: diffusion_collections = diffusion_model.get_component_collections() for target, source in zip( @@ -595,12 +639,11 @@ def __repr__(self) -> str: str A string representation of the SampleModel. """ - return ( - f'{self.__class__.__name__}(' - f'unique_name={self.unique_name!r}, unit={self.unit},\n' - f' Q={self.Q},\n' - f' components={self.components}, diffusion_models={self.diffusion_models},\n' - f' temperature={self.temperature},\n' - f' detailed_balance_settings={self.detailed_balance_settings})' + f'{self.__class__.__name__}(unique_name={self.unique_name}, ' + f'x_unit={self.x_unit}, y_unit={self.y_unit}, ' + f'Q = {None if self.Q is None else self.Q.values}, \n ' + f'components = {self.components}, diffusion_models = {self.diffusion_models}, ' + f'temperature = {self.temperature}, ' + f'detailed_balance_settings = {self.detailed_balance_settings})' ) diff --git a/src/easydynamics/settings/convolution_settings.py b/src/easydynamics/settings/convolution_settings.py index fadc13a1e..309efa9c6 100644 --- a/src/easydynamics/settings/convolution_settings.py +++ b/src/easydynamics/settings/convolution_settings.py @@ -90,6 +90,13 @@ def __init__( self._suppress_warnings = suppress_warnings self._convolution_plan_is_valid = False + # Plan-invalidation bookkeeping for convolvers sharing this settings object. + # _plan_version is bumped whenever the plan is invalidated (accuracy knob change or + # an explicit convolution_plan_is_valid = False). _blessed_plan_version records the + # version at the last explicit convolution_plan_is_valid = True, which suppresses + # rebuilds for all convolvers until the next invalidation. + self._plan_version = 0 + self._blessed_plan_version = -1 @property def upsample_factor(self) -> Numeric | None: @@ -173,6 +180,11 @@ def extension_factor(self, factor: Numeric) -> None: If factor is negative. """ + if factor is None: + self._extension_factor = factor + self.convolution_plan_is_valid = False + return + if not isinstance(factor, Numeric): raise TypeError('Extension factor must be a number.') if factor < 0.0: @@ -211,6 +223,48 @@ def convolution_plan_is_valid(self, is_valid: bool) -> None: if not isinstance(is_valid, bool): raise TypeError('convolution_plan_is_valid must be True or False.') self._convolution_plan_is_valid = is_valid + if is_valid: + # An explicit True blesses the current version: all convolvers sharing these + # settings may skip rebuilding until the next invalidation. + self._blessed_plan_version = self._plan_version + else: + # An explicit False (or a knob setter delegating here) invalidates the plan for + # every convolver sharing these settings. + self._plan_version += 1 + + def _plan_valid_for(self, seen_version: int) -> bool: + """ + Check whether a convolver that last rebuilt at seen_version can skip rebuilding. + + Parameters + ---------- + seen_version : int + The plan version the convolver recorded when it last rebuilt its plan. + + Returns + ------- + bool + True if the plan flag is set and no invalidation happened since the convolver's rebuild + (or the current version was explicitly blessed). + """ + return self._convolution_plan_is_valid and ( + seen_version == self._plan_version or self._blessed_plan_version == self._plan_version + ) + + def _mark_plan_built(self) -> int: + """ + Record that one convolver rebuilt its plan. + + Sets the plan flag without blessing the current version, so other convolvers sharing this + settings object still detect invalidations they have not consumed yet. + + Returns + ------- + int + The current plan version, to be stored by the convolver. + """ + self._convolution_plan_is_valid = True + return self._plan_version @property def suppress_warnings(self) -> bool: @@ -243,6 +297,22 @@ def suppress_warnings(self, suppress: bool) -> None: raise TypeError('suppress_warnings must be True or False.') self._suppress_warnings = suppress + def __copy__(self) -> 'ConvolutionSettings': + """ + Return a shallow copy of the ConvolutionSettings. + + Returns + ------- + 'ConvolutionSettings' + A new ConvolutionSettings instance with the same parameter values. + """ + return ConvolutionSettings( + upsample_factor=self.upsample_factor, + extension_factor=self.extension_factor, + suppress_warnings=self.suppress_warnings, + display_name=self.display_name, + ) + def __repr__(self) -> str: """ Return a string representation of the ConvolutionSettings. diff --git a/src/easydynamics/utils/fit_target.py b/src/easydynamics/utils/fit_target.py new file mode 100644 index 000000000..55c7ea5c5 --- /dev/null +++ b/src/easydynamics/utils/fit_target.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + + +@dataclass(frozen=True) +class FitTarget: + """ + One fittable prediction of a model, bound to a key in a parameters Dataset. + + Models declare their predictions by returning FitTargets (see + ``DiffusionModelBase.get_fit_targets``), and ``FitBinding`` maps them onto the dataset keys + they should be fitted against. Instances are immutable snapshots created on demand, so the + units always reflect the model state at the time the targets are built. + + Attributes + ---------- + name : str + The prediction's name (e.g. ``'width'``, ``'area'``, ``'delta_area'``, ``'value'``). + dataset_key : str + The key in the parameters Dataset holding the data this prediction is fitted against. + function : Callable + The fit function; called as ``function(x)`` with raw x values expressed in *x_unit* and + returning raw values expressed in *y_unit*. + label : str + Display label used for plots and results (e.g. ``'DeltaLorentz width'``). + x_unit : str | None + The unit *function* expects its input in, or None if no unit conversion applies. + y_unit : str | None + The unit of *function*'s output, or None if no unit conversion applies. + """ + + name: str + dataset_key: str + function: Callable + label: str + x_unit: str | None + y_unit: str | None diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py index d9fc27c0f..02ec54e5f 100644 --- a/src/easydynamics/utils/utils.py +++ b/src/easydynamics/utils/utils.py @@ -4,8 +4,10 @@ import numpy as np import scipp as sc from easyscience.variable import DescriptorNumber +from easyscience.variable import Parameter from numpy.typing import ArrayLike from scipp.constants import hbar as scipp_hbar +from scipp.constants import k as scipp_k Numeric = float | int @@ -13,19 +15,101 @@ energy_type = np.ndarray | Numeric | list | ArrayLike | sc.Variable hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar) +kb = DescriptorNumber.from_scipp('kb', scipp_k) angstrom = DescriptorNumber('angstrom', 1e-10, unit='m') +def verify_Q_index(Q_index: int, Q: sc.Variable | None, allow_none: bool = False) -> None: + """ + Verify that Q_index is a valid integer index into Q. + + Parameters + ---------- + Q_index : int + Index to validate. + Q : sc.Variable | None + The Q values (may be None if no data is loaded). + allow_none : bool, default=False + Whether or not to allow Q_index to be None + + Raises + ------ + TypeError + If Q_index is not an int (or not an int or None when allow_none=True). + IndexError + If Q_index is out of range. + ValueError + If Q is None and Q_index is not None. + """ + if allow_none and Q_index is None: + return + + if Q_index is None or not isinstance(Q_index, int): + if allow_none: + raise TypeError(f'Q_index must be an int or None, got {type(Q_index).__name__}') + raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}') + + if Q is None: + raise ValueError('Q is None, cannot validate Q_index.') + + if not (0 <= Q_index < len(Q)): + raise IndexError(f'Q_index {Q_index} is out of bounds for Q of length {len(Q)}') + + +def convert_parameter_unit(parameter: Parameter, unit: str | sc.Unit) -> None: + """ + Convert a parameter to a new unit, keeping dependent parameters consistent. + + Independent parameters are converted with ``convert_unit``. Dependent parameters are converted + with ``set_desired_unit``, so the new unit survives later dependency-graph recomputations (a + plain ``convert_unit`` would be reverted to the old desired unit the next time the dependency + expression is re-evaluated). + + Parameters + ---------- + parameter : Parameter + The parameter to convert. + unit : str | sc.Unit + The unit to convert to. + """ + if parameter.independent: + parameter.convert_unit(str(unit)) + else: + parameter.set_desired_unit(str(unit)) + + +def energy_to_scipp(energy: np.ndarray, unit: str | sc.Unit) -> sc.Variable: + """ + Convert a numpy energy array to a scipp Variable with dimension 'energy'. + + Parameters + ---------- + energy : np.ndarray + The energy array to be converted + unit : str | sc.Unit + The unit of the energy + + Returns + ------- + sc.Variable + Energy as sc.Variable. + """ + return sc.array(dims=['energy'], values=energy, unit=unit) + + def _validate_and_convert_Q( Q: np.ndarray | Numeric | list | ArrayLike | sc.Variable | None, -) -> np.ndarray | None: +) -> sc.Variable | None: """ - Validate and convert Q to a numpy array. + Validate and convert Q to a scipp Variable in 1/angstrom. + + Numbers, lists, and numpy arrays are assumed to be in 1/angstrom. Scipp Variables may be in any + unit convertible to 1/angstrom and are converted. Parameters ---------- Q : np.ndarray | Numeric | list | ArrayLike | sc.Variable | None - Scattering vector values in 1/angstrom. + Scattering vector values. Raises ------ @@ -37,8 +121,8 @@ def _validate_and_convert_Q( Returns ------- - np.ndarray | None - Q as a np.ndarray or None if Q is None. + sc.Variable | None + Q as a sc.Variable with dimension 'Q' and unit 1/angstrom, or None if Q is None. """ if Q is None: return None @@ -59,7 +143,7 @@ def _validate_and_convert_Q( if Q.dims != ('Q',): raise ValueError("Q must have a single dimension named 'Q'.") Q = Q.to(unit='1/angstrom') - return Q.values + return Q def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None: @@ -92,6 +176,30 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None: return unit +def _assert_valid_unit(unit: str | sc.Unit) -> None: + """ + Assert that the given unit is recognised by scipp. + + Parameters + ---------- + unit : str | sc.Unit + The unit to validate. + + Raises + ------ + TypeError + If unit is not a string or scipp Unit. + ValueError + If the string is not a valid scipp unit. + """ + if not isinstance(unit, (str, sc.Unit)): + raise TypeError(f'unit must be a string or sc.Unit, got {type(unit).__name__}') + try: + sc.Unit(str(unit)) + except sc.UnitError as e: + raise ValueError(f"'{unit}' is not a valid scipp unit.") from e + + def _in_notebook() -> bool: """ Check if the code is running in a Jupyter notebook. diff --git a/tests/performance_tests/convolution/convolution_width_thresholds.ipynb b/tests/performance_tests/convolution/convolution_width_thresholds.ipynb index fbb5df296..543002d54 100644 --- a/tests/performance_tests/convolution/convolution_width_thresholds.ipynb +++ b/tests/performance_tests/convolution/convolution_width_thresholds.ipynb @@ -48,7 +48,7 @@ ")\n", "numerical_convolver.upsample_factor = None\n", "for gwidth in gaussian_widths:\n", - " sample_components.components[0].width = gwidth\n", + " sample_components[0].width = gwidth\n", " y_analytical = analytical_convolver.convolution()\n", "\n", " y_numerical = numerical_convolver.convolution()\n", @@ -106,8 +106,8 @@ ")\n", "numerical_convolver.upsample_factor = None\n", "for gwidth, gcenter in zip(gaussian_widths, gaussian_centers, strict=True):\n", - " sample_components.components[0].width = gwidth\n", - " sample_components.components[0].center = gcenter\n", + " sample_components[0].width = gwidth\n", + " sample_components[0].center = gcenter\n", " y_analytical = analytical_convolver.convolution()\n", "\n", " y_numerical = numerical_convolver.convolution()\n", @@ -135,7 +135,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "default", "language": "python", "name": "python3" }, @@ -149,7 +149,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.14.4" + "version": "3.14.5" } }, "nbformat": 4, diff --git a/tests/performance_tests/utils/detailed_balance_approximations.ipynb b/tests/performance_tests/utils/detailed_balance_approximations.ipynb index f50de3cf5..7b5a75fa0 100644 --- a/tests/performance_tests/utils/detailed_balance_approximations.ipynb +++ b/tests/performance_tests/utils/detailed_balance_approximations.ipynb @@ -42,8 +42,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "\n", "x = np.linspace(1e-10, 1e-5, 1000)\n", "\n", "y = x / (1 - np.exp(-x))\n", @@ -62,7 +60,7 @@ ], "metadata": { "kernelspec": { - "display_name": "newdynamics", + "display_name": "default", "language": "python", "name": "python3" }, @@ -76,7 +74,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.13" + "version": "3.14.5" } }, "nbformat": 4, diff --git a/tests/unit/easydynamics/analysis/test_analysis.py b/tests/unit/easydynamics/analysis/test_analysis.py index 516b718ef..2a67357ce 100644 --- a/tests/unit/easydynamics/analysis/test_analysis.py +++ b/tests/unit/easydynamics/analysis/test_analysis.py @@ -98,6 +98,24 @@ def test_analysis_list_contains_all_Q_indices(self, analysis): for i in range(3): assert analysis.analysis_list[i].Q_index == i + def test_analysis_list_shares_convolution_settings(self, analysis): + # WHEN THEN EXPECT: every Analysis1d holds the same ConvolutionSettings object, so + # user changes to analysis.convolution_settings reach all Q indices (regression: + # per-Q copies silently detached the settings) + for analysis1d in analysis.analysis_list: + assert analysis1d.convolution_settings is analysis.convolution_settings + + def test_convolution_settings_changes_reach_all_Q(self, analysis): + # WHEN: the analysis list has been built + assert len(analysis.analysis_list) == 3 + + # THEN: mutate the settings on the Analysis after the fact + analysis.convolution_settings.upsample_factor = 20 + + # EXPECT: the per-Q analyses see the new value + for analysis1d in analysis.analysis_list: + assert analysis1d.convolution_settings.upsample_factor == 20 + def test_analysis_list_setter_raises(self, analysis): # WHEN / THEN / EXPECT with pytest.raises( @@ -152,7 +170,7 @@ def test_calculate_with_invalid_Q_index(self, analysis): # WHEN / THEN / EXPECT with pytest.raises( IndexError, - match='must be a valid index', + match='Q_index 3 is out of bounds', ): analysis.calculate(Q_index=3) @@ -512,7 +530,7 @@ def test_parameters_to_dataset_different_units(self, analysis): ) # Convert the unit of a component to eV. - analysis.sample_model.get_component_collection(Q_index=1)[0].convert_unit('eV') + analysis.sample_model.get_component_collection(Q_index=1)[0].convert_x_unit('eV') # THEN parameters_dataset = analysis.parameters_to_dataset() @@ -533,12 +551,18 @@ def test_parameters_to_dataset_different_units(self, analysis): assert 'Q' in parameters_dataset[parameter_name].dims def test_parameters_to_dataset_raises_on_duplicate_names(self, analysis): - # Add a second Gaussian with the same parameter names as the first - analysis.sample_model.append_component( - Gaussian(name='GaussianName', display_name='Gaussian2', area=0.5) - ) + # Add a second Gaussian with the same parameter names as the first. Appending and the + # later per-Q copy both warn about the duplicate names; capture them so the test suite + # stays warning-clean. + with pytest.warns(UserWarning, match='Duplicate component names'): + analysis.sample_model.append_component( + Gaussian(name='GaussianName', display_name='Gaussian2', area=0.5) + ) - with pytest.raises(ValueError, match='Duplicate parameter names'): + with ( + pytest.warns(UserWarning, match='Duplicate component names'), + pytest.raises(ValueError, match='Duplicate parameter names'), + ): analysis.parameters_to_dataset() @pytest.mark.parametrize( @@ -732,15 +756,18 @@ def test_on_instrument_model_changed(self, analysis): def test_on_convolution_settings_changed(self, analysis): # WHEN - new_convolution_settings = ConvolutionSettings() + new_convolution_settings = ConvolutionSettings(upsample_factor=7, extension_factor=0.3) # THEN (this calls _on_convolution_settings_changed internally) analysis.convolution_settings = new_convolution_settings - # EXPECT + # EXPECT: the parent holds the new settings object assert analysis.convolution_settings is new_convolution_settings + # Each Analysis1d gets its own copy of the settings (so plan_is_valid is + # tracked independently per Q), but all values are propagated from the parent. for analysis1d in analysis.analysis_list: - assert analysis1d.convolution_settings is new_convolution_settings + assert analysis1d.convolution_settings.upsample_factor == 7 + assert analysis1d.convolution_settings.extension_factor == pytest.approx(0.3) def test_fit_single_Q_valid(self, analysis): # WHEN @@ -757,7 +784,7 @@ def test_fit_single_Q_invalid_Q_index(self, analysis): # WHEN / THEN / EXPECT with pytest.raises( IndexError, - match='must be a valid index', + match='Q_index 3 is out of bounds', ): analysis.fit(Q_index=3) @@ -824,6 +851,51 @@ def test_fit_all_Q_simultaneously(self, analysis): # And that the result from the fit method is returned assert result == fake_fit_result + def test_fit_all_Q_simultaneously_with_nan_data(self): + # Regression test: energy must be sliced via sc.array(dims=['energy'], values=mask), + # NOT via energy[numpy_bool_array]. The bug: numpy booleans are treated as integer + # indices by scipp (True→1, False→0), so energy[[True,False,True]] returns 3 elements + # with wrong values instead of filtering to the 2 finite points. + + # WHEN: data with a NaN at index 1; finite energies are -1.0 and 1.0 + Q = sc.array(dims=['Q'], values=[1.0], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[-1.0, 0.0, 1.0], unit='meV') + values = np.array([[1.0, np.nan, 2.0]]) + variances = np.array([[0.1, 0.1, 0.1]]) + data = sc.array(dims=['Q', 'energy'], values=values, variances=variances) + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + experiment = Experiment(data=data_array) + + # Set up a sample model and analysis + sample_model = SampleModel(components=Gaussian(name='G')) + analysis = Analysis(experiment=experiment, sample_model=sample_model) + + captured_energy = [] + original_refresh = analysis.analysis_list[0].refresh_convolver + + # Patch the refresh_convolver method to capture the energy passed to it + def capture_refresh(energy, **kwargs): + captured_energy.append(energy) + return original_refresh(energy=energy, **kwargs) + + analysis.analysis_list[0].refresh_convolver = capture_refresh + analysis.get_fit_functions = MagicMock(return_value=['fit_fn']) + + fake_fitter_instance = MagicMock() + fake_fitter_instance.fit.return_value = object() + + # THEN + with patch( + 'easydynamics.analysis.analysis.MultiFitter', + return_value=fake_fitter_instance, + ): + analysis._fit_all_Q_simultaneously() + + # EXPECT: only the 2 finite energy points (-1.0, 1.0) were passed, not all 3 + assert len(captured_energy) == 1 + assert len(captured_energy[0]) == 2 + np.testing.assert_array_equal(captured_energy[0].values, [-1.0, 1.0]) + def test_get_fit_functions(self, analysis): # WHEN @@ -923,9 +995,9 @@ def test_create_components_dataset_single_Q(self, analysis_single_Q): assert components_dataset.sizes['Q'] == 1 assert components_dataset.coords['Q'].ndim == 1 - def test_ensure_analysis_list_current_clears_dirty_when_Q_is_none(self): - # An Analysis with no experiment has Q=None; _ensure_analysis_list_current should still - # clear the dirty flag without attempting to build the list. + def test_ensure_analysis_list_current_stays_dirty_when_Q_is_none(self): + # When Q is None, _ensure_analysis_list_current should NOT clear the dirty flag — + # it must stay dirty so the list is rebuilt as soon as Q becomes available. # WHEN analysis = Analysis(display_name='NoQ') @@ -935,8 +1007,8 @@ def test_ensure_analysis_list_current_clears_dirty_when_Q_is_none(self): # THEN result = analysis.analysis_list - # EXPECT - dirty flag cleared, list stays empty - assert analysis._analysis_list_is_dirty is False + # EXPECT - dirty flag preserved, list stays empty + assert analysis._analysis_list_is_dirty is True assert result == [] def test_rebin_marks_analysis_list_dirty(self, analysis): @@ -1039,3 +1111,9 @@ def test_direct_experiment_rebin_does_not_update_analysis_list(self, analysis): # EXPECT - analysis_list is NOT marked dirty (callers must use Analysis.rebin()) assert analysis._analysis_list_is_dirty is False + + def test_repr(self, analysis): + repr_str = repr(analysis) + assert 'Analysis' in repr_str + assert 'display_name=' in repr_str + assert 'n_analyses=' in repr_str diff --git a/tests/unit/easydynamics/analysis/test_analysis1d.py b/tests/unit/easydynamics/analysis/test_analysis1d.py index 5bd571634..a44c3ed4b 100644 --- a/tests/unit/easydynamics/analysis/test_analysis1d.py +++ b/tests/unit/easydynamics/analysis/test_analysis1d.py @@ -75,8 +75,8 @@ def test_Q_index_setter(self, analysis1d): @pytest.mark.parametrize( 'invalid_Q_index, expected_exception, expected_message', [ - (-1, IndexError, 'Q_index must be'), - (10, IndexError, 'Q_index must be'), + (-1, IndexError, 'Q_index -1 is out of bounds'), + (10, IndexError, 'Q_index 10 is out of bounds'), ('invalid', TypeError, 'Q_index must be '), (np.nan, TypeError, 'Q_index must be '), ([1, 2], TypeError, 'Q_index must be '), @@ -111,10 +111,10 @@ def test_calculate_updates_convolver_and_calls_calculate(self, analysis1d): result = analysis1d.calculate() # EXPECT - analysis1d._create_convolver.assert_called_once() - assert analysis1d._convolver is fake_convolver - analysis1d._calculate.assert_called_once() + # calculate() passes the convolver directly to _calculate without storing it on self + _, call_kwargs = analysis1d._calculate.call_args + assert call_kwargs['convolver'] is fake_convolver np.testing.assert_array_equal(result, expected_result) def test__calculate_adds_sample_and_background(self, analysis1d): @@ -150,7 +150,7 @@ def test_fit_calls_fitter_with_correct_arguments(self, analysis1d): fake_weights = np.array([0.1, 0.2, 0.3]) fake_mask = np.array([True, False, True]) - analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock( return_value=(fake_x, fake_y, fake_weights, fake_mask) ) @@ -180,7 +180,7 @@ def test_fit_calls_fitter_with_correct_arguments(self, analysis1d): fit_function='fit_func', ) - analysis1d.experiment._extract_x_y_weights_only_finite.assert_called_once() + analysis1d.experiment.extract_x_y_weights_only_finite.assert_called_once() fake_fitter_instance.fit.assert_called_once_with( x=fake_x, @@ -508,15 +508,19 @@ def test_calculate_energy_with_offset_different_units(self, analysis1d): # WHEN energy = analysis1d.experiment.energy energy_offset = analysis1d.instrument_model.get_energy_offset(Q_index=analysis1d.Q_index) - energy_offset.value = 1.0 # override with a simple value for testing - energy_offset.convert_unit('eV') + energy_offset.value = 1.0 # set to 1.0 in original unit (meV) + energy_offset.convert_unit('eV') # now 0.001 eV, still represents 1 meV # THEN result = analysis1d._calculate_energy_with_offset(energy, energy_offset) - # EXPECT - expected = energy.values - energy_offset.value - np.testing.assert_array_equal(result.values, expected) + # EXPECT: offset must be converted to energy's unit before subtraction + offset_in_energy_unit = sc.to_unit( + sc.scalar(energy_offset.value, unit=str(energy_offset.unit)), + str(energy.unit), + ).value + expected = energy.values - offset_in_energy_unit + np.testing.assert_array_almost_equal(result.values, expected) def test_calculate_energy_with_offset_raises_if_incompatible_units(self, analysis1d): # WHEN @@ -524,9 +528,7 @@ def test_calculate_energy_with_offset_raises_if_incompatible_units(self, analysi energy_offset = Parameter(name='energy_offset', value=1.0, unit='m') # incompatible unit # THEN / EXPECT - with pytest.raises( - sc.UnitError, match='Energy and energy offset must have compatible units' - ): + with pytest.raises(sc.UnitError): analysis1d._calculate_energy_with_offset(energy, energy_offset) ############# @@ -920,7 +922,7 @@ def test_fit_marks_convolver_dirty_when_sample_model_components_change(self, ana 'easydynamics.analysis.analysis1d.EasyScienceFitter', return_value=MagicMock(fit=MagicMock(return_value=MagicMock())), ): - analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock( return_value=( np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0]), @@ -946,7 +948,7 @@ def test_fit_does_not_rebuild_convolver_when_nothing_changed(self, analysis1d): 'easydynamics.analysis.analysis1d.EasyScienceFitter', return_value=MagicMock(fit=MagicMock(return_value=MagicMock())), ): - analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock( return_value=( np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0]), @@ -1030,7 +1032,7 @@ def test_fit_marks_convolver_dirty_when_resolution_model_components_change(self, 'easydynamics.analysis.analysis1d.EasyScienceFitter', return_value=MagicMock(fit=MagicMock(return_value=MagicMock())), ): - analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + analysis1d.experiment.extract_x_y_weights_only_finite = MagicMock( return_value=( np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0]), @@ -1042,3 +1044,67 @@ def test_fit_marks_convolver_dirty_when_resolution_model_components_change(self, # EXPECT analysis1d._create_convolver.assert_called_once() + + # ───── Regression tests ───── + + @pytest.fixture + def analysis1d_with_nan(self): + """analysis1d fixture whose data contains a NaN at the second energy point.""" + Q = sc.array(dims=['Q'], values=[1.0], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV') + data = sc.array( + dims=['Q', 'energy'], + values=[[1.0, float('nan'), 3.0]], + variances=[[0.1, 0.2, 0.3]], + ) + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + experiment = Experiment(data=data_array) + sample_model = SampleModel(components=Gaussian()) + instrument_model = InstrumentModel() + return Analysis1d( + display_name='TestNaN', + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + Q_index=0, + ) + + def test_create_residuals_array_with_nan_data_does_not_crash(self, analysis1d_with_nan): + # Before the fix, residuals subtracted a 2-point model from 3-point data + # (including NaN) which caused a dimension mismatch crash. + # WHEN + + # THEN + result = analysis1d_with_nan._create_residuals_array() + + # EXPECT + assert isinstance(result, sc.DataArray) + # Only the 2 finite energy points survive the mask. + assert result.sizes['energy'] == 2 + + def test_data_and_model_to_datagroup_with_nan_excludes_nan_from_data( + self, analysis1d_with_nan + ): + # Before the fix, 'Data' contained the full 3-point grid (including NaN) + # and computing Residuals crashed on the dimension mismatch. + # WHEN + energy = sc.array(dims=['energy'], values=[20.0, 30.0, 40.0], unit='meV') + + # THEN + datagroup = analysis1d_with_nan.data_and_model_to_datagroup( + energy=energy, include_residuals=True + ) + + # EXPECT + assert isinstance(datagroup, sc.DataGroup) + # 'Data' must contain only the 2 finite points. + assert datagroup['Data'].sizes['energy'] == 2 + # Residuals must be present and have matching size. + assert 'Residuals' in datagroup + assert datagroup['Residuals'].sizes['energy'] == 2 + + def test_repr(self, analysis1d): + repr_str = repr(analysis1d) + assert 'Analysis1d' in repr_str + assert 'display_name=' in repr_str + assert 'Q_index=' in repr_str diff --git a/tests/unit/easydynamics/analysis/test_analysis_base.py b/tests/unit/easydynamics/analysis/test_analysis_base.py index 7507b0b35..b68210c61 100644 --- a/tests/unit/easydynamics/analysis/test_analysis_base.py +++ b/tests/unit/easydynamics/analysis/test_analysis_base.py @@ -92,19 +92,28 @@ def test_init_detailed_balance_settings(self): assert analysis.detailed_balance_settings is detailed_balance_settings def test_init_extra_parameter(self): + # WHEN extra_parameter = Parameter(name='param1', value=1.0) + + # THEN analysis = AnalysisBase(extra_parameters=extra_parameter) + # EXPECT assert analysis._extra_parameters == [extra_parameter] def test_init_extra_parameters(self): + # WHEN extra_parameters = [ Parameter(name='param1', value=1.0), Parameter(name='param2', value=2.0), ] + + # THEN analysis = AnalysisBase(extra_parameters=extra_parameters) + # EXPECT assert analysis._extra_parameters == extra_parameters def test_init_calls_on_experiment_changed(self): + # WHEN THEN EXPECT with patch.object(AnalysisBase, '_on_experiment_changed') as mock_on_experiment_changed: AnalysisBase() mock_on_experiment_changed.assert_called_once() @@ -159,74 +168,99 @@ def test_init_calls_on_experiment_changed(self): ], ) def test_init_invalid_inputs(self, kwargs, expected_exception, expected_message): + # WHEN THEN EXPECT with pytest.raises(expected_exception, match=expected_message): AnalysisBase(**kwargs) def test_experiment_setter_calls_on_experiment_changed(self, analysis_base): + # WHEN THEN EXPECT with patch.object(analysis_base, '_on_experiment_changed') as mock_on_experiment_changed: new_experiment = Experiment() analysis_base.experiment = new_experiment mock_on_experiment_changed.assert_called_once() def test_experiment_setter_invalid_type(self, analysis_base): + # WHEN THEN EXPECT with pytest.raises(TypeError, match='experiment must be an instance of Experiment'): analysis_base.experiment = 'not an experiment' def test_experiment_setter_valid(self, analysis_base): + # WHEN new_experiment = Experiment() analysis_base.experiment = new_experiment + # THEN EXPECT assert analysis_base.experiment == new_experiment def test_sample_model_setter_invalid_type(self, analysis_base): + # WHEN THEN EXPECT with pytest.raises(TypeError, match='sample_model must be an instance of SampleModel'): analysis_base.sample_model = 'not a sample model' def test_sample_model_setter_valid(self, analysis_base): + # WHEN new_sample_model = SampleModel() + + # THEN analysis_base.sample_model = new_sample_model + # EXPECT assert analysis_base.sample_model == new_sample_model def test_sample_model_setter_calls_on_sample_model_changed(self, analysis_base): + # WHEN + new_sample_model = SampleModel() with patch.object( analysis_base, '_on_sample_model_changed' ) as mock_on_sample_model_changed: - new_sample_model = SampleModel() + # THEN analysis_base.sample_model = new_sample_model + + # EXPECT mock_on_sample_model_changed.assert_called_once() def test_instrument_model_setter_invalid_type(self, analysis_base): + # WHEN THEN EXPECT with pytest.raises( TypeError, match='instrument_model must be an instance of InstrumentModel' ): analysis_base.instrument_model = 'not an instrument model' def test_instrument_model_setter_valid(self, analysis_base): + # WHEN new_instrument_model = InstrumentModel() + + # THEN analysis_base.instrument_model = new_instrument_model + + # EXPECT assert analysis_base.instrument_model == new_instrument_model def test_instrument_model_setter_calls_on_instrument_model_changed(self, analysis_base): + # WHEN + new_instrument_model = InstrumentModel() with patch.object( analysis_base, '_on_instrument_model_changed' ) as mock_on_instrument_model_changed: - new_instrument_model = InstrumentModel() + # THEN analysis_base.instrument_model = new_instrument_model + + # EXPECT mock_on_instrument_model_changed.assert_called_once() def test_Q_property(self, analysis_base): - # Create a mock Q value + # WHEN fake_Q = [1, 2, 3] - - # Patch the 'experiment' attribute's Q property with patch.object( type(analysis_base.experiment), 'Q', new_callable=PropertyMock ) as mock_Q: mock_Q.return_value = fake_Q - result = analysis_base.Q # Access the property + # THEN + result = analysis_base.Q + # EXPECT assert result == fake_Q mock_Q.assert_called_once() def test_Q_setter_raises(self, analysis_base): + # WHEN THEN EXPECT with pytest.raises( AttributeError, match=r'Q is a read-only property derived from the Experiment.', @@ -234,19 +268,20 @@ def test_Q_setter_raises(self, analysis_base): analysis_base.Q = [1, 2, 3] def test_energy_property(self, analysis_base): - # Create a mock energy value + # WHEN fake_energy = [10, 20, 30] - - # Patch the 'experiment' attribute's energy property with patch.object( type(analysis_base.experiment), 'energy', new_callable=PropertyMock ) as mock_energy: mock_energy.return_value = fake_energy - result = analysis_base.energy # Access the property + # THEN + result = analysis_base.energy + # EXPECT assert result == fake_energy mock_energy.assert_called_once() def test_energy_setter_raises(self, analysis_base): + # WHEN THEN EXPECT with pytest.raises( AttributeError, match=r'energy is a read-only property derived from the Experiment.', @@ -254,30 +289,32 @@ def test_energy_setter_raises(self, analysis_base): analysis_base.energy = [10, 20, 30] def test_temperature_property_no_temperature(self, analysis_base): - # Patch the 'experiment' attribute's temperature property to - # return None + # WHEN with patch.object( type(analysis_base.sample_model), 'temperature', new_callable=PropertyMock ) as mock_temperature: mock_temperature.return_value = None - result = analysis_base.temperature # Access the property + # THEN + result = analysis_base.temperature + # EXPECT assert result is None mock_temperature.assert_called_once() def test_temperature_property(self, analysis_base): - # Create a mock temperature value + # WHEN fake_temperature = 300 - - # Patch the 'sample_model' attribute's temperature property with patch.object( type(analysis_base.sample_model), 'temperature', new_callable=PropertyMock ) as mock_temperature: mock_temperature.return_value = fake_temperature - result = analysis_base.temperature # Access the property + # THEN + result = analysis_base.temperature + # EXPECT assert result == fake_temperature mock_temperature.assert_called_once() def test_temperature_setter_raises(self, analysis_base): + # WHEN THEN EXPECT with pytest.raises( AttributeError, match='temperature is a read-only property', @@ -381,6 +418,7 @@ def test_extra_parameters_property(self, analysis_base, extra_parameters): ], ) def test_extra_parameters_setter_invalid_type(self, analysis_base, invalid_extra_parameters): + # WHEN THEN EXPECT with pytest.raises( TypeError, match='extra_parameters must be', @@ -392,6 +430,7 @@ def test_extra_parameters_setter_invalid_type(self, analysis_base, invalid_extra ############# def test_normalize_resolution_calls_instrument_model(self, analysis_base): + # WHEN THEN EXPECT with patch.object( analysis_base.instrument_model, 'normalize_resolution' ) as mock_normalize_resolution: @@ -458,12 +497,13 @@ def test_get_parameters_near_bounds_with_tolerances(self, analysis_base_with_com def test_get_parameters_near_bounds_errors( self, analysis_base_with_components, rtol, atol, expected_param, expected_error ): + # WHEN THEN with pytest.raises(expected_error) as exc: analysis_base_with_components.get_parameters_near_bounds( rtol=rtol, atol=atol, ) - + # EXPECT assert expected_param in str(exc.value) def test_not_finite_parameters(self, analysis_base_with_components): @@ -499,10 +539,9 @@ def test_on_experiment_changed_updates_Q(self, analysis_base): analysis_base._on_experiment_changed() # EXPECT - # assert that the Q attribute was set np.testing.assert_array_equal(analysis_base.Q, fake_Q) - np.testing.assert_array_equal(analysis_base.sample_model.Q, fake_Q) - np.testing.assert_array_equal(analysis_base.instrument_model.Q, fake_Q) + np.testing.assert_array_equal(analysis_base.sample_model.Q.values, fake_Q) + np.testing.assert_array_equal(analysis_base.instrument_model.Q.values, fake_Q) def test_on_sample_model_changed_updates_Q(self, analysis_base): # WHEN @@ -518,37 +557,45 @@ def test_on_sample_model_changed_updates_Q(self, analysis_base): analysis_base._on_sample_model_changed() # EXPECT - np.testing.assert_array_equal(analysis_base.sample_model.Q, fake_Q) + np.testing.assert_array_equal(analysis_base.sample_model.Q.values, fake_Q) def test_on_instrument_model_changed_updates_Q(self, analysis_base): + # WHEN fake_Q = [1, 2, 3] - - # Patch the Q property of analysis_base with patch.object( type(analysis_base.experiment), 'Q', new_callable=PropertyMock ) as mock_Q: mock_Q.return_value = fake_Q - + # THEN analysis_base._on_instrument_model_changed() - np.testing.assert_array_equal(analysis_base.instrument_model.Q, fake_Q) + # EXPECT + np.testing.assert_array_equal(analysis_base.instrument_model.Q.values, fake_Q) - def test_verify_Q_index_valid(self, analysis_base): + def test_verify_Q_index_valid(self, analysis_base_with_components): # WHEN valid_Q_index = 0 # THEN - result = analysis_base._verify_Q_index(valid_Q_index) + result = analysis_base_with_components._verify_Q_index(valid_Q_index) # EXPECT assert result == valid_Q_index - def test_verify_Q_index_invalid(self, analysis_base): + def test_verify_Q_index_invalid(self, analysis_base_with_components): # WHEN invalid_Q_index = -1 # THEN / EXPECT - with pytest.raises(IndexError, match='Q_index must be a valid index'): - analysis_base._verify_Q_index(invalid_Q_index) + with pytest.raises(IndexError, match='Q_index -1 is out of bounds for Q of length 2'): + analysis_base_with_components._verify_Q_index(invalid_Q_index) + + def test_verify_Q_index_invalid_when_Q_is_none(self, analysis_base): + # WHEN + positive_Q_index = 0 + + # THEN / EXPECT + with pytest.raises(ValueError, match='Q is None, cannot validate Q_index'): + analysis_base._verify_Q_index(positive_Q_index) def test_repr(self, analysis_base): # WHEN diff --git a/tests/unit/easydynamics/analysis/test_fit_binding.py b/tests/unit/easydynamics/analysis/test_fit_binding.py index 191d1ba1a..901ab9e50 100644 --- a/tests/unit/easydynamics/analysis/test_fit_binding.py +++ b/tests/unit/easydynamics/analysis/test_fit_binding.py @@ -2,288 +2,272 @@ # SPDX-License-Identifier: BSD-3-Clause -from unittest.mock import Mock - +import numpy as np import pytest from easydynamics.analysis.fit_binding import FitBinding from easydynamics.sample_model.components.gaussian import Gaussian +from easydynamics.sample_model.components.polynomial import Polynomial from easydynamics.sample_model.diffusion_model.brownian_translational_diffusion import ( BrownianTranslationalDiffusion, ) +from easydynamics.sample_model.diffusion_model.delta_lorentz import DeltaLorentz +from easydynamics.utils.fit_target import FitTarget class TestFitBinding: @pytest.fixture - def fit_binding(self): - model = Gaussian() - return FitBinding(parameter_name='parameter1', model=model) + def component_binding(self): + model = Gaussian(display_name='GaussianModel') + return FitBinding(model=model, targets='parameter1') @pytest.fixture - def diffusion_fit_binding(self): - model = BrownianTranslationalDiffusion() - return FitBinding(parameter_name='parameter3', model=model) - - def test_initialization(self, fit_binding): - # WHEN THEN EXPECT - assert isinstance(fit_binding, FitBinding) - assert fit_binding.parameter_name == 'parameter1' - assert isinstance(fit_binding.model, Gaussian) - assert fit_binding.modes is None - - @pytest.mark.parametrize( - 'parameter_name, model, modes, error_msg', - [ - # parameter_name errors - (123, Gaussian(), None, 'parameter_name must be a string'), - (None, Gaussian(), None, 'parameter_name must be a string'), - # model errors - ( - 'param', - 123, - None, - 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase', - ), - ( - 'param', - 'not_a_model', - None, - 'model must be a ModelComponent, ComponentCollection, or DiffusionModelBase', - ), - # modes type errors - ( - 'param', - Gaussian(), - 123, - 'modes must be a string, list of strings, or None', - ), - ( - 'param', - Gaussian(), - {'mode': 'area'}, - 'modes must be a string, list of strings, or None', - ), - # modes list content errors - ( - 'param', - Gaussian(), - ['area', 123], - 'All modes in the list must be strings', - ), - ('param', Gaussian(), [None], 'All modes in the list must be strings'), - ], - ) - def test_fitbinding_init_errors(self, parameter_name, model, modes, error_msg): - with pytest.raises(TypeError, match=error_msg): - FitBinding( - parameter_name=parameter_name, - model=model, - modes=modes, - ) + def diffusion_binding(self): + model = BrownianTranslationalDiffusion(lorentzian_name='Lorentzian') + return FitBinding(model=model) # ------------------------------------------------------------------ - # Properties + # Initialization and validation # ------------------------------------------------------------------ - def test_parameter_name_setter(self, fit_binding): - # WHEN - fit_binding.parameter_name = 'new_parameter' + def test_initialization(self, component_binding): + # WHEN THEN EXPECT + assert isinstance(component_binding, FitBinding) + assert component_binding.targets == 'parameter1' + assert isinstance(component_binding.model, Gaussian) - # THEN EXPECT - assert fit_binding.parameter_name == 'new_parameter' + def test_initialization_invalid_model_raises(self): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='model must be a ModelComponent'): + FitBinding(model='not_a_model', targets='parameter1') - def test_parameter_name_setter_errors(self, fit_binding): - with pytest.raises(TypeError, match='parameter_name must be a string'): - fit_binding.parameter_name = 123 + @pytest.mark.parametrize( + 'targets', + [None, ['parameter1'], {'value': 'parameter1'}, 123], + ids=['none', 'list', 'dict', 'int'], + ) + def test_component_model_requires_string_targets(self, targets): + # WHEN THEN EXPECT: component models have a single prediction, so targets must be + # the dataset key + with pytest.raises(TypeError, match='targets must be the dataset key'): + FitBinding(model=Gaussian(), targets=targets) - def test_model_setter(self, fit_binding): - # WHEN + def test_diffusion_unknown_prediction_raises(self): + # WHEN THEN EXPECT model = BrownianTranslationalDiffusion() + with pytest.raises(ValueError, match=r'Unknown prediction.*Available predictions'): + FitBinding(model=model, targets=['nonsense']) - # THEN - fit_binding.model = model - - # EXPECT - assert fit_binding.model is model + def test_diffusion_unknown_prediction_in_dict_raises(self): + # WHEN THEN EXPECT + model = BrownianTranslationalDiffusion() + with pytest.raises(ValueError, match='Unknown prediction'): + FitBinding(model=model, targets={'delta_area': 'Elastic area'}) - def test_model_setter_errors(self, fit_binding): - with pytest.raises( - TypeError, - match='model must be a ModelComponent, ComponentCollection, or DiffusionModelBase', - ): - fit_binding.model = 'not_a_model' + def test_diffusion_invalid_targets_type_raises(self): + # WHEN THEN EXPECT + model = BrownianTranslationalDiffusion() + with pytest.raises(TypeError, match='targets must be None'): + FitBinding(model=model, targets=123) - def test_modes_setter(self, fit_binding): - # WHEN - fit_binding.modes = 'area' + def test_diffusion_non_string_dataset_key_raises(self): + # WHEN THEN EXPECT + model = BrownianTranslationalDiffusion() + with pytest.raises(TypeError, match='dataset keys'): + FitBinding(model=model, targets={'width': 123}) - # THEN EXPECT - assert fit_binding.modes == ['area'] + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ - # WHEN - fit_binding.modes = ['area', 'width'] + def test_model_setter_revalidates_targets(self): + # WHEN: a binding using DeltaLorentz's delta_area prediction + binding = FitBinding(model=DeltaLorentz(), targets=['delta_area']) - # THEN EXPECT - assert fit_binding.modes == ['area', 'width'] + # THEN EXPECT: switching to a model without that prediction fails + with pytest.raises(ValueError, match='Unknown prediction'): + binding.model = BrownianTranslationalDiffusion() - def test_modes_setter_errors(self, fit_binding): - with pytest.raises(TypeError, match='modes must be a string, list of strings, or None'): - fit_binding.modes = 123 + def test_model_setter_invalid_type_raises(self, component_binding): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='model must be a ModelComponent'): + component_binding.model = 'not_a_model' - with pytest.raises(TypeError, match='modes must be a string, list of strings, or None'): - fit_binding.modes = {'mode': 'area'} + def test_targets_setter(self, diffusion_binding): + # WHEN THEN + diffusion_binding.targets = ['width'] - with pytest.raises(TypeError, match='All modes in the list must be strings'): - fit_binding.modes = ['area', 123] + # EXPECT + assert [t.name for t in diffusion_binding.get_targets()] == ['width'] - with pytest.raises(TypeError, match='All modes in the list must be strings'): - fit_binding.modes = [None] + def test_targets_setter_invalid_raises(self, diffusion_binding): + # WHEN THEN EXPECT + with pytest.raises(ValueError, match='Unknown prediction'): + diffusion_binding.targets = ['nonsense'] # ------------------------------------------------------------------ - # Other methods + # get_targets # ------------------------------------------------------------------ - def test_build_callables_component(self, fit_binding): + def test_component_target(self, component_binding): # WHEN - mock_model = Mock() - mock_model.evaluate.return_value = 1.0 - fit_binding._model = mock_model - - # THEN - callables = fit_binding.build_callables() + targets = component_binding.get_targets() + + # THEN EXPECT: a single target evaluating the component, with its units + assert len(targets) == 1 + target = targets[0] + assert isinstance(target, FitTarget) + assert target.name == 'value' + assert target.dataset_key == 'parameter1' + assert target.label == 'GaussianModel' + assert target.x_unit == component_binding.model.x_unit + assert target.y_unit == component_binding.model.y_unit + + def test_component_target_function_evaluates_model(self, component_binding): + # WHEN + target = component_binding.get_targets()[0] + x = np.linspace(-1, 1, 5) - # EXPECT - assert len(callables) == 1 - assert callable(callables[0]) - assert callables[0](0) == pytest.approx(1.0) - mock_model.evaluate.assert_called_once_with(0) + # THEN EXPECT + np.testing.assert_allclose(target.function(x), component_binding.model.evaluate(x)) - def test_build_callables_diffusion(self, diffusion_fit_binding): + def test_diffusion_default_targets(self, diffusion_binding): # WHEN - mock_model = Mock(spec=BrownianTranslationalDiffusion) - mock_model.calculate_QISF.return_value = 2.0 - mock_model.scale.value = 3.0 - mock_model.calculate_width.return_value = 0.5 - diffusion_fit_binding._model = mock_model + targets = diffusion_binding.get_targets() - # THEN - callables = diffusion_fit_binding.build_callables() + # THEN EXPECT: all predictions with default dataset keys from the Lorentzian name + assert [t.name for t in targets] == ['area', 'width'] + assert [t.dataset_key for t in targets] == ['Lorentzian area', 'Lorentzian width'] - # EXPECT - assert len(callables) == 2 - assert callable(callables[0]) - assert callable(callables[1]) - assert callables[0](0) == pytest.approx(6.0) # 2.0 * 3.0 - assert callables[1](0) == pytest.approx(0.5) - mock_model.calculate_QISF.assert_called_once_with(0) - mock_model.calculate_width.assert_called_once_with(0) - - def test_build_callables_diffusion_with_modes(self, diffusion_fit_binding): + def test_diffusion_string_target(self): # WHEN - diffusion_fit_binding.modes = 'area' - mock_model = Mock(spec=BrownianTranslationalDiffusion) - mock_model.calculate_QISF.return_value = 2.0 - mock_model.scale.value = 3.0 - diffusion_fit_binding._model = mock_model + model = BrownianTranslationalDiffusion(lorentzian_name='QENS') + binding = FitBinding(model=model, targets='width') # THEN - callables = diffusion_fit_binding.build_callables() + targets = binding.get_targets() # EXPECT - assert len(callables) == 1 - assert callable(callables[0]) - assert callables[0](0) == pytest.approx(6.0) # 2.0 * 3.0 - mock_model.calculate_QISF.assert_called_once_with(0) + assert len(targets) == 1 + assert targets[0].name == 'width' + assert targets[0].dataset_key == 'QENS width' - def test_get_model_names(self, fit_binding): - # WHEN THEN - model_names = fit_binding.get_model_names() - - # EXPECT - assert model_names == ['Gaussian'] + def test_diffusion_list_targets(self): + # WHEN + model = BrownianTranslationalDiffusion(lorentzian_name='QENS') + binding = FitBinding(model=model, targets=['width', 'area']) + + # THEN EXPECT: order follows the user's list + assert [t.name for t in binding.get_targets()] == ['width', 'area'] + + def test_diffusion_dict_targets_map_dataset_keys(self): + # WHEN: mapping DeltaLorentz predictions to custom dataset keys + model = DeltaLorentz() + binding = FitBinding( + model=model, + targets={ + 'width': 'L width', + 'area': 'L area', + 'delta_area': 'Elastic area', + }, + ) - def test_get_model_names_diffusion(self, diffusion_fit_binding): - # WHEN THEN - model_names = diffusion_fit_binding.get_model_names() + # THEN + targets = {t.name: t for t in binding.get_targets()} # EXPECT - assert model_names == [ - 'BrownianTranslationalDiffusion area', - 'BrownianTranslationalDiffusion width', - ] + assert targets['width'].dataset_key == 'L width' + assert targets['area'].dataset_key == 'L area' + assert targets['delta_area'].dataset_key == 'Elastic area' - def test_get_parameter_names(self, fit_binding): - # WHEN THEN - parameter_names = fit_binding.get_parameter_names() + def test_delta_lorentz_default_targets(self): + # WHEN: default keys derive from the component names + model = DeltaLorentz(lorentzian_name='Lorentzian', delta_name='Delta function') + binding = FitBinding(model=model) - # EXPECT - assert parameter_names == ['parameter1'] - - def test_get_parameter_names_diffusion(self, diffusion_fit_binding): - # WHEN THEN - parameter_names = diffusion_fit_binding.get_parameter_names() + # THEN + targets = {t.name: t for t in binding.get_targets()} # EXPECT - assert parameter_names == ['parameter3 area', 'parameter3 width'] + assert set(targets) == {'area', 'width', 'delta_area'} + assert targets['area'].dataset_key == 'Lorentzian area' + assert targets['width'].dataset_key == 'Lorentzian width' + assert targets['delta_area'].dataset_key == 'Delta function area' - # ------------------------------------------------------------------ - # Private methods - # ------------------------------------------------------------------ + def test_diffusion_target_units(self, diffusion_binding): + # WHEN + targets = {t.name: t for t in diffusion_binding.get_targets()} - def test_build_diffusion_callable(self, diffusion_fit_binding): + # THEN EXPECT: predictions take Q in 1/angstrom; widths are in the model's x_unit and + # areas in the scale unit + assert targets['width'].x_unit == '1/angstrom' + assert targets['width'].y_unit == diffusion_binding.model.x_unit + assert targets['area'].y_unit == str(diffusion_binding.model.scale.unit) - # WHEN - mock_model = Mock() - mock_model.calculate_QISF.return_value = 2.0 - mock_model.scale.value = 3.0 - mock_model.calculate_width.return_value = 0.5 - diffusion_fit_binding._model = mock_model + def test_diffusion_target_units_track_conversions(self, diffusion_binding): + # WHEN: targets are snapshots of the live model, so unit conversions are reflected + diffusion_binding.model.convert_x_unit('ueV') # THEN - area_callable = diffusion_fit_binding._build_diffusion_callable(mode='area') - width_callable = diffusion_fit_binding._build_diffusion_callable(mode='width') + targets = {t.name: t for t in diffusion_binding.get_targets()} # EXPECT - assert area_callable(0) == pytest.approx(6.0) # 2.0 * 3.0 - mock_model.calculate_QISF.assert_called_once_with(0) + assert targets['width'].y_unit == 'ueV' - assert width_callable(0) == pytest.approx(0.5) - mock_model.calculate_width.assert_called_once_with(0) + def test_diffusion_target_functions(self): + # WHEN + model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, scale=0.5) + binding = FitBinding(model=model) + Q = np.array([0.5, 1.0]) # THEN - result_area = area_callable(0, unused_arg=123) - result_width = width_callable(0, unused_arg=123) + targets = {t.name: t for t in binding.get_targets()} # EXPECT - assert result_area == pytest.approx(6.0) # Should ignore unused_arg - assert result_width == pytest.approx(0.5) # Should ignore unused_arg - - def test_build_diffusion_callable_errors(self, diffusion_fit_binding): - with pytest.raises(ValueError, match='Unknown diffusion mode: invalid_mode'): - diffusion_fit_binding._build_diffusion_callable(mode='invalid_mode') + np.testing.assert_allclose(targets['width'].function(Q), model.calculate_width(Q)) + np.testing.assert_allclose( + targets['area'].function(Q), model.calculate_QISF(Q) * model.scale.value + ) - def test_get_modes(self, diffusion_fit_binding): + def test_delta_lorentz_delta_area_function(self): # WHEN - modes = diffusion_fit_binding._get_modes() - - # EXPECT - assert modes == ['area', 'width'] + model = DeltaLorentz(A_0=0.5, mean_u_squared=0.3, scale=2.0, lorentzian_width=0.1) + binding = FitBinding(model=model, targets=['delta_area']) + Q = np.array([0.5, 1.0]) # THEN - diffusion_fit_binding.modes = 'area' - modes = diffusion_fit_binding._get_modes() + target = binding.get_targets()[0] + # EXPECT - assert modes == ['area'] + np.testing.assert_allclose(target.function(Q), model.calculate_EISF(Q) * model.scale.value) # ------------------------------------------------------------------ # dunder methods # ------------------------------------------------------------------ - def test_repr(self, fit_binding): - # WHEN - repr_str = repr(fit_binding) - # THEN EXPECT + def test_repr(self, diffusion_binding): + # WHEN THEN + repr_str = repr(diffusion_binding) + + # EXPECT assert 'FitBinding' in repr_str - assert "parameter_name='parameter1'" in repr_str - assert "model='Gaussian'" in repr_str - assert 'modes=None' in repr_str + assert 'model=' in repr_str + assert 'targets=' in repr_str + + +class TestFitBindingWorkflows: + """End-to-end regression tests for the standard ParameterAnalysis workflows.""" + + def test_polynomial_targets_gaussian_area(self): + # WHEN: fitting a Polynomial to a 'Gaussian area' dataset key + polynomial = Polynomial( + coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV', display_name='Line' + ) + binding = FitBinding(model=polynomial, targets='Gaussian area') + + # THEN + target = binding.get_targets()[0] + + # EXPECT + assert target.dataset_key == 'Gaussian area' + assert target.label == 'Line' diff --git a/tests/unit/easydynamics/analysis/test_parameter_analysis.py b/tests/unit/easydynamics/analysis/test_parameter_analysis.py index 14d9c58c6..031f813cf 100644 --- a/tests/unit/easydynamics/analysis/test_parameter_analysis.py +++ b/tests/unit/easydynamics/analysis/test_parameter_analysis.py @@ -17,12 +17,25 @@ from easydynamics.sample_model.diffusion_model.brownian_translational_diffusion import ( BrownianTranslationalDiffusion, ) +from easydynamics.utils.fit_target import FitTarget + + +def make_target(dataset_key, function, label, x_unit=None, y_unit=None, name='value'): + """Build a FitTarget for mocking FitBinding.get_targets in tests.""" + return FitTarget( + name=name, + dataset_key=dataset_key, + function=function, + label=label, + x_unit=x_unit, + y_unit=y_unit, + ) class TestParameterAnalysis: @pytest.fixture def dataset(self): - Q = sc.array(dims=['Q'], values=[0.1, 0.2]) + Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom') return sc.Dataset( data={ 'parameter1': sc.DataArray( @@ -63,11 +76,14 @@ def mock_model_dataset(self): @pytest.fixture def parameter_analysis(self, dataset): - model = Polynomial(coefficients=[1.0, 0.5]) + model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV') diffusion_model = BrownianTranslationalDiffusion() - fit_binding1 = FitBinding(parameter_name='parameter1', model=model) - fit_binding2 = FitBinding(parameter_name='parameter3', model=diffusion_model) + fit_binding1 = FitBinding(model=model, targets='parameter1') + fit_binding2 = FitBinding( + model=diffusion_model, + targets={'area': 'parameter3 area', 'width': 'parameter3 width'}, + ) return ParameterAnalysis(parameters=dataset, bindings=[fit_binding1, fit_binding2]) @@ -75,8 +91,11 @@ def test_initialization(self, parameter_analysis): # WHEN THEN EXPECT assert isinstance(parameter_analysis, ParameterAnalysis) assert len(parameter_analysis.bindings) == 2 - assert parameter_analysis.bindings[0].parameter_name == 'parameter1' - assert parameter_analysis.bindings[1].parameter_name == 'parameter3' + assert parameter_analysis.bindings[0].targets == 'parameter1' + assert parameter_analysis.bindings[1].targets == { + 'area': 'parameter3 area', + 'width': 'parameter3 width', + } def test_parameter_property(self, parameter_analysis): # WHEN @@ -133,7 +152,7 @@ def test_bindings_property(self, parameter_analysis): # WHEN model = Polynomial(coefficients=[2.0, 1.0]) - new_binding = FitBinding(parameter_name='parameter2', model=model) + new_binding = FitBinding(model=model, targets='parameter2') parameter_analysis.bindings = new_binding # THEN EXPECT @@ -162,7 +181,7 @@ def test_fit_no_parameters_raises(self, parameter_analysis): def test_fit_wrong_parameter_name_raises(self, parameter_analysis): # WHEN model = Polynomial(coefficients=[2.0, 1.0]) - incorrect_binding = FitBinding(parameter_name='nonexistent_parameter', model=model) + incorrect_binding = FitBinding(model=model, targets='nonexistent_parameter') parameter_analysis.bindings = incorrect_binding # THEN EXPECT @@ -172,6 +191,87 @@ def test_fit_wrong_parameter_name_raises(self, parameter_analysis): ): parameter_analysis.fit() + def test_fit_incompatible_x_unit_raises(self, dataset): + # WHEN: Polynomial has x_unit='meV' but Q coordinate has unit '1/angstrom' + model = Polynomial(coefficients=[1.0, 0.5], x_unit='meV', y_unit='meV') + binding = FitBinding(model=model, targets='parameter1') + pa = ParameterAnalysis(parameters=dataset, bindings=binding) + + # THEN EXPECT + with pytest.raises(Exception, match=r'Q coordinate unit .* is incompatible'): + pa.fit() + + def test_fit_incompatible_y_unit_raises(self, dataset): + # WHEN: Polynomial has y_unit='1/angstrom' but parameter1 has unit 'meV' + model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='1/angstrom') + binding = FitBinding(model=model, targets='parameter1') + pa = ParameterAnalysis(parameters=dataset, bindings=binding) + + # THEN EXPECT + with pytest.raises(Exception, match="Parameter 'parameter1' unit 'meV' is incompatible"): + pa.fit() + + def test_fit_converts_x_unit(self): + # WHEN: Q is in '1/m' but the model declares x_unit='1/angstrom'. + # 1 1/m = 1e-10 1/angstrom, so values [1e10, 2e10] 1/m → [1.0, 2.0] 1/angstrom. + Q = sc.array(dims=['Q'], values=[1e10, 2e10], unit='1/m') + dataset = sc.Dataset( + data={ + 'param': sc.DataArray( + data=sc.array(dims=['Q'], values=[1.0, 2.0], variances=[0.1, 0.2], unit='meV'), + coords={'Q': Q}, + ) + } + ) + model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV') + pa = ParameterAnalysis( + parameters=dataset, bindings=FitBinding(model=model, targets='param') + ) + + with patch('easydynamics.analysis.parameter_analysis.MultiFitter') as MockMultiFitter: + MockMultiFitter.return_value.fit.return_value = MagicMock() + # THEN + pa.fit() + x_passed = MockMultiFitter.return_value.fit.call_args.kwargs['x'][0] + + # EXPECT: x values converted from 1/m to 1/angstrom + np.testing.assert_allclose(x_passed, [1.0, 2.0]) + + def test_fit_converts_y_unit(self): + # WHEN: parameter is in 'eV' but model declares y_unit='meV'. + # 1 eV = 1000 meV, so values [0.001, 0.002] eV → [1.0, 2.0] meV. + # Weights (= 1/std) scale by 1/y_factor: sqrt(1e-6)=1e-3 eV^-1 → 1.0 meV^-1. + Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom') + dataset = sc.Dataset( + data={ + 'param': sc.DataArray( + data=sc.array( + dims=['Q'], + values=[0.001, 0.002], + variances=[1e-6, 4e-6], + unit='eV', + ), + coords={'Q': Q}, + ) + } + ) + model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV') + pa = ParameterAnalysis( + parameters=dataset, bindings=FitBinding(model=model, targets='param') + ) + + with patch('easydynamics.analysis.parameter_analysis.MultiFitter') as MockMultiFitter: + MockMultiFitter.return_value.fit.return_value = MagicMock() + # THEN + pa.fit() + kwargs = MockMultiFitter.return_value.fit.call_args.kwargs + y_passed = kwargs['y'][0] + w_passed = kwargs['weights'][0] + + # EXPECT: y values converted from eV to meV; weights scaled inversely + np.testing.assert_allclose(y_passed, [1.0, 2.0]) + np.testing.assert_allclose(w_passed, [1.0, 0.5]) + def test_fit_success(self, parameter_analysis): # WHEN mock_result = MagicMock() @@ -469,9 +569,11 @@ def test_calculate_model_dataset_missing_parameter_raises(self, parameter_analys binding = parameter_analysis.bindings[0] with ( - patch.object(binding, 'get_parameter_names', return_value=['missing_name']), - patch.object(binding, 'get_model_names', return_value=['model']), - patch.object(binding, 'build_callables', return_value=[lambda x: x]), + patch.object( + binding, + 'get_targets', + return_value=[make_target('missing_name', lambda x: x, 'model')], + ), pytest.raises(ValueError, match='not found in parameters Dataset'), ): parameter_analysis.calculate_model_dataset([binding]) @@ -483,10 +585,10 @@ def test_calculate_model_dataset_single_binding(self, parameter_analysis): # Mock callable to return predictable output mock_callable = MagicMock(return_value=np.array([10.0, 20.0])) - with ( - patch.object(binding, 'build_callables', return_value=[mock_callable]), - patch.object(binding, 'get_model_names', return_value=['model1']), - patch.object(binding, 'get_parameter_names', return_value=['parameter1']), + with patch.object( + binding, + 'get_targets', + return_value=[make_target('parameter1', mock_callable, 'model1')], ): # THEN result = parameter_analysis.calculate_model_dataset([binding]) @@ -516,12 +618,16 @@ def test_calculate_model_dataset_multiple_bindings(self, parameter_analysis): mock_callable2 = MagicMock(return_value=np.array([30.0, 40.0])) with ( - patch.object(binding1, 'build_callables', return_value=[mock_callable1]), - patch.object(binding1, 'get_model_names', return_value=['model1']), - patch.object(binding1, 'get_parameter_names', return_value=['parameter1']), - patch.object(binding2, 'build_callables', return_value=[mock_callable2]), - patch.object(binding2, 'get_model_names', return_value=['model2']), - patch.object(binding2, 'get_parameter_names', return_value=['parameter2']), + patch.object( + binding1, + 'get_targets', + return_value=[make_target('parameter1', mock_callable1, 'model1')], + ), + patch.object( + binding2, + 'get_targets', + return_value=[make_target('parameter2', mock_callable2, 'model2')], + ), ): # THEN result = parameter_analysis.calculate_model_dataset([binding1, binding2]) @@ -562,19 +668,18 @@ def test_calculate_model_dataset_multiple_bindings_diffusion(self, parameter_ana mock_callable2w = MagicMock(return_value=np.array([50.0, 60.0])) with ( - patch.object(binding1, 'build_callables', return_value=[mock_callable1]), - patch.object(binding1, 'get_model_names', return_value=['model1']), - patch.object(binding1, 'get_parameter_names', return_value=['parameter1']), patch.object( - binding2, - 'build_callables', - return_value=[mock_callable2a, mock_callable2w], + binding1, + 'get_targets', + return_value=[make_target('parameter1', mock_callable1, 'model1')], ), - patch.object(binding2, 'get_model_names', return_value=['model2a', 'model2w']), patch.object( binding2, - 'get_parameter_names', - return_value=['parameter3 area', 'parameter3 width'], + 'get_targets', + return_value=[ + make_target('parameter3 area', mock_callable2a, 'model2a', name='area'), + make_target('parameter3 width', mock_callable2w, 'model2w', name='width'), + ], ), ): # THEN @@ -614,10 +719,76 @@ def test_calculate_model_dataset_multiple_bindings_diffusion(self, parameter_ana np.testing.assert_allclose(args[0], [0.1, 0.2]) assert kwargs == {} + def test_calculate_model_dataset_converts_x_unit(self): + # WHEN: Q is in '1/m' but model declares x_unit='1/angstrom'. + # Values [1e10, 2e10] 1/m → [1.0, 2.0] 1/angstrom passed to callable. + Q = sc.array(dims=['Q'], values=[1e10, 2e10], unit='1/m') + dataset = sc.Dataset( + data={ + 'param': sc.DataArray( + data=sc.array(dims=['Q'], values=[1.0, 2.0], unit='meV'), + coords={'Q': Q}, + ) + } + ) + model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV') + binding = FitBinding(model=model, targets='param') + pa = ParameterAnalysis(parameters=dataset, bindings=binding) + + mock_callable = MagicMock(return_value=np.array([10.0, 20.0])) + # THEN + with patch.object( + binding, + 'get_targets', + return_value=[ + make_target( + 'param', mock_callable, 'Polynomial', x_unit='1/angstrom', y_unit='meV' + ) + ], + ): + pa.calculate_model_dataset([binding]) + + # EXPECT: callable received x in 1/angstrom, not raw 1/m values + args, _ = mock_callable.call_args + np.testing.assert_allclose(args[0], [1.0, 2.0]) + + def test_calculate_model_dataset_converts_y_unit(self): + # WHEN: parameter is in 'eV' but model declares y_unit='meV'. + # Callable returns [1.0, 2.0] meV → stored as [0.001, 0.002] eV in the DataArray. + Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom') + dataset = sc.Dataset( + data={ + 'param': sc.DataArray( + data=sc.array(dims=['Q'], values=[0.001, 0.002], unit='eV'), + coords={'Q': Q}, + ) + } + ) + model = Polynomial(coefficients=[1.0, 0.5], x_unit='1/angstrom', y_unit='meV') + binding = FitBinding(model=model, targets='param') + pa = ParameterAnalysis(parameters=dataset, bindings=binding) + + mock_callable = MagicMock(return_value=np.array([1.0, 2.0])) # meV + # THEN + with patch.object( + binding, + 'get_targets', + return_value=[ + make_target( + 'param', mock_callable, 'Polynomial', x_unit='1/angstrom', y_unit='meV' + ) + ], + ): + result = pa.calculate_model_dataset([binding]) + + # EXPECT: model output converted from meV back to eV + np.testing.assert_allclose(result['Polynomial'].values, [0.001, 0.002]) + assert result['Polynomial'].unit == sc.Unit('eV') + def test_append_binding(self, parameter_analysis): # WHEN model = Polynomial(coefficients=[2.0, 1.0]) - new_binding = FitBinding(parameter_name='parameter2', model=model) + new_binding = FitBinding(model=model, targets='parameter2') # THEN parameter_analysis.append_binding(new_binding) @@ -666,8 +837,8 @@ def test_get_all_variables_overlapping_variables(self, parameter_analysis): # Create two bindings with overlapping variables model1 = Gaussian(display_name='model1') - binding1 = FitBinding(parameter_name='parameter1', model=model1) - binding2 = FitBinding(parameter_name='parameter2', model=model1) + binding1 = FitBinding(model=model1, targets='parameter1') + binding2 = FitBinding(model=model1, targets='parameter2') parameter_analysis.bindings = [binding1, binding2] @@ -832,14 +1003,14 @@ def test_get_xyweight_from_dataset_wrong_parameter_name_raises(self, parameter_a parameter_analysis._get_xyweight_from_dataset('nonexistent_parameter') @pytest.mark.parametrize( - 'non_finite_variance', - [np.inf, -np.inf, np.nan, -1.0, 0.0], - ids=['inf', '-inf', 'nan', 'negative', 'zero'], + 'bad_variance', + [np.inf, -np.inf, -1.0, 0.0], + ids=['inf', '-inf', 'negative', 'zero'], ) def test_get_xyweight_from_dataset_non_finite_weights_raises( - self, parameter_analysis, non_finite_variance + self, parameter_analysis, bad_variance ): - # WHEN + # Non-NaN invalid variances (inf, negative, zero) should still raise. Q = sc.array(dims=['Q'], values=[0.1, 0.2]) parameter_analysis.parameters = sc.Dataset( data={ @@ -847,7 +1018,7 @@ def test_get_xyweight_from_dataset_non_finite_weights_raises( data=sc.array( dims=['Q'], values=[1.0, 2.0], - variances=[1.0, non_finite_variance], + variances=[1.0, bad_variance], unit='meV', ), coords={'Q': Q}, @@ -861,6 +1032,33 @@ def test_get_xyweight_from_dataset_non_finite_weights_raises( ): parameter_analysis._get_xyweight_from_dataset('parameter1') + def test_get_xyweight_from_dataset_nan_variance_filters_row(self, parameter_analysis): + # NaN variances arise when a parameter is absent for a given Q; those rows are filtered. + + # WHEN + Q = sc.array(dims=['Q'], values=[0.1, 0.2]) + parameter_analysis.parameters = sc.Dataset( + data={ + 'parameter1': sc.DataArray( + data=sc.array( + dims=['Q'], + values=[1.0, np.nan], + variances=[0.25, np.nan], + unit='meV', + ), + coords={'Q': Q}, + ), + } + ) + + # THEN + x, y, w = parameter_analysis._get_xyweight_from_dataset('parameter1') + + # EXPECT + np.testing.assert_allclose(x, [0.1]) + np.testing.assert_allclose(y, [1.0]) + np.testing.assert_allclose(w, [1 / np.sqrt(0.25)]) + def test_get_xyweight_from_dataset_valid(self, parameter_analysis): # WHEN THEN x, y, w = parameter_analysis._get_xyweight_from_dataset('parameter1') @@ -871,6 +1069,74 @@ def test_get_xyweight_from_dataset_valid(self, parameter_analysis): expected_w = 1 / np.sqrt([0.1, 0.2]) np.testing.assert_allclose(w, expected_w) + def test_get_xyweight_from_dataset_no_variances(self, parameter_analysis): + # WHEN + Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom') + parameter_analysis.parameters = sc.Dataset( + data={ + 'parameter1': sc.DataArray( + data=sc.array(dims=['Q'], values=[1.0, 2.0], unit='meV'), + coords={'Q': Q}, + ), + } + ) + + # THEN + x, y, w = parameter_analysis._get_xyweight_from_dataset('parameter1') + + # EXPECT + np.testing.assert_allclose(x, [0.1, 0.2]) + np.testing.assert_allclose(y, [1.0, 2.0]) + np.testing.assert_allclose(w, [1.0, 1.0]) + + def test_get_xyweight_from_dataset_all_nan_variances_raises(self, parameter_analysis): + # WHEN + Q = sc.array(dims=['Q'], values=[0.1, 0.2], unit='1/angstrom') + parameter_analysis.parameters = sc.Dataset( + data={ + 'parameter1': sc.DataArray( + data=sc.array( + dims=['Q'], + values=[np.nan, np.nan], + variances=[np.nan, np.nan], + unit='meV', + ), + coords={'Q': Q}, + ), + } + ) + + # THEN EXPECT + with pytest.raises( + ValueError, + match="No finite positive variances found for parameter 'parameter1'", + ): + parameter_analysis._get_xyweight_from_dataset('parameter1') + + def test_get_unit_conversions_none_x_unit(self, parameter_analysis): + # WHEN + model = Polynomial(coefficients=[1.0], x_unit=None, y_unit='meV') + binding = FitBinding(model=model, targets='parameter1') + + # THEN + x_factor, y_factor = parameter_analysis._get_unit_conversions(binding.get_targets()[0]) + + # EXPECT + assert x_factor == pytest.approx(1.0) + assert y_factor == pytest.approx(1.0) + + def test_get_unit_conversions_none_y_unit(self, parameter_analysis): + # WHEN + model = Polynomial(coefficients=[1.0], x_unit='1/angstrom', y_unit=None) + binding = FitBinding(model=model, targets='parameter1') + + # THEN + x_factor, y_factor = parameter_analysis._get_unit_conversions(binding.get_targets()[0]) + + # EXPECT + assert x_factor == pytest.approx(1.0) + assert y_factor == pytest.approx(1.0) + def test_repr(self, parameter_analysis): # WHEN repr_str = repr(parameter_analysis) @@ -883,3 +1149,102 @@ def test_repr(self, parameter_analysis): assert f'n_parameters={len(parameter_analysis.parameters)}' in repr_str assert 'parameter_names=' in repr_str assert 'bindings=' in repr_str + + +class TestParameterAnalysisWorkflows: + """End-to-end fits for the standard workflows on synthetic data.""" + + @staticmethod + def _dataset_from_targets(model, Q, unit_overrides=None): + """Build a parameters Dataset from a model's own predictions (no noise).""" + unit_overrides = unit_overrides or {} + Q_coord = sc.array(dims=['Q'], values=Q, unit='1/angstrom') + data = {} + for target in model.get_fit_targets(): + values = target.function(Q) + unit = unit_overrides.get(target.name, target.y_unit) + factor = sc.to_unit(sc.scalar(1.0, unit=target.y_unit), unit).value + data[target.dataset_key] = sc.DataArray( + data=sc.array( + dims=['Q'], + values=values * factor, + variances=(0.01 * values * factor) ** 2, + unit=unit, + ), + coords={'Q': Q_coord}, + ) + return sc.Dataset(data) + + def test_delta_lorentz_three_target_simultaneous_fit(self): + # WHEN: synthetic width, area, and delta area curves from a known DeltaLorentz + from easydynamics.sample_model.diffusion_model.delta_lorentz import DeltaLorentz + + Q = np.linspace(0.4, 2.0, 9) + truth = DeltaLorentz(scale=2.0, mean_u_squared=0.3, A_0=0.6, lorentzian_width=0.12) + dataset = self._dataset_from_targets(truth, Q) + + # THEN: fit a fresh DeltaLorentz to all three dataset keys simultaneously + fit_model = DeltaLorentz(scale=1.5, mean_u_squared=0.2, A_0=0.5, lorentzian_width=0.1) + pa = ParameterAnalysis(parameters=dataset, bindings=FitBinding(model=fit_model)) + pa.fit() + + # EXPECT: the shared parameters are recovered across all three curves + assert fit_model.scale.value == pytest.approx(2.0, rel=1e-3) + assert fit_model.mean_u_squared.value == pytest.approx(0.3, rel=1e-3) + assert fit_model.A_0.value == pytest.approx(0.6, rel=1e-3) + assert fit_model.lorentzian_width.value == pytest.approx(0.12, rel=1e-3) + + def test_jump_diffusion_width_only_fit(self): + # WHEN: synthetic widths from a known jump diffusion model + from easydynamics.sample_model.diffusion_model.jump_translational_diffusion import ( + JumpTranslationalDiffusion, + ) + + Q = np.linspace(0.4, 2.0, 9) + truth = JumpTranslationalDiffusion(diffusion_coefficient=2.4e-9, relaxation_time=2.0) + dataset = self._dataset_from_targets(truth, Q) + + # THEN: fit only the width prediction + fit_model = JumpTranslationalDiffusion(diffusion_coefficient=1.0e-9, relaxation_time=1.0) + fit_model.scale.fixed = True + pa = ParameterAnalysis( + parameters=dataset, bindings=FitBinding(model=fit_model, targets=['width']) + ) + pa.fit() + + # EXPECT + assert fit_model.diffusion_coefficient.value == pytest.approx(2.4e-9, rel=1e-3) + assert fit_model.relaxation_time.value == pytest.approx(2.0, rel=1e-3) + + def test_diffusion_fit_converts_dataset_units(self): + # WHEN: the dataset stores widths in ueV and Q in 1/nm, while the model works in meV + # and 1/angstrom (regression: diffusion fits used to ignore units entirely, silently + # misfitting scaled data) + Q = np.linspace(0.4, 2.0, 9) + truth = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9) + width_target = {t.name: t for t in truth.get_fit_targets()}['width'] + widths_mev = width_target.function(Q) + + Q_coord = sc.array(dims=['Q'], values=Q * 10, unit='1/nm') # 1 1/angstrom = 10 1/nm + dataset = sc.Dataset({ + width_target.dataset_key: sc.DataArray( + data=sc.array( + dims=['Q'], + values=widths_mev * 1000, # meV -> ueV + variances=(0.01 * widths_mev * 1000) ** 2, + unit='ueV', + ), + coords={'Q': Q_coord}, + ) + }) + + # THEN + fit_model = BrownianTranslationalDiffusion(diffusion_coefficient=1.0e-9) + fit_model.scale.fixed = True + pa = ParameterAnalysis( + parameters=dataset, bindings=FitBinding(model=fit_model, targets=['width']) + ) + pa.fit() + + # EXPECT: the diffusion coefficient is recovered despite the unit differences + assert fit_model.diffusion_coefficient.value == pytest.approx(2.4e-9, rel=1e-3) diff --git a/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py index 6767640a4..c205c9959 100644 --- a/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py +++ b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py @@ -13,7 +13,7 @@ class TestEasyDynamicsModelBase: def easy_dynamics_modelbase(self): """Fixture for creating an instance of EasyDynamicsModelBase.""" - return EasyDynamicsModelBase(name='TestModel', unit='meV') + return EasyDynamicsModelBase(name='TestModel', x_unit='meV') def test_initialization(self, easy_dynamics_modelbase): """Test that the EasyDynamicsModelBase is initialized correctly.""" @@ -22,6 +22,8 @@ def test_initialization(self, easy_dynamics_modelbase): assert easy_dynamics_modelbase.name == 'TestModel' assert easy_dynamics_modelbase.display_name == 'TestModel' assert easy_dynamics_modelbase.unique_name is not None + assert easy_dynamics_modelbase.x_unit == 'meV' + assert easy_dynamics_modelbase.y_unit == 'dimensionless' def test_init_raises_type_error_for_invalid_name(self): """Test that initializing with an invalid name raises a TypeError.""" @@ -68,9 +70,15 @@ def test_name_setter_invalid_type(self, easy_dynamics_modelbase, invalid_name): def test_unit_property(self, easy_dynamics_modelbase): # WHEN THEN EXPECT - assert easy_dynamics_modelbase.unit == 'meV' + assert easy_dynamics_modelbase.x_unit == 'meV' + assert easy_dynamics_modelbase.y_unit == 'dimensionless' - def test_unit_setter_raises(self, easy_dynamics_modelbase): + def test_x_unit_setter_raises(self, easy_dynamics_modelbase): # WHEN / THEN / EXPECT - with pytest.raises(AttributeError, match='Use convert_unit to change '): - easy_dynamics_modelbase.unit = 'K' + with pytest.raises(AttributeError, match='read-only'): + easy_dynamics_modelbase.x_unit = 'K' + + def test_y_unit_setter_raises(self, easy_dynamics_modelbase): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match='read-only'): + easy_dynamics_modelbase.y_unit = '1/meV' diff --git a/tests/unit/easydynamics/convolution/test_analytical_convolution.py b/tests/unit/easydynamics/convolution/test_analytical_convolution.py index c9d3f46e2..55c6cc297 100644 --- a/tests/unit/easydynamics/convolution/test_analytical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_analytical_convolution.py @@ -289,7 +289,6 @@ def test_convolute_analytic_pair_delta(self, default_analytical_convolution, fun """ # WHEN THEN delta_function = DeltaFunction(area=2.0, center=0.5) - function1 = Gaussian(area=3.0, center=-1.0, width=1.0) convoluted = default_analytical_convolution._convolute_analytic_pair( delta_function, function1 diff --git a/tests/unit/easydynamics/convolution/test_convolution.py b/tests/unit/easydynamics/convolution/test_convolution.py index dbb4290d4..2c4c5da3c 100644 --- a/tests/unit/easydynamics/convolution/test_convolution.py +++ b/tests/unit/easydynamics/convolution/test_convolution.py @@ -73,7 +73,8 @@ def test_init(self, default_convolution): assert default_convolution.upsample_factor == 5 assert default_convolution.extension_factor == pytest.approx(0.2) assert default_convolution.temperature is None - assert default_convolution.unit == 'meV' + assert default_convolution.x_unit == 'meV' + assert default_convolution.y_unit == 'dimensionless' assert default_convolution.detailed_balance_settings.normalize_detailed_balance is True assert isinstance(default_convolution._energy_grid, EnergyGrid) @@ -107,7 +108,8 @@ def test_init_components(self, convolution_with_components): assert convolution_with_components.upsample_factor == 5 assert convolution_with_components.extension_factor == pytest.approx(0.2) assert convolution_with_components.temperature is None - assert convolution_with_components.unit == 'meV' + assert convolution_with_components.x_unit == 'meV' + assert convolution_with_components.y_unit == 'dimensionless' assert ( convolution_with_components.detailed_balance_settings.normalize_detailed_balance is True @@ -403,7 +405,7 @@ def test_check_if_pair_is_analytic_raises_with_delta_in_resolution(self, default ) @pytest.mark.parametrize('delta_component', [True, False], ids=['with_delta', 'without_delta']) @pytest.mark.parametrize( - 'temperature', [None, 100], ids=['with_temperature', 'without_temperature'] + 'temperature', [None, 100], ids=['without_temperature', 'with_temperature'] ) def test_build_convolution_plan( self, @@ -577,3 +579,54 @@ def test_setattr_invalidates_plan_for_tracked_attribute( # EXPECT assert conv.convolution_settings.convolution_plan_is_valid is False + + def test_convert_y_unit_propagates_to_sub_convolvers(self): + # WHEN: Convolution with Gaussian sample and resolution components, + # both with y_unit='1/meV' + energy = np.linspace(-10, 10, 5001) + sample_components = ComponentCollection() + sample_components.append_component( + Gaussian(name='G', area=1.0, center=0.0, width=0.5, x_unit='meV', y_unit='1/meV') + ) + resolution_components = ComponentCollection() + resolution_components.append_component( + Gaussian(name='R', area=1.0, center=0.0, width=0.3, x_unit='meV', y_unit='1/meV') + ) + conv = Convolution( + energy=energy, + sample_components=sample_components, + resolution_components=resolution_components, + ) + + # THEN: convert y_unit to '1/eV' + conv.convert_y_unit('1/eV') + + # EXPECT: y_unit updated and propagated to sub-convolvers via Convolution.convert_y_unit + assert conv.y_unit == '1/eV' + assert conv._analytical_convolver._y_unit == '1/eV' + + def test_convert_y_unit_propagates_to_numerical_convolver(self): + # WHEN: a DHO sample component forces a numerical convolver + energy = np.linspace(-10, 10, 5001) + sample_components = ComponentCollection() + sample_components.append_component( + DampedHarmonicOscillator( + name='DHO', area=1.0, center=1.0, width=0.5, x_unit='meV', y_unit='1/meV' + ) + ) + resolution_components = ComponentCollection() + resolution_components.append_component( + Gaussian(name='R', area=1.0, center=0.0, width=0.3, x_unit='meV', y_unit='1/meV') + ) + conv = Convolution( + energy=energy, + sample_components=sample_components, + resolution_components=resolution_components, + ) + + # THEN + conv.convert_y_unit('1/eV') + + # EXPECT + assert conv.y_unit == '1/eV' + assert conv._numerical_convolver._y_unit == '1/eV' diff --git a/tests/unit/easydynamics/convolution/test_convolution_base.py b/tests/unit/easydynamics/convolution/test_convolution_base.py index fe8ceef62..7bc26fff7 100644 --- a/tests/unit/easydynamics/convolution/test_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_convolution_base.py @@ -32,6 +32,8 @@ def test_init(self, convolution_base): assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100)) assert isinstance(convolution_base._sample_components, ComponentCollection) assert isinstance(convolution_base._resolution_components, ComponentCollection) + assert convolution_base.x_unit == 'meV' + assert convolution_base.y_unit == 'dimensionless' def test_init_with_model_component(self): # WHEN @@ -78,7 +80,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': 'invalid', 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'unit': 'meV', + 'x_unit': 'meV', 'energy_offset': 0, }, 'Energy must be', @@ -88,7 +90,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': 'invalid', 'resolution_components': ComponentCollection(), - 'unit': 'meV', + 'x_unit': 'meV', 'energy_offset': 0, }, ( @@ -101,7 +103,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': 'invalid', - 'unit': 'meV', + 'x_unit': 'meV', 'energy_offset': 0, }, ( @@ -114,7 +116,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'unit': 123, + 'x_unit': 123, 'energy_offset': 0, }, 'unit must be ', @@ -124,7 +126,17 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'unit': 'meV', + 'y_unit': 123, + 'energy_offset': 0, + }, + 'unit must be ', + ), + ( + { + 'energy': np.linspace(-10, 10, 100), + 'sample_components': ComponentCollection(), + 'resolution_components': ComponentCollection(), + 'x_unit': 'meV', 'energy_offset': 'invalid', }, 'Energy_offset must be ', @@ -181,17 +193,17 @@ def test_unit_setter_raises(self, convolution_base): # WHEN THEN EXPECT with pytest.raises( AttributeError, - match=r'Use convert_unit to change the unit between allowed types ', + match=r'read-only', ): - convolution_base.unit = 'K' + convolution_base.x_unit = 'K' def test_convert_unit(self, convolution_base): # WHEN THEN - convolution_base.convert_unit('eV') + convolution_base.convert_x_unit('eV') # EXPECT assert convolution_base.energy.unit == 'eV' - assert convolution_base.unit == 'eV' + assert convolution_base.x_unit == 'eV' assert np.allclose(convolution_base.energy.values, np.linspace(-0.01, 0.01, 100)) def test_convert_unit_invalid_type_raises(self, convolution_base): @@ -200,7 +212,7 @@ def test_convert_unit_invalid_type_raises(self, convolution_base): TypeError, match=r'Energy unit must be a string or scipp unit.', ): - convolution_base.convert_unit(123) + convolution_base.convert_x_unit(123) def test_convert_unit_invalid_unit_rollback(self, convolution_base): # WHEN THEN @@ -208,10 +220,10 @@ def test_convert_unit_invalid_unit_rollback(self, convolution_base): UnitError, match=r'Conversion from `meV` to `s` is not valid.', ): - convolution_base.convert_unit('s') + convolution_base.convert_x_unit('s') # EXPECT - assert convolution_base.unit == 'meV' + assert convolution_base.x_unit == 'meV' assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100)) def test_convert_unit_invalid_offset_unit_rollback(self, convolution_base): @@ -223,10 +235,10 @@ def test_convert_unit_invalid_offset_unit_rollback(self, convolution_base): UnitError, match=r'Conversion from `s` to `meV` is not valid.', ): - convolution_base.convert_unit('meV') + convolution_base.convert_x_unit('meV') # EXPECT - assert convolution_base.unit == 'meV' + assert convolution_base.x_unit == 'meV' assert convolution_base.energy_offset.unit == 's' def test_energy_offset_property(self, convolution_base): @@ -255,6 +267,21 @@ def test_energy_with_offset_setter_raises(self, convolution_base): with pytest.raises(AttributeError): convolution_base.energy_with_offset = 5 + def test_energy_with_offset_unit_conversion(self, convolution_base): + # WHEN: energy is in meV and energy_offset given in eV (0.001 eV = 1 meV) + convolution_base.energy_offset = Parameter( + name='energy_offset', + value=0.001, + unit='eV', # 0.001 eV = 1 meV + ) + + # THEN + result = convolution_base.energy_with_offset + + # EXPECT: offset is converted to meV before subtracting → shifted by -1 meV + expected = convolution_base.energy.values - 1.0 + np.testing.assert_allclose(result.values, expected, rtol=1e-5) + def test_sample_components_property(self, convolution_base): # WHEN THEN EXPECT assert isinstance(convolution_base.sample_components, ComponentCollection) @@ -303,3 +330,76 @@ def test_resolution_components_setter_invalid_type_raises(self, convolution_base ), ): convolution_base.resolution_components = 'invalid' + + def test_y_unit_setter_raises(self, convolution_base): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match='read-only'): + convolution_base.y_unit = '1/meV' + + def test_convert_y_unit(self, convolution_base): + # WHEN: sample component with y_unit='1/meV' + convolution_base.sample_components.append_component( + Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + ) + + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + convolution_base.convert_y_unit('1/eV') + + # EXPECT: y_unit updated and propagated to sample components + assert convolution_base.y_unit == '1/eV' + for component in convolution_base.sample_components: + assert component.y_unit == '1/eV' + + def test_convert_y_unit_invalid_type_raises(self, convolution_base): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + convolution_base.convert_y_unit(123) + + def test_convert_y_unit_rollback_on_failure(self, convolution_base): + # WHEN: sample component with default y_unit='dimensionless' + convolution_base.sample_components.append_component(Gaussian(area=1.0, x_unit='meV')) + + # THEN EXPECT: 'K' is dimensionally incompatible → triggers rollback + with pytest.raises(UnitError): + convolution_base.convert_y_unit('K') + + assert convolution_base.y_unit == 'dimensionless' + + def test_convert_x_unit_without_sample_components(self): + # WHEN + cb = ConvolutionBase(energy=np.linspace(-10, 10, 100), sample_components=None) + + # THEN + cb.convert_x_unit('eV') + + # EXPECT + assert cb.energy.unit == 'eV' + assert np.allclose(cb.energy.values, np.linspace(-0.01, 0.01, 100)) + assert cb.x_unit == 'eV' + + def test_convert_x_unit_without_resolution_components(self): + # WHEN + sample = ComponentCollection(display_name='Sample') + cb = ConvolutionBase( + energy=np.linspace(-10, 10, 100), + sample_components=sample, + resolution_components=None, + ) + + # THEN + cb.convert_x_unit('eV') + + # EXPECT + assert cb.energy.unit == 'eV' + assert np.allclose(cb.energy.values, np.linspace(-0.01, 0.01, 100)) + assert cb.x_unit == 'eV' + + def test_convert_y_unit_without_sample_components(self): + # WHEN + cb = ConvolutionBase(energy=np.linspace(-10, 10, 100), sample_components=None) + + # THEN + cb.convert_y_unit('1/meV') + + # EXPECT + assert cb.y_unit == '1/meV' diff --git a/tests/unit/easydynamics/convolution/test_energy_grid.py b/tests/unit/easydynamics/convolution/test_energy_grid.py index 01cf19d40..e0390b3ee 100644 --- a/tests/unit/easydynamics/convolution/test_energy_grid.py +++ b/tests/unit/easydynamics/convolution/test_energy_grid.py @@ -8,12 +8,14 @@ class TestEnergyGrid: def test_energy_grid_attributes(self): + # WHEN energy_dense = np.array([-1.0, 0.0, 1.0]) energy_dense_centered = np.array([-1.0, 0.0, 1.0]) energy_dense_step = 1.0 energy_span_dense = 2.0 energy_even_length_offset = 0.0 + # THEN energy_grid = EnergyGrid( energy_dense=energy_dense, energy_span_dense=energy_span_dense, @@ -22,6 +24,7 @@ def test_energy_grid_attributes(self): energy_dense_step=energy_dense_step, ) + # EXPECT assert np.array_equal(energy_grid.energy_dense, energy_dense) assert np.array_equal(energy_grid.energy_dense_centered, energy_dense_centered) assert energy_grid.energy_dense_step == energy_dense_step diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py index 755381c24..6e43ce017 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py @@ -4,33 +4,86 @@ import numpy as np import pytest import scipp as sc +from easyscience.variable import Parameter from scipy.signal import fftconvolve from easydynamics.convolution.energy_grid import EnergyGrid from easydynamics.convolution.numerical_convolution import NumericalConvolution from easydynamics.sample_model import Gaussian from easydynamics.sample_model.component_collection import ComponentCollection +from easydynamics.settings.convolution_settings import ConvolutionSettings from easydynamics.utils.detailed_balance import detailed_balance_factor +def _make_numerical_convolution(convolution_settings=None): + energy = np.linspace(-10, 10, 5001) + sample_components = ComponentCollection(display_name='ComponentCollection') + sample_components.append_component(Gaussian(name='Gaussian1', area=2.0, center=0.1, width=0.4)) + resolution_components = ComponentCollection(display_name='ResolutionModel') + resolution_components.append_component( + Gaussian(name='GaussianRes', area=3.0, center=0.2, width=0.5) + ) + + return NumericalConvolution( + energy=energy, + sample_components=sample_components, + resolution_components=resolution_components, + convolution_settings=convolution_settings, + ) + + class TestNumericalConvolution: @pytest.fixture def default_numerical_convolution(self): - energy = np.linspace(-10, 10, 5001) - sample_components = ComponentCollection(display_name='ComponentCollection') - sample_components.append_component( - Gaussian(name='Gaussian1', area=2.0, center=0.1, width=0.4) - ) - resolution_components = ComponentCollection(display_name='ResolutionModel') - resolution_components.append_component( - Gaussian(name='GaussianRes', area=3.0, center=0.2, width=0.5) - ) + return _make_numerical_convolution() - return NumericalConvolution( - energy=energy, - sample_components=sample_components, - resolution_components=resolution_components, - ) + def test_settings_change_invalidates_all_sharing_convolvers(self): + # WHEN: two convolvers sharing one ConvolutionSettings, and an accuracy knob changes + settings = ConvolutionSettings() + convolver_a = _make_numerical_convolution(settings) + convolver_b = _make_numerical_convolution(settings) + settings.upsample_factor = 10 + + # THEN: A consumes the change by rebuilding its plan + convolver_a.convolution() + + # EXPECT: A is current, but B still sees the stale plan (regression: A's rebuild + # setting the shared flag used to mask the settings change from B) + assert convolver_a._convolution_plan_is_current() is True + assert convolver_b._convolution_plan_is_current() is False + + # and B becomes current after its own rebuild + convolver_b.convolution() + assert convolver_b._convolution_plan_is_current() is True + + def test_manual_plan_valid_true_blesses_all_sharing_convolvers(self): + # WHEN: a settings change followed by an explicit convolution_plan_is_valid = True + settings = ConvolutionSettings() + convolver_a = _make_numerical_convolution(settings) + convolver_b = _make_numerical_convolution(settings) + settings.upsample_factor = 10 + + # THEN: the escape hatch suppresses rebuilds for every convolver sharing the settings + settings.convolution_plan_is_valid = True + + # EXPECT + assert convolver_a._convolution_plan_is_current() is True + assert convolver_b._convolution_plan_is_current() is True + + def test_manual_plan_valid_false_invalidates_all_sharing_convolvers(self): + # WHEN: two current convolvers sharing one ConvolutionSettings + settings = ConvolutionSettings() + convolver_a = _make_numerical_convolution(settings) + convolver_b = _make_numerical_convolution(settings) + + # THEN: an explicit False invalidates the plan for every convolver, and one + # convolver rebuilding does not mask the invalidation from the other + settings.convolution_plan_is_valid = False + convolver_a.convolution() + + # EXPECT + assert convolver_a._convolution_plan_is_current() is True + assert convolver_b._convolution_plan_is_current() is False def test_init(self, default_numerical_convolution): """ @@ -48,7 +101,8 @@ def test_init(self, default_numerical_convolution): assert default_numerical_convolution.upsample_factor == 5 assert default_numerical_convolution.extension_factor == pytest.approx(0.2) assert default_numerical_convolution.temperature is None - assert default_numerical_convolution.unit == 'meV' + assert default_numerical_convolution.x_unit == 'meV' + assert default_numerical_convolution.y_unit == 'dimensionless' assert ( default_numerical_convolution.detailed_balance_settings.normalize_detailed_balance is True @@ -233,3 +287,53 @@ def fake_interp(*args, **kwargs): # noqa: ARG001 assert result.shape == conv.energy.values.shape else: assert result.shape == dense.shape + + def test_convolution_with_energy_offset_in_different_unit(self, default_numerical_convolution): + # WHEN: energy grid is in meV, offset given as a Parameter in eV (0.001 eV = 1.0 meV) + conv = default_numerical_convolution + conv.energy_offset = Parameter(name='energy_offset', value=0.001, unit='eV') + result_ev = conv.convolution() + + # THEN: run again with the numerically equivalent offset as an explicit meV Parameter + conv.energy_offset = Parameter(name='energy_offset', value=1.0, unit='meV') + result_mev = conv.convolution() + + # EXPECT: unit conversion is transparent — results are numerically identical + np.testing.assert_allclose(result_ev, result_mev, rtol=1e-6) + + def test_repr(self, default_numerical_convolution): + r = repr(default_numerical_convolution) + assert 'NumericalConvolution' in r + assert 'energy_len=' in r + + # ───── Regression tests ───── + + def test_detailed_balance_energy_includes_even_length_offset( + self, default_numerical_convolution, monkeypatch + ): + # WHEN: even-length energy grid with temperature + nc = default_numerical_convolution + nc.energy = np.linspace(-10, 10, 100) # even-length → non-zero energy_even_length_offset + nc.temperature = 300.0 + + captured = {} + + # spy_dbf wraps detailed_balance_factor: captures the energy argument passed to it, + # then delegates to the real function so the convolution still produces a valid result. + def spy_dbf(**kwargs): + captured['energy'] = kwargs['energy'] + return detailed_balance_factor(**kwargs) + + monkeypatch.setattr( + 'easydynamics.convolution.numerical_convolution.detailed_balance_factor', spy_dbf + ) + + # THEN: run convolution + nc.convolution() + + # EXPECT: DBF receives energy_dense - energy_even_length_offset + # Before the fix, energy_even_length_offset was omitted, causing a half-bin error. + grid = nc._energy_grid + np.testing.assert_allclose( + captured['energy'], grid.energy_dense - grid.energy_even_length_offset, atol=1e-12 + ) diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py index 73bcee0bd..2d64e0da1 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py @@ -48,7 +48,8 @@ def test_init(self, default_numerical_convolution_base): assert default_numerical_convolution_base.upsample_factor == 5 assert default_numerical_convolution_base.extension_factor == pytest.approx(0.2) assert default_numerical_convolution_base.temperature is None - assert default_numerical_convolution_base.unit == 'meV' + assert default_numerical_convolution_base.x_unit == 'meV' + assert default_numerical_convolution_base.y_unit == 'dimensionless' assert ( default_numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance is True @@ -79,7 +80,7 @@ def test_init_with_custom_parameters(self): detailed_balance_settings=detailed_balance_settings, temperature=temperature, temperature_unit=temperature_unit, - unit=unit, + x_unit=unit, ) # EXPECT @@ -87,7 +88,8 @@ def test_init_with_custom_parameters(self): assert numerical_convolution_base.extension_factor == pytest.approx(0.5) assert numerical_convolution_base.temperature.value == temperature assert numerical_convolution_base.temperature.unit == temperature_unit - assert numerical_convolution_base.unit == unit + assert numerical_convolution_base.x_unit == unit + assert numerical_convolution_base.y_unit == 'dimensionless' assert ( numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance is False @@ -489,9 +491,11 @@ def test_create_energy_grid_upsample_none_non_uniform_raises( and non-uniform energy raises ValueError (the energy grid must always be uniform). """ - # WHEN + # WHEN: non-uniform energy with upsample_factor=None default_numerical_convolution_base.energy = np.array([0, 1, 3, 6, 10]) default_numerical_convolution_base.upsample_factor = None + + # THEN EXPECT with pytest.raises( ValueError, match='Input array `energy` must be uniformly spaced if upsample_factor is not given', @@ -638,12 +642,12 @@ def test_repr(self, default_numerical_convolution_base): # Sample and resolution models assert 'ComponentCollection' in repr_str - assert 'components=[No components]' in repr_str + assert 'Components: No components' in repr_str assert 'sample_components=' in repr_str assert 'resolution_components=' in repr_str # Important parameters - assert 'unit=meV' in repr_str + assert 'x_unit=meV' in repr_str assert 'upsample_factor=5' in repr_str assert 'extension_factor=0.2' in repr_str assert 'temperature=None' in repr_str diff --git a/tests/unit/easydynamics/experiment/test_experiment.py b/tests/unit/easydynamics/experiment/test_experiment.py index 7ffc83c87..ed360d94e 100644 --- a/tests/unit/easydynamics/experiment/test_experiment.py +++ b/tests/unit/easydynamics/experiment/test_experiment.py @@ -142,6 +142,18 @@ def test_load_hdf5_invalid_file_raises(self, experiment): with pytest.raises(OSError): experiment.load_hdf5('non_existent_file.h5') + def test_load_hdf5_invalid_data_type_raises(self, experiment): + "Test that loading a file that returns a non-DataArray raises TypeError" + # WHEN THEN EXPECT + with ( + patch( + 'easydynamics.experiment.experiment.sc_load_hdf5', + return_value=sc.scalar(1.0), + ), + pytest.raises(TypeError, match=r'sc\.DataArray'), + ): + experiment.load_hdf5('fake_file.h5') + def test_save_hdf5(self, tmp_path, experiment): "Test saving data to an HDF5 file. Load the saved file" 'using scipp and compare to the original data.' @@ -319,15 +331,55 @@ def test_get_masked_energy_no_data_returns_None(self): @pytest.mark.parametrize( 'Q_index', - [-1, 100, 'not an index'], - ids=['negative_index', 'out_of_bounds_index', 'invalid_type'], + [-1, 100], + ids=['negative_index', 'out_of_bounds_index'], ) def test_get_masked_energy_invalid_Q_index_raises(self, experiment_with_data, Q_index): - "Test getting masked energy raises IndexError when Q index is invalid" + "Test getting masked energy raises IndexError when Q index is out of range" # WHEN THEN EXPECT with pytest.raises(IndexError): experiment_with_data.get_masked_energy(Q_index=Q_index) + def test_get_masked_energy_invalid_type_raises(self, experiment_with_data): + "Test getting masked energy raises TypeError when Q index is not an integer" + # WHEN THEN EXPECT + with pytest.raises(TypeError): + experiment_with_data.get_masked_energy(Q_index='not an index') + + def test_get_finite_energy_mask_returns_none_without_data(self): + "Test get_finite_energy_mask returns None when no data is loaded" + # WHEN THEN EXPECT + assert Experiment().get_finite_energy_mask(Q_index=0) is None + + def test_get_masked_binned_data_returns_none_without_data(self): + "Test get_masked_binned_data returns None when no data is loaded" + # WHEN THEN EXPECT + assert Experiment().get_masked_binned_data(Q_index=0) is None + + def test_get_finite_energy_mask_with_data(self, experiment_with_data): + "Test get_finite_energy_mask returns a boolean scipp variable when data is loaded" + # WHEN + + # THEN + mask = experiment_with_data.get_finite_energy_mask(Q_index=0) + + # EXPECT + assert mask is not None + assert mask.dims == ('energy',) + assert len(mask) == 3 + + def test_get_masked_binned_data_with_data(self, experiment_with_data): + "Test get_masked_binned_data returns a DataArray with all-finite data loaded" + # WHEN + + # THEN + result = experiment_with_data.get_masked_binned_data(Q_index=0) + + # EXPECT + assert result is not None + assert isinstance(result, sc.DataArray) + assert result.dims == ('energy',) + ############## # test plotting ############## @@ -496,8 +548,8 @@ def test_extract_x_y_var(self, experiment_with_data): experiment_with_data.data.variances[Q_index], ) - def test_extract_x_y_weights_only_finite_zero_variances(self, experiment_with_data): - "Test that _extract_x_y_weights_only_finite raises ValueError when variances contain zeros" + def testextract_x_y_weights_only_finite_zero_variances(self, experiment_with_data): + "Test that extract_x_y_weights_only_finite raises ValueError when variances contain zeros" # WHEN Q_index = 0 invalid_data = experiment_with_data._data.copy() @@ -507,10 +559,10 @@ def test_extract_x_y_weights_only_finite_zero_variances(self, experiment_with_da # THEN EXPECT with pytest.raises(ValueError, match='Cannot compute weights: some variances are zero'): - Experiment(data=invalid_data)._extract_x_y_weights_only_finite(Q_index=Q_index) + Experiment(data=invalid_data).extract_x_y_weights_only_finite(Q_index=Q_index) - def test_extract_x_y_weights_only_finite(self, experiment_with_data): - "Test that _extract_x_y_weights_only_finite only returns finite values" + def testextract_x_y_weights_only_finite(self, experiment_with_data): + "Test that extract_x_y_weights_only_finite only returns finite values" # WHEN Q_index = 0 invalid_data = experiment_with_data._data.copy() @@ -518,7 +570,7 @@ def test_extract_x_y_weights_only_finite(self, experiment_with_data): invalid_data.data.variances[Q_index][1] = np.nan # THEN - x, y, weights, mask = Experiment(data=invalid_data)._extract_x_y_weights_only_finite( + x, y, weights, mask = Experiment(data=invalid_data).extract_x_y_weights_only_finite( Q_index=Q_index ) @@ -533,7 +585,7 @@ def test_extract_x_y_weights_only_finite(self, experiment_with_data): # Mask should indicate which values were removed assert np.array_equal(mask, [False, False, True]) - def test_extract_x_y_weights_only_finite_zero_variance(self, experiment_with_data): + def testextract_x_y_weights_only_finite_zero_variance(self, experiment_with_data): "Test getting x y and weights when variances are None" # WHEN Q_index = 0 @@ -543,9 +595,7 @@ def test_extract_x_y_weights_only_finite_zero_variance(self, experiment_with_dat experiment_with_data.data = data # THEN - x, y, weights, mask = experiment_with_data._extract_x_y_weights_only_finite( - Q_index=Q_index - ) + x, y, weights, mask = experiment_with_data.extract_x_y_weights_only_finite(Q_index=Q_index) # EXPECT assert np.array_equal(x, experiment_with_data.energy.values) diff --git a/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py b/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py index 02fc31dfa..12f73235f 100644 --- a/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py +++ b/tests/unit/easydynamics/sample_model/components/test_damped_harmonic_oscillator.py @@ -5,7 +5,9 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter +from scipp import UnitError from scipy.integrate import simpson from easydynamics.sample_model import DampedHarmonicOscillator @@ -20,7 +22,7 @@ def dho(self): area=2.0, center=1.5, width=0.3, - unit='meV', + x_unit='meV', ) def test_init_no_inputs(self): @@ -32,7 +34,8 @@ def test_init_no_inputs(self): assert dho.area.value == pytest.approx(1.0) assert dho.center.value == pytest.approx(1.0) assert dho.width.value == pytest.approx(1.0) - assert dho.unit == 'meV' + assert dho.x_unit == 'meV' + assert dho.y_unit == 'dimensionless' def test_initialization(self, dho: DampedHarmonicOscillator): # WHEN THEN EXPECT @@ -40,26 +43,36 @@ def test_initialization(self, dho: DampedHarmonicOscillator): assert dho.area.value == pytest.approx(2.0) assert dho.center.value == pytest.approx(1.5) assert dho.width.value == pytest.approx(0.3) - assert dho.unit == 'meV' + assert dho.x_unit == 'meV' @pytest.mark.parametrize( 'kwargs, expected_message', [ ( - {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'unit': 'meV'}, + {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'x_unit': 'meV'}, 'area must be a number', ), ( - {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'unit': 'meV'}, + {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'x_unit': 'meV'}, 'center must be ', ), ( - {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'unit': 'meV'}, + {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'x_unit': 'meV'}, 'width must be a number', ), ( - {'area': 2.0, 'center': 0.5, 'width': 0.6, 'unit': 123}, - 'unit must be None', + {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 123}, + 'unit must be None, a string', + ), + ( + { + 'area': 2.0, + 'center': 0.5, + 'width': 0.6, + 'x_unit': 'meV', + 'y_unit': 123, + }, + 'unit must be None, a string', ), ], ) @@ -78,7 +91,7 @@ def test_negative_width_raises(self): area=2.0, center=0.5, width=-0.6, - unit='meV', + x_unit='meV', ) def test_negative_area_warns(self): @@ -89,7 +102,7 @@ def test_negative_area_warns(self): area=-2.0, center=0.5, width=0.6, - unit='meV', + x_unit='meV', ) @pytest.mark.parametrize( @@ -108,11 +121,15 @@ def test_property_setters( invalid_value, invalid_message, ): - # set valid + # WHEN + + # THEN : set a valid value setattr(dho, prop, valid_value) + + # EXPECT assert getattr(dho, prop).value == valid_value - # invalid + # WHEN: set an invalid value — THEN EXPECT with pytest.raises(TypeError, match=invalid_message): setattr(dho, prop, invalid_value) @@ -167,12 +184,12 @@ def test_area_matches_parameter(self, dho: DampedHarmonicOscillator): # EXPECT assert numerical_area == pytest.approx(dho.area.value, rel=2e-3) - def test_convert_unit(self, dho: DampedHarmonicOscillator): + def test_convert_x_unit(self, dho: DampedHarmonicOscillator): # WHEN THEN - dho.convert_unit('microeV') + dho.convert_x_unit('microeV') # EXPECT - assert dho.unit == 'microeV' + assert dho.x_unit == 'microeV' assert dho.area.value == pytest.approx(2 * 1e3) assert dho.center.value == pytest.approx(1.5 * 1e3) assert dho.width.value == pytest.approx(0.3 * 1e3) @@ -194,7 +211,7 @@ def test_copy(self, dho: DampedHarmonicOscillator): assert dho_copy.width.value == dho.width.value assert dho_copy.width.fixed == dho.width.fixed - assert dho_copy.unit == dho.unit + assert dho_copy.x_unit == dho.x_unit def test_repr(self, dho: DampedHarmonicOscillator): # WHEN THEN @@ -202,8 +219,91 @@ def test_repr(self, dho: DampedHarmonicOscillator): # EXPECT assert 'DampedHarmonicOscillator' in repr_str - assert "name='TestDHOName'" in repr_str - assert 'unit=meV' in repr_str - assert 'area=' in repr_str - assert 'center=' in repr_str - assert 'width=' in repr_str + assert 'name = TestDHOName' in repr_str + assert 'x_unit = meV' in repr_str + assert 'area =' in repr_str + assert 'center =' in repr_str + assert 'width =' in repr_str + + def test_y_unit_custom(self): + # WHEN THEN + dho = DampedHarmonicOscillator( + area=1.0, center=1.0, width=0.3, x_unit='meV', y_unit='1/meV' + ) + # EXPECT + assert dho.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, dho: DampedHarmonicOscillator): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + dho.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless' + dho = DampedHarmonicOscillator( + area=1.0, center=1.0, width=0.3, x_unit='meV', y_unit='1/meV' + ) + + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + dho.convert_y_unit('1/eV') + + # EXPECT: y_unit updated and area value rescaled (1e3 factor) + assert dho.y_unit == '1/eV' + assert dho.area.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, dho: DampedHarmonicOscillator): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + dho.convert_y_unit(123) + + def test_evaluate_scipp_output(self, dho: DampedHarmonicOscillator): + # WHEN + x = np.linspace(0.5, 5.0, 50) + + # THEN + result = dho.evaluate(x, output='scipp') + + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 50 + np.testing.assert_allclose(result.values, dho.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + dho = DampedHarmonicOscillator( + area=1.0, center=1.0, width=0.3, x_unit='meV', y_unit='1/meV' + ) + x = np.linspace(0.5, 5.0, 50) + + # THEN + result = dho.evaluate(x, output='scipp') + + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + def test_convert_x_unit_invalid_type_raises(self, dho: DampedHarmonicOscillator): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'): + dho.convert_x_unit(123) + + def test_convert_x_unit_rollback_on_failure(self, dho: DampedHarmonicOscillator): + # WHEN THEN + with pytest.raises(UnitError): + dho.convert_x_unit('m') + # EXPECT: state rolled back + assert dho.x_unit == 'meV' + assert dho.area.value == pytest.approx(2.0) + assert dho.center.value == pytest.approx(1.5) + assert dho.width.value == pytest.approx(0.3) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + dho = DampedHarmonicOscillator(area=1.0, center=1.0, width=0.3, x_unit='meV') + # THEN + with pytest.raises(UnitError): + dho.convert_y_unit('K') + # EXPECT: state rolled back + assert dho.y_unit == 'dimensionless' + assert dho.area.value == pytest.approx(1.0) diff --git a/tests/unit/easydynamics/sample_model/components/test_delta_function.py b/tests/unit/easydynamics/sample_model/components/test_delta_function.py index de3148b7d..50baf2317 100644 --- a/tests/unit/easydynamics/sample_model/components/test_delta_function.py +++ b/tests/unit/easydynamics/sample_model/components/test_delta_function.py @@ -20,7 +20,7 @@ def delta_function(self): display_name='TestDeltaFunction', area=2.0, center=0.5, - unit='meV', + x_unit='meV', ) def test_init_no_inputs(self): @@ -31,7 +31,8 @@ def test_init_no_inputs(self): assert delta_function.display_name == 'DeltaFunction' assert delta_function.area.value == pytest.approx(1.0) assert delta_function.center.value == pytest.approx(0.0) - assert delta_function.unit == 'meV' + assert delta_function.x_unit == 'meV' + assert delta_function.y_unit == 'dimensionless' assert delta_function.center.fixed is True def test_initialization(self, delta_function: DeltaFunction): @@ -39,21 +40,25 @@ def test_initialization(self, delta_function: DeltaFunction): assert delta_function.display_name == 'TestDeltaFunction' assert delta_function.area.value == pytest.approx(2.0) assert delta_function.center.value == pytest.approx(0.5) - assert delta_function.unit == 'meV' + assert delta_function.x_unit == 'meV' @pytest.mark.parametrize( 'kwargs, expected_message', [ ( - {'area': 'invalid', 'center': 0.5, 'unit': 'meV'}, + {'area': 'invalid', 'center': 0.5, 'x_unit': 'meV'}, 'area must be a number', ), ( - {'area': 2.0, 'center': 'invalid', 'unit': 'meV'}, + {'area': 2.0, 'center': 'invalid', 'x_unit': 'meV'}, 'center must be ', ), ( - {'area': 2.0, 'center': 0.5, 'unit': 123}, + {'area': 2.0, 'center': 0.5, 'x_unit': 123}, + 'unit must be ', + ), + ( + {'area': 2.0, 'center': 0.5, 'x_unit': 'meV', 'y_unit': 123}, 'unit must be ', ), ], @@ -65,7 +70,7 @@ def test_input_type_validation_raises(self, kwargs, expected_message): def test_negative_area_warns(self): # WHEN THEN EXPECT with pytest.warns(UserWarning, match='may not be physically meaningful'): - DeltaFunction(display_name='TestDeltaFunction', area=-2.0, center=0.5, unit='meV') + DeltaFunction(display_name='TestDeltaFunction', area=-2.0, center=0.5, x_unit='meV') @pytest.mark.parametrize( 'prop, valid_value, invalid_value, invalid_message', @@ -82,11 +87,13 @@ def test_property_setters( invalid_value, invalid_message, ): - # set valid + # WHEN: set a valid value setattr(delta_function, prop, valid_value) + + # THEN EXPECT assert getattr(delta_function, prop).value == valid_value - # invalid + # WHEN: set an invalid value — THEN EXPECT with pytest.raises(TypeError, match=invalid_message): setattr(delta_function, prop, invalid_value) @@ -147,7 +154,7 @@ def test_evaluate_with_incompatible_unit_raises(self, delta_function: DeltaFunct # THEN EXPECT with pytest.raises( UnitError, - match='Input x has unit nm, but DeltaFunction component ', + match='Input x has unit nm', ): delta_function.evaluate(x) @@ -190,12 +197,12 @@ def test_get_all_parameters(self, delta_function: DeltaFunction): actual_names = {param.name for param in params} assert actual_names == expected_names - def test_convert_unit(self, delta_function: DeltaFunction): + def test_convert_x_unit(self, delta_function: DeltaFunction): # WHEN THEN - delta_function.convert_unit('microeV') + delta_function.convert_x_unit('microeV') # EXPECT - assert delta_function.unit == 'microeV' + assert delta_function.x_unit == 'microeV' assert delta_function.area.value == pytest.approx(2 * 1e3) assert delta_function.center.value == pytest.approx(0.5 * 1e3) @@ -213,7 +220,7 @@ def test_copy(self, delta_function: DeltaFunction): assert delta_copy.center.value == delta_function.center.value assert delta_copy.center.fixed == delta_function.center.fixed - assert delta_copy.unit == delta_function.unit + assert delta_copy.x_unit == delta_function.x_unit def test_repr(self, delta_function: DeltaFunction): # WHEN THEN @@ -221,7 +228,103 @@ def test_repr(self, delta_function: DeltaFunction): # EXPECT assert 'DeltaFunction' in repr_str - assert "name='DeltaFunctionName'" in repr_str - assert 'unit=meV' in repr_str - assert 'area=' in repr_str - assert 'center=' in repr_str + assert 'name = DeltaFunctionName' in repr_str + assert 'x_unit = meV' in repr_str + assert 'area =' in repr_str + assert 'center =' in repr_str + + def test_y_unit_custom(self): + # WHEN THEN + delta = DeltaFunction(area=1.0, x_unit='meV', y_unit='1/meV') + + # EXPECT + assert delta.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, delta_function: DeltaFunction): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + delta_function.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless' + delta = DeltaFunction(area=1.0, x_unit='meV', y_unit='1/meV') + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + delta.convert_y_unit('1/eV') + # EXPECT: y_unit updated and area value rescaled (1e3 factor) + assert delta.y_unit == '1/eV' + assert delta.area.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, delta_function: DeltaFunction): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + delta_function.convert_y_unit(123) + + def test_evaluate_scipp_output(self, delta_function: DeltaFunction): + # WHEN + x = np.linspace(-5, 5, 50) + # THEN + result = delta_function.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 50 + np.testing.assert_allclose(result.values, delta_function.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + delta = DeltaFunction(area=1.0, x_unit='meV', y_unit='1/meV') + x = np.linspace(-5, 5, 50) + # THEN + result = delta.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + @pytest.mark.parametrize( + 'x, center, expected_idx', + [ + (np.array([0.5, 1.0, 2.0]), 0.5, 0), # center at first element (line 202) + (np.array([0.0, 1.0, 1.5]), 1.5, 2), # center at last element (line 207) + ], + ids=['center_at_first', 'center_at_last'], + ) + def test_evaluate_center_at_boundary(self, x, center, expected_idx): + # WHEN + area = 1.0 + delta = DeltaFunction(area=area, center=center, x_unit='meV') + + # THEN + result = delta.evaluate(x) + + # EXPECT + # All elements except the boundary one should be zero + assert result[expected_idx] > 0.0 + other_indices = [i for i in range(len(x)) if i != expected_idx] + assert all(result[i] == pytest.approx(0.0) for i in other_indices) + # Boundary bin width: both left and right are set to the single adjacent spacing + bin_width = x[1] - x[0] if expected_idx == 0 else x[-1] - x[-2] + assert np.isclose(result[expected_idx], area / bin_width, rtol=1e-10) + + def test_convert_x_unit_invalid_type_raises(self, delta_function: DeltaFunction): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'): + delta_function.convert_x_unit(123) + + def test_convert_x_unit_rollback_on_failure(self, delta_function: DeltaFunction): + # WHEN THEN + with pytest.raises(UnitError): + delta_function.convert_x_unit('m') + # EXPECT: state rolled back + assert delta_function.x_unit == 'meV' + assert delta_function.area.value == pytest.approx(2.0) + assert delta_function.center.value == pytest.approx(0.5) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + delta = DeltaFunction(area=1.0, center=0.0, x_unit='meV', y_unit='dimensionless') + # THEN + with pytest.raises(UnitError): + delta.convert_y_unit('K') + # EXPECT: state rolled back + assert delta.y_unit == 'dimensionless' + assert delta.area.value == pytest.approx(1.0) diff --git a/tests/unit/easydynamics/sample_model/components/test_exponential.py b/tests/unit/easydynamics/sample_model/components/test_exponential.py index b7fa55031..701bf73a9 100644 --- a/tests/unit/easydynamics/sample_model/components/test_exponential.py +++ b/tests/unit/easydynamics/sample_model/components/test_exponential.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter from scipp import UnitError @@ -20,7 +21,7 @@ def exponential(self): amplitude=2.0, center=0.5, rate=1.2, - unit='meV', + x_unit='meV', ) def test_init_no_inputs(self): @@ -32,7 +33,8 @@ def test_init_no_inputs(self): assert exponential.amplitude.value == pytest.approx(1.0) assert exponential.center.value == pytest.approx(0.0) assert exponential.rate.value == pytest.approx(1.0) - assert exponential.unit == 'meV' + assert exponential.x_unit == 'meV' + assert exponential.y_unit == 'dimensionless' def test_initialization(self, exponential: Exponential): # WHEN THEN EXPECT @@ -40,23 +42,33 @@ def test_initialization(self, exponential: Exponential): assert exponential.amplitude.value == pytest.approx(2.0) assert exponential.center.value == pytest.approx(0.5) assert exponential.rate.value == pytest.approx(1.2) - assert exponential.unit == 'meV' + assert exponential.x_unit == 'meV' @pytest.mark.parametrize( 'kwargs, expected_message', [ ( - {'amplitude': 'invalid', 'center': 0.5, 'rate': 1.0, 'unit': 'meV'}, + {'amplitude': 'invalid', 'center': 0.5, 'rate': 1.0, 'x_unit': 'meV'}, 'amplitude must be a number', ), ( - {'amplitude': 2.0, 'center': 'invalid', 'rate': 1.0, 'unit': 'meV'}, + {'amplitude': 2.0, 'center': 'invalid', 'rate': 1.0, 'x_unit': 'meV'}, 'center must be None or a number', ), ( - {'amplitude': 2.0, 'center': 0.5, 'rate': 'invalid', 'unit': 'meV'}, + {'amplitude': 2.0, 'center': 0.5, 'rate': 'invalid', 'x_unit': 'meV'}, 'rate must be a number', ), + ( + { + 'amplitude': 2.0, + 'center': 0.5, + 'rate': 1.0, + 'x_unit': 'meV', + 'y_unit': 123, + }, + 'unit must be None, a string', + ), ], ) def test_input_type_validation_raises(self, kwargs, expected_message): @@ -67,12 +79,12 @@ def test_input_type_validation_raises(self, kwargs, expected_message): 'kwargs, expected_message', [ ( - {'amplitude': np.nan, 'center': 0.5, 'rate': 1.0, 'unit': 'meV'}, - 'amplitude must be a finite number or a Parameter', + {'amplitude': np.nan, 'center': 0.5, 'rate': 1.0, 'x_unit': 'meV'}, + 'amplitude must be finite', ), ( - {'amplitude': 2.0, 'center': 0.5, 'rate': np.nan, 'unit': 'meV'}, - 'rate must be a finite number or a Parameter', + {'amplitude': 2.0, 'center': 0.5, 'rate': np.nan, 'x_unit': 'meV'}, + 'rate must be finite', ), ], ) @@ -96,11 +108,12 @@ def test_property_setters( invalid_value, invalid_message, ): - # set valid + # WHEN: set a valid value setattr(exponential, prop, valid_value) + # THEN EXPECT assert getattr(exponential, prop).value == valid_value - # invalid + # WHEN: set an invalid value — THEN EXPECT with pytest.raises(TypeError, match=invalid_message): setattr(exponential, prop, invalid_value) @@ -144,12 +157,14 @@ def test_get_all_parameters(self, exponential: Exponential): assert actual_names == expected_names - def test_convert_unit(self, exponential: Exponential): + def test_convert_x_unit(self, exponential: Exponential): # WHEN - exponential.convert_unit('microeV') - # THEN EXPECT - assert exponential.unit == 'microeV' + # THEN + exponential.convert_x_unit('microeV') + + # EXPECT + assert exponential.x_unit == 'microeV' assert exponential.amplitude.value == pytest.approx(2.0 * 1e3) assert exponential.center.value == pytest.approx(0.5 * 1e3) @@ -158,21 +173,21 @@ def test_convert_unit(self, exponential: Exponential): assert exponential.rate.value == pytest.approx(1.2 / 1e3) assert str(exponential.rate.unit) == '1/ueV' - def test_convert_unit_incorrect_unit_raises(self, exponential: Exponential): + def test_convert_x_unit_incorrect_unit_raises(self, exponential: Exponential): # WHEN THEN EXPECT with pytest.raises(TypeError, match=r'unit must be a string or sc.Unit'): - exponential.convert_unit(123) + exponential.convert_x_unit(123) - def test_convert_unit_rollback(self, exponential: Exponential): - # WHEN + def test_convert_x_unit_rollback(self, exponential: Exponential): + # WHEN THEN with pytest.raises( UnitError, match=r'Failed to convert unit: Conversion from `meV` to `m` is not valid.', ): - exponential.convert_unit('m') + exponential.convert_x_unit('m') - # THEN EXPECT - values should be unchanged - assert exponential.unit == 'meV' + # EXPECT - values should be unchanged + assert exponential.x_unit == 'meV' assert exponential.amplitude.value == pytest.approx(2.0) assert exponential.amplitude.unit == 'meV' assert exponential.center.value == pytest.approx(0.5) @@ -182,9 +197,11 @@ def test_convert_unit_rollback(self, exponential: Exponential): def test_copy(self, exponential: Exponential): # WHEN + + # THEN exponential_copy = copy(exponential) - # THEN EXPECT + # EXPECT assert exponential_copy is not exponential assert exponential_copy.display_name == exponential.display_name @@ -197,7 +214,8 @@ def test_copy(self, exponential: Exponential): assert exponential_copy.rate.value == exponential.rate.value assert exponential_copy.rate.fixed == exponential.rate.fixed - assert exponential_copy.unit == exponential.unit + assert exponential_copy.x_unit == exponential.x_unit + assert exponential_copy.y_unit == exponential.y_unit def test_repr(self, exponential: Exponential): # WHEN @@ -205,8 +223,76 @@ def test_repr(self, exponential: Exponential): # THEN EXPECT assert 'Exponential' in repr_str - assert "name='ExponentialName'" in repr_str - assert 'unit=meV' in repr_str - assert 'amplitude=' in repr_str - assert 'center=' in repr_str - assert 'rate=' in repr_str + assert 'name = ExponentialName' in repr_str + assert 'x_unit = meV' in repr_str + assert 'amplitude =' in repr_str + assert 'center =' in repr_str + assert 'rate =' in repr_str + + def test_y_unit_custom(self): + # WHEN THEN + exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV', y_unit='1/meV') + # EXPECT + assert exp.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, exponential: Exponential): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + exponential.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: x_unit='meV', y_unit='1/meV' → amplitude_unit='dimensionless' + exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV', y_unit='1/meV') + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + exp.convert_y_unit('1/eV') + # EXPECT: y_unit updated and amplitude value rescaled (1e3 factor) + assert exp.y_unit == '1/eV' + assert exp.amplitude.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, exponential: Exponential): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + exponential.convert_y_unit(123) + + def test_evaluate_scipp_output(self, exponential: Exponential): + # WHEN + x = np.linspace(-5, 5, 50) + # THEN + result = exponential.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 50 + np.testing.assert_allclose(result.values, exponential.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV', y_unit='1/meV') + x = np.linspace(-5, 5, 50) + # THEN + result = exp.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + def test_init_rejects_parameter_amplitude(self): + # WHEN THEN EXPECT + amplitude_param = Parameter(name='amp', value=3.0, unit='meV') + with pytest.raises(TypeError, match='amplitude must be a number'): + Exponential(amplitude=amplitude_param, x_unit='meV') + + def test_init_rejects_parameter_rate(self): + # WHEN THEN EXPECT + rate_param = Parameter(name='rate', value=0.5, unit='1/meV') + with pytest.raises(TypeError, match='rate must be a number'): + Exponential(rate=rate_param, x_unit='meV') + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + exp = Exponential(amplitude=1.0, center=0.0, rate=1.0, x_unit='meV') + # THEN + with pytest.raises(UnitError): + exp.convert_y_unit('K') + # EXPECT: state rolled back + assert exp.y_unit == 'dimensionless' + assert exp.amplitude.value == pytest.approx(1.0) diff --git a/tests/unit/easydynamics/sample_model/components/test_expression_component.py b/tests/unit/easydynamics/sample_model/components/test_expression_component.py index 48fbafe26..b509f61b4 100644 --- a/tests/unit/easydynamics/sample_model/components/test_expression_component.py +++ b/tests/unit/easydynamics/sample_model/components/test_expression_component.py @@ -1,13 +1,18 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +import warnings from copy import copy import numpy as np import pytest +import scipp as sc +from easyscience.variable import DescriptorNumber from easyscience.variable import Parameter from easydynamics.sample_model import ExpressionComponent +from easydynamics.sample_model import Gaussian +from easydynamics.sample_model import Lorentzian class TestExpressionComponent: @@ -16,14 +21,14 @@ def expr(self): return ExpressionComponent( 'A * exp(-(x - x0)**2 / (2*sigma**2))', parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, - unit='meV', + x_unit='meV', display_name='TestExpression', ) def test_init_valid(self, expr: ExpressionComponent): # WHEN THEN EXPECT assert expr.display_name == 'TestExpression' - assert expr.unit == 'meV' + assert expr.x_unit == 'meV' assert expr.A.value == pytest.approx(2.0) assert expr.x0.value == pytest.approx(0.5) @@ -44,6 +49,50 @@ def test_init_with_parameter(self): # EXPECT assert expr.A.value == pytest.approx(3.0) + def test_parameters_are_dimensionless_by_default(self, expr: ExpressionComponent): + # WHEN THEN EXPECT + assert str(expr.A.unit) == 'dimensionless' + assert str(expr.x0.unit) == 'dimensionless' + assert str(expr.sigma.unit) == 'dimensionless' + + def test_init_with_parameter_units(self): + # WHEN THEN + expr = ExpressionComponent( + 'A * exp(-(x - x0)**2 / (2*sigma**2))', + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + parameter_units={'x0': 'meV', 'sigma': 'meV'}, + ) + + # EXPECT: the named parameters get their units, values unchanged; others stay + # dimensionless + assert str(expr.x0.unit) == 'meV' + assert str(expr.sigma.unit) == 'meV' + assert str(expr.A.unit) == 'dimensionless' + assert expr.x0.value == pytest.approx(0.5) + assert expr.sigma.value == pytest.approx(0.6) + + def test_init_parameter_units_overrides_parameter_instance_unit(self): + # WHEN: a Parameter instance in eV, but parameter_units requests meV + A = Parameter('A', 3.0, unit='eV') + + # THEN: A*x evaluates to meV**2, which mismatches the default dimensionless y_unit + with pytest.warns(UserWarning, match='does not match'): + expr = ExpressionComponent('A * x', parameters={'A': A}, parameter_units={'A': 'meV'}) + + # EXPECT: the unit is relabelled without rescaling the value + assert str(expr.A.unit) == 'meV' + assert expr.A.value == pytest.approx(3.0) + + def test_init_parameter_units_unknown_name_raises(self): + # WHEN THEN EXPECT + with pytest.raises(ValueError, match='unknown parameter'): + ExpressionComponent('A * x', parameter_units={'B': 'meV'}) + + def test_init_parameter_units_not_a_dict_raises(self): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='parameter_units must be None or a dictionary'): + ExpressionComponent('A * x', parameter_units='meV') + def test_invalid_expression_raises(self): # WHEN THEN EXPECT with pytest.raises(ValueError, match='Invalid expression'): @@ -131,10 +180,110 @@ def test_expression_is_read_only(self, expr: ExpressionComponent): with pytest.raises(AttributeError, match='cannot be changed'): expr.expression = 'x' - def test_convert_unit_not_implemented(self, expr: ExpressionComponent): + def test_set_unit_relabels_without_rescaling(self, expr: ExpressionComponent): + # WHEN + expr.sigma.min = 0.1 + expr.sigma.max = 10.0 + + # THEN: relabelling only sigma leaves the expression unit-inconsistent, which warns + with pytest.warns(UserWarning, match='not unit-consistent'): + expr.set_unit('sigma', 'meV') + + # EXPECT: unit relabelled; value and bounds keep their numeric values + assert str(expr.sigma.unit) == 'meV' + assert expr.sigma.value == pytest.approx(0.6) + assert expr.sigma.min == pytest.approx(0.1) + assert expr.sigma.max == pytest.approx(10.0) + + def test_set_unit_accepts_scipp_unit(self, expr: ExpressionComponent): + # WHEN THEN: relabelling only x0 leaves the expression unit-inconsistent, which warns + with pytest.warns(UserWarning, match='not unit-consistent'): + expr.set_unit('x0', sc.Unit('meV')) + + # EXPECT + assert str(expr.x0.unit) == 'meV' + + def test_set_unit_leaves_other_parameters_untouched(self, expr: ExpressionComponent): + # WHEN THEN + with pytest.warns(UserWarning, match='not unit-consistent'): + expr.set_unit('x0', 'meV') + + # EXPECT + assert str(expr.A.unit) == 'dimensionless' + assert str(expr.sigma.unit) == 'dimensionless' + + def test_set_unit_unknown_parameter_raises(self, expr: ExpressionComponent): + # WHEN THEN EXPECT + with pytest.raises(KeyError, match="No parameter named 'B'"): + expr.set_unit('B', 'meV') + + def test_set_unit_invalid_unit_raises(self, expr: ExpressionComponent): + # WHEN THEN EXPECT + with pytest.raises(ValueError, match='not a valid scipp unit'): + expr.set_unit('A', 'not_a_unit') + + @pytest.mark.parametrize( + 'name, unit, expected_message', + [ + (123, 'meV', 'name must be a string'), + ('A', 123, 'unit must be a string or sc.Unit'), + ], + ids=['non_string_name', 'non_unit_unit'], + ) + def test_set_unit_type_validation( + self, expr: ExpressionComponent, name, unit, expected_message + ): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=expected_message): + expr.set_unit(name, unit) + + def test_set_unit_dependent_parameter_raises(self, expr: ExpressionComponent): + # WHEN: make A dependent on sigma + expr.A.make_dependent_on( + dependency_expression='2 * sigma', + dependency_map={'sigma': expr.sigma}, + ) + + # THEN EXPECT + with pytest.raises(AttributeError, match='dependent parameter'): + expr.set_unit('A', 'meV') + + def test_parameter_units_survive_copy(self): + # WHEN + expr = ExpressionComponent( + 'A * x + B', + parameters={'A': 2.0, 'B': 3.0}, + parameter_units={'A': '1/meV'}, + ) + + # THEN + expr_copy = copy(expr) + + # EXPECT: per-parameter units round-trip through to_dict/from_dict + assert str(expr_copy.A.unit) == '1/meV' + assert str(expr_copy.B.unit) == 'dimensionless' + assert expr_copy.A.value == pytest.approx(2.0) + + def test_convert_x_unit_not_implemented(self, expr: ExpressionComponent): + # WHEN THEN EXPECT + with pytest.raises(NotImplementedError, match='not implemented'): + expr.convert_x_unit('microeV') + + def test_convert_y_unit_not_implemented(self, expr: ExpressionComponent): # WHEN THEN EXPECT with pytest.raises(NotImplementedError, match='not implemented'): - expr.convert_unit('microeV') + expr.convert_y_unit('1/meV') + + def test_evaluate_scipp_output(self, expr: ExpressionComponent): + # WHEN + x = np.linspace(-2, 2, 30) + # THEN + result = expr.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 30 + np.testing.assert_allclose(result.values, expr.evaluate(x, output='numpy')) def test_missing_parameter_defaults(self): # WHEN THEN @@ -166,20 +315,19 @@ def test_repr(self, expr: ExpressionComponent): def test_evaluate_scalar_input(self, expr: ExpressionComponent): # WHEN x = 0.5 - result = expr.evaluate(x) - # THEN + result = expr.evaluate(x) + # EXPECT expected = 2.0 * np.exp(-((x - 0.5) ** 2) / (2 * 0.6**2)) assert np.isclose(result, expected) def test_reserved_name_not_parameter(self): # WHEN expr = ExpressionComponent('x + A', parameters={'A': 2.0}) - # THEN params = expr.get_all_variables() names = {p.name for p in params} - + # EXPECT assert 'A' in names assert 'x' not in names # x is reserved @@ -191,7 +339,8 @@ def test_copy(self, expr: ExpressionComponent): assert expr_copy is not expr assert isinstance(expr_copy, ExpressionComponent) assert expr_copy.expression == expr.expression - assert expr_copy.unit == expr.unit + assert expr_copy.x_unit == expr.x_unit + assert expr_copy.y_unit == expr.y_unit assert expr_copy.display_name == expr.display_name assert expr_copy.A.value == pytest.approx(expr.A.value) @@ -209,3 +358,416 @@ def test_erf(self): # EXPECT expected = np.array([-0.84270079, 0.0, 0.84270079]) # erf(-1), erf(0), erf(1) np.testing.assert_allclose(result, expected, rtol=1e-5) + + +GAUSSIAN_EXPRESSION = 'A / (sigma*sqrt(2*pi)) * exp(-(x - x0)**2 / (2*sigma**2))' +GAUSSIAN_UNITS = {'A': 'meV', 'x0': 'meV', 'sigma': 'meV'} + + +class TestExpressionComponentUnitCorrectness: + """Compare unit-aware expressions against the built-in components.""" + + @pytest.fixture + def gaussian_expr(self): + return ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + parameter_units=GAUSSIAN_UNITS, + x_unit='meV', + ) + + def test_matches_gaussian_component(self, gaussian_expr): + # WHEN + gaussian = Gaussian(area=2.0, center=0.5, width=0.6, x_unit='meV') + x = np.linspace(-3, 3, 101) + + # THEN EXPECT: same values on a plain numpy grid + np.testing.assert_allclose(gaussian_expr.evaluate(x), gaussian.evaluate(x)) + + def test_matches_gaussian_component_scipp_input_and_output(self, gaussian_expr): + # WHEN + gaussian = Gaussian(area=2.0, center=0.5, width=0.6, x_unit='meV') + x = sc.linspace('energy', -3.0, 3.0, 101, unit='meV') + + # THEN + expr_result = gaussian_expr.evaluate(x, output='scipp') + gaussian_result = gaussian.evaluate(x, output='scipp') + + # EXPECT: same values and the same output unit + np.testing.assert_allclose(expr_result.values, gaussian_result.values) + assert expr_result.unit == gaussian_result.unit + + def test_matches_gaussian_component_in_other_unit(self): + # WHEN: the built-in Gaussian auto-converts its meV parameters to a ueV grid; the + # expression is given the physically identical parameters expressed in ueV directly. + gaussian = Gaussian(area=2.0, center=0.5, width=0.6, x_unit='meV') + expr = ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2000.0, 'x0': 500.0, 'sigma': 600.0}, + parameter_units={'A': 'ueV', 'x0': 'ueV', 'sigma': 'ueV'}, + x_unit='ueV', + ) + x = sc.linspace('energy', -3000.0, 3000.0, 101, unit='ueV') + + # THEN EXPECT: identical physical results (the Gaussian output is dimensionless, so no + # y-scale factor is involved) + np.testing.assert_allclose(expr.evaluate(x), gaussian.evaluate(x)) + + def test_matches_lorentzian_component(self): + # WHEN + lorentzian = Lorentzian(area=2.0, center=0.5, width=0.3, x_unit='meV') + # Note: 'gamma' would collide with sympy's gamma function, so use 'hwhm' + expr = ExpressionComponent( + 'A / pi * hwhm / ((x - x0)**2 + hwhm**2)', + parameters={'A': 2.0, 'x0': 0.5, 'hwhm': 0.3}, + parameter_units={'A': 'meV', 'x0': 'meV', 'hwhm': 'meV'}, + x_unit='meV', + ) + x = np.linspace(-3, 3, 101) + + # THEN EXPECT + np.testing.assert_allclose(expr.evaluate(x), lorentzian.evaluate(x)) + + def test_consistent_units_do_not_warn(self): + # WHEN THEN EXPECT: a fully unit-consistent expression constructs without warnings + with warnings.catch_warnings(): + warnings.simplefilter('error') + ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + parameter_units=GAUSSIAN_UNITS, + x_unit='meV', + ) + + def test_unit_agnostic_expression_does_not_warn(self): + # WHEN THEN EXPECT: without any parameter units the consistency check stays silent, + # even though x_unit is meV and the parameters are dimensionless + with warnings.catch_warnings(): + warnings.simplefilter('error') + ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + x_unit='meV', + ) + + +class TestExpressionComponentOutputUnit: + def test_output_unit_gaussian_is_dimensionless(self): + # WHEN: area in meV divided by sigma in meV + expr = ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + parameter_units=GAUSSIAN_UNITS, + x_unit='meV', + ) + + # THEN EXPECT + assert expr.output_unit == 'dimensionless' + + def test_output_unit_with_dimensionless_area_is_one_over_x(self): + # WHEN: dimensionless amplitude over a meV width gives 1/meV + with pytest.warns(UserWarning, match='does not match'): + expr = ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + parameter_units={'x0': 'meV', 'sigma': 'meV'}, + x_unit='meV', + ) + + # THEN EXPECT + assert sc.Unit(expr.output_unit) == sc.Unit('1/meV') + + def test_output_unit_matching_y_unit_does_not_warn(self): + # WHEN THEN EXPECT: declaring the matching y_unit silences the mismatch warning + with warnings.catch_warnings(): + warnings.simplefilter('error') + expr = ExpressionComponent( + GAUSSIAN_EXPRESSION, + parameters={'A': 2.0, 'x0': 0.5, 'sigma': 0.6}, + parameter_units={'x0': 'meV', 'sigma': 'meV'}, + x_unit='meV', + y_unit='1/meV', + ) + assert sc.Unit(expr.output_unit) == sc.Unit(expr.y_unit) + + def test_output_unit_preserved_by_abs(self): + # WHEN + with pytest.warns(UserWarning, match='does not match'): + expr = ExpressionComponent( + 'abs(x - x0)', parameters={'x0': 0.5}, parameter_units={'x0': 'meV'} + ) + + # THEN EXPECT + assert sc.Unit(expr.output_unit) == sc.Unit('meV') + + def test_output_unit_sign_is_dimensionless(self): + # WHEN + expr = ExpressionComponent( + 'A * sign(x - x0)', parameters={'A': 1.0, 'x0': 0.5}, parameter_units={'x0': 'meV'} + ) + + # THEN EXPECT + assert expr.output_unit == 'dimensionless' + + def test_output_unit_sqrt_of_squared_quantity(self): + # WHEN: sqrt(sigma**2) is a half-integer power of a quantity with units + with warnings.catch_warnings(): + warnings.simplefilter('error') + expr = ExpressionComponent( + 'x / sqrt(2*pi*sigma**2) + offset', + parameters={'sigma': 0.5, 'offset': 1.0}, + parameter_units={'sigma': 'meV'}, + x_unit='meV', + ) + + # THEN EXPECT: meV / sqrt(meV**2) + dimensionless = dimensionless + assert expr.output_unit == 'dimensionless' + + def test_output_unit_incompatible_addition_raises(self): + # WHEN + with pytest.warns(UserWarning, match='not unit-consistent'): + expr = ExpressionComponent( + 'x + A', parameters={'A': 1.0}, parameter_units={'A': 'K'}, x_unit='meV' + ) + + # THEN EXPECT + with pytest.raises(sc.UnitError, match='incompatible units'): + _ = expr.output_unit + + def test_output_unit_exp_of_dimensional_quantity_raises(self): + # WHEN + with pytest.warns(UserWarning, match='not unit-consistent'): + expr = ExpressionComponent( + 'exp(x * A)', parameters={'A': 1.0}, parameter_units={'A': 'meV'}, x_unit='meV' + ) + + # THEN EXPECT + with pytest.raises(sc.UnitError, match='dimensionless argument'): + _ = expr.output_unit + + def test_output_unit_symbolic_power_of_dimensional_base_raises(self): + # WHEN + with pytest.warns(UserWarning, match='not unit-consistent'): + expr = ExpressionComponent( + 'sigma**n * x', + parameters={'sigma': 0.5, 'n': 2.0}, + parameter_units={'sigma': 'meV'}, + ) + + # THEN EXPECT + with pytest.raises(sc.UnitError, match='symbolic power'): + _ = expr.output_unit + + def test_output_unit_exponent_with_unit_raises(self): + # WHEN: the exponent itself carries a unit + with pytest.warns(UserWarning, match='not unit-consistent'): + expr = ExpressionComponent( + 'x**A', parameters={'A': 1.0}, parameter_units={'A': 'meV'}, x_unit='meV' + ) + + # THEN EXPECT + with pytest.raises(sc.UnitError, match='exponent must be dimensionless'): + _ = expr.output_unit + + def test_output_unit_quarter_power_of_dimensional_base_raises(self): + # WHEN: only integer and half-integer powers of quantities with units are supported + with pytest.warns(UserWarning, match='not unit-consistent'): + expr = ExpressionComponent( + 'sigma**0.25 * x', parameters={'sigma': 0.5}, parameter_units={'sigma': 'meV'} + ) + + # THEN EXPECT + with pytest.raises(sc.UnitError, match='raised to the power'): + _ = expr.output_unit + + def test_output_unit_with_x_unit_none(self): + # WHEN: without an x_unit, x is treated as dimensionless in the propagation + with pytest.warns(UserWarning, match='does not match'): + expr = ExpressionComponent( + 'A * x', parameters={'A': 1.0}, parameter_units={'A': 'meV'}, x_unit=None + ) + + # THEN EXPECT + assert sc.Unit(expr.output_unit) == sc.Unit('meV') + + def test_evaluate_warns_for_genuinely_different_x_unit(self): + # WHEN: x carries a different (but compatible) unit than x_unit + expr = ExpressionComponent( + 'A * x', parameters={'A': 1.0}, parameter_units={'A': '1/meV'}, x_unit='meV' + ) + x = sc.linspace('energy', -1.0, 1.0, 11, unit='ueV') + + # THEN EXPECT: ExpressionComponent cannot auto-convert, so it warns and uses raw values + with pytest.warns(UserWarning, match='cannot auto-convert'): + result = expr.evaluate(x) + np.testing.assert_allclose(result, x.values) + + def test_evaluate_does_not_warn_for_equivalent_unit_spelling(self): + # WHEN: 'ueV' and scipp's canonical micro-sign spelling are the same unit + expr = ExpressionComponent( + 'A * x', parameters={'A': 1.0}, parameter_units={'A': '1/ueV'}, x_unit='ueV' + ) + x = sc.linspace('energy', -1.0, 1.0, 11, unit='\u00b5eV') + + # THEN EXPECT: no spurious warning from the differing spellings + with warnings.catch_warnings(): + warnings.simplefilter('error') + expr.evaluate(x) + + def test_set_unit_warns_when_breaking_consistency(self): + # WHEN: a consistent expression + expr = ExpressionComponent( + 'A * (x - x0)', + parameters={'A': 1.0, 'x0': 0.5}, + parameter_units={'A': '1/meV', 'x0': 'meV'}, + x_unit='meV', + ) + + # THEN EXPECT: relabelling A breaks the output unit + with pytest.warns(UserWarning, match='does not match'): + expr.set_unit('A', '1/ueV') + + +class TestExpressionComponentPhysicalConstants: + def test_kb_constant_value_and_unit(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT: Boltzmann constant in meV/K, a read-only DescriptorNumber + assert expr.kb.value == pytest.approx(0.08617333, rel=1e-6) + assert sc.Unit(str(expr.kb.unit)) == sc.Unit('meV/K') + assert isinstance(expr.kb, DescriptorNumber) + assert not isinstance(expr.kb, Parameter) + + def test_hbar_constant_value_and_unit(self): + # WHEN + expr = ExpressionComponent( + 'exp(-(x * tau / hbar)**2)', parameters={'tau': 1.0}, parameter_units={'tau': 'ps'} + ) + + # THEN EXPECT: hbar in meV*ps, a read-only DescriptorNumber + assert expr.hbar.value == pytest.approx(0.6582120, rel=1e-6) + assert sc.Unit(str(expr.hbar.unit)) == sc.Unit('meV*ps') + assert isinstance(expr.hbar, DescriptorNumber) + assert not isinstance(expr.hbar, Parameter) + + def test_boltzmann_factor_value(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + x = np.array([0.0, 10.0, 25.0]) + + # THEN + result = expr.evaluate(x) + + # EXPECT + expected = np.exp(-x / (0.08617333262145177 * 300.0)) + np.testing.assert_allclose(result, expected, rtol=1e-9) + + def test_constants_are_unit_consistent(self): + # WHEN THEN EXPECT: x/(kb*T) is dimensionless, so no warning is raised + with warnings.catch_warnings(): + warnings.simplefilter('error') + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + assert expr.output_unit == 'dimensionless' + + def test_constant_with_dimensionless_temperature_warns(self): + # WHEN THEN EXPECT: forgetting the unit on T makes x/(kb*T) carry kelvins + with pytest.warns(UserWarning, match='not unit-consistent'): + ExpressionComponent('exp(-x / (kb * T))', parameters={'T': 300.0}) + + def test_constants_property(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT + assert set(expr.constants) == {'kb'} + + def test_constant_not_in_get_all_variables(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT: constants are not fittable variables + names = {p.name for p in expr.get_all_variables()} + assert names == {'T'} + + def test_constant_cannot_be_set(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT + with pytest.raises(AttributeError, match='physical constant'): + expr.kb = 5.0 + + def test_set_unit_on_constant_raises(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT + with pytest.raises(AttributeError, match='physical constant'): + expr.set_unit('kb', 'meV/K') + + def test_parameter_units_for_constant_raises(self): + # WHEN THEN EXPECT + with pytest.raises(ValueError, match='physical constant'): + ExpressionComponent( + 'exp(-x / (kb * T))', + parameters={'T': 300.0}, + parameter_units={'T': 'K', 'kb': 'meV/K'}, + ) + + def test_constant_convert_unit_rescales(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN + expr.kb.convert_unit('eV/K') + + # EXPECT: the value is rescaled with the unit + assert expr.kb.value == pytest.approx(8.617333e-05, rel=1e-6) + + def test_user_parameter_overrides_constant(self): + # WHEN: the user explicitly provides 'kb' as a parameter + with warnings.catch_warnings(): + warnings.simplefilter('error') + expr = ExpressionComponent('exp(-x / (kb * T))', parameters={'kb': 2.0, 'T': 300.0}) + + # THEN EXPECT: it is an ordinary (fittable, dimensionless) parameter + assert not expr.constants + assert expr.kb.value == pytest.approx(2.0) + assert expr.kb.fixed is False + names = {p.name for p in expr.get_all_variables()} + assert names == {'kb', 'T'} + + def test_constant_in_dir(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT + assert 'kb' in dir(expr) + + def test_constant_in_repr(self): + # WHEN + expr = ExpressionComponent( + 'exp(-x / (kb * T))', parameters={'T': 300.0}, parameter_units={'T': 'K'} + ) + + # THEN EXPECT + assert 'kb' in repr(expr) diff --git a/tests/unit/easydynamics/sample_model/components/test_gaussian.py b/tests/unit/easydynamics/sample_model/components/test_gaussian.py index a370e74fb..3ef9b01b7 100644 --- a/tests/unit/easydynamics/sample_model/components/test_gaussian.py +++ b/tests/unit/easydynamics/sample_model/components/test_gaussian.py @@ -5,7 +5,9 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter +from scipp import UnitError from scipy.integrate import simpson from easydynamics.sample_model import Gaussian @@ -20,7 +22,7 @@ def gaussian(self): area=2.0, center=0.5, width=0.6, - unit='meV', + x_unit='meV', ) def test_init_no_inputs(self): @@ -32,7 +34,8 @@ def test_init_no_inputs(self): assert gaussian.area.value == pytest.approx(1.0) assert gaussian.center.value == pytest.approx(0.0) assert gaussian.width.value == pytest.approx(1.0) - assert gaussian.unit == 'meV' + assert gaussian.x_unit == 'meV' + assert gaussian.y_unit == 'dimensionless' assert gaussian.center.fixed is True def test_initialization(self, gaussian: Gaussian): @@ -41,33 +44,38 @@ def test_initialization(self, gaussian: Gaussian): assert gaussian.area.value == pytest.approx(2.0) assert gaussian.center.value == pytest.approx(0.5) assert gaussian.width.value == pytest.approx(0.6) - assert gaussian.unit == 'meV' + assert gaussian.x_unit == 'meV' @pytest.mark.parametrize( 'kwargs, expected_message', [ ( - {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'unit': 'meV'}, + {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'x_unit': 'meV'}, 'area must be a number', ), ( - {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'unit': 'meV'}, + {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'x_unit': 'meV'}, 'center must be None or a number', ), ( - {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'unit': 'meV'}, + {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'x_unit': 'meV'}, 'width must be a number', ), ( - {'area': 2.0, 'center': 0.5, 'width': 0.6, 'unit': 123}, - 'unit must be None', + {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 123}, + 'unit must be None, a string', + ), + ( + {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 'meV', 'y_unit': 123}, + 'unit must be None, a string', ), ], ids=[ 'invalid area', 'invalid center', 'invalid width', - 'invalid unit', + 'invalid x_unit', + 'invalid y_unit', ], ) def test_input_type_validation_raises(self, kwargs, expected_message): @@ -84,7 +92,7 @@ def test_negative_width_raises(self): area=2.0, center=0.5, width=-0.6, - unit='meV', + x_unit='meV', ) def test_negative_area_warns(self): @@ -95,7 +103,7 @@ def test_negative_area_warns(self): area=-2.0, center=0.5, width=0.6, - unit='meV', + x_unit='meV', ) @pytest.mark.parametrize( @@ -109,11 +117,12 @@ def test_negative_area_warns(self): def test_property_setters( self, gaussian: Gaussian, prop, valid_value, invalid_value, invalid_message ): - # set valid + # WHEN: set a valid value setattr(gaussian, prop, valid_value) + # THEN EXPECT assert getattr(gaussian, prop).value == valid_value - # invalid + # WHEN: set an invalid value — THEN EXPECT with pytest.raises(TypeError, match=invalid_message): setattr(gaussian, prop, invalid_value) @@ -177,12 +186,12 @@ def test_area_matches_parameter(self, gaussian: Gaussian): numerical_area = simpson(y, x) assert np.isclose(numerical_area, gaussian.area.value, rtol=1e-3) - def test_convert_unit(self, gaussian: Gaussian): + def test_convert_x_unit(self, gaussian: Gaussian): # WHEN THEN - gaussian.convert_unit('microeV') + gaussian.convert_x_unit('microeV') # EXPECT - assert gaussian.unit == 'microeV' + assert gaussian.x_unit == 'microeV' assert gaussian.area.value == pytest.approx(2 * 1e3) assert gaussian.center.value == pytest.approx(0.5 * 1e3) assert gaussian.width.value == pytest.approx(0.6 * 1e3) @@ -216,15 +225,91 @@ def test_copy(self, gaussian: Gaussian): assert gaussian_copy.width.min == gaussian.width.min assert gaussian_copy.width.max == gaussian.width.max - assert gaussian_copy.unit == gaussian.unit + assert gaussian_copy.x_unit == gaussian.x_unit def test_repr(self, gaussian: Gaussian): # WHEN THEN repr_str = repr(gaussian) # EXPECT assert 'Gaussian' in repr_str - assert "name='GaussianName'" in repr_str - assert 'unit=meV' in repr_str - assert 'area=' in repr_str - assert 'center=' in repr_str - assert 'width=' in repr_str + assert 'name = GaussianName' in repr_str + assert 'x_unit = meV' in repr_str + assert 'area =' in repr_str + assert 'center =' in repr_str + assert 'width =' in repr_str + + def test_y_unit_custom(self): + # WHEN THEN + gaussian = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + # EXPECT + assert gaussian.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, gaussian: Gaussian): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + gaussian.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: x_unit='meV', y_unit='1/meV' → area_unit ≈ dimensionless + gaussian = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + gaussian.convert_y_unit('1/eV') + + # EXPECT: unit updated and area value rescaled (1/eV = 1e-3/meV, so value x 1e3) + assert gaussian.y_unit == '1/eV' + assert gaussian.area.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, gaussian: Gaussian): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + gaussian.convert_y_unit(123) + + def test_evaluate_scipp_output(self, gaussian: Gaussian): + # WHEN + x = np.linspace(-5, 5, 100) + + # THEN + result = gaussian.evaluate(x, output='scipp') + + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 100 + np.testing.assert_allclose(result.values, gaussian.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + gaussian = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + x = np.linspace(-5, 5, 100) + + # WHEN + result = gaussian.evaluate(x, output='scipp') + + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + def test_convert_x_unit_invalid_type_raises(self, gaussian: Gaussian): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'): + gaussian.convert_x_unit(123) + + def test_convert_x_unit_rollback_on_failure(self, gaussian: Gaussian): + # WHEN THEN + with pytest.raises(UnitError): + gaussian.convert_x_unit('m') + # EXPECT: state rolled back + assert gaussian.x_unit == 'meV' + assert gaussian.area.value == pytest.approx(2.0) + assert gaussian.center.value == pytest.approx(0.5) + assert gaussian.width.value == pytest.approx(0.6) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + gaussian = Gaussian(area=1.0, center=0.0, width=0.5, x_unit='meV') + # THEN + with pytest.raises(UnitError): + gaussian.convert_y_unit('K') + # EXPECT: state rolled back + assert gaussian.y_unit == 'dimensionless' + assert gaussian.area.value == pytest.approx(1.0) diff --git a/tests/unit/easydynamics/sample_model/components/test_lorentzian.py b/tests/unit/easydynamics/sample_model/components/test_lorentzian.py index 1aefee6a9..97e02aad9 100644 --- a/tests/unit/easydynamics/sample_model/components/test_lorentzian.py +++ b/tests/unit/easydynamics/sample_model/components/test_lorentzian.py @@ -5,7 +5,9 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter +from scipp import UnitError from scipy.integrate import simpson from easydynamics.sample_model import Lorentzian @@ -20,7 +22,7 @@ def lorentzian(self): area=2.0, center=0.5, width=0.6, - unit='meV', + x_unit='meV', ) def test_init_no_inputs(self): @@ -32,7 +34,8 @@ def test_init_no_inputs(self): assert lorentzian.area.value == pytest.approx(1.0) assert lorentzian.center.value == pytest.approx(0.0) assert lorentzian.width.value == pytest.approx(1.0) - assert lorentzian.unit == 'meV' + assert lorentzian.x_unit == 'meV' + assert lorentzian.y_unit == 'dimensionless' assert lorentzian.center.fixed is True def test_initialization(self, lorentzian: Lorentzian): @@ -41,26 +44,30 @@ def test_initialization(self, lorentzian: Lorentzian): assert lorentzian.area.value == pytest.approx(2.0) assert lorentzian.center.value == pytest.approx(0.5) assert lorentzian.width.value == pytest.approx(0.6) - assert lorentzian.unit == 'meV' + assert lorentzian.x_unit == 'meV' @pytest.mark.parametrize( 'kwargs, expected_message', [ ( - {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'unit': 'meV'}, + {'area': 'invalid', 'center': 0.5, 'width': 0.6, 'x_unit': 'meV'}, 'area must be a number', ), ( - {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'unit': 'meV'}, - 'center must be None', + {'area': 2.0, 'center': 'invalid', 'width': 0.6, 'x_unit': 'meV'}, + 'center must be None or a number', ), ( - {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'unit': 'meV'}, + {'area': 2.0, 'center': 0.5, 'width': 'invalid', 'x_unit': 'meV'}, 'width must be a number', ), ( - {'area': 2.0, 'center': 0.5, 'width': 0.6, 'unit': 123}, - 'unit must be None', + {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 123}, + 'unit must be None, a string', + ), + ( + {'area': 2.0, 'center': 0.5, 'width': 0.6, 'x_unit': 'meV', 'y_unit': 123}, + 'unit must be None, a string', ), ], ) @@ -78,7 +85,7 @@ def test_negative_width_raises(self): area=2.0, center=0.5, width=-0.6, - unit='meV', + x_unit='meV', ) def test_negative_area_warns(self): @@ -89,7 +96,7 @@ def test_negative_area_warns(self): area=-2.0, center=0.5, width=0.6, - unit='meV', + x_unit='meV', ) @pytest.mark.parametrize( @@ -103,11 +110,12 @@ def test_negative_area_warns(self): def test_property_setters( self, lorentzian: Lorentzian, prop, valid_value, invalid_value, invalid_message ): - # set valid + # WHEN: set a valid value setattr(lorentzian, prop, valid_value) + # THEN EXPECT assert getattr(lorentzian, prop).value == valid_value - # invalid + # WHEN: set an invalid value — THEN EXPECT with pytest.raises(TypeError, match=invalid_message): setattr(lorentzian, prop, invalid_value) @@ -166,12 +174,12 @@ def test_area_matches_parameter(self, lorentzian: Lorentzian): # EXPECT assert numerical_area == pytest.approx(lorentzian.area.value, rel=2e-3) - def test_convert_unit(self, lorentzian: Lorentzian): + def test_convert_x_unit(self, lorentzian: Lorentzian): # WHEN THEN - lorentzian.convert_unit('microeV') + lorentzian.convert_x_unit('microeV') # EXPECT - assert lorentzian.unit == 'microeV' + assert lorentzian.x_unit == 'microeV' assert lorentzian.area.value == pytest.approx(2 * 1e3) assert lorentzian.center.value == pytest.approx(0.5 * 1e3) assert lorentzian.width.value == pytest.approx(0.6 * 1e3) @@ -193,7 +201,7 @@ def test_copy(self, lorentzian: Lorentzian): assert lorentzian_copy.width.value == lorentzian.width.value assert lorentzian_copy.width.fixed == lorentzian.width.fixed - assert lorentzian_copy.unit == lorentzian.unit + assert lorentzian_copy.x_unit == lorentzian.x_unit def test_repr(self, lorentzian: Lorentzian): # WHEN THEN @@ -201,8 +209,79 @@ def test_repr(self, lorentzian: Lorentzian): # EXPECT assert 'Lorentzian' in repr_str - assert "name='LorentzianName'" in repr_str - assert 'unit=meV' in repr_str - assert 'area=' in repr_str - assert 'center=' in repr_str - assert 'width=' in repr_str + assert 'name = LorentzianName' in repr_str + assert 'x_unit = meV' in repr_str + assert 'area =' in repr_str + assert 'center =' in repr_str + assert 'width =' in repr_str + + def test_y_unit_custom(self): + # WHEN THEN + lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV') + # EXPECT + assert lor.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, lorentzian: Lorentzian): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + lorentzian.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless' + lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV') + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + lor.convert_y_unit('1/eV') + # EXPECT: y_unit updated and area value rescaled (1e3 factor) + assert lor.y_unit == '1/eV' + assert lor.area.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, lorentzian: Lorentzian): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + lorentzian.convert_y_unit(123) + + def test_evaluate_scipp_output(self, lorentzian: Lorentzian): + # WHEN + x = np.linspace(-5, 5, 50) + # THEN + result = lorentzian.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 50 + np.testing.assert_allclose(result.values, lorentzian.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV') + x = np.linspace(-5, 5, 50) + # THEN + result = lor.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + def test_convert_x_unit_invalid_type_raises(self, lorentzian: Lorentzian): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'): + lorentzian.convert_x_unit(123) + + def test_convert_x_unit_rollback_on_failure(self, lorentzian: Lorentzian): + # WHEN THEN + with pytest.raises(UnitError): + lorentzian.convert_x_unit('m') + # EXPECT: state rolled back + assert lorentzian.x_unit == 'meV' + assert lorentzian.area.value == pytest.approx(2.0) + assert lorentzian.center.value == pytest.approx(0.5) + assert lorentzian.width.value == pytest.approx(0.6) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + lor = Lorentzian(area=1.0, center=0.0, width=0.5, x_unit='meV') + # THEN + with pytest.raises(UnitError): + lor.convert_y_unit('K') + # EXPECT: state rolled back + assert lor.y_unit == 'dimensionless' + assert lor.area.value == pytest.approx(1.0) diff --git a/tests/unit/easydynamics/sample_model/components/test_mixins.py b/tests/unit/easydynamics/sample_model/components/test_mixins.py index e30f4bdd1..bc7a312d1 100644 --- a/tests/unit/easydynamics/sample_model/components/test_mixins.py +++ b/tests/unit/easydynamics/sample_model/components/test_mixins.py @@ -18,7 +18,7 @@ def dummy_model(self): @pytest.mark.parametrize('area_input', [2, 2.0]) def test_create_area_parameter_from_numeric(self, dummy_model, area_input, unit): # WHEN THEN - area_param = dummy_model._create_area_parameter(area_input, 'TestModel', unit=unit) + area_param = dummy_model._create_area_parameter(area_input, 'TestModel', x_unit=unit) # EXPECT assert isinstance(area_param, Parameter) @@ -59,7 +59,7 @@ def test_negative_area_warns(self, dummy_model): def test_create_center_parameter_from_numeric(self, dummy_model, center_input, unit): # WHEN THEN center_param = dummy_model._create_center_parameter( - center_input, 'TestModel', fix_if_none=False, unit=unit + center_input, 'TestModel', fix_if_none=False, x_unit=unit ) # EXPECT assert isinstance(center_param, Parameter) @@ -104,7 +104,7 @@ def test_create_center_parameter_invalid_numeric_raises(self, dummy_model, non_f def test_create_width_parameter_from_numeric(self, dummy_model, width_input, unit): # WHEN THEN width_param = dummy_model._create_width_parameter( - width_input, 'TestModel', param_name='width', unit=unit + width_input, 'TestModel', param_name='width', x_unit=unit ) # EXPECT assert isinstance(width_param, Parameter) diff --git a/tests/unit/easydynamics/sample_model/components/test_model_component.py b/tests/unit/easydynamics/sample_model/components/test_model_component.py index 514ced73b..8a1a08678 100644 --- a/tests/unit/easydynamics/sample_model/components/test_model_component.py +++ b/tests/unit/easydynamics/sample_model/components/test_model_component.py @@ -5,7 +5,9 @@ import pytest import scipp as sc from easyscience.variable import Parameter +from scipp import UnitError +from easydynamics.sample_model.components.gaussian import Gaussian from easydynamics.sample_model.components.model_component import ModelComponent @@ -15,13 +17,13 @@ def __init__(self): self.area = Parameter(name='area', value=1.0, unit='meV', fixed=False) self.center = Parameter(name='center', value=2.0, unit='meV', fixed=True) self.width = Parameter(name='width', value=3.0, unit='meV', fixed=True) - self._unit = 'meV' + self._x_unit = 'meV' def get_all_parameters(self): return [self.area, self.center, self.width] - def evaluate(self, x): - return np.zeros_like(x) + def _evaluate_values(self, x_vals, _eval_unit): + return np.zeros_like(x_vals) class TestModelComponent: @@ -31,23 +33,23 @@ def dummy(self): def test_unit_cannot_be_set_directly(self, dummy: ModelComponent): # WHEN THEN EXPECT - with pytest.raises(AttributeError, match='Unit is read-only'): - dummy.unit = 'K' + with pytest.raises(AttributeError, match='read-only'): + dummy.x_unit = 'K' def test_convert_unit(self, dummy: DummyComponent): # WHEN THEN - dummy.convert_unit('microeV') + dummy.convert_x_unit('microeV') # EXPECT - assert dummy.unit == 'microeV' + assert dummy.x_unit == 'microeV' assert dummy.area.value == pytest.approx(1 * 1e3) assert dummy.center.value == pytest.approx(2 * 1e3) assert dummy.width.value == pytest.approx(3 * 1e3) def test_convert_unit_incorrect_unit_raises(self, dummy: DummyComponent): # WHEN THEN EXPECT - with pytest.raises(TypeError, match=r'Unit must be a string or sc.Unit'): - dummy.convert_unit(123) + with pytest.raises(TypeError, match=r'unit must be a string or sc.Unit'): + dummy.convert_x_unit(123) def test_free_and_fix_all_parameters(self, dummy): # WHEN THEN EXPECT @@ -92,8 +94,11 @@ def test_repr(self, dummy): ], ) def test_prepare_x_for_evaluate_various_inputs(self, dummy, x_input, expected_array): - x_prepared = dummy._prepare_x_for_evaluate(x_input) + # WHEN THEN + result = dummy._prepare_x_for_evaluate(x_input) + x_prepared, _detected_unit, _dim = result + # EXPECT assert isinstance(x_prepared, np.ndarray) assert x_prepared.shape == expected_array.shape np.testing.assert_array_equal(x_prepared, expected_array) @@ -137,26 +142,126 @@ def test_prepare_x_for_evaluate_with_incompatible_unit_raises(self, dummy): # THEN EXPECT with pytest.raises( Exception, - match='Input x has unit nm, but DummyComponent component ', + match='Input x has unit nm', ): dummy._prepare_x_for_evaluate(x) - def test_prepare_x_for_evaluate_with_different_unit_warns(self, dummy): + def test_prepare_x_for_evaluate_with_different_unit_no_warn(self, dummy): # WHEN x = sc.array(dims=['x'], values=[1.0, 2.0, 3.0], unit='microeV') - # THEN EXPECT - with pytest.warns( - UserWarning, - match='Input x has unit [µμ]eV, but DummyComponent component ', - ): - x_prepared = dummy._prepare_x_for_evaluate(x) + # THEN: compatible units are accepted without warning; + # the component's x_unit is NOT mutated and x values are returned as-is. + x_prepared, _detected_unit, _dim = dummy._prepare_x_for_evaluate(x) # EXPECT assert isinstance(x_prepared, np.ndarray) assert x_prepared.shape == (3,) np.testing.assert_array_equal(x_prepared, [1.0, 2.0, 3.0]) - assert dummy.unit == 'µeV' # noqa: RUF001 - assert dummy.area.value == pytest.approx(1.0 * 1e3) - assert dummy.center.value == pytest.approx(2.0 * 1e3) - assert dummy.width.value == pytest.approx(3.0 * 1e3) + assert dummy.x_unit == 'meV' # component unit unchanged + assert dummy.area.value == pytest.approx(1.0) # parameter values unchanged + assert dummy.center.value == pytest.approx(2.0) + assert dummy.width.value == pytest.approx(3.0) + + def test_resolve_param_value_same_unit_returns_raw_value(self, dummy): + # WHEN: target unit matches parameter unit + + # THEN + result = dummy._resolve_param_value(dummy.area, 'meV') + # EXPECT: raw value returned without conversion + assert result == pytest.approx(dummy.area.value) + + def test_resolve_param_value_none_target_returns_raw_value(self, dummy): + # WHEN: target unit is None + + # THEN + result = dummy._resolve_param_value(dummy.area, None) + + # EXPECT: raw value returned without conversion + assert result == pytest.approx(dummy.area.value) + + def test_resolve_param_value_converts_without_mutating(self, dummy): + # WHEN: target unit differs from parameter unit + + # THEN + result = dummy._resolve_param_value(dummy.area, 'eV') + # EXPECT: converted value (1.0 meV → 0.001 eV) + assert result == pytest.approx(0.001) + # parameter itself is not mutated + assert dummy.area.value == pytest.approx(1.0) + assert str(dummy.area.unit) == 'meV' + + def test_evaluate_with_compatible_unit_gives_correct_result(self): + # WHEN: Gaussian in meV and a physically equivalent Gaussian in eV + g_mev = Gaussian(area=1.0, center=0.0, width=0.5, x_unit='meV') + g_ev = Gaussian(area=0.001, center=0.0, width=0.0005, x_unit='eV') + + x_ev = sc.array( + dims=['energy'], + values=np.array([-0.002, -0.001, 0.0, 0.001, 0.002]), + unit='eV', + ) + x_ev_np = np.array([-0.002, -0.001, 0.0, 0.001, 0.002]) + + # THEN: evaluate meV-Gaussian with x in eV + result_mev = g_mev.evaluate(x_ev) + result_ev = g_ev.evaluate(x_ev_np) + + # EXPECT: physically identical outputs + np.testing.assert_allclose(result_mev, result_ev, rtol=1e-10) + # EXPECT: model state is unchanged + assert g_mev.x_unit == 'meV' + assert g_mev.width.value == pytest.approx(0.5) + assert g_mev.area.value == pytest.approx(1.0) + + # ───── Regression tests ───── + + def test_convert_x_unit_rollback_on_failure(self, dummy: DummyComponent): + # Conversion to 'm' (length) is incompatible with 'meV' (energy) → triggers rollback + with pytest.raises(UnitError): + dummy.convert_x_unit('m') + # Parameters should be restored to original values after rollback + assert dummy.x_unit == 'meV' + assert dummy.area.value == pytest.approx(1.0) + assert dummy.center.value == pytest.approx(2.0) + assert dummy.width.value == pytest.approx(3.0) + + def test_convert_y_unit_not_implemented(self, dummy: DummyComponent): + with pytest.raises(NotImplementedError, match='does not support convert_y_unit'): + dummy.convert_y_unit('1/meV') + + def test_evaluate_invalid_output_raises(self, dummy: DummyComponent): + # WHEN THEN EXPECT + with pytest.raises(ValueError, match="output must be 'numpy' or 'scipp'"): + dummy.evaluate(np.linspace(-1, 1, 5), output='Scipp') + + def test_eval_area_unit(self, dummy: DummyComponent): + # WHEN THEN EXPECT: y_unit is dimensionless, so the area unit equals the eval unit + assert sc.Unit(dummy._eval_area_unit('meV')) == sc.Unit('meV') + + def test_eval_area_unit_none_eval_unit(self, dummy: DummyComponent): + # WHEN THEN EXPECT: without an eval unit there is no area unit + assert dummy._eval_area_unit(None) is None + + def test_eval_area_unit_none_y_unit(self, dummy: DummyComponent): + # WHEN + dummy._y_unit = None + + # THEN EXPECT + assert dummy._eval_area_unit('meV') is None + + def test_evaluate_preserves_dataarray_coord_key_as_dim(self): + # WHEN: a Gaussian and a DataArray where the coord key ('energy') differs + # from the coord Variable's internal dim name ('x'). This is a valid scipp + # non-dimension coordinate: the data's dimension is 'x' and the coord is + # labelled 'energy' but lives on the same 'x' axis. + g = Gaussian(name='G', area=1.0, center=0.0, width=1.0, x_unit='meV') + coord = sc.Variable(dims=['x'], values=np.linspace(-5.0, 5.0, 10), unit='meV') + data = sc.Variable(dims=['x'], values=np.ones(10)) + da = sc.DataArray(data=data, coords={'energy': coord}) + # THEN: evaluate with scipp output + # Before the fix, dim was overwritten with coord.dims[0] = 'x', so the + # output Variable had dim 'x' instead of the coord key 'energy'. + result = g.evaluate(da, output='scipp') + # EXPECT: output dim must be the coord key 'energy', not the Variable dim 'x'. + assert result.dims == ('energy',) diff --git a/tests/unit/easydynamics/sample_model/components/test_polynomial.py b/tests/unit/easydynamics/sample_model/components/test_polynomial.py index 2ba3f5525..8f6550bee 100644 --- a/tests/unit/easydynamics/sample_model/components/test_polynomial.py +++ b/tests/unit/easydynamics/sample_model/components/test_polynomial.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter from scipp import UnitError @@ -27,7 +28,8 @@ def test_init_no_inputs(self): # EXPECT assert polynomial.display_name == 'Polynomial' assert polynomial.coefficients[0].value == pytest.approx(0.0) - assert polynomial.unit == 'meV' + assert polynomial.x_unit == 'meV' + assert polynomial.y_unit == 'dimensionless' def test_initialization(self, polynomial: Polynomial): # WHEN THEN EXPECT @@ -48,7 +50,11 @@ def test_initialization(self, polynomial: Polynomial): 'Each coefficient must be ', ), ( - {'coefficients': [1.0, -2.0, 3.0], 'unit': 123}, + {'coefficients': [1.0, -2.0, 3.0], 'x_unit': 123}, + 'unit must be ', + ), + ( + {'coefficients': [1.0, -2.0, 3.0], 'x_unit': 'meV', 'y_unit': 123}, 'unit must be ', ), ( @@ -59,6 +65,7 @@ def test_initialization(self, polynomial: Polynomial): ], ) def test_input_type_validation_raises(self, kwargs, expected_message): + # WHEN THEN EXPECT with pytest.raises(TypeError, match=expected_message): Polynomial(display_name='TestPolynomial', **kwargs) @@ -118,13 +125,12 @@ def test_set_coefficients(self, polynomial: Polynomial, values): assert np.isclose(polynomial.coefficients[i].value, expected) def test_set_coefficients_wrong_length_raises(self, polynomial: Polynomial): - """Ensure that setting coefficients with mismatched length - raises an error.""" + # WHEN THEN EXPECT with pytest.raises(ValueError, match='Number of coefficients'): polynomial.coefficients = [1.0, 2.0] # shorter list def test_set_coefficients_invalid_type_raises(self, polynomial: Polynomial): - """Ensure that invalid coefficient types raise a TypeError.""" + # WHEN THEN EXPECT with pytest.raises(TypeError): polynomial.coefficients = [1.0, 'invalid', 3.0] @@ -137,8 +143,11 @@ def test_set_coefficients_invalid_type_raises(self, polynomial: Polynomial): ], ) def test_set_coefficients_raises(self, invalid_coeffs, expected_message): + # WHEN + # THEN + polynomial = Polynomial(display_name='TestPolynomial', coefficients=[1.0, -2.0, 3.0]) with pytest.raises(TypeError, match=expected_message): - polynomial = Polynomial(display_name='TestPolynomial', coefficients=[1.0, -2.0, 3.0]) + # EXPECT polynomial.coefficients = invalid_coeffs def test_coefficient_values(self, polynomial: Polynomial): @@ -161,20 +170,20 @@ def test_get_all_parameters(self, polynomial: Polynomial): actual_names = {param.name for param in params} assert actual_names == expected_names - def test_convert_unit(self, polynomial: Polynomial): + def test_convert_x_unit(self, polynomial: Polynomial): # WHEN - polynomial.convert_unit('microeV') + polynomial.convert_x_unit('microeV') # THEN EXPECT - assert polynomial._unit == 'microeV' + assert polynomial._x_unit == 'microeV' assert np.isclose(polynomial.coefficients[0].value, 1.0) assert np.isclose(polynomial.coefficients[1].value, -2.0 * 1e-3) assert np.isclose(polynomial.coefficients[2].value, 3.0 * 1e-6) - def test_convert_unit_raises_invalid_unit(self, polynomial: Polynomial): + def test_convert_x_unit_raises_invalid_unit(self, polynomial: Polynomial): # WHEN THEN EXPECT - with pytest.raises(UnitError, match='unit must be '): - polynomial.convert_unit(123) + with pytest.raises(Exception, match='unit must be '): + polynomial.convert_x_unit(123) def test_copy(self, polynomial: Polynomial): # WHEN THEN @@ -192,10 +201,101 @@ def test_copy(self, polynomial: Polynomial): assert copied_coeff.value == original_coeff.value assert copied_coeff.fixed == original_coeff.fixed + def test_y_unit_custom(self): + # WHEN THEN + p = Polynomial(coefficients=[1.0, 2.0], x_unit='meV', y_unit='1/meV') + # EXPECT + assert p.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, polynomial: Polynomial): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + polynomial.y_unit = '1/meV' + + def test_convert_y_unit_scales_all_coefficients(self): + # WHEN: polynomial with two non-zero coefficients and a physical y_unit + p = Polynomial(coefficients=[3.0, 1.0], x_unit='meV', y_unit='meV^-1') + x = np.array([2.0]) + val_before = p.evaluate(x)[0] # 3.0 + 1.0*2.0 = 5.0 [meV^-1] + + # THEN + p.convert_y_unit('eV^-1') + + # EXPECT: both coefficients rescaled by 1000 (1 meV^-1 = 1000 eV^-1) + assert p.y_unit == 'eV^-1' + assert np.isclose(p.coefficients[0].value, 3000.0) + assert np.isclose(p.coefficients[1].value, 1000.0) + assert np.isclose(p.evaluate(x)[0], val_before * 1000.0) + + def test_evaluate_scipp_output(self): + # WHEN + p = Polynomial(coefficients=[1.0, 2.0], x_unit='meV', suppress_warnings=True) + x = np.linspace(-3, 3, 40) + # THEN + result = p.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 40 + np.testing.assert_allclose(result.values, p.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + p = Polynomial( + coefficients=[1.0, 2.0], + x_unit='meV', + y_unit='1/meV', + suppress_warnings=True, + ) + x = np.linspace(-3, 3, 40) + # THEN + result = p.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + def test_repr(self, polynomial: Polynomial): # WHEN THEN repr_str = repr(polynomial) # EXPECT - assert "name='PolynomialName'" in repr_str - assert 'coefficients=' in repr_str + assert 'name = PolynomialName' in repr_str + assert 'coefficients =' in repr_str + + def test_evaluate_with_scipp_x_different_compatible_unit(self): + # WHEN: polynomial with x_unit='meV', coefficients [1.0, 1.0] → f(x) = 1 + x + p = Polynomial(coefficients=[1.0, 1.0], x_unit='meV') + # THEN: evaluate with x in eV (compatible unit) — triggers unit-rescaling branch + x_eV = sc.array(dims=['x'], values=np.array([0.001, 0.002]), unit='eV') + result = p.evaluate(x_eV) + # EXPECT: 0.001 eV = 1 meV → f(1)=2, 0.002 eV = 2 meV → f(2)=3; state not mutated + np.testing.assert_allclose(result, [2.0, 3.0], rtol=1e-5) + assert p.x_unit == 'meV' + + def test_evaluate_with_equivalent_unit_spelling_does_not_rescale(self): + # WHEN: 'ueV' and scipp's canonical micro-sign spelling are the same unit + # (regression: a string comparison used to treat them as different and rescale) + p = Polynomial(coefficients=[1.0, 1.0], x_unit='ueV', suppress_warnings=True) + x = sc.array(dims=['x'], values=np.array([1.0, 2.0]), unit='\u00b5eV') + + # THEN + result = p.evaluate(x) + + # EXPECT: f(x) = 1 + x with the raw coefficient values, no rescaling applied + np.testing.assert_allclose(result, [2.0, 3.0]) + + def test_convert_y_unit_invalid_type_raises(self, polynomial: Polynomial): + # WHEN THEN EXPECT + with pytest.raises(UnitError, match='new_y_unit must be a string or a scipp unit'): + polynomial.convert_y_unit(123) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + p = Polynomial(coefficients=[1.0, 2.0], x_unit='meV') + # THEN + with pytest.raises(UnitError): + p.convert_y_unit('K') + # EXPECT: state rolled back + assert p.y_unit == 'dimensionless' + assert np.isclose(p.coefficients[0].value, 1.0) + assert np.isclose(p.coefficients[1].value, 2.0) diff --git a/tests/unit/easydynamics/sample_model/components/test_voigt.py b/tests/unit/easydynamics/sample_model/components/test_voigt.py index 8b8d44cf3..42eb5099c 100644 --- a/tests/unit/easydynamics/sample_model/components/test_voigt.py +++ b/tests/unit/easydynamics/sample_model/components/test_voigt.py @@ -5,7 +5,9 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter +from scipp import UnitError from scipy.integrate import simpson from scipy.special import voigt_profile @@ -22,7 +24,7 @@ def voigt(self): center=0.5, gaussian_width=0.6, lorentzian_width=0.7, - unit='meV', + x_unit='meV', ) def test_init_no_inputs(self): @@ -35,7 +37,8 @@ def test_init_no_inputs(self): assert voigt.center.value == pytest.approx(0.0) assert voigt.gaussian_width.value == pytest.approx(1.0) assert voigt.lorentzian_width.value == pytest.approx(1.0) - assert voigt.unit == 'meV' + assert voigt.x_unit == 'meV' + assert voigt.y_unit == 'dimensionless' assert voigt.center.fixed is True def test_initialization(self, voigt: Voigt): @@ -45,7 +48,7 @@ def test_initialization(self, voigt: Voigt): assert voigt.center.value == pytest.approx(0.5) assert voigt.gaussian_width.value == pytest.approx(0.6) assert voigt.lorentzian_width.value == pytest.approx(0.7) - assert voigt.unit == 'meV' + assert voigt.x_unit == 'meV' @pytest.mark.parametrize( 'kwargs, expected_message', @@ -56,7 +59,7 @@ def test_initialization(self, voigt: Voigt): 'center': 0.5, 'gaussian_width': 0.6, 'lorentzian_width': 0.7, - 'unit': 'meV', + 'x_unit': 'meV', }, 'area must be a number', ), @@ -66,7 +69,7 @@ def test_initialization(self, voigt: Voigt): 'center': 'invalid', 'gaussian_width': 0.6, 'lorentzian_width': 0.7, - 'unit': 'meV', + 'x_unit': 'meV', }, 'center must be None', ), @@ -76,7 +79,7 @@ def test_initialization(self, voigt: Voigt): 'center': 0.5, 'gaussian_width': 'invalid', 'lorentzian_width': 0.7, - 'unit': 'meV', + 'x_unit': 'meV', }, 'gaussian_width must be a number', ), @@ -86,7 +89,7 @@ def test_initialization(self, voigt: Voigt): 'center': 0.5, 'gaussian_width': 0.6, 'lorentzian_width': 'invalid', - 'unit': 'meV', + 'x_unit': 'meV', }, 'lorentzian_width must be a number', ), @@ -96,13 +99,25 @@ def test_initialization(self, voigt: Voigt): 'center': 0.5, 'gaussian_width': 0.6, 'lorentzian_width': 0.7, - 'unit': 123, + 'x_unit': 123, }, - 'unit must be None,', + 'unit must be None, a string', + ), + ( + { + 'area': 2.0, + 'center': 0.5, + 'gaussian_width': 0.6, + 'lorentzian_width': 0.7, + 'x_unit': 'meV', + 'y_unit': 123, + }, + 'unit must be None, a string', ), ], ) def test_input_type_validation_raises(self, kwargs, expected_message): + # WHEN THEN EXPECT with pytest.raises(TypeError, match=expected_message): Voigt(display_name='TestVoigt', **kwargs) @@ -117,7 +132,7 @@ def test_negative_gaussian_width_raises(self): center=0.5, gaussian_width=-0.6, lorentzian_width=0.7, - unit='meV', + x_unit='meV', ) def test_negative_lorentzian_width_raises(self): @@ -132,7 +147,7 @@ def test_negative_lorentzian_width_raises(self): center=0.5, gaussian_width=0.6, lorentzian_width=-0.7, - unit='meV', + x_unit='meV', ) def test_negative_area_warns(self): @@ -144,7 +159,7 @@ def test_negative_area_warns(self): center=0.5, gaussian_width=0.6, lorentzian_width=0.7, - unit='meV', + x_unit='meV', ) @pytest.mark.parametrize( @@ -164,11 +179,12 @@ def test_negative_area_warns(self): def test_property_setters( self, voigt: Voigt, prop, valid_value, invalid_value, invalid_message ): - # set valid + # WHEN: set a valid value setattr(voigt, prop, valid_value) + # THEN EXPECT assert getattr(voigt, prop).value == valid_value - # invalid + # WHEN: set an invalid value — THEN EXPECT with pytest.raises(TypeError, match=invalid_message): setattr(voigt, prop, invalid_value) @@ -214,19 +230,19 @@ def test_center_is_fixed_if_init_to_None(self): center=None, gaussian_width=0.6, lorentzian_width=0.7, - unit='meV', + x_unit='meV', ) # EXPECT assert test_voigt.center.value == pytest.approx(0.0) assert test_voigt.center.fixed is True - def test_convert_unit(self, voigt: Voigt): + def test_convert_x_unit(self, voigt: Voigt): # WHEN THEN - voigt.convert_unit('microeV') + voigt.convert_x_unit('microeV') # EXPECT - assert voigt.unit == 'microeV' + assert voigt.x_unit == 'microeV' assert voigt.area.value == pytest.approx(2 * 1e3) assert voigt.center.value == pytest.approx(0.5 * 1e3) assert voigt.gaussian_width.value == pytest.approx(0.6 * 1e3) @@ -285,7 +301,7 @@ def test_copy(self, voigt: Voigt): assert voigt_copy.lorentzian_width.value == voigt.lorentzian_width.value assert voigt_copy.lorentzian_width.fixed == voigt.lorentzian_width.fixed - assert voigt_copy.unit == voigt.unit + assert voigt_copy.x_unit == voigt.x_unit def test_repr(self, voigt: Voigt): # WHEN THEN @@ -293,9 +309,102 @@ def test_repr(self, voigt: Voigt): # EXPECT assert 'Voigt' in repr_str - assert "name='VoigtName'" in repr_str - assert 'unit=meV' in repr_str - assert 'area=' in repr_str - assert 'center=' in repr_str - assert 'gaussian_width=' in repr_str - assert 'lorentzian_width=' in repr_str + assert 'name = VoigtName' in repr_str + assert 'x_unit = meV' in repr_str + assert 'area =' in repr_str + assert 'center =' in repr_str + assert 'gaussian_width =' in repr_str + assert 'lorentzian_width =' in repr_str + + def test_y_unit_custom(self): + # WHEN THEN + voigt = Voigt( + area=1.0, + center=0.0, + gaussian_width=0.5, + lorentzian_width=0.3, + x_unit='meV', + y_unit='1/meV', + ) + # EXPECT + assert voigt.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, voigt: Voigt): + # WHEN THEN EXPECT + with pytest.raises(AttributeError): + voigt.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: x_unit='meV', y_unit='1/meV' → area_unit='dimensionless' + voigt = Voigt( + area=1.0, + center=0.0, + gaussian_width=0.5, + lorentzian_width=0.3, + x_unit='meV', + y_unit='1/meV', + ) + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + voigt.convert_y_unit('1/eV') + # EXPECT: y_unit updated and area value rescaled (1e3 factor) + assert voigt.y_unit == '1/eV' + assert voigt.area.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, voigt: Voigt): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + voigt.convert_y_unit(123) + + def test_evaluate_scipp_output(self, voigt: Voigt): + # WHEN + x = np.linspace(-5, 5, 50) + # THEN + result = voigt.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('dimensionless') + assert len(result.values) == 50 + np.testing.assert_allclose(result.values, voigt.evaluate(x, output='numpy')) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + voigt = Voigt( + area=1.0, + center=0.0, + gaussian_width=0.5, + lorentzian_width=0.3, + x_unit='meV', + y_unit='1/meV', + ) + x = np.linspace(-5, 5, 50) + # THEN + result = voigt.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + def test_convert_x_unit_invalid_type_raises(self, voigt: Voigt): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'x_unit must be a string or sc\.Unit'): + voigt.convert_x_unit(123) + + def test_convert_x_unit_rollback_on_failure(self, voigt: Voigt): + # WHEN THEN + with pytest.raises(UnitError): + voigt.convert_x_unit('m') + # EXPECT: state rolled back + assert voigt.x_unit == 'meV' + assert voigt.area.value == pytest.approx(2.0) + assert voigt.center.value == pytest.approx(0.5) + assert voigt.gaussian_width.value == pytest.approx(0.6) + assert voigt.lorentzian_width.value == pytest.approx(0.7) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN + voigt = Voigt(area=1.0, center=0.0, gaussian_width=0.5, lorentzian_width=0.3, x_unit='meV') + # THEN + with pytest.raises(UnitError): + voigt.convert_y_unit('K') + # EXPECT: state rolled back + assert voigt.y_unit == 'dimensionless' + assert voigt.area.value == pytest.approx(1.0) diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py index 7941d4068..4b2137540 100644 --- a/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py +++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_brownian_translational_diffusion.py @@ -25,16 +25,60 @@ def brownian_diffusion_model(self): def test_init_default(self, brownian_diffusion_model): # WHEN THEN EXPECT assert brownian_diffusion_model.display_name == 'BrownianTranslationalDiffusion' - assert brownian_diffusion_model.unit == 'meV' + assert brownian_diffusion_model.x_unit == 'meV' + assert brownian_diffusion_model.y_unit == 'dimensionless' assert brownian_diffusion_model.scale.value == pytest.approx(1.0) + assert brownian_diffusion_model.scale.unit == 'meV' assert brownian_diffusion_model.diffusion_coefficient.value == pytest.approx(1.0) + def test_convert_x_unit_rescales_widths(self): + # WHEN + model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=np.array([1.0])) + width_mev = model.calculate_width()[0] + + # THEN + model.convert_x_unit('ueV') + + # EXPECT: calculated width and the component width follow the new unit + assert model.calculate_width()[0] == pytest.approx(width_mev * 1000) + collection = model.get_component_collections(0) + assert sc.Unit(str(collection[0].width.unit)) == sc.Unit('ueV') + assert collection[0].width.value == pytest.approx(width_mev * 1000) + assert sc.Unit(str(model.scale.unit)) == sc.Unit('ueV') + + def test_convert_x_unit_converts_collections_in_place(self): + # WHEN + model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=np.array([1.0])) + collection_before = model.get_component_collections(0) + + # THEN + model.convert_x_unit('ueV') + + # EXPECT: conversion does not regenerate the collections (regression: rebuilding + # replaced the component objects, breaking external references) + assert model.get_component_collections(0) is collection_before + + def test_convert_x_unit_persists_after_dependency_update(self): + # WHEN + model = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=np.array([1.0])) + width = model.get_component_collections(0)[0].width + model.convert_x_unit('ueV') + width_uev = width.value + + # THEN: trigger a dependency-graph recompute + model.diffusion_coefficient = 4.8e-9 + + # EXPECT: the dependent width stays in the new unit (regression: a plain convert_unit + # was reverted to the old desired unit by the next graph update) + assert sc.Unit(str(width.unit)) == sc.Unit('ueV') + assert width.value == pytest.approx(2 * width_uev) + @pytest.mark.parametrize( 'kwargs,expected_exception, expected_message', [ ( { - 'unit': 123, + 'x_unit': 123, 'scale': 1.0, 'diffusion_coefficient': 1.0, }, @@ -43,7 +87,16 @@ def test_init_default(self, brownian_diffusion_model): ), ( { - 'unit': 'meV', + 'y_unit': 123, + 'scale': 1.0, + 'diffusion_coefficient': 1.0, + }, + TypeError, + None, + ), + ( + { + 'x_unit': 'meV', 'scale': 'invalid', 'diffusion_coefficient': 1.0, }, @@ -52,7 +105,7 @@ def test_init_default(self, brownian_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': -123.4, 'diffusion_coefficient': 1.0, }, @@ -61,7 +114,7 @@ def test_init_default(self, brownian_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': 1.0, 'diffusion_coefficient': 'invalid', }, @@ -70,7 +123,7 @@ def test_init_default(self, brownian_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': 1.0, 'diffusion_coefficient': -123.4, }, @@ -78,6 +131,14 @@ def test_init_default(self, brownian_diffusion_model): 'diffusion_coefficient must be non-negative', ), ], + ids=[ + 'invalid_x_unit', + 'invalid_y_unit', + 'invalid_scale_type', + 'invalid_scale_negative', + 'invalid_diffusion_coefficient_type', + 'invalid_diffusion_coefficient_negative', + ], ) def test_input_type_validation_raises(self, kwargs, expected_exception, expected_message): with pytest.raises(expected_exception, match=expected_message): @@ -123,6 +184,17 @@ def test_calculate_width(self, brownian_diffusion_model): expected_widths = 1.0 * unit_conversion_factor.value * (Q_values**2) np.testing.assert_allclose(widths, expected_widths, rtol=1e-5) + def test_calculate_width_scipp_Q_converts_unit(self, brownian_diffusion_model): + # WHEN: the same Q expressed in 1/angstrom (numpy, assumed) and in 1/nm (scipp) + Q_values = np.array([0.1, 0.2, 0.3]) + Q_scipp = sc.Variable(dims=['Q'], values=Q_values * 10, unit='1/nm') + + # THEN EXPECT: scipp input is converted to 1/angstrom before the width calculation + np.testing.assert_allclose( + brownian_diffusion_model.calculate_width(Q_scipp), + brownian_diffusion_model.calculate_width(Q_values), + ) + def test_calculate_EISF(self, brownian_diffusion_model): # WHEN Q_values = np.array([0.1, 0.2, 0.3]) # Example Q values in Å^-1 @@ -180,9 +252,11 @@ def test_create_component_collections(self, brownian_diffusion_model, Q): model = component_collections[model_index] assert len(model) == 1 component = model[0] - assert component.width.unit == brownian_diffusion_model.unit + assert component.width.unit == brownian_diffusion_model.x_unit assert np.isclose(component.width.value, expected_widths[model_index]) assert component.width.independent is False + # area.unit = area_unit = x_unit * y_unit + assert component.area.unit == 'meV' def test_write_width_dependency_expression(self, brownian_diffusion_model): # WHEN THEN @@ -213,6 +287,11 @@ def test_write_area_dependency_expression_raises(self, brownian_diffusion_model) with pytest.raises(TypeError, match='QISF must be a float'): brownian_diffusion_model._write_area_dependency_expression('invalid') + def test_y_unit_setter_raises(self, brownian_diffusion_model): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match=r'read-only'): + brownian_diffusion_model.y_unit = '1/meV' + def test_repr(self, brownian_diffusion_model): # WHEN THEN repr_str = repr(brownian_diffusion_model) @@ -221,3 +300,5 @@ def test_repr(self, brownian_diffusion_model): assert 'BrownianTranslationalDiffusion' in repr_str assert 'diffusion_coefficient' in repr_str assert 'scale=' in repr_str + # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...' + assert 'x_unit=meV, y_unit=dimensionless' in repr_str diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py index b3740db66..1e69bcec7 100644 --- a/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py +++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_delta_lorentz.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter from easydynamics.sample_model.components.delta_function import DeltaFunction @@ -15,6 +16,57 @@ class TestDeltaLorentz: def delta_lorentz_model(self): return DeltaLorentz() + def test_mean_u_squared_unit_conversion_leaves_physics_unchanged(self): + # WHEN: a model with a non-zero Debye-Waller factor + model = DeltaLorentz(A_0=0.5, mean_u_squared=1.0, lorentzian_width=0.1, Q=np.array([1.0])) + eisf_before = model.calculate_EISF() + qisf_before = model.calculate_QISF() + delta_area_before = model.get_component_collections(0)[1].area.value + + # THEN: convert mean_u_squared to nm**2 (same physical value, different number) + model.mean_u_squared.convert_unit('nm**2') + + # EXPECT: EISF, QISF, and the dependent delta area are unchanged + np.testing.assert_allclose(model.calculate_EISF(), eisf_before) + np.testing.assert_allclose(model.calculate_QISF(), qisf_before) + assert model.get_component_collections(0)[1].area.value == pytest.approx(delta_area_before) + + def test_convert_x_unit_converts_lorentzian_width(self): + # WHEN + model = DeltaLorentz(lorentzian_width=0.1, Q=np.array([1.0])) + + # THEN + model.convert_x_unit('ueV') + + # EXPECT: the width template is rescaled and the regenerated component follows + assert sc.Unit(str(model.lorentzian_width.unit)) == sc.Unit('ueV') + assert model.lorentzian_width.value == pytest.approx(100.0) + collection = model.get_component_collections(0) + assert sc.Unit(str(collection[0].width.unit)) == sc.Unit('ueV') + assert collection[0].width.value == pytest.approx(100.0) + + def test_convert_x_unit_with_q_varying_width_preserves_state(self): + # WHEN: a model with per-Q widths, one of which has been changed from the template + model = DeltaLorentz( + lorentzian_width=0.1, + Q=np.array([1.0, 2.0]), + allow_Q_variation={'lorentzian_width': True}, + ) + model._lorentzian_width_list[0].value = 0.2 + collection_before = model.get_component_collections(0) + + # THEN + model.convert_x_unit('ueV') + + # EXPECT: conversion happens in place — per-Q values are converted, not reset to the + # template, and the collections are not regenerated (regression: conversion used to + # rebuild the collections, discarding per-Q state) + assert model.get_component_collections(0) is collection_before + assert model._lorentzian_width_list[0].value == pytest.approx(200.0) + assert model._lorentzian_width_list[1].value == pytest.approx(100.0) + for width in model._lorentzian_width_list: + assert sc.Unit(str(width.unit)) == sc.Unit('ueV') + @pytest.fixture def delta_lorentz_model_with_Q(self): Q = np.linspace(0.5, 2, 7) @@ -38,16 +90,22 @@ def delta_lorentz_model_with_Q_no_variation(self): def test_init_default(self, delta_lorentz_model): # WHEN THEN EXPECT assert delta_lorentz_model.display_name == 'DeltaLorentz' - assert delta_lorentz_model.unit == 'meV' + assert delta_lorentz_model.x_unit == 'meV' + assert delta_lorentz_model.y_unit == 'dimensionless' assert delta_lorentz_model.scale.value == pytest.approx(1.0) assert delta_lorentz_model.mean_u_squared.value == pytest.approx(0.0) assert delta_lorentz_model.A_0.value == pytest.approx(1.0) assert delta_lorentz_model.lorentzian_width.value == pytest.approx(1.0) + def test_y_unit_setter_raises(self, delta_lorentz_model): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match=r'read-only'): + delta_lorentz_model.y_unit = '1/meV' + def test_init_with_Q(self, delta_lorentz_model_with_Q): # WHEN THEN EXPECT assert delta_lorentz_model_with_Q.display_name == 'DeltaLorentz' - assert delta_lorentz_model_with_Q.unit == 'meV' + assert delta_lorentz_model_with_Q.x_unit == 'meV' assert delta_lorentz_model_with_Q.scale.value == pytest.approx(1.0) assert delta_lorentz_model_with_Q.mean_u_squared.value == pytest.approx(0.0) assert delta_lorentz_model_with_Q.A_0.value == pytest.approx(0.5) @@ -404,7 +462,7 @@ def test_calculate_EISF_with_Q(self, delta_lorentz_model_with_Q): for i in range(len(eisf)): expected = delta_lorentz_model_with_Q._A_0_list[i].value * np.exp( -delta_lorentz_model_with_Q.mean_u_squared.value - * delta_lorentz_model_with_Q.Q[i] ** 2 + * delta_lorentz_model_with_Q.Q.values[i] ** 2 ) assert eisf[i] == pytest.approx(expected) @@ -428,7 +486,7 @@ def test_calculate_QISF_with_Q(self, delta_lorentz_model_with_Q): for i in range(len(qisf)): expected = delta_lorentz_model_with_Q._A_1_list[i].value * np.exp( -delta_lorentz_model_with_Q.mean_u_squared.value - * delta_lorentz_model_with_Q.Q[i] ** 2 + * delta_lorentz_model_with_Q.Q.values[i] ** 2 ) assert qisf[i] == pytest.approx(expected) @@ -459,7 +517,8 @@ def test_create_component_collections_with_Q_variation(self, delta_lorentz_model assert collection[0].width.independent is True assert collection[0].area.independent is False assert ( - 'scale * exp(-mean_u_squared.value *' in collection[0].area.dependency_expression + 'scale * exp(-mean_u_squared_ratio.value *' + in collection[0].area.dependency_expression ) assert 'A_1' in collection[0].area.dependency_expression @@ -467,7 +526,8 @@ def test_create_component_collections_with_Q_variation(self, delta_lorentz_model assert isinstance(collection[1], DeltaFunction) assert collection[1].area.independent is False assert ( - 'scale * exp(-mean_u_squared.value *' in collection[1].area.dependency_expression + 'scale * exp(-mean_u_squared_ratio.value *' + in collection[1].area.dependency_expression ) assert 'A_0' in collection[1].area.dependency_expression @@ -491,7 +551,8 @@ def test_create_component_collections_with_no_Q_variation( assert collection[0].width.dependency_expression == 'lorentzian_width' assert collection[0].area.independent is False assert ( - 'scale * exp(-mean_u_squared.value *' in collection[0].area.dependency_expression + 'scale * exp(-mean_u_squared_ratio.value *' + in collection[0].area.dependency_expression ) assert 'A_1' in collection[0].area.dependency_expression @@ -499,16 +560,17 @@ def test_create_component_collections_with_no_Q_variation( assert isinstance(collection[1], DeltaFunction) assert collection[1].area.independent is False assert ( - 'scale * exp(-mean_u_squared.value *' in collection[1].area.dependency_expression + 'scale * exp(-mean_u_squared_ratio.value *' + in collection[1].area.dependency_expression ) assert 'A_0' in collection[1].area.dependency_expression @pytest.mark.parametrize( - 'Q_index', + ('Q_index', 'expected_exception', 'expected_message'), [ - -1, - 100, - 'string', + (-1, IndexError, r'Q_index -1 is out of bounds'), + (100, IndexError, r'Q_index 100 is out of bounds'), + ('string', TypeError, r'Q_index must be an int or None, got str'), ], ids=[ 'negative index', @@ -516,9 +578,11 @@ def test_create_component_collections_with_no_Q_variation( 'non-integer index', ], ) - def test_get_independent_variables_raises(self, delta_lorentz_model_with_Q, Q_index): + def test_get_independent_variables_raises( + self, delta_lorentz_model_with_Q, Q_index, expected_exception, expected_message + ): # WHEN THEN EXPECT - with pytest.raises(ValueError, match=r'Q_index must be an integer between 0 and'): + with pytest.raises(expected_exception, match=expected_message): delta_lorentz_model_with_Q.get_independent_variables(Q_index=Q_index) def test_get_all_variables_no_Q_index(self, delta_lorentz_model_with_Q): @@ -578,10 +642,10 @@ def test_get_all_variables_with_Q_index_no_Q_variation( def test_get_all_variables_invalid_Q_index(self, delta_lorentz_model_with_Q): # WHEN THEN EXPECT - with pytest.raises(ValueError, match='Q_index must be an integer between 0 and'): + with pytest.raises(IndexError, match='Q_index -1 is out of bounds'): delta_lorentz_model_with_Q.get_all_variables(Q_index=-1) - with pytest.raises(ValueError, match='Q_index must be an integer between 0 and'): + with pytest.raises(IndexError, match=r'Q_index \d+ is out of bounds'): delta_lorentz_model_with_Q.get_all_variables(Q_index=len(delta_lorentz_model_with_Q.Q)) @pytest.mark.parametrize( @@ -856,3 +920,22 @@ def test_repr(self, delta_lorentz_model): assert 'mean_u_squared' in repr_str assert 'A_0' in repr_str assert 'lorentzian_width' in repr_str + # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...' + assert 'x_unit=meV, y_unit=dimensionless' in repr_str + + # ───── Regression tests ───── + + def test_calculate_width_raises_after_clear_Q_when_allow_Q_variation( + self, delta_lorentz_model_with_Q + ): + # WHEN: model with Q-variation enabled for lorentzian_width + assert delta_lorentz_model_with_Q._allow_Q_variation['lorentzian_width'] is True + assert len(delta_lorentz_model_with_Q._lorentzian_width_list) > 0 + + # THEN: clear Q (empties _lorentzian_width_list) + delta_lorentz_model_with_Q.clear_Q(confirm=True) + assert len(delta_lorentz_model_with_Q._lorentzian_width_list) == 0 + + # THEN: before the fix, calculate_width() silently returned [] instead of raising. + with pytest.raises(ValueError, match='Lorentzian width Q-variation list is empty'): + delta_lorentz_model_with_Q.calculate_width() diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py index 3364e5104..d3d56c3b6 100644 --- a/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py +++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_diffusion_model_base.py @@ -3,7 +3,9 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable.parameter import Parameter +from scipp import UnitError from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase @@ -19,7 +21,10 @@ def test_init_default(self, diffusion_model): assert diffusion_model.name == 'DiffusionModel' assert diffusion_model.lorentzian_name == 'DiffusionModel' assert diffusion_model.lorentzian_display_name == 'DiffusionModel' - assert diffusion_model.unit == 'meV' + assert diffusion_model.x_unit == 'meV' + assert diffusion_model.y_unit == 'dimensionless' + # scale.unit = area_unit = x_unit * y_unit = meV * dimensionless = meV + assert diffusion_model.scale.unit == 'meV' def test_init_raises(self): # WHEN THEN EXPECT @@ -29,13 +34,66 @@ def test_init_raises(self): with pytest.raises(TypeError, match=r'lorentzian_display_name must be a string or None'): DiffusionModelBase(lorentzian_display_name=123) - def test_unit_setter_raises(self, diffusion_model): + def test_x_unit_setter_raises(self, diffusion_model): # WHEN THEN EXPECT with pytest.raises( AttributeError, - match=r'Unit is read-only. Use convert_unit to change the unit between allowed types', + match=r'read-only', ): - diffusion_model.unit = 'eV' + diffusion_model.x_unit = 'eV' + + def test_y_unit_setter_raises(self, diffusion_model): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match=r'read-only'): + diffusion_model.y_unit = '1/meV' + + def test_init_accepts_scipp_units(self): + # WHEN THEN + model = DiffusionModelBase(x_unit=sc.Unit('meV'), y_unit=sc.Unit('1/meV')) + + # EXPECT: scale.unit = x_unit * y_unit = dimensionless + assert str(model.x_unit) == 'meV' + assert sc.Unit(str(model.scale.unit)) == sc.Unit('dimensionless') + + def test_convert_x_unit(self, diffusion_model): + # WHEN THEN + diffusion_model.convert_x_unit('ueV') + + # EXPECT: x_unit and the scale unit (x_unit * y_unit) follow + assert sc.Unit(diffusion_model.x_unit) == sc.Unit('ueV') + assert sc.Unit(str(diffusion_model.scale.unit)) == sc.Unit('ueV') + + def test_convert_x_unit_invalid_type_raises(self, diffusion_model): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'x_unit must be a string or sc.Unit'): + diffusion_model.convert_x_unit(123) + + @pytest.mark.parametrize('unit', ['m', '1/ps'], ids=['length', 'frequency']) + def test_convert_x_unit_non_energy_raises(self, diffusion_model, unit): + # WHEN THEN EXPECT: only energy axes are supported for now + with pytest.raises(UnitError, match=r'convertible to meV'): + diffusion_model.convert_x_unit(unit) + + def test_convert_y_unit(self): + # WHEN + model = DiffusionModelBase(y_unit='1/meV') + + # THEN + model.convert_y_unit('1/eV') + + # EXPECT: y_unit and the scale unit follow (meV * 1/eV) + assert sc.Unit(model.y_unit) == sc.Unit('1/eV') + assert sc.Unit(str(model.scale.unit)) == sc.Unit('meV/eV') + + def test_convert_y_unit_invalid_type_raises(self, diffusion_model): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=r'y_unit must be a string or sc.Unit'): + diffusion_model.convert_y_unit(123) + + def test_convert_y_unit_incompatible_raises(self, diffusion_model): + # WHEN THEN EXPECT: dimensionless -> K changes the dimension of the scale unit + with pytest.raises(UnitError): + diffusion_model.convert_y_unit('K') @pytest.mark.parametrize( ('attribute', 'value', 'expected'), @@ -131,7 +189,9 @@ def test_Q_property(self, diffusion_model): diffusion_model.Q = [1.0, 2.0, 3.0] # EXPECT - np.testing.assert_allclose(diffusion_model.Q, [1.0, 2.0, 3.0]) + assert isinstance(diffusion_model.Q, sc.Variable) + assert diffusion_model.Q.unit == sc.Unit('1/angstrom') + np.testing.assert_allclose(diffusion_model.Q.values, [1.0, 2.0, 3.0]) # THEN EXPECT with pytest.raises(ValueError, match=r'New Q values are not similar to the old ones'): @@ -147,13 +207,24 @@ def test_Q_property(self, diffusion_model): # EXPECT assert diffusion_model.Q is None + def test_Q_setter_accepts_equivalent_scipp_Q_in_other_unit(self, diffusion_model): + # WHEN + diffusion_model.Q = [1.0, 2.0, 3.0] + + # THEN: the same Q in 1/nm is equivalent after conversion, so no error is raised + diffusion_model.Q = sc.Variable(dims=['Q'], values=[10.0, 20.0, 30.0], unit='1/nm') + + # EXPECT + np.testing.assert_allclose(diffusion_model.Q.values, [1.0, 2.0, 3.0]) + def test_repr(self, diffusion_model): # WHEN THEN repr_str = repr(diffusion_model) # EXPECT assert 'DiffusionModelBase' in repr_str - assert 'unit=meV' in repr_str + # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...' + assert 'x_unit=meV, y_unit=dimensionless' in repr_str def test_get_independent_variables(self, diffusion_model): # WHEN THEN EXPECT @@ -163,11 +234,11 @@ def test_get_independent_variables(self, diffusion_model): assert independent_vars == [] @pytest.mark.parametrize( - 'Q_index', + ('Q_index', 'expected_exception', 'expected_message'), [ - -1, - 100, - 'string', + (-1, ValueError, 'Q is None'), + (100, ValueError, 'Q is None'), + ('string', TypeError, 'Q_index must be an int or None, got str'), ], ids=[ 'negative index', @@ -175,17 +246,19 @@ def test_get_independent_variables(self, diffusion_model): 'non-integer index', ], ) - def test_get_independent_variables_raises(self, diffusion_model, Q_index): + def test_get_independent_variables_raises( + self, diffusion_model, Q_index, expected_exception, expected_message + ): # WHEN THEN EXPECT - with pytest.raises(ValueError, match=r'Q_index must be an integer between 0 and'): + with pytest.raises(expected_exception, match=expected_message): diffusion_model.get_independent_variables(Q_index=Q_index) @pytest.mark.parametrize( - 'Q_index', + ('Q_index', 'expected_exception', 'expected_message'), [ - -1, - 100, - 'string', + (-1, ValueError, 'Q is None'), + (100, ValueError, 'Q is None'), + ('string', TypeError, 'Q_index must be an int or None, got str'), ], ids=[ 'negative index', @@ -193,9 +266,11 @@ def test_get_independent_variables_raises(self, diffusion_model, Q_index): 'non-integer index', ], ) - def test_get_all_variables_raises(self, diffusion_model, Q_index): + def test_get_all_variables_raises( + self, diffusion_model, Q_index, expected_exception, expected_message + ): # WHEN THEN EXPECT - with pytest.raises(ValueError, match=r'Q_index must be an integer between 0 and'): + with pytest.raises(expected_exception, match=expected_message): diffusion_model.get_all_variables(Q_index=Q_index) def test_create_component_collections_no_Q(self, diffusion_model): diff --git a/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py b/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py index 014a221f7..7fcf4dacd 100644 --- a/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py +++ b/tests/unit/easydynamics/sample_model/diffusion_model/test_jump_translational_diffusion.py @@ -25,8 +25,10 @@ def jump_diffusion_model(self): def test_init_default(self, jump_diffusion_model): # WHEN THEN EXPECT assert jump_diffusion_model.display_name == 'JumpTranslationalDiffusion' - assert jump_diffusion_model.unit == 'meV' + assert jump_diffusion_model.x_unit == 'meV' + assert jump_diffusion_model.y_unit == 'dimensionless' assert jump_diffusion_model.scale.value == pytest.approx(1.0) + assert jump_diffusion_model.scale.unit == 'meV' assert jump_diffusion_model.diffusion_coefficient.value == pytest.approx(1.0) assert jump_diffusion_model.relaxation_time.value == pytest.approx(1.0) @@ -35,7 +37,7 @@ def test_init_default(self, jump_diffusion_model): [ ( { - 'unit': 123, + 'x_unit': 123, 'scale': 1.0, 'diffusion_coefficient': 1.0, 'relaxation_time': 1.0, @@ -45,7 +47,17 @@ def test_init_default(self, jump_diffusion_model): ), ( { - 'unit': 'meV', + 'y_unit': 123, + 'scale': 1.0, + 'diffusion_coefficient': 1.0, + 'relaxation_time': 1.0, + }, + TypeError, + None, + ), + ( + { + 'x_unit': 'meV', 'scale': 'invalid', 'diffusion_coefficient': 1.0, 'relaxation_time': 1.0, @@ -55,7 +67,7 @@ def test_init_default(self, jump_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': 1.0, 'diffusion_coefficient': 'invalid', 'relaxation_time': 1.0, @@ -65,7 +77,7 @@ def test_init_default(self, jump_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': 1.0, 'diffusion_coefficient': -1.0, 'relaxation_time': 1.0, @@ -75,7 +87,7 @@ def test_init_default(self, jump_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': 1.0, 'diffusion_coefficient': 1.0, 'relaxation_time': 'invalid', @@ -85,7 +97,7 @@ def test_init_default(self, jump_diffusion_model): ), ( { - 'unit': 'meV', + 'x_unit': 'meV', 'scale': 1.0, 'diffusion_coefficient': 1.0, 'relaxation_time': -1.0, @@ -94,6 +106,15 @@ def test_init_default(self, jump_diffusion_model): 'relaxation_time must be non-negative', ), ], + ids=[ + 'invalid_x_unit', + 'invalid_y_unit', + 'invalid_scale_type', + 'invalid_diffusion_coefficient_type', + 'invalid_diffusion_coefficient_negative', + 'invalid_relaxation_time_type', + 'invalid_relaxation_time_negative', + ], ) def test_input_type_validation_raises(self, kwargs, expected_exception, expected_message): with pytest.raises(expected_exception, match=expected_message): @@ -159,7 +180,7 @@ def test_calculate_width(self, jump_diffusion_model): # EXPECT expected_widths = scipp_hbar * diffusion_coefficient_sc * (Q_values**2) / (1 + denominator) - expected_widths = expected_widths.to(unit=jump_diffusion_model.unit) + expected_widths = expected_widths.to(unit=jump_diffusion_model.x_unit) np.testing.assert_allclose(widths, expected_widths.values, rtol=1e-5) @@ -221,9 +242,11 @@ def test_create_component_collections(self, jump_diffusion_model, Q): model = component_collections[model_index] assert len(model) == 1 component = model[0] - assert component.width.unit == jump_diffusion_model.unit + assert component.width.unit == jump_diffusion_model.x_unit assert np.isclose(component.width.value, expected_widths[model_index]) assert component.width.independent is False + # area.unit = area_unit = x_unit * y_unit + assert component.area.unit == 'meV' def test_write_width_dependency_expression(self, jump_diffusion_model): # WHEN THEN @@ -257,6 +280,11 @@ def test_write_area_dependency_expression_raises(self, jump_diffusion_model): with pytest.raises(TypeError, match='QISF must be a float'): jump_diffusion_model._write_area_dependency_expression('invalid') + def test_y_unit_setter_raises(self, jump_diffusion_model): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match=r'read-only'): + jump_diffusion_model.y_unit = '1/meV' + def test_repr(self, jump_diffusion_model): # WHEN THEN repr_str = repr(jump_diffusion_model) @@ -265,3 +293,5 @@ def test_repr(self, jump_diffusion_model): assert 'JumpTranslationalDiffusion' in repr_str assert 'diffusion_coefficient' in repr_str assert 'scale=' in repr_str + # Regression: a stray ')' used to mangle this into 'x_unit=meV), y_unit=...' + assert 'x_unit=meV, y_unit=dimensionless' in repr_str diff --git a/tests/unit/easydynamics/sample_model/test_background_model.py b/tests/unit/easydynamics/sample_model/test_background_model.py index a698a1bf0..ff1233d84 100644 --- a/tests/unit/easydynamics/sample_model/test_background_model.py +++ b/tests/unit/easydynamics/sample_model/test_background_model.py @@ -18,14 +18,14 @@ def background_model(self): area=1.0, center=0.0, width=1.0, - unit='meV', + x_unit='meV', ) component2 = Lorentzian( display_name='TestLorentzian1', area=2.0, center=1.0, width=0.5, - unit='meV', + x_unit='meV', ) component_collection = ComponentCollection() component_collection.append_component(component1) @@ -33,19 +33,19 @@ def background_model(self): return BackgroundModel( display_name='InitModel', components=component_collection, - unit='meV', + x_unit='meV', Q=np.array([1.0, 2.0, 3.0]), ) def test_init(self, background_model): # WHEN THEN - model = background_model # EXPECT - assert model.display_name == 'InitModel' - assert model.unit == 'meV' - assert len(model.components) == 2 - np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) + assert background_model.display_name == 'InitModel' + assert background_model.x_unit == 'meV' + assert background_model.y_unit == 'dimensionless' + assert len(background_model.components) == 2 + np.testing.assert_array_equal(background_model.Q.values, np.array([1.0, 2.0, 3.0])) @pytest.mark.parametrize( 'invalid_component, expected_error_msg', @@ -80,3 +80,19 @@ def test_init_raises_with_invalid_components(self, invalid_component, expected_e collection = ComponentCollection() collection.append_component(invalid_component) BackgroundModel(components=collection) + + def test_y_unit_setter_raises(self, background_model): + # WHEN / THEN / EXPECT + with pytest.raises(AttributeError): + background_model.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + model = BackgroundModel(components=g, x_unit='meV') + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + model.convert_y_unit('1/eV') + # EXPECT + assert model.y_unit == '1/eV' + assert model.components[0].y_unit == '1/eV' + assert g.area.value == pytest.approx(1e3) diff --git a/tests/unit/easydynamics/sample_model/test_component_collection.py b/tests/unit/easydynamics/sample_model/test_component_collection.py index 7ae15ed65..d72b84ce8 100644 --- a/tests/unit/easydynamics/sample_model/test_component_collection.py +++ b/tests/unit/easydynamics/sample_model/test_component_collection.py @@ -5,10 +5,12 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter from scipy.integrate import simpson from easydynamics.sample_model import ComponentCollection +from easydynamics.sample_model import ExpressionComponent from easydynamics.sample_model import Gaussian from easydynamics.sample_model import Lorentzian from easydynamics.sample_model import Polynomial @@ -24,7 +26,7 @@ def component_collection(self): area=1.0, center=0.0, width=1.0, - unit='meV', + x_unit='meV', ) component2 = Lorentzian( name='TestLorentzian1Name', @@ -32,7 +34,7 @@ def component_collection(self): area=2.0, center=1.0, width=0.5, - unit='meV', + x_unit='meV', ) model.append_component(component1) model.append_component(component2) @@ -45,10 +47,12 @@ def test_init(self): # EXPECT assert component_collection.display_name == 'InitModel' assert not component_collection + assert component_collection.x_unit == 'meV' + assert component_collection.y_unit == 'dimensionless' def test_init_with_component(self): # WHEN THEN - component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, unit='meV') + component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, x_unit='meV') component_collection = ComponentCollection(display_name='InitModel', components=component1) # EXPECT @@ -58,9 +62,9 @@ def test_init_with_component(self): def test_init_with_components(self): # WHEN THEN - component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, unit='meV') + component1 = Gaussian(name='TestGaussian1', area=1.0, center=0.0, width=1.0, x_unit='meV') component2 = Lorentzian( - name='TestLorentzian1', area=2.0, center=1.0, width=0.5, unit='meV' + name='TestLorentzian1', area=2.0, center=1.0, width=0.5, x_unit='meV' ) component_collection = ComponentCollection( display_name='InitModel', components=[component1, component2] @@ -91,13 +95,13 @@ def test_init_with_invalid_list_of_components_raises(self): def test_init_with_invalid_unit_raises(self): # WHEN THEN EXPECT with pytest.raises(TypeError, match='unit must be'): - ComponentCollection(unit=123) + ComponentCollection(x_unit=123) # ───── Component Management ───── def test_append_component(self, component_collection): # WHEN - component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV') + component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV') # THEN component_collection.append_component(component) # EXPECT @@ -105,7 +109,7 @@ def test_append_component(self, component_collection): def test_append_component_collection(self, component_collection): # WHEN - component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV') + component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV') component_collection2 = ComponentCollection() component_collection2.append_component(component) # THEN @@ -127,7 +131,7 @@ def test_append_invalid_component_raises(self, component_collection): def test_getitem(self, component_collection): # WHEN - component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV') + component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV') # THEN component_collection.append_component(component) # EXPECT @@ -140,7 +144,7 @@ def test_is_empty(self): assert component_collection.is_empty is True # WHEN THEN - component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, unit='meV') + component = Gaussian(name='TestComponent', area=1.0, center=0.0, width=1.0, x_unit='meV') component_collection.append_component(component) # EXPECT assert component_collection.is_empty is False @@ -158,68 +162,100 @@ def test_list_component_names(self, component_collection): assert components[0] == 'TestGaussian1Name' assert components[1] == 'TestLorentzian1Name' - def test_convert_unit(self, component_collection): + def test_convert_x_unit(self, component_collection): # WHEN THEN - component_collection.convert_unit('eV') + component_collection.convert_x_unit('eV') # EXPECT for component in component_collection: - assert component.unit == 'eV' + assert component.x_unit == 'eV' - def test_convert_unit_incorrect_unit_raises(self, component_collection): + def test_convert_x_unit_incorrect_unit_raises(self, component_collection): # WHEN THEN EXPECT - with pytest.raises(TypeError, match=r'Unit must be a string or sc.Unit'): - component_collection.convert_unit(123) + with pytest.raises(TypeError, match=r'unit must be a string or sc.Unit'): + component_collection.convert_x_unit(123) - def test_convert_unit_failure_rolls_back(self, component_collection): + def test_convert_x_unit_failure_rolls_back(self, component_collection): # WHEN THEN # Introduce a faulty component that will fail conversion class FaultyComponent(Gaussian): - def convert_unit(self, _unit: str) -> None: + def convert_x_unit(self, _unit: str) -> None: raise RuntimeError('Conversion failed.') faulty_component = FaultyComponent( - name='FaultyComponent', area=1.0, center=0.0, width=1.0, unit='meV' + name='FaultyComponent', area=1.0, center=0.0, width=1.0, x_unit='meV' ) component_collection.append_component(faulty_component) - original_units = {component.name: component.unit for component in component_collection} + original_units = {component.name: component.x_unit for component in component_collection} # EXPECT with pytest.raises(RuntimeError, match=r'Conversion failed.'): - component_collection.convert_unit('eV') + component_collection.convert_x_unit('eV') # Check that all components have their original units for component in component_collection: - assert component.unit == original_units[component.name] + assert component.x_unit == original_units[component.name] - def test_set_unit(self, component_collection): + def test_set_x_unit(self, component_collection): # WHEN THEN EXPECT with pytest.raises( AttributeError, - match=r'Unit is read-only. Use convert_unit to change the unit', + match=r'read-only', ): - component_collection.unit = 'eV' + component_collection.x_unit = 'eV' def test_evaluate(self, component_collection): # WHEN x = np.linspace(-5, 5, 100) + + # THEN result = component_collection.evaluate(x) # EXPECT expected_result = component_collection[0].evaluate(x) + component_collection[1].evaluate(x) np.testing.assert_allclose(result, expected_result, rtol=1e-5) def test_evaluate_no_components_returns_zero(self): - # WHEN THEN + # WHEN component_collection = ComponentCollection(display_name='EmptyModel') x = np.linspace(-5, 5, 100) - # EXPECT + # THEN result = component_collection.evaluate(x) + + # EXPECT assert np.all(result == pytest.approx(0.0)) assert result.shape == x.shape + def test_evaluate_no_components_scipp_output(self): + # WHEN + component_collection = ComponentCollection(display_name='EmptyModel', y_unit='1/meV') + x = np.linspace(-5, 5, 100) + + # THEN + result = component_collection.evaluate(x, output='scipp') + + # EXPECT: an sc.Variable of zeros carrying the collection's y_unit + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + assert np.all(result.values == pytest.approx(0.0)) + + def test_evaluate_no_components_scipp_input(self): + # WHEN + component_collection = ComponentCollection(display_name='EmptyModel') + x = sc.linspace('energy', -5.0, 5.0, 100, unit='meV') + + # THEN + result = component_collection.evaluate(x, output='scipp') + + # EXPECT: zeros on the input grid, keeping the input's dimension name + assert isinstance(result, sc.Variable) + assert result.dims == ('energy',) + assert np.all(result.values == pytest.approx(0.0)) + def test_evaluate_component(self, component_collection): - # WHEN THEN + # WHEN x = np.linspace(-5, 5, 100) + + # THEN result1 = component_collection.evaluate_component(x, 'TestGaussian1Name') result2 = component_collection.evaluate_component(x, 'TestLorentzian1Name') @@ -290,13 +326,57 @@ def test_normalize_area_not_finite_area_raises(self, component_collection, area_ def test_normalize_area_non_area_component_warns(self, component_collection): # WHEN - component1 = Polynomial(display_name='TestPolynomial', coefficients=[1, 2, 3], unit='meV') + component1 = Polynomial( + display_name='TestPolynomial', coefficients=[1, 2, 3], x_unit='meV' + ) component_collection.append_component(component1) # THEN EXPECT with pytest.warns(UserWarning, match="does not have an 'area' "): component_collection.normalize_area() + def test_convert_x_unit_rollback_skipped_when_old_unit_none(self): + # WHEN: a collection without an x_unit of its own + collection = ComponentCollection( + components=Gaussian(name='G', area=1.0, width=0.5, x_unit='meV'), x_unit=None + ) + + # THEN: an incompatible unit fails; the outer rollback is skipped (no old unit to + # restore) but the component's own atomic rollback keeps it consistent + with pytest.raises(sc.UnitError): + collection.convert_x_unit('m') + + # EXPECT + assert collection[0].x_unit == 'meV' + + def test_normalize_area_only_non_area_components_raises(self): + # WHEN: no component in the collection has an area attribute + collection = ComponentCollection( + components=Polynomial(display_name='OnlyPolynomial', coefficients=[1, 2]) + ) + + # THEN EXPECT + with ( + pytest.warns(UserWarning, match="does not have an 'area' "), + pytest.raises(ValueError, match='No components with an area attribute'), + ): + collection.normalize_area() + + def test_normalize_area_mixed_units(self): + # WHEN: two Gaussians with compatible but different area units (meV and ueV) + gaussian_mev = Gaussian(name='G1', area=1.0, width=1.0, x_unit='meV') + gaussian_uev = Gaussian(name='G2', area=1000.0, width=500.0, x_unit='ueV') + collection = ComponentCollection(components=[gaussian_mev, gaussian_uev]) + + # THEN: both areas are physically 1 meV, so each should end up at half its value + collection.normalize_area() + + # EXPECT: areas sum to 1 in the first component's unit (meV) + total_mev = gaussian_mev.area.value + gaussian_uev.area.value / 1000.0 + assert total_mev == pytest.approx(1.0) + assert gaussian_mev.area.value == pytest.approx(0.5) + assert gaussian_uev.area.value == pytest.approx(500.0) + def test_get_all_parameters(self, component_collection): # WHEN THEN parameters = component_collection.get_all_parameters() @@ -364,6 +444,7 @@ def test_fix_and_free_all_parameters(self, component_collection): assert param.fixed is False def test_contains(self, component_collection): + # WHEN THEN EXPECT: membership by name and by object assert 'TestGaussian1Name' in component_collection assert 'TestLorentzian1Name' in component_collection assert 'NonExistentComponent' not in component_collection @@ -373,9 +454,10 @@ def test_contains(self, component_collection): assert gaussian_component in component_collection assert lorentzian_component in component_collection - # WHEN THEN - fake_component = Gaussian(name='FakeGaussian', area=1.0, center=0.0, width=1.0, unit='meV') - # EXPECT + # WHEN: a component not in the collection — THEN EXPECT + fake_component = Gaussian( + name='FakeGaussian', area=1.0, center=0.0, width=1.0, x_unit='meV' + ) assert fake_component not in component_collection assert 123 not in component_collection # Invalid type @@ -392,13 +474,15 @@ def test_to_dict(self, component_collection): # EXPECT assert model_dict['display_name'] == component_collection.display_name - assert model_dict['unit'] == component_collection.unit + assert model_dict['x_unit'] == component_collection.x_unit + assert model_dict['y_unit'] == component_collection.y_unit assert len(model_dict['components']) == len(component_collection) for comp, comp_dict in zip(component_collection, model_dict['components'], strict=True): assert comp_dict['@class'] == type(comp).__name__ assert comp_dict['display_name'] == comp.display_name - assert comp_dict['unit'] == comp.unit + assert comp_dict['x_unit'] == comp.x_unit + assert comp_dict['y_unit'] == comp.y_unit def test_from_dict(self, component_collection): # WHEN @@ -409,12 +493,15 @@ def test_from_dict(self, component_collection): # EXPECT assert new_model.display_name == component_collection.display_name + assert new_model.x_unit == component_collection.x_unit + assert new_model.y_unit == component_collection.y_unit assert len(new_model) == len(component_collection) for orig_comp, new_comp in zip(component_collection, new_model, strict=True): assert type(new_comp) is type(orig_comp) assert new_comp.display_name == orig_comp.display_name - assert new_comp.unit == orig_comp.unit + assert new_comp.x_unit == orig_comp.x_unit + assert new_comp.y_unit == orig_comp.y_unit orig_params = orig_comp.get_all_parameters() new_params = new_comp.get_all_parameters() @@ -426,6 +513,13 @@ def test_from_dict(self, component_collection): assert param_new.value == param_orig.value assert param_new.fixed == param_orig.fixed + @pytest.mark.parametrize('missing_key', ['x_unit', 'y_unit', 'components', 'name']) + def test_from_dict_requires_all_keys(self, component_collection, missing_key): + model_dict = component_collection.to_dict() + del model_dict[missing_key] + with pytest.raises(KeyError): + ComponentCollection.from_dict(model_dict) + def test_copy(self, component_collection): # WHEN component_collection[0].area.min = 0.5 @@ -451,7 +545,8 @@ def test_copy(self, component_collection): # Same type and display name assert type(copied_comp) is type(orig_comp) assert copied_comp.display_name == orig_comp.display_name - assert copied_comp.unit == orig_comp.unit + assert copied_comp.x_unit == orig_comp.x_unit + assert copied_comp.y_unit == orig_comp.y_unit # Parameters are deep-copied and equivalent orig_params = orig_comp.get_all_parameters() @@ -487,3 +582,95 @@ def test_no_warning_with_unique_names(self, recwarn): ComponentCollection(components=[g1, g2]) user_warnings = [w for w in recwarn.list if issubclass(w.category, UserWarning)] assert not user_warnings + + def test_y_unit_custom(self): + # WHEN THEN + cc = ComponentCollection(y_unit='1/meV') + # EXPECT + assert cc.y_unit == '1/meV' + + def test_y_unit_setter_raises(self, component_collection): + # WHEN THEN EXPECT + with pytest.raises(AttributeError, match=r'read-only'): + component_collection.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: components with y_unit='1/meV' so area_unit ≈ dimensionless + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV') + cc = ComponentCollection(components=[g, lor]) + + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + cc.convert_y_unit('1/eV') + + # EXPECT + assert cc.y_unit == '1/eV' + for component in cc: + assert component.y_unit == '1/eV' + assert g.area.value == pytest.approx(1e3) + assert lor.area.value == pytest.approx(1e3) + + def test_convert_y_unit_invalid_type_raises(self, component_collection): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + component_collection.convert_y_unit(123) + + def test_convert_x_unit_rollback_on_failure(self): + # WHEN: collection whose first Gaussian converts fine, but second has an + # ExpressionComponent that raises NotImplementedError for convert_x_unit. + g = Gaussian(area=1.0, x_unit='meV') + expr = ExpressionComponent('A * x', parameters={'A': 1.0}, x_unit='meV') + cc = ComponentCollection(components=[g, expr]) + original_area = g.area.value + + # THEN: attempt a unit conversion that will fail on the ExpressionComponent + with pytest.raises(NotImplementedError): + cc.convert_x_unit('microeV') + + # EXPECT: Gaussian is rolled back to its original state + assert cc.x_unit == 'meV' + assert g.x_unit == 'meV' + assert g.area.value == pytest.approx(original_area) + + def test_convert_y_unit_rollback_on_failure(self): + # WHEN: collection where first Gaussian converts successfully but second + # ExpressionComponent always raises NotImplementedError for convert_y_unit. + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + expr = ExpressionComponent('A * x', parameters={'A': 1.0}, x_unit='meV') + cc = ComponentCollection(components=[g, expr], y_unit='1/meV') + original_area = g.area.value + + # THEN: attempt y_unit conversion that will fail on the ExpressionComponent + with pytest.raises(NotImplementedError): + cc.convert_y_unit('1/eV') + + # EXPECT: collection y_unit and Gaussian are both rolled back + assert cc.y_unit == '1/meV' + assert g.y_unit == '1/meV' + assert g.area.value == pytest.approx(original_area) + + def test_evaluate_scipp_output_with_y_unit(self): + # WHEN + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + cc = ComponentCollection(components=[g], y_unit='1/meV') + x = np.linspace(-5, 5, 50) + # THEN + result = cc.evaluate(x, output='scipp') + # EXPECT + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/meV') + + # ───── Regression tests ───── + + def test_evaluate_scipp_output_multi_component_does_not_raise(self, component_collection): + # WHEN: collection with two components (Gaussian + Lorentzian) + x = sc.Variable(dims=['energy'], values=np.linspace(-5.0, 5.0, 100), unit='meV') + # THEN: evaluate with scipp output + # Before the fix, sum() started from int 0 → '0 + sc.Variable' raised TypeError. + result = component_collection.evaluate(x, output='scipp') + # EXPECT: returns a Variable whose values are the sum of both components + assert isinstance(result, sc.Variable) + expected = component_collection[0].evaluate(x, output='scipp') + component_collection[ + 1 + ].evaluate(x, output='scipp') + assert sc.allclose(result, expected) diff --git a/tests/unit/easydynamics/sample_model/test_instrument_model.py b/tests/unit/easydynamics/sample_model/test_instrument_model.py index cb716fc42..a802b1961 100644 --- a/tests/unit/easydynamics/sample_model/test_instrument_model.py +++ b/tests/unit/easydynamics/sample_model/test_instrument_model.py @@ -68,12 +68,13 @@ def test_init(self, instrument_model): # EXPECT assert model.display_name == 'TestInstrumentModel' - np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) + assert isinstance(model.Q, sc.Variable) + assert model.Q.unit == sc.Unit('1/angstrom') + np.testing.assert_array_equal(model.Q.values, np.array([1.0, 2.0, 3.0])) assert isinstance(model.background_model, BackgroundModel) assert isinstance(model.resolution_model, ResolutionModel) - np.testing.assert_array_equal(model.background_model.Q, np.array([1.0, 2.0, 3.0])) - np.testing.assert_array_equal(model.resolution_model.Q, np.array([1.0, 2.0, 3.0])) - np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) + np.testing.assert_array_equal(model.background_model.Q.values, np.array([1.0, 2.0, 3.0])) + np.testing.assert_array_equal(model.resolution_model.Q.values, np.array([1.0, 2.0, 3.0])) def test_init_defaults(self): # WHEN THEN @@ -113,7 +114,7 @@ def test_init_sample_model_as_resolution_model(self, sample_model): 'energy_offset must be a number', ), ( - {'unit': 123}, + {'x_unit': 123}, TypeError, 'unit must be', ), @@ -122,7 +123,7 @@ def test_init_sample_model_as_resolution_model(self, sample_model): 'invalid resolution_model', 'invalid background_model', 'invalid energy_offset', - 'invalid unit', + 'invalid x unit', ], ) def test_instrument_model_init_invalid_inputs( @@ -209,13 +210,10 @@ def test_clear_Q_raises_without_confirm(self, instrument_model): with pytest.raises(ValueError, match='Clearing Q values requires confirmation'): instrument_model.clear_Q() - def test_unit_setter_raises(self, instrument_model): + def test_x_unit_setter_raises(self, instrument_model): # WHEN / THEN / EXPECT - with pytest.raises( - AttributeError, - match=r'Unit is read-only. Use convert_unit to change the unit between allowed types ', - ): - instrument_model.unit = 'meV' + with pytest.raises(AttributeError): + instrument_model.x_unit = 'meV' def test_energy_offset_setter(self, instrument_model): # WHEN @@ -272,16 +270,16 @@ def test_get_energy_offset_no_Q_raises(self, instrument_model): ): instrument_model.get_energy_offset(0) - def test_convert_unit_calls_all_children(self, instrument_model): + def test_convert_x_unit_calls_all_children(self, instrument_model): # WHEN new_unit = 'eV' # THEN # Ensure energy offsets are built before mocking instrument_model._ensure_energy_offsets_current() - # Mock downstream convert_unit calls - instrument_model._background_model.convert_unit = MagicMock() - instrument_model._resolution_model.convert_unit = MagicMock() + # Mock downstream convert_x_unit calls + instrument_model._background_model.convert_x_unit = MagicMock() + instrument_model._resolution_model.convert_x_unit = MagicMock() instrument_model._energy_offset.convert_unit = MagicMock() for offset in instrument_model._energy_offsets: offset.convert_unit = MagicMock() @@ -290,28 +288,28 @@ def test_convert_unit_calls_all_children(self, instrument_model): 'easydynamics.sample_model.instrument_model._validate_unit', return_value=new_unit, ) as mock_validate: - instrument_model.convert_unit(new_unit) + instrument_model.convert_x_unit(new_unit) # EXPECT mock_validate.assert_called_once_with(new_unit) - instrument_model._background_model.convert_unit.assert_called_once_with(new_unit) - instrument_model._resolution_model.convert_unit.assert_called_once_with(new_unit) + instrument_model._background_model.convert_x_unit.assert_called_once_with(new_unit) + instrument_model._resolution_model.convert_x_unit.assert_called_once_with(new_unit) instrument_model._energy_offset.convert_unit.assert_called_once_with(new_unit) for offset in instrument_model._energy_offsets: offset.convert_unit.assert_called_once_with(new_unit) # final state - assert instrument_model.unit == new_unit + assert instrument_model.x_unit == new_unit - def test_convert_unit_None_raises(self, instrument_model): + def test_convert_x_unit_None_raises(self, instrument_model): # WHEN / THEN / EXPECT with pytest.raises( ValueError, match=' must be a valid unit', ): - instrument_model.convert_unit(None) + instrument_model.convert_x_unit(None) def test_fix_resolution_parameters(self, instrument_model): # WHEN @@ -420,7 +418,9 @@ def test_generate_energy_offsets_Q_none(self, instrument_model): def test_generate_energy_offsets(self, instrument_model): # WHEN - instrument_model._Q = np.array([1.0, 2.0, 3.0, 4.0]) + instrument_model._Q = sc.Variable( + dims=['Q'], values=[1.0, 2.0, 3.0, 4.0], unit='1/angstrom' + ) # THEN instrument_model._generate_energy_offsets() @@ -429,7 +429,7 @@ def test_generate_energy_offsets(self, instrument_model): assert len(instrument_model._energy_offsets) == 4 for offset in instrument_model._energy_offsets: assert offset.name == 'energy_offset' - assert offset.unit == instrument_model.unit + assert offset.unit == instrument_model.x_unit assert offset.value == instrument_model.energy_offset.value def test_Q_setter(self, instrument_model_without_Q): @@ -441,8 +441,12 @@ def test_Q_setter(self, instrument_model_without_Q): # EXPECT assert instrument_model_without_Q._energy_offsets_is_dirty is True - np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q) - np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q) + np.testing.assert_array_equal( + instrument_model_without_Q.background_model.Q.values, first_new_Q + ) + np.testing.assert_array_equal( + instrument_model_without_Q.resolution_model.Q.values, first_new_Q + ) # THEN new_Q = np.array([4.0, 5.0, 6.0]) @@ -457,8 +461,12 @@ def test_Q_setter(self, instrument_model_without_Q): # EXPECT # Q values remain unchanged - np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q) - np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q) + np.testing.assert_array_equal( + instrument_model_without_Q.background_model.Q.values, first_new_Q + ) + np.testing.assert_array_equal( + instrument_model_without_Q.resolution_model.Q.values, first_new_Q + ) # THEN - set Q to an equivalent scipp Variable; values match so should be accepted new_Q = sc.Variable(dims=['Q'], values=[1.0, 2.0, 3.0], unit='1/angstrom') @@ -466,8 +474,12 @@ def test_Q_setter(self, instrument_model_without_Q): # EXPECT - Q propagated to child models, offsets marked dirty again assert instrument_model_without_Q._energy_offsets_is_dirty is True - np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q) - np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q) + np.testing.assert_array_equal( + instrument_model_without_Q.background_model.Q.values, first_new_Q + ) + np.testing.assert_array_equal( + instrument_model_without_Q.resolution_model.Q.values, first_new_Q + ) def test_fix_and_free_offset(self, instrument_model): # WHEN @@ -516,13 +528,13 @@ def test_fix_or_free_energy_offset_nonint_Q_index_raises(self, instrument_model) # WHEN / THEN / EXPECT with pytest.raises( TypeError, - match='Q_index must be an int or None, got str', + match='Q_index must be an int', ): instrument_model.fix_energy_offset(Q_index='invalid_index') with pytest.raises( TypeError, - match='Q_index must be an int or None, got str', + match='Q_index must be an int', ): instrument_model.free_energy_offset(Q_index='invalid_index') @@ -567,7 +579,7 @@ def test_repr_contains_expected_fields(self, instrument_model): # EXPECT assert repr_str.startswith('InstrumentModel(') assert f'unique_name={instrument_model.unique_name!r}' in repr_str - assert f'unit={instrument_model.unit}' in repr_str + assert f'x_unit={instrument_model.x_unit}' in repr_str assert 'Q_len=3' in repr_str assert f'resolution_model={instrument_model._resolution_model!r}' in repr_str assert f'background_model={instrument_model._background_model!r}' in repr_str diff --git a/tests/unit/easydynamics/sample_model/test_model_base.py b/tests/unit/easydynamics/sample_model/test_model_base.py index 5b58370d6..0b19950b8 100644 --- a/tests/unit/easydynamics/sample_model/test_model_base.py +++ b/tests/unit/easydynamics/sample_model/test_model_base.py @@ -23,7 +23,7 @@ def model_base(self): area=1.0, center=0.0, width=1.0, - unit='meV', + x_unit='meV', ) component2 = Lorentzian( name='TestLorentzian1Name', @@ -31,7 +31,7 @@ def model_base(self): area=2.0, center=1.0, width=0.5, - unit='meV', + x_unit='meV', ) component_collection = ComponentCollection() component_collection.append_component(component1) @@ -39,19 +39,22 @@ def model_base(self): return ModelBase( display_name='InitModel', components=component_collection, - unit='meV', + x_unit='meV', Q=np.array([1.0, 2.0, 3.0]), ) def test_init(self, model_base): # WHEN THEN - model = model_base # EXPECT - assert model.display_name == 'InitModel' - assert model.unit == 'meV' - assert len(model.components) == 2 - np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) + assert model_base.display_name == 'InitModel' + assert model_base.x_unit == 'meV' + assert model_base.y_unit == 'dimensionless' + assert len(model_base.components) == 2 + assert isinstance(model_base.Q, sc.Variable) + assert model_base.Q.dims == ('Q',) + assert model_base.Q.unit == sc.Unit('1/angstrom') + np.testing.assert_array_equal(model_base.Q.values, np.array([1.0, 2.0, 3.0])) def test_init_raises_with_invalid_components(self): # WHEN / THEN / EXPECT @@ -78,8 +81,8 @@ def test_evaluate_calls_all_component_collections(self, model_base): result = model_base.evaluate(x) # EXPECT - collection1.evaluate.assert_called_once_with(x) - collection2.evaluate.assert_called_once_with(x) + collection1.evaluate.assert_called_once_with(x, output='numpy') + collection2.evaluate.assert_called_once_with(x, output='numpy') np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0])) np.testing.assert_allclose(result[1], np.array([4.0, 5.0, 6.0])) @@ -128,7 +131,7 @@ def test_get_all_variables(self, model_base): # WHEN all_vars = model_base.get_all_variables() - # THEN + # EXPECT expected_var_display_names = { 'TestGaussian1Name area', 'TestGaussian1Name center', @@ -168,7 +171,7 @@ def test_get_all_variables_with_invalid_Q_index_raises(self, model_base): # WHEN / THEN / EXPECT with pytest.raises( IndexError, - match='Q_index 5 is out of bounds for component collections of length 3', + match='Q_index 5 is out of bounds for Q of length 3', ): model_base.get_all_variables(Q_index=5) @@ -198,7 +201,7 @@ def test_get_component_collection_invalid_index_raises(self, model_base): # WHEN THEN EXPECT with pytest.raises( IndexError, - match='Q_index 5 is out of bounds for ', + match='Q_index 5 is out of bounds for Q of length 3', ): model_base.get_component_collection(Q_index=5) @@ -246,36 +249,105 @@ def test_append_component_invalid_type_raises(self, model_base): with pytest.raises(TypeError, match=' must be '): model_base.append_component('invalid_component') - def test_unit_property(self, model_base): + def test_x_unit_property(self, model_base): # WHEN - unit = model_base.unit + unit = model_base.x_unit # THEN / EXPECT assert unit == 'meV' - def test_unit_setter_raises(self, model_base): + def test_x_unit_setter_raises(self, model_base): # WHEN / THEN / EXPECT - with pytest.raises(AttributeError, match='Use convert_unit to change '): - model_base.unit = 'K' + with pytest.raises(AttributeError): + model_base.x_unit = 'K' + + def test_convert_x_unit(self, model_base): + # Build collections before conversion so we can verify in-place update + _ = model_base.get_component_collection(0) + assert model_base._component_collections_is_dirty is False + collection_before = model_base._component_collections[0] - def test_convert_unit(self, model_base): # WHEN - model_base.convert_unit('eV') + model_base.convert_x_unit('eV') - # THEN / EXPECT - assert model_base.unit == 'eV' + # THEN / EXPECT: dirty flag NOT set and same collections reused (not rebuilt) + assert model_base._component_collections_is_dirty is False + assert model_base._component_collections[0] is collection_before + + assert model_base.x_unit == 'eV' for component in model_base.components: - assert component.unit == 'eV' + assert component.x_unit == 'eV' + for collection in model_base._component_collections: + for component in collection: + assert component.x_unit == 'eV' - def test_convert_unit_invalid_raises(self, model_base): + def test_convert_x_unit_invalid_raises(self, model_base): # WHEN / THEN / EXPECT with pytest.raises(UnitError): - model_base.convert_unit('invalid_unit') + model_base.convert_x_unit('invalid_unit') - def test_convert_unit_incorrect_unit_raises(self, model_base): + def test_convert_x_unit_incorrect_unit_raises(self, model_base): # WHEN THEN EXPECT with pytest.raises(TypeError, match=r'Unit must be a string or sc.Unit'): - model_base.convert_unit(123) + model_base.convert_x_unit(123) + + def test_components_setter_none(self, model_base): + # WHEN THEN + model_base.components = None + # EXPECT + assert len(model_base.components) == 0 + + def test_convert_x_unit_rollback_when_old_unit_none(self): + # WHEN: model with _x_unit=None (rollback branch is skipped when old_unit is None) + component = Gaussian(name='G', area=1.0, center=0.0, width=0.5, x_unit='meV') + model = ModelBase(display_name='M', x_unit=None, components=component) + model._x_unit = None + # THEN + with pytest.raises(UnitError): + model.convert_x_unit('m') # incompatible unit triggers failure + # EXPECT: Gaussian's own atomic rollback keeps it at 'meV' even though + # ModelBase's outer rollback loop is skipped when old_unit is None + assert component.x_unit == 'meV' + + def test_convert_x_unit_rollback_on_failure(self, model_base): + # WHEN THEN + with pytest.raises(UnitError): + model_base.convert_x_unit('m') + # EXPECT: state rolled back + assert model_base.x_unit == 'meV' + for component in model_base.components: + assert component.x_unit == 'meV' + + def test_convert_y_unit_rollback_on_failure(self, model_base): + # WHEN THEN + with pytest.raises(UnitError): + model_base.convert_y_unit('K') + # EXPECT: state rolled back + assert model_base.y_unit == 'dimensionless' + + def test_convert_x_unit_rollback_restores_collections(self): + # WHEN: a model with built per-Q collections + component = Gaussian(name='G', area=1.0, center=0.0, width=0.5, x_unit='meV') + model = ModelBase(display_name='M', components=component, Q=[1.0, 2.0]) + collection = model.get_component_collection(0) + + # THEN: an incompatible unit fails and triggers the rollback of components and + # collections + with pytest.raises(UnitError): + model.convert_x_unit('m') + + # EXPECT + assert model.x_unit == 'meV' + assert component.x_unit == 'meV' + assert collection[0].x_unit == 'meV' + + def test_component_collections_empty_without_Q(self): + # WHEN: a model without Q regenerates its collections + model = ModelBase(display_name='M', components=Gaussian(name='G')) + + # THEN EXPECT: no per-Q collections and therefore no variables + assert model.get_all_variables() == [] + assert model._component_collections == [] def test_components_setter(self, model_base): # WHEN @@ -320,8 +392,9 @@ def test_Q_setter_raises_if_Q_is_not_similar(self, model_base): [1.0, 2.0, 3.0], np.array([1.0, 2.0, 3.0]), sc.Variable(dims=['Q'], values=[1.0, 2.0, 3.0], unit='1/angstrom'), + sc.Variable(dims=['Q'], values=[10.0, 20.0, 30.0], unit='1/nm'), ], - ids=['list', 'numpy_array', 'scipp_variable'], + ids=['list', 'numpy_array', 'scipp_variable', 'scipp_variable_other_unit'], ) def test_Q_setter_with_similar_Q(self, model_base, new_Q): # WHEN @@ -331,7 +404,7 @@ def test_Q_setter_with_similar_Q(self, model_base, new_Q): model_base.Q = new_Q # EXPECT - np.testing.assert_array_equal(model_base.Q, old_Q) + np.testing.assert_array_equal(model_base.Q.values, old_Q.values) def test_Q_setter_with_none(self, model_base): # WHEN @@ -352,7 +425,20 @@ def test_Q_setter_when_current_Q_is_none(self, model_base): model_base.Q = new_Q # EXPECT - np.testing.assert_array_equal(model_base.Q, np.array(new_Q)) + np.testing.assert_array_equal(model_base.Q.values, np.array(new_Q)) + + def test_Q_stored_as_scipp_in_inverse_angstrom(self, model_base): + # WHEN: a scipp Q in 1/nm + model_base._Q = None + new_Q = sc.Variable(dims=['Q'], values=[5.0, 10.0], unit='1/nm') + + # THEN + model_base.Q = new_Q + + # EXPECT: stored canonically in 1/angstrom + assert isinstance(model_base.Q, sc.Variable) + assert model_base.Q.unit == sc.Unit('1/angstrom') + np.testing.assert_allclose(model_base.Q.values, [0.5, 1.0]) def test_clear_Q(self, model_base): # WHEN @@ -388,3 +474,42 @@ def test_repr(self, model_base): assert 'unit' in repr_str assert 'Q=' in repr_str assert 'components=' in repr_str + + def test_y_unit_setter_raises(self, model_base): + # WHEN / THEN / EXPECT + with pytest.raises(AttributeError): + model_base.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN: model with components where y_unit='1/meV' so area_unit ≈ dimensionless + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + lor = Lorentzian(area=1.0, x_unit='meV', y_unit='1/meV') + cc = ComponentCollection(components=[g, lor]) + model = ModelBase(components=cc, x_unit='meV', Q=np.array([1.0])) + + # Build collections before conversion so we can verify in-place update + _ = model.get_component_collection(0) + assert model._component_collections_is_dirty is False + collection_before = model._component_collections[0] + + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + model.convert_y_unit('1/eV') + + # EXPECT: dirty flag NOT set and same collections reused (not rebuilt) + assert model._component_collections_is_dirty is False + assert model._component_collections[0] is collection_before + + # EXPECT: model y_unit and template components updated + assert model.y_unit == '1/eV' + for component in model.components: + assert component.y_unit == '1/eV' + assert g.area.value == pytest.approx(1e3) + assert lor.area.value == pytest.approx(1e3) + # EXPECT: component collections updated in-place (not rebuilt from templates) + for component in collection_before: + assert component.y_unit == '1/eV' + + def test_convert_y_unit_invalid_raises(self, model_base): + # WHEN THEN EXPECT + with pytest.raises(TypeError): + model_base.convert_y_unit(123) diff --git a/tests/unit/easydynamics/sample_model/test_resolution_model.py b/tests/unit/easydynamics/sample_model/test_resolution_model.py index f894daf20..102363f60 100644 --- a/tests/unit/easydynamics/sample_model/test_resolution_model.py +++ b/tests/unit/easydynamics/sample_model/test_resolution_model.py @@ -21,14 +21,14 @@ def resolution_model(self): area=1.0, center=0.0, width=1.0, - unit='meV', + x_unit='meV', ) component2 = Lorentzian( display_name='TestLorentzian1', area=2.0, center=1.0, width=0.5, - unit='meV', + x_unit='meV', ) component_collection = ComponentCollection() component_collection.append_component(component1) @@ -36,7 +36,7 @@ def resolution_model(self): return ResolutionModel( display_name='InitModel', components=component_collection, - unit='meV', + x_unit='meV', Q=np.array([1.0, 2.0, 3.0]), ) @@ -48,7 +48,7 @@ def sample_model(self): area=1.0, center=0.0, width=1.0, - unit='meV', + x_unit='meV', ) component2 = Lorentzian( name='TestLorentzian1Name', @@ -56,7 +56,7 @@ def sample_model(self): area=2.0, center=1.0, width=0.5, - unit='meV', + x_unit='meV', ) component_collection = ComponentCollection() component_collection.append_component(component1) @@ -65,20 +65,20 @@ def sample_model(self): return SampleModel( display_name='InitModel', components=component_collection, - unit='meV', + x_unit='meV', Q=np.array([1.0, 2.0, 3.0]), temperature=10.0, ) def test_init(self, resolution_model): # WHEN THEN - model = resolution_model # EXPECT - assert model.display_name == 'InitModel' - assert model.unit == 'meV' - assert len(model.components) == 2 - np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) + assert resolution_model.display_name == 'InitModel' + assert resolution_model.x_unit == 'meV' + assert resolution_model.y_unit == 'dimensionless' + assert len(resolution_model.components) == 2 + np.testing.assert_array_equal(resolution_model.Q.values, np.array([1.0, 2.0, 3.0])) @pytest.mark.parametrize( 'invalid_component, expected_error_msg', @@ -215,10 +215,10 @@ def test_from_sample_model( # EXPECT assert resolution_model.display_name == 'InitModel' - assert resolution_model.unit == 'meV' + assert resolution_model.x_unit == 'meV' assert len(resolution_model.components) == 2 np.testing.assert_array_equal( - resolution_model.Q, + resolution_model.Q.values, np.array([1.0, 2.0, 3.0]), ) @@ -242,6 +242,21 @@ def test_from_sample_model( variables = resolution_model.get_all_variables() assert all(var.fixed for var in variables) is all_fixed + def test_from_sample_model_installed_collections_are_stable(self, sample_model): + # WHEN + resolution_model = ResolutionModel.from_sample_model(sample_model) + + # EXPECT: the installed collections are not scheduled for a rebuild (regression: init + # callbacks used to set the dirty flag, so the next access silently regenerated the + # collections from templates, discarding the normalization) + assert resolution_model._component_collections_is_dirty is False + + # and repeated access returns the same installed, normalized collections + first_access = resolution_model.get_component_collection(0) + second_access = resolution_model.get_component_collection(0) + assert first_access is second_access + assert (first_access[0].area.value + first_access[1].area.value) == pytest.approx(1.0) + def test_from_sample_model_with_no_Q(self, sample_model): # WHEN sample_model_no_Q = SampleModel( @@ -301,3 +316,19 @@ def test_from_sample_model_invalid_components(self, sample_model): match='cannot be a DeltaFunction', ): ResolutionModel.from_sample_model(sample_model) + + def test_y_unit_setter_raises(self, resolution_model): + # WHEN / THEN / EXPECT + with pytest.raises(AttributeError): + resolution_model.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + model = ResolutionModel(components=g, x_unit='meV') + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + model.convert_y_unit('1/eV') + # EXPECT + assert model.y_unit == '1/eV' + assert model.components[0].y_unit == '1/eV' + assert g.area.value == pytest.approx(1e3) diff --git a/tests/unit/easydynamics/sample_model/test_sample_model.py b/tests/unit/easydynamics/sample_model/test_sample_model.py index 8086c92e6..424a41f0b 100644 --- a/tests/unit/easydynamics/sample_model/test_sample_model.py +++ b/tests/unit/easydynamics/sample_model/test_sample_model.py @@ -6,6 +6,7 @@ import numpy as np import pytest +import scipp as sc from scipp import UnitError from easydynamics.sample_model import ComponentCollection @@ -28,7 +29,7 @@ def sample_model(self): area=1.0, center=0.0, width=1.0, - unit='meV', + x_unit='meV', ) component2 = Lorentzian( name='TestLorentzian1Name', @@ -36,7 +37,7 @@ def sample_model(self): area=2.0, center=1.0, width=0.5, - unit='meV', + x_unit='meV', ) component_collection = ComponentCollection() component_collection.append_component(component1) @@ -50,7 +51,7 @@ def sample_model(self): display_name='InitModel', components=component_collection, diffusion_models=diffusion_model, - unit='meV', + x_unit='meV', Q=np.array([1.0, 2.0, 3.0]), temperature=10.0, ) @@ -62,7 +63,8 @@ def test_init(self, sample_model): # EXPECT assert model.display_name == 'InitModel' - assert model.unit == 'meV' + assert model.x_unit == 'meV' + assert model.y_unit == 'dimensionless' assert len(model.components) == 2 assert isinstance(model.diffusion_models, list) assert len(model.diffusion_models) == 1 @@ -71,7 +73,7 @@ def test_init(self, sample_model): assert model.normalize_detailed_balance is True assert model.use_detailed_balance is True assert isinstance(model.detailed_balance_settings, DetailedBalanceSettings) - np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) + np.testing.assert_array_equal(model.Q.values, np.array([1.0, 2.0, 3.0])) def test_init_custom_input(self): # WHEN THEN @@ -308,6 +310,62 @@ def test_convert_temperature_unit_raises_with_invalid_unit(self, sample_model): ): model.convert_temperature_unit('invalid_unit') + def test_convert_x_unit_propagates_to_diffusion_models(self): + # WHEN: a SampleModel with an attached diffusion model + Q = np.array([1.0, 2.0]) + brownian = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=Q) + model = SampleModel(Q=Q, diffusion_models=brownian) + width_mev = model.get_component_collection(0)[0].width.value + + # THEN + model.convert_x_unit('ueV') + + # EXPECT: the diffusion model follows the SampleModel's new unit, and the merged + # collections carry rescaled widths in the new unit + assert sc.Unit(str(brownian.x_unit)) == sc.Unit('ueV') + assert sc.Unit(str(brownian.scale.unit)) == sc.Unit('ueV') + collection = model.get_component_collection(0) + assert sc.Unit(str(collection[0].width.unit)) == sc.Unit('ueV') + assert collection[0].width.value == pytest.approx(width_mev * 1000) + + def test_convert_x_unit_does_not_mark_collections_dirty(self): + # WHEN: a SampleModel with an attached diffusion model and built collections + Q = np.array([1.0, 2.0]) + brownian = BrownianTranslationalDiffusion(diffusion_coefficient=2.4e-9, Q=Q) + model = SampleModel(Q=Q, diffusion_models=brownian) + collection_before = model.get_component_collection(0) + width = collection_before[0].width + + # THEN + model.convert_x_unit('ueV') + + # EXPECT: conversion works in place — no dirty flag, no rebuild, and the converted + # dependent width stays in the new unit through later dependency-graph updates + # (regression: conversion used to mark the collections dirty, and the rebuild + # discarded per-Q state and object references) + assert model._component_collections_is_dirty is False + assert model.get_component_collection(0) is collection_before + width_uev = width.value + brownian.diffusion_coefficient = 4.8e-9 + assert sc.Unit(str(width.unit)) == sc.Unit('ueV') + assert width.value == pytest.approx(2 * width_uev) + + def test_convert_y_unit_propagates_to_diffusion_models(self): + # WHEN: SampleModel and diffusion model sharing y_unit='1/meV' + Q = np.array([1.0]) + brownian = BrownianTranslationalDiffusion( + diffusion_coefficient=2.4e-9, Q=Q, y_unit='1/meV' + ) + model = SampleModel(Q=Q, y_unit='1/meV', diffusion_models=brownian) + + # THEN + model.convert_y_unit('1/eV') + + # EXPECT + assert sc.Unit(str(brownian.y_unit)) == sc.Unit('1/eV') + assert sc.Unit(str(brownian.scale.unit)) == sc.Unit('meV/eV') + assert sc.Unit(str(model.y_unit)) == sc.Unit('1/eV') + def test_normalize_detailed_balance_setter(self, sample_model): # WHEN model = sample_model @@ -402,12 +460,12 @@ def test_evaluate_calls_dbf(self, sample_model): energy=x, temperature=sample_model.temperature, divide_by_temperature=sample_model.normalize_detailed_balance, - energy_unit=sample_model.unit, + energy_unit=sample_model.x_unit, ) # Check that evaluate was called on each component - collection1.evaluate.assert_called_once_with(x) - collection2.evaluate.assert_called_once_with(x) + collection1.evaluate.assert_called_once_with(x, output='numpy') + collection2.evaluate.assert_called_once_with(x, output='numpy') # Check that DBF was applied elementwise np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0]) * 10.0) @@ -452,8 +510,8 @@ def test_evaluate_doesnt_call_dbf_when_disabled( mock_dbf.assert_not_called() # Check that evaluate was called on each component - collection1.evaluate.assert_called_once_with(x) - collection2.evaluate.assert_called_once_with(x) + collection1.evaluate.assert_called_once_with(x, output='numpy') + collection2.evaluate.assert_called_once_with(x, output='numpy') # Check that results were not modified by DBF np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0])) @@ -531,9 +589,28 @@ def test_repr(self, sample_model): # THEN / EXPECT assert 'SampleModel' in repr_str - assert 'unit=' in repr_str - assert 'Q=' in repr_str + assert 'Q = ' in repr_str assert 'components' in repr_str assert 'diffusion_models' in repr_str assert 'temperature' in repr_str + # Regression: a stray ')' and a missing separator used to mangle this into + # 'x_unit=meV), y_unit=dimensionlessQ = ...', and the closing ')' was missing + assert 'x_unit=meV, y_unit=dimensionless' in repr_str + assert repr_str.rstrip().endswith(')') assert 'normalize_detailed_balance' in repr_str + + def test_y_unit_setter_raises(self, sample_model): + # WHEN / THEN / EXPECT + with pytest.raises(AttributeError): + sample_model.y_unit = '1/meV' + + def test_convert_y_unit(self): + # WHEN + g = Gaussian(area=1.0, x_unit='meV', y_unit='1/meV') + model = SampleModel(components=g, x_unit='meV') + # THEN: convert y_unit to '1/eV' (same dimension, different scale) + model.convert_y_unit('1/eV') + # EXPECT + assert model.y_unit == '1/eV' + assert model.components[0].y_unit == '1/eV' + assert g.area.value == pytest.approx(1e3) diff --git a/tests/unit/easydynamics/settings/test_convolution_settings.py b/tests/unit/easydynamics/settings/test_convolution_settings.py index 202f5baea..ccb900774 100644 --- a/tests/unit/easydynamics/settings/test_convolution_settings.py +++ b/tests/unit/easydynamics/settings/test_convolution_settings.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from copy import copy + import pytest from easydynamics.settings.convolution_settings import ConvolutionSettings @@ -22,6 +24,21 @@ def test_init(self, default_convolution_settings): assert default_convolution_settings.extension_factor == pytest.approx(0.2) assert default_convolution_settings.convolution_plan_is_valid is False + def test_copy(self): + # WHEN + settings = ConvolutionSettings( + upsample_factor=10, extension_factor=0.5, suppress_warnings=True + ) + + # THEN + settings_copy = copy(settings) + + # EXPECT: a distinct instance with the same knob values + assert settings_copy is not settings + assert settings_copy.upsample_factor == 10 + assert settings_copy.extension_factor == pytest.approx(0.5) + assert settings_copy.suppress_warnings is True + def test_init_with_custom_parameters(self): """ Test initialization of ConvolutionSettings with custom @@ -135,12 +152,13 @@ def test_upsample_factor_setter_invalid( @pytest.mark.parametrize( 'value', - [0.0, 0.2, 1, 5.5], + [0.0, 0.2, 1, 5.5, None], ids=[ 'zero', 'typical_fraction', 'integer', 'float', + 'none_valid', ], ) def test_extension_factor_setter_valid(self, default_convolution_settings, value): @@ -152,7 +170,8 @@ def test_extension_factor_setter_valid(self, default_convolution_settings, value default_convolution_settings.extension_factor = value # EXPECT - assert default_convolution_settings.extension_factor == pytest.approx(float(value)) + expected = pytest.approx(float(value)) if value is not None else None + assert default_convolution_settings.extension_factor == expected assert default_convolution_settings.convolution_plan_is_valid is False @pytest.mark.parametrize( diff --git a/tests/unit/easydynamics/test_exceptions.py b/tests/unit/easydynamics/test_exceptions.py index 494f6f7c9..bc7731a5e 100644 --- a/tests/unit/easydynamics/test_exceptions.py +++ b/tests/unit/easydynamics/test_exceptions.py @@ -7,11 +7,14 @@ class TestAmbiguousNameError: def test_initialization(self): + # WHEN name = 'test' matches = ['test1', 'test2', 'test3'] + # THEN error = AmbiguousNameError(name, matches) + # EXPECT assert error.name == name assert error.matches == matches assert str(error) == ( @@ -19,11 +22,14 @@ def test_initialization(self): ) def test_empty_matches(self): + # WHEN name = 'unknown' matches = [] + # THEN error = AmbiguousNameError(name, matches) + # EXPECT assert error.name == name assert error.matches == matches assert str(error) == ("Ambiguous name 'unknown' matches 0 elements: []") diff --git a/tests/unit/easydynamics/test_import.py b/tests/unit/easydynamics/test_import.py index e062efcbe..d2ffe77ba 100644 --- a/tests/unit/easydynamics/test_import.py +++ b/tests/unit/easydynamics/test_import.py @@ -3,6 +3,5 @@ def test_import_easydynamics(): - import easydynamics - - assert easydynamics is not None + # WHEN THEN EXPECT: importing raises no error + import easydynamics # noqa: F401 diff --git a/tests/unit/easydynamics/utils/test_detailed_balance.py b/tests/unit/easydynamics/utils/test_detailed_balance.py index bebf5407a..2d2284d36 100644 --- a/tests/unit/easydynamics/utils/test_detailed_balance.py +++ b/tests/unit/easydynamics/utils/test_detailed_balance.py @@ -9,6 +9,7 @@ from scipp.constants import Boltzmann as kB from easydynamics.utils import detailed_balance_factor +from easydynamics.utils.detailed_balance import _convert_to_scipp_variable kB_meV_per_K = sc.to_unit(kB, 'meV/K').value @@ -305,3 +306,30 @@ def test_incompatible_temperature_unit_raises(self): energy_unit=energy_unit, temperature_unit=temperature_unit, ) + + +class TestConvertToScippVariable: + """Tests for _convert_to_scipp_variable internal helper.""" + + @pytest.mark.parametrize( + 'name, expected_match', + [ + ('energy', 'energy must be a number'), + ('temperature', 'temperature must be a number'), + ], + ids=['energy_name', 'other_name'], + ) + def test_invalid_type_raises_type_error(self, name, expected_match): + # WHEN THEN EXPECT + with pytest.raises(TypeError, match=expected_match): + _convert_to_scipp_variable({'invalid': 'type'}, name=name, unit='meV') + + def test_invalid_unit_scalar_raises_unit_error(self): + # WHEN THEN EXPECT + with pytest.raises(UnitError, match='Invalid unit string'): + _convert_to_scipp_variable(1.0, name='energy', unit='not_a_real_unit_xyz') + + def test_invalid_unit_array_raises_unit_error(self): + # WHEN THEN EXPECT + with pytest.raises(UnitError, match='Invalid unit string'): + _convert_to_scipp_variable([1.0, 2.0], name='energy', unit='not_a_real_unit_xyz') diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py index 680d2a28a..6cd13b998 100644 --- a/tests/unit/easydynamics/utils/test_utils.py +++ b/tests/unit/easydynamics/utils/test_utils.py @@ -21,12 +21,14 @@ class TestValidateAndConvertQ: ], ) def test_validate_and_convert_Q_numeric_and_array(self, Q_input, expected): - # WHEN THEN + # WHEN THEN: numbers, lists, and numpy arrays are assumed to be in 1/angstrom result = _validate_and_convert_Q(Q_input) # EXPECT - assert isinstance(result, np.ndarray) - np.testing.assert_allclose(result, expected) + assert isinstance(result, sc.Variable) + assert result.dims == ('Q',) + assert result.unit == sc.Unit('1/angstrom') + np.testing.assert_allclose(result.values, expected) def test_validate_and_convert_Q_scipp_variable(self): # WHEN @@ -36,8 +38,20 @@ def test_validate_and_convert_Q_scipp_variable(self): result = _validate_and_convert_Q(Q) # EXPECT - assert isinstance(result, np.ndarray) - np.testing.assert_allclose(result, [1.0, 2.0]) + assert isinstance(result, sc.Variable) + assert result.unit == sc.Unit('1/angstrom') + np.testing.assert_allclose(result.values, [1.0, 2.0]) + + def test_validate_and_convert_Q_scipp_variable_other_unit(self): + # WHEN: a scipp Q in 1/nm + Q = sc.array(dims=['Q'], values=[10.0, 20.0], unit='1/nm') + + # THEN + result = _validate_and_convert_Q(Q) + + # EXPECT: converted to 1/angstrom + assert result.unit == sc.Unit('1/angstrom') + np.testing.assert_allclose(result.values, [1.0, 2.0]) def test_validate_and_convert_Q_none(self): # WHEN THEN EXPECT @@ -87,16 +101,24 @@ class TestValidateUnit: ], ) def test_validate_unit_valid(self, unit_input): + # WHEN + + # THEN unit = _validate_unit(unit_input) + # EXPECT if unit_input is None: assert unit is None else: assert isinstance(unit, str) def test_validate_unit_string_conversion(self): + # WHEN + + # THEN unit = _validate_unit(sc.Unit('meV')) + # EXPECT assert isinstance(unit, str) assert unit == 'meV' @@ -111,6 +133,7 @@ def test_validate_unit_string_conversion(self): ], ) def test_validate_unit_invalid_type(self, unit_input): + # WHEN THEN EXPECT with pytest.raises(TypeError, match='unit must be None, a string, or a scipp Unit'): _validate_unit(unit_input)