Metadata-Version: 2.1
Name: numpyro-oop
Version: 0.0.1
Summary: A convenient object-oriented wrapper for working with numpyro models.
Author-email: Thomas Wallis <thomas.wallis@tu-darmstadt.de>
License: Copyright (c) 2024 Thomas Wallis
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
Keywords: numpyro,probabilistic programming,mcmc,bayesian inference
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: dm-tree ==0.1.8
Requires-Dist: multipledispatch ==1.0.0
Requires-Dist: pytz ==2024.1
Requires-Dist: pyparsing ==3.1.2 ; python_full_version >= "3.6.8"
Requires-Dist: tzdata ==2024.1 ; python_version >= "2"
Requires-Dist: python-dateutil ==2.9.0.post0 ; python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3"
Requires-Dist: six ==1.16.0 ; python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3"
Requires-Dist: arviz ==0.19.0 ; python_version >= "3.10"
Requires-Dist: jax ==0.4.31 ; python_version >= "3.10"
Requires-Dist: jaxlib ==0.4.31 ; python_version >= "3.10"
Requires-Dist: scipy ==1.14.0 ; python_version >= "3.10"
Requires-Dist: opt-einsum ==3.3.0 ; python_version >= "3.5"
Requires-Dist: kiwisolver ==1.4.5 ; python_version >= "3.7"
Requires-Dist: tqdm ==4.66.4 ; python_version >= "3.7"
Requires-Dist: cycler ==0.12.1 ; python_version >= "3.8"
Requires-Dist: fonttools ==4.53.1 ; python_version >= "3.8"
Requires-Dist: h5py ==3.11.0 ; python_version >= "3.8"
Requires-Dist: packaging ==24.1 ; python_version >= "3.8"
Requires-Dist: pillow ==10.4.0 ; python_version >= "3.8"
Requires-Dist: setuptools ==72.1.0 ; python_version >= "3.8"
Requires-Dist: typing-extensions ==4.12.2 ; python_version >= "3.8"
Requires-Dist: contourpy ==1.2.1 ; python_version >= "3.9"
Requires-Dist: h5netcdf ==1.3.0 ; python_version >= "3.9"
Requires-Dist: matplotlib ==3.9.1 ; python_version >= "3.9"
Requires-Dist: ml-dtypes ==0.4.0 ; python_version >= "3.9"
Requires-Dist: numpy ==2.0.1 ; python_version >= "3.9"
Requires-Dist: numpyro ==0.15.2 ; python_version >= "3.9"
Requires-Dist: pandas ==2.2.2 ; python_version >= "3.9"
Requires-Dist: xarray ==2024.7.0 ; python_version >= "3.9"
Requires-Dist: xarray-einstats ==0.7.0 ; python_version >= "3.9"
Provides-Extra: dev
Requires-Dist: asttokens ==2.4.1 ; extra == 'dev'
Requires-Dist: fastjsonschema ==2.20.0 ; extra == 'dev'
Requires-Dist: fqdn ==1.5.1 ; extra == 'dev'
Requires-Dist: isoduration ==20.11.0 ; extra == 'dev'
Requires-Dist: jsonpointer ==3.0.0 ; extra == 'dev'
Requires-Dist: ptyprocess ==0.7.0 ; extra == 'dev'
Requires-Dist: pure-eval ==0.2.3 ; extra == 'dev'
Requires-Dist: stack-data ==0.6.3 ; extra == 'dev'
Requires-Dist: uri-template ==1.3.0 ; extra == 'dev'
Requires-Dist: wcwidth ==0.2.13 ; extra == 'dev'
Requires-Dist: webcolors ==24.6.0 ; extra == 'dev'
Requires-Dist: webencodings ==0.5.1 ; extra == 'dev'
Requires-Dist: appnope ==0.1.4 ; (platform_system == "Darwin") and extra == 'dev'
Requires-Dist: beautifulsoup4 ==4.12.3 ; (python_full_version >= "3.6.0") and extra == 'dev'
Requires-Dist: pyparsing ==3.1.2 ; (python_full_version >= "3.6.8") and extra == 'dev'
Requires-Dist: charset-normalizer ==3.3.2 ; (python_full_version >= "3.7.0") and extra == 'dev'
Requires-Dist: prompt-toolkit ==3.0.47 ; (python_full_version >= "3.7.0") and extra == 'dev'
Requires-Dist: nbclient ==0.10.0 ; (python_full_version >= "3.8.0") and extra == 'dev'
Requires-Dist: pandocfilters ==1.5.1 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3") and extra == 'dev'
Requires-Dist: python-dateutil ==2.9.0.post0 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3") and extra == 'dev'
Requires-Dist: six ==1.16.0 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3") and extra == 'dev'
Requires-Dist: defusedxml ==0.7.1 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3, 3.4") and extra == 'dev'
Requires-Dist: rfc3339-validator ==0.1.4 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3, 3.4") and extra == 'dev'
Requires-Dist: rfc3986-validator ==0.1.1 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3, 3.4") and extra == 'dev'
Requires-Dist: psutil ==6.0.0 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3, 3.4, 3.5") and extra == 'dev'
Requires-Dist: send2trash ==1.8.3 ; (python_version >= "2.7" and python_version not in "3.0, 3.1, 3.2, 3.3, 3.4, 3.5") and extra == 'dev'
Requires-Dist: ipython ==8.26.0 ; (python_version >= "3.10") and extra == 'dev'
Requires-Dist: decorator ==5.1.1 ; (python_version >= "3.5") and extra == 'dev'
Requires-Dist: executing ==2.0.1 ; (python_version >= "3.5") and extra == 'dev'
Requires-Dist: idna ==3.7 ; (python_version >= "3.5") and extra == 'dev'
Requires-Dist: mypy-extensions ==1.0.0 ; (python_version >= "3.5") and extra == 'dev'
Requires-Dist: nest-asyncio ==1.6.0 ; (python_version >= "3.5") and extra == 'dev'
Requires-Dist: argon2-cffi-bindings ==21.2.0 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: certifi ==2024.7.4 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: jedi ==0.19.1 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: overrides ==7.7.0 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: parso ==0.8.4 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: python-json-logger ==2.0.7 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: pyyaml ==6.0.1 ; (python_version >= "3.6") and extra == 'dev'
Requires-Dist: argon2-cffi ==23.1.0 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: attrs ==23.2.0 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: click ==8.1.7 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: h11 ==0.14.0 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: iniconfig ==2.0.0 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: jinja2 ==3.1.4 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: kiwisolver ==1.4.5 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: markupsafe ==2.1.5 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: mistune ==3.0.2 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: notebook-shim ==0.2.4 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: pyproject-hooks ==1.1.0 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: pyzmq ==26.0.3 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: sniffio ==1.3.1 ; (python_version >= "3.7") and extra == 'dev'
Requires-Dist: anyio ==4.4.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: arrow ==1.3.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: async-lru ==2.0.4 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: babel ==2.15.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: black ==24.4.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: bleach ==6.1.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: build ==1.2.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: cffi ==1.16.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: comm ==0.2.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: cycler ==0.12.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: debugpy ==1.8.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: fonttools ==4.53.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: graphviz ==0.20.3 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: httpcore ==1.0.5 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: httpx ==0.27.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: ipykernel ==6.29.5 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: json5 ==0.9.25 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jsonschema[format-nongpl] ==4.23.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jsonschema-specifications ==2023.12.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyter-client ==8.6.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyter-core ==5.7.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyter-events ==0.10.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyter-lsp ==2.2.5 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyter-server ==2.14.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyter-server-terminals ==0.5.3 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyterlab ==4.2.4 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyterlab-pygments ==0.3.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: jupyterlab-server ==2.27.3 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: matplotlib-inline ==0.1.7 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: nbconvert ==7.16.4 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: nbformat ==5.10.4 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: packaging ==24.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: pathspec ==0.12.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: pillow ==10.4.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: platformdirs ==4.2.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: pluggy ==1.5.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: prometheus-client ==0.20.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: pycparser ==2.22 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: pygments ==2.18.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: pytest ==8.3.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: referencing ==0.35.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: requests ==2.32.3 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: rpds-py ==0.19.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: setuptools ==72.1.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: soupsieve ==2.5 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: terminado ==0.18.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: tinycss2 ==1.3.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: tornado ==6.4.1 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: traitlets ==5.14.3 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: types-python-dateutil ==2.9.0.20240316 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: urllib3 ==2.2.2 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: websocket-client ==1.8.0 ; (python_version >= "3.8") and extra == 'dev'
Requires-Dist: contourpy ==1.2.1 ; (python_version >= "3.9") and extra == 'dev'
Requires-Dist: matplotlib ==3.9.1 ; (python_version >= "3.9") and extra == 'dev'
Requires-Dist: numpy ==2.0.1 ; (python_version >= "3.9") and extra == 'dev'
Requires-Dist: pexpect ==4.9.0 ; (sys_platform != "win32" and sys_platform != "emscripten") and extra == 'dev'

# An object-oriented interface to numpyro

This package provides a wrapper for working with [numpyro](https://num.pyro.ai/) models.
It aims to remain model-agnostic while packaging up much of the model-fitting code to reduce repetition.

It is intended to make life a bit easier for people who are already familiar with numpyro and Bayesian modelling.
It is not intended to fulfil the same high-level wrapper role as packages such as [brms](https://paul-buerkner.github.io/brms/).
The user is still required to write the model.

## Getting started

```shell
pip install numpyro-oop
```

The basic idea is that the user defines a new class that inherits from `BaseNumpyroModel`
and defines (minimally) the model to be fit by overriding the `model` method:

```python
from numpyro_oop import BaseNumpyroModel

class DemoModel(BaseNumpyroModel):
    def model(self, data=None):
        # write the numpyro model here
        ...

# df is the dataset to fit, defined elsewhere
m1 = DemoModel(data=df, seed=42)
```

All other sampling and prediction steps are then handled by `numpyro-oop` or related libraries (e.g. `arviz`):

```python
m1.sample()  # sample from the model
preds = m1.predict()  # generate model predictions for the dataset given at initialization, or pass a new dataset
m1.generate_arviz_data()  # generate an ArviZ InferenceData object stored in self.arviz_data
```
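
The division of labour described above (the subclass supplies only `model`, the base class owns the shared workflow) can be illustrated with a plain-Python sketch. This is not how `numpyro-oop` is implemented: `SketchBaseModel`, `DemoSketch`, and the toy "model" below are hypothetical stand-ins that use only the standard library, and they assume nothing about the real `BaseNumpyroModel` beyond the pattern shown in this README.

```python
import random
from statistics import mean


class SketchBaseModel:
    """Hypothetical stand-in for a base class; NOT numpyro-oop's implementation."""

    def __init__(self, data, seed=0):
        self.data = data
        self.rng = random.Random(seed)  # seeded so repeated runs give the same draws
        self.draws = None

    def model(self, data=None):
        # Subclasses must override this with their own model definition.
        raise NotImplementedError

    def sample(self, num_samples=100):
        # Shared workflow code lives in the base class, so subclasses do not repeat it.
        self.draws = [self.model(data=self.data) for _ in range(num_samples)]
        return self.draws


class DemoSketch(SketchBaseModel):
    def model(self, data=None):
        # Toy "model" (an assumption for illustration): a noisy estimate of the data mean.
        return mean(data) + self.rng.gauss(0.0, 0.1)


m = DemoSketch(data=[1.0, 2.0, 3.0], seed=42)
draws = m.sample(num_samples=50)
```

Because the seed is stored at initialization, constructing a second `DemoSketch` with the same data and seed yields the same draws, which mirrors the role the `seed` argument plays in the `DemoModel(data=df, seed=42)` call above.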

A more complete demo can be found in `/scripts/demo_1.ipynb`.

### Roadmap after initial release

- [ ] Include doctests and improved examples
- [ ] Demo and tests for multiple group variables
- [ ] Export docs to a static page (readthedocs or similar), with detailed info on class methods and attributes
- [ ] CI test setup
- [ ] Contributor guidelines
- [ ] Fix type hints via linter checks


### Development notes

- Update dependencies with `make update-deps`
- Update and (re)install the environment with `make update-and-install`



