splisosm.model
==============

.. py:module:: splisosm.model

.. autoapi-nested-parse::

   Multinomial GLM/GLMM model implementations for SPLISOSM.



Classes
-------

.. autoapisummary::

   splisosm.model.MultinomGLM
   splisosm.model.MultinomGLMM


Module Contents
---------------

.. py:class:: MultinomGLM(fitting_method = 'iwls', fitting_configs = {})

   Bases: :py:obj:`BaseModel`, :py:obj:`torch.nn.Module`


   The Multinomial Generalized Linear Model for spatial isoform expression.

   Unlike :py:class:`MultinomGLMM`, this model has no random effect term::

       Y ~ Multinomial(alpha, Y.sum(1))
       eta = multinomial-logit(alpha) = X @ beta + bias_eta
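
   The NumPy sketch below illustrates the link above by mapping the linear predictor ``eta`` to isoform
   ratios ``alpha``. It assumes, hypothetically, that the last isoform is the reference category with its
   logit fixed at 0. This page does not state which reference SPLISOSM uses, and ``eta_to_ratio`` is an
   illustrative name, not part of the package.

   .. code-block:: python

      import numpy as np


      def eta_to_ratio(eta):
          """Inverse multinomial-logit: (n_spots, n_isos - 1) -> (n_spots, n_isos).

          Assumes the last isoform is the reference category with logit 0.
          """
          logits = np.concatenate([eta, np.zeros((eta.shape[0], 1))], axis=1)
          # Subtract the row-wise max for numerical stability before exponentiating.
          logits = logits - logits.max(axis=1, keepdims=True)
          expl = np.exp(logits)
          return expl / expl.sum(axis=1, keepdims=True)


      rng = np.random.default_rng(0)
      n_spots, n_factors, n_isos = 100, 2, 3
      X = rng.normal(size=(n_spots, n_factors))
      beta = rng.normal(size=(n_factors, n_isos - 1))
      bias_eta = rng.normal(size=n_isos - 1)
      eta = X @ beta + bias_eta  # (n_spots, n_isos - 1)
      alpha = eta_to_ratio(eta)  # (n_spots, n_isos); each row sums to 1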

   Given isoform counts of a gene ``Y`` (n_spots, n_isos) and design matrix ``X`` (n_spots, n_factors),
   ``MultinomGLM.fit`` finds the MAP estimates of the following learnable parameters:

   - ``beta``: (n_factors, n_isos - 1) covariate coefficients of the fixed effect term.
   - ``bias_eta``: (n_isos - 1) intercepts of the fixed effect term.

   Inference is performed by maximizing the log likelihood using ``fitting_method`` (default: ``'iwls'``).
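
   For the canonical multinomial-logit link, Newton's method and Fisher scoring coincide, and Fisher scoring
   is what IWLS carries out. The NumPy sketch below fits one gene this way under stated assumptions: the last
   isoform is the reference, the design matrix already contains an intercept column, and there is no
   regularization. ``fit_multinom_newton`` is a hypothetical helper, not SPLISOSM's implementation.

   .. code-block:: python

      import numpy as np


      def softmax_ref_last(eta):
          # Append the reference logit 0 for the last isoform, then normalize per spot.
          logits = np.concatenate([eta, np.zeros((eta.shape[0], 1))], axis=1)
          logits -= logits.max(axis=1, keepdims=True)
          e = np.exp(logits)
          return e / e.sum(axis=1, keepdims=True)


      def fit_multinom_newton(Y, X, max_iter=50, tol=1e-10):
          """Newton/IWLS for multinomial logistic regression (hypothetical helper).

          Y: (n_spots, n_isos) counts; X: (n_spots, p) design incl. intercept.
          Returns W of shape (p, n_isos - 1).
          """
          n_spots, n_isos = Y.shape
          p, q = X.shape[1], n_isos - 1
          n = Y.sum(axis=1)  # per-spot total counts
          W = np.zeros((p, q))
          for _ in range(max_iter):
              A = softmax_ref_last(X @ W)[:, :q]  # ratios of non-reference isoforms
              # Gradient of the log-likelihood: X^T (Y - n * alpha).
              grad = X.T @ (Y[:, :q] - n[:, None] * A)  # (p, q)
              # Per-spot information block: n_i * (diag(a_i) - a_i a_i^T).
              S = n[:, None, None] * (np.einsum("ik,kl->ikl", A, np.eye(q))
                                      - np.einsum("ik,il->ikl", A, A))
              # Negative Hessian, with parameter index (j, k) flattened row-major.
              info = np.einsum("ij,im,ikl->jkml", X, X, S).reshape(p * q, p * q)
              step = np.linalg.solve(info, grad.reshape(-1)).reshape(p, q)
              W += step
              if np.abs(step).max() < tol:
                  break
          return W


      rng = np.random.default_rng(1)
      Y = rng.integers(1, 20, size=(50, 3)).astype(float)
      X = np.ones((50, 1))  # intercept-only design
      W = fit_multinom_newton(Y, X)

   With an intercept-only design, the fitted logits reduce to the log ratios of the pooled isoform totals,
   which makes a quick sanity check.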

   .. rubric:: Example

   >>> from splisosm.model import MultinomGLM
   >>> import torch
   >>> # Generate synthetic data
   >>> counts = torch.randint(0, 10, (5, 100, 3))  # 5 genes, 100 spots, 3 isoforms each
   >>> # Fit the GLM model
   >>> model = MultinomGLM(fitting_method='iwls')
   >>> model.setup_data(counts, design_mtx=None)
   >>> model.fit()
   >>> print(model)
   >>> # Extract the fitted isoform ratios
   >>> isoform_ratios = model.get_isoform_ratio()  # shape (5, 100, 3)
   >>> # Fitted parameters
   >>> print(model.beta.shape)  # shape (5, 0, 2)
   >>> print(model.bias_eta.shape)  # shape (5, 2)

   :param fitting_method: Method for fitting the model. One of:

                          - ``'iwls'``: Iteratively reweighted least squares.
                          - ``'newton'``: Newton's method.
                          - ``'gd'``: Gradient descent.

   :param fitting_configs: Dictionary of fitting configurations. Keys include:

                           - ``'lr'``: float, Learning rate for gradient descent or Newton's method.
                           - ``'optim'``: str, Optimizer type, one of ``'adam'``, ``'sgd'``, or ``'lbfgs'``.
                           - ``'tol'``: float, Tolerance for convergence.
                           - ``'max_epochs'``: int, Maximum number of epochs for fitting.
                           - ``'patience'``: int, Number of epochs to wait for improvement before stopping.
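
   This page does not document the exact stopping rule. The sketch below shows one common reading of
   ``tol`` and ``patience``: an epoch counts as an improvement only if the loss drops by more than ``tol``,
   and fitting stops after ``patience`` consecutive epochs without improvement. ``EarlyStopper`` is a
   hypothetical class, not SPLISOSM's ``PatienceLogger``.

   .. code-block:: python

      class EarlyStopper:
          """Hypothetical patience-based stopping rule (not SPLISOSM's PatienceLogger)."""

          def __init__(self, tol=1e-5, patience=5):
              self.tol = tol
              self.patience = patience
              self.best = float("inf")
              self.bad_epochs = 0

          def step(self, loss):
              """Record one epoch's loss; return True if fitting should stop."""
              if loss < self.best - self.tol:
                  self.best = loss
                  self.bad_epochs = 0
              else:
                  self.bad_epochs += 1
              return self.bad_epochs >= self.patience


      stopper = EarlyStopper(tol=1e-3, patience=2)
      losses = [10.0, 5.0, 4.9995, 4.9994, 4.0]
      stopped_at = None
      for epoch, loss in enumerate(losses):
          if stopper.step(loss):
              stopped_at = epoch
              break

   In this trace, the two decreases smaller than ``tol`` after epoch 1 exhaust a patience of 2, so the loop
   stops at epoch 3 and never sees the larger drop at epoch 4.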


   .. py:method:: clone()

      Clone a model with the same set of parameters.



   .. py:method:: fit(diagnose = False, verbose = False, quiet = False, random_seed = None)

      Fit the model using all data.

      :param diagnose: Whether to store parameter changes during training (passed to PatienceLogger).
      :param verbose: Whether to print verbose information during fitting.
      :param quiet: Whether to suppress output during fitting.
      :param random_seed: Random seed for reproducibility.

      :returns: **params_iter** -- If `diagnose` is True, returns a dictionary of parameter changes during training. Otherwise returns None.
      :rtype: dict or None



   .. py:method:: forward()

      Calculate the log probability of the data.

      :returns: **log_prob** -- Shape (n_genes,), the log probability for each gene.
      :rtype: torch.Tensor
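
      For reference, the multinomial log-probability of one gene, summed over spots, can be written with
      log-gamma terms. This is consistent with the ``torch.lgamma`` requirement noted under ``setup_data``.
      The NumPy sketch below includes the constant multinomial coefficient. This page does not say whether
      ``forward`` keeps that constant, so including it is an assumption. ``multinom_log_prob`` is a
      hypothetical helper.

      .. code-block:: python

         import math

         import numpy as np


         def multinom_log_prob(Y, alpha):
             """Sum over spots of log Multinomial(Y_i | n_i, alpha_i).

             Y: (n_spots, n_isos) counts; alpha: (n_spots, n_isos) ratios.
             """
             lgamma = np.vectorize(math.lgamma)
             n = Y.sum(axis=1)
             # log n_i! - sum_k log y_ik!  (the multinomial coefficient)
             log_coef = lgamma(n + 1) - lgamma(Y + 1).sum(axis=1)
             # sum_k y_ik * log(alpha_ik), treating 0 * log(0) as 0
             safe_alpha = np.where(Y > 0, alpha, 1.0)
             log_kernel = np.where(Y > 0, Y * np.log(safe_alpha), 0.0).sum(axis=1)
             return float((log_coef + log_kernel).sum())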



   .. py:method:: get_isoform_ratio()

      Extract the fitted isoform ratio across space.

      :returns: **ratio** -- Shape (n_genes, n_spots, n_isos), the fitted isoform ratio across space.
      :rtype: torch.Tensor



   .. py:method:: setup_data(counts, design_mtx = None, device = 'cpu')

      Set up the data for the model.

      :param counts: Shape (n_genes, n_spots, n_isos) or (n_spots, n_isos).
                     For batched calculations, all genes in the batch must have the same number of isoforms.
      :param design_mtx: Shape (n_spots, n_factors). Design matrix of spatial covariates.
                         If None, an intercept-only design matrix will be used.
      :param device: ``'cpu'`` or ``'cuda'``. ``'mps'`` is currently not supported because ``torch.lgamma`` is not supported on MPS.



   .. py:method:: update_params_from_dict(params)

      Update a subset of model parameters with a dictionary of parameters.

      :param params: A dictionary of parameters to be updated. The keys must be
                     existing parameter names in the model.



   .. py:attribute:: fitting_configs
      :type:  dict

      Dictionary of fitting configurations.


   .. py:attribute:: fitting_method
      :type:  Literal['iwls', 'newton', 'gd']

      Method for fitting the model.


   .. py:attribute:: fitting_time
      :type:  float

      Time taken for fitting the model.


   .. py:attribute:: n_factors
      :type:  int | None

      Number of covariates in the design matrix.


   .. py:attribute:: n_genes
      :type:  int

      Number of genes in the batch.


   .. py:attribute:: n_isos
      :type:  int

      Number of isoforms per gene in the batch.


   .. py:attribute:: n_spots
      :type:  int

      Number of samples/spots.


.. py:class:: MultinomGLMM(share_variance = True, var_parameterization_sigma_theta = True, var_fix_sigma = True, var_prior_model = 'none', var_prior_model_params = {}, init_ratio = 'observed', fitting_method = 'joint_gd', fitting_configs = {})

   Bases: :py:obj:`MultinomGLM`, :py:obj:`BaseModel`, :py:obj:`torch.nn.Module`


   The Multinomial Generalized Linear Mixed Model for spatial isoform expression.

   The model is defined as follows::

       Y ~ Multinomial(alpha, Y.sum(1))
       eta = multinomial-logit(alpha) = X @ beta + bias_eta + nu
       nu ~ MVN(0, sigma^2 * (theta * V_sp + (1 - theta) * I))
          = MVN(0, sigma_sp^2 * V_sp + sigma_nsp^2 * I)
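
   To make the generative model concrete, the NumPy sketch below simulates one gene. It draws each of the
   ``n_isos - 1`` columns of ``nu`` independently from the stated multivariate normal (an assumption about
   how columns are coupled), omits the ``X @ beta`` term for brevity, and uses the last isoform as a
   hypothetical reference category. A squared-exponential kernel stands in for ``V_sp`` purely for
   illustration; it is not ``splisosm.utils.get_cov_sp``.

   .. code-block:: python

      import numpy as np

      rng = np.random.default_rng(0)
      n_spots, n_isos = 200, 3
      coords = rng.random((n_spots, 2))

      # Illustrative spatial covariance V_sp (squared-exponential kernel, unit diagonal).
      d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(-1)
      V_sp = np.exp(-d2 / (2 * 0.1 ** 2))

      sigma, theta = 0.8, 0.7
      cov = sigma ** 2 * (theta * V_sp + (1 - theta) * np.eye(n_spots))

      # nu: (n_spots, n_isos - 1), one MVN draw per logit column via Cholesky.
      L = np.linalg.cholesky(cov)
      nu = L @ rng.standard_normal((n_spots, n_isos - 1))

      bias_eta = np.array([0.5, -0.5])
      eta = bias_eta + nu  # fixed effect X @ beta omitted for simplicity

      # Inverse multinomial-logit with the last isoform as reference (logit 0).
      logits = np.concatenate([eta, np.zeros((n_spots, 1))], axis=1)
      alpha = np.exp(logits - logits.max(axis=1, keepdims=True))
      alpha /= alpha.sum(axis=1, keepdims=True)

      totals = rng.integers(5, 50, size=n_spots)  # per-spot total counts, Y.sum(1)
      Y = np.stack([rng.multinomial(t, a) for t, a in zip(totals, alpha)])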

   Given isoform counts of a gene ``Y`` (n_spots, n_isos), design matrix ``X`` (n_spots, n_factors),
   and spatial covariance matrix ``V_sp`` (n_spots, n_spots), the model estimates the isoform usage
   ratio ``alpha`` (n_spots, n_isos) across space.
   Specifically, `MultinomGLMM.fit` will find the MAP estimates of the following learnable parameters:

   - ``beta``: (n_factors, n_isos - 1) covariate coefficients of the fixed effect term.
   - ``bias_eta``: (n_isos - 1) intercepts of the fixed effect term.
   - ``nu``: (n_spots, n_isos - 1) the random effect term.
   - variance components: each of length n_isos - 1 (or 1 if `share_variance` is True).
     If `var_parameterization_sigma_theta` is True, they are (``sigma``, ``theta_logit``),
     representing total variance and logit of spatial variance proportion ``theta``.
     Otherwise they are (``sigma_sp``, ``sigma_nsp``), representing spatial and non-spatial
     variance components.
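
   Both parameterizations describe the same covariance. With total variance ``s`` and
   ``theta = sigmoid(theta_logit)``, the spatial and non-spatial components are ``s * theta`` and
   ``s * (1 - theta)``. The sketch below converts between the two forms on the variance scale. This page
   does not specify whether the stored parameters hold variances or standard deviations, so the scale is an
   assumption, and the function names are hypothetical.

   .. code-block:: python

      import numpy as np


      def sigmoid(x):
          return 1.0 / (1.0 + np.exp(-x))


      def to_sp_nsp(var_total, theta_logit):
          """(total variance, logit of spatial proportion) -> (spatial, non-spatial) variances."""
          theta = sigmoid(theta_logit)
          return var_total * theta, var_total * (1.0 - theta)


      def to_total_logit(var_sp, var_nsp):
          """(spatial, non-spatial) variances -> (total variance, logit of spatial proportion)."""
          var_total = var_sp + var_nsp
          theta = var_sp / var_total
          return var_total, np.log(theta / (1.0 - theta))


      # theta_logit = 0 means an even split between spatial and non-spatial variance.
      var_sp, var_nsp = to_sp_nsp(np.array([2.0]), np.array([0.0]))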

   Inference algorithms can be categorized into two types based on the optimization objective:

   - Joint: Maximize the joint likelihood (with the random effect ``nu``).
     This is equivalent to the first-order Laplace approximation of the marginal likelihood.
   - Marginal: Maximize the marginal likelihood (with the random effect ``nu`` integrated out).
     The integral is approximated by a second-order Laplace approximation.
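
   For reference, the standard Laplace approximation expands the joint log density around its mode
   :math:`\hat{\nu}`. In the generic form below, :math:`\phi` collects ``beta``, ``bias_eta``, and the
   variance components, :math:`d` is the number of entries of ``nu``, and :math:`H` is the negative Hessian
   of the joint log density at the mode. Keeping only the joint log density at the mode gives the joint
   objective; the marginal objective also keeps the log-determinant correction. The exact constants used by
   SPLISOSM may differ.

   .. math::

      \log p(Y \mid \phi) = \log \int p(Y \mid \nu, \phi)\, p(\nu \mid \phi)\, d\nu
      \approx \log p(Y \mid \hat{\nu}, \phi) + \log p(\hat{\nu} \mid \phi)
      + \frac{d}{2} \log 2\pi - \frac{1}{2} \log \det H(\hat{\nu})

      \hat{\nu} = \arg\max_{\nu} \left[ \log p(Y \mid \nu, \phi) + \log p(\nu \mid \phi) \right],
      \quad H(\hat{\nu}) = -\nabla^2_{\nu} \left[ \log p(Y \mid \nu, \phi) + \log p(\nu \mid \phi) \right]_{\nu = \hat{\nu}}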

   Methods implemented:

   - ``'joint_gd'``: Maximize the joint likelihood using gradient descent.
   - ``'joint_newton'``: Maximize the joint likelihood using Newton's method.
   - ``'marginal_gd'``: Maximize the marginal likelihood using gradient descent.
   - ``'marginal_newton'``: Maximize the marginal likelihood using Newton's method.
     In this method, ``nu`` is updated with Newton's method every
     ``'update_nu_every_k'`` iterations, while ``beta``, ``bias_eta``, and the variance components
     are updated with gradient descent.
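
   The Newton update for ``nu`` can be illustrated in the simplest setting of two isoforms, where the
   multinomial reduces to a binomial with a single logit column. The NumPy sketch below holds ``bias_eta``
   and the covariance fixed, omits covariates, and finds the mode of ``log p(Y | nu) + log p(nu)`` by
   Newton's method. ``newton_mode_nu`` is a hypothetical helper illustrating the technique, not SPLISOSM's
   routine.

   .. code-block:: python

      import numpy as np


      def newton_mode_nu(y1, n, bias, K, max_iter=50, tol=1e-10):
          """Mode of log Binom(y1 | n, sigmoid(bias + nu)) - 0.5 * nu^T K^{-1} nu.

          y1: (n_spots,) counts of isoform 1; n: (n_spots,) totals;
          bias: scalar intercept; K: (n_spots, n_spots) covariance of nu.
          """
          K_inv = np.linalg.inv(K)
          nu = np.zeros_like(y1, dtype=float)
          for _ in range(max_iter):
              p = 1.0 / (1.0 + np.exp(-(bias + nu)))
              # Gradient: binomial score minus prior precision times nu.
              grad = (y1 - n * p) - K_inv @ nu
              # Negative Hessian: binomial information plus prior precision.
              neg_hess = np.diag(n * p * (1.0 - p)) + K_inv
              step = np.linalg.solve(neg_hess, grad)
              nu = nu + step
              if np.abs(step).max() < tol:
                  break
          return nu


      rng = np.random.default_rng(0)
      n_spots = 30
      coords = rng.random((n_spots, 1))
      K = 0.5 * np.exp(-((coords - coords.T) ** 2) / 0.05) + 0.1 * np.eye(n_spots)
      n = rng.integers(5, 30, size=n_spots).astype(float)
      y1 = rng.binomial(n.astype(int), 0.6).astype(float)
      nu_hat = newton_mode_nu(y1, n, bias=0.4, K=K)

   Because the objective is strictly concave in ``nu``, the gradient at the returned ``nu_hat`` should be
   numerically zero.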

   .. rubric:: Notes

   A held-out likelihood could also be implemented for model selection.

   .. rubric:: Example

   >>> from splisosm.model import MultinomGLMM
   >>> from splisosm.utils import get_cov_sp
   >>> import torch
   >>> # Generate synthetic data
   >>> counts = torch.randint(0, 10, (5, 100, 3))  # 5 genes, 100 spots, 3 isoforms each
   >>> coords = torch.rand(100, 2)  # 100 spots with 2D coordinates
   >>> K_sp = get_cov_sp(coords, k=4, rho=0.9)  # spatial covariance matrix of shape (100, 100)
   >>> # Fit the GLMM model
   >>> model = MultinomGLMM(fitting_method='joint_gd')
   >>> model.setup_data(counts, corr_sp=K_sp, design_mtx=None)
   >>> model.fit()
   >>> print(model)
   >>> # Extract the fitted isoform ratios
   >>> isoform_ratios = model.get_isoform_ratio()  # shape (5, 100, 3)
   >>> # Fitted parameters
   >>> print(model.beta.shape)  # shape (5, 0, 2)
   >>> print(model.bias_eta.shape)  # shape (5, 2)
   >>> print(model.nu.shape)  # shape (5, 100, 2)
   >>> print(model.sigma.shape)  # shape (5, 1)
   >>> print(model.theta_logit.shape)  # shape (5, 1)

   :param share_variance: Whether to use the same variance across isoforms. If True, the variance components
                          will be of length 1. If False, the variance components will be of length n_isos - 1.
   :param var_parameterization_sigma_theta: Whether to parameterize the variance components as (``sigma``, ``theta_logit``) or (``sigma_sp``, ``sigma_nsp``).
                                            If True, the variance components will be (``sigma``, ``theta_logit``), where ``sigma`` is the total variance and
                                            ``theta_logit`` is the logit of the spatial variance proportion.
                                            If False, the variance components will be (``sigma_sp``, ``sigma_nsp``), where ``sigma_sp`` is the spatial
                                            variance and ``sigma_nsp`` is the non-spatial variance.
   :param var_fix_sigma: Whether to fix the total variance (``sigma``) or not. If True, the total variance will be fixed to the initial value,
                         which is the average per-spot variance of isoform counts normalized by their mean expression.
                         See `MultinomGLMM._initialize_params` for details.
   :param var_prior_model: The prior model on the total variance ``sigma``. Default is ``'none'`` with no prior.
                           Other options are ``'gamma'`` (Gamma prior) and ``'inv_gamma'`` (Inverse Gamma prior).
   :param var_prior_model_params: The parameters for the prior model on the total variance ``sigma``.
                                  For ``'gamma'``, the default parameters are ``{'alpha': 2.0, 'beta': 0.3}``.
                                  For ``'inv_gamma'``, the default parameters are ``{'alpha': 3, 'beta': 0.5}``.
   :param init_ratio: The initialization method for the logit isoform usage ratio. Options are ``'observed'`` (initialize using observed counts)
                      and ``'uniform'`` (equal isoform usage across space).
   :param fitting_method: The fitting method to use. Options are ``'joint_gd'`` (joint likelihood with gradient descent),
                          ``'joint_newton'`` (joint likelihood with Newton's method),
                          ``'marginal_gd'`` (marginal likelihood with gradient descent),
                          and ``'marginal_newton'`` (marginal likelihood with Newton's method).
   :param fitting_configs: A dictionary of fitting configurations with the following keys:

                           - ``'lr'``: float, learning rate
                           - ``'optim'``: str, optimization method (one of ``'adam'``, ``'sgd'``, or ``'lbfgs'``)
                           - ``'tol'``: float, tolerance for convergence
                           - ``'max_epochs'``: int, maximum number of epochs
                           - ``'patience'``: int, number of epochs to wait for improvement before stopping
                           - ``'update_nu_every_k'``: int, interval (in iterations) between Newton updates of ``nu`` when using ``fitting_method='marginal_newton'``
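
   The sketch below evaluates the two prior log-densities with the defaults listed above, as one way a prior
   on the total variance could enter a MAP objective. It assumes ``beta`` is a rate parameter for
   ``'gamma'`` and a scale parameter for ``'inv_gamma'`` (the usual conventions), and that the log prior is
   added to the log-likelihood; this page confirms neither. ``log_prior`` is a hypothetical helper.

   .. code-block:: python

      import math


      def log_prior(var_total, model="none", params=None):
          """Hypothetical log-prior on the total variance (shape-rate / shape-scale)."""
          if model == "none":
              return 0.0
          defaults = {"gamma": {"alpha": 2.0, "beta": 0.3},
                      "inv_gamma": {"alpha": 3.0, "beta": 0.5}}
          p = {**defaults[model], **(params or {})}
          a, b, x = p["alpha"], p["beta"], var_total
          if model == "gamma":
              # Gamma(shape=a, rate=b): a*log(b) - lgamma(a) + (a-1)*log(x) - b*x
              return a * math.log(b) - math.lgamma(a) + (a - 1) * math.log(x) - b * x
          if model == "inv_gamma":
              # InvGamma(shape=a, scale=b): a*log(b) - lgamma(a) - (a+1)*log(x) - b/x
              return a * math.log(b) - math.lgamma(a) - (a + 1) * math.log(x) - b / x
          raise ValueError(f"unknown prior model: {model}")


      # Under these assumptions, a MAP objective would be log_likelihood + log_prior(var_total, model).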


   .. py:method:: clone()

      Clone a Multinomial GLMM model with the same set of parameters.



   .. py:method:: fit(diagnose = False, verbose = False, quiet = False, random_seed = None)

      Fit the model using all data.

      :param diagnose: Whether to store parameter changes during training (passed to PatienceLogger).
      :param verbose: Whether to print verbose information during fitting.
      :param quiet: Whether to suppress output during fitting.
      :param random_seed: Random seed for reproducibility.

      :returns: **params_iter** -- If `diagnose` is True, returns a dictionary of parameter changes during training. Otherwise returns None.
      :rtype: dict or None



   .. py:method:: forward()

      Calculate the log-likelihood or log-marginal-likelihood of the model.

      :returns: **log_prob** -- Shape (n_genes,), the log probability for each gene.
      :rtype: torch.Tensor



   .. py:method:: setup_data(counts, corr_sp = None, design_mtx = None, device = 'cpu', corr_sp_eigvals = None, corr_sp_eigvecs = None)

      Set up the data for the model.

      :param counts: Shape (n_genes, n_spots, n_isos) or (n_spots, n_isos).
                     For batched calculations, all genes in the batch must have the same number of isoforms.
      :param corr_sp: Shape (n_spots, n_spots), spatial covariance matrix.
                      If None, the eigendecomposition of the spatial covariance matrix must be provided.
      :param design_mtx: Shape (n_spots, n_factors). Design matrix of spatial covariates.
                         If None, an intercept-only design matrix will be used.
      :param device: ``'cpu'`` or ``'cuda'``. ``'mps'`` is currently not supported because ``torch.lgamma`` is not supported on MPS.
      :param corr_sp_eigvals: Shape (n_spots,), eigenvalues of spatial covariance.
                              If None, the spatial covariance matrix `corr_sp` must be provided.
      :param corr_sp_eigvecs: Shape (n_spots, n_spots), eigenvectors of spatial covariance.
                              If None, the spatial covariance matrix `corr_sp` must be provided.
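
      A likely reason to accept the eigendecomposition is that the mixed covariance shares eigenvectors with
      ``V_sp``: if ``V_sp = U diag(lam) U^T``, then
      ``theta * V_sp + (1 - theta) * I = U diag(theta * lam + 1 - theta) U^T``, so its inverse and
      log-determinant follow from the eigenvalues alone. The identity is standard, but whether SPLISOSM
      exploits it exactly this way is an assumption. The NumPy sketch below checks it numerically with
      ``numpy.linalg.eigh``.

      .. code-block:: python

         import numpy as np

         rng = np.random.default_rng(0)
         n_spots = 50
         A = rng.normal(size=(n_spots, n_spots))
         V_sp = A @ A.T / n_spots  # any symmetric PSD matrix stands in for corr_sp

         # Analogous to corr_sp_eigvals / corr_sp_eigvecs.
         eigvals, eigvecs = np.linalg.eigh(V_sp)

         theta = 0.6
         mixed_eigvals = theta * eigvals + (1 - theta)  # eigenvalues of theta*V + (1-theta)*I

         # Inverse and log-determinant computed from the eigenvalues alone.
         mixed_inv = (eigvecs / mixed_eigvals) @ eigvecs.T
         mixed_logdet = np.log(mixed_eigvals).sum()

         direct = theta * V_sp + (1 - theta) * np.eye(n_spots)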



   .. py:method:: var_sp_prop()

      Output the proportions of the spatial variance.

      :returns: **var_sp_prop** -- The proportion ``theta`` of spatial variance of shape (n_genes, n_var_components).
      :rtype: torch.Tensor



   .. py:method:: var_total()

      Output the total variance.

      :returns: **var_total** -- The total variance ``sigma`` of shape (n_genes, n_var_components).
      :rtype: torch.Tensor



   .. py:attribute:: fitting_configs
      :type:  dict

      A dictionary of fitting configurations.


   .. py:attribute:: fitting_method
      :type:  Literal['joint_gd', 'joint_newton', 'marginal_gd', 'marginal_newton']

      The fitting method to use.


   .. py:attribute:: fitting_time
      :type:  float

      The time taken to fit the model.


   .. py:attribute:: init_ratio
      :type:  Literal['observed', 'uniform']

      The initialization method for the logit isoform usage ratio ``gamma``.
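
      A minimal sketch of the two options follows. It assumes per-spot logits taken relative to the last
      isoform and a pseudocount of 0.5 to guard against zero counts. Both choices are assumptions; the
      actual logic lives in ``MultinomGLMM._initialize_params``. ``init_logit_ratio`` is a hypothetical
      helper.

      .. code-block:: python

         import numpy as np


         def init_logit_ratio(Y, init_ratio="observed", pseudocount=0.5):
             """Hypothetical initialization of the (n_spots, n_isos - 1) logit ratios."""
             n_spots, n_isos = Y.shape
             if init_ratio == "uniform":
                 # Equal usage: every logit relative to the reference isoform is 0.
                 return np.zeros((n_spots, n_isos - 1))
             if init_ratio == "observed":
                 # Per-spot log ratio of each isoform to the (last) reference isoform.
                 Yp = Y + pseudocount
                 return np.log(Yp[:, :-1] / Yp[:, -1:])
             raise ValueError(f"unknown init_ratio: {init_ratio}")


         Y = np.array([[3, 1, 1], [0, 2, 4]], dtype=float)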


   .. py:attribute:: share_variance
      :type:  bool

      Whether to use the same variance across isoforms.


   .. py:attribute:: var_fix_sigma
      :type:  bool

      Whether to fix the total variance (``sigma``) or not.


   .. py:attribute:: var_parameterization_sigma_theta
      :type:  bool

      Whether variance components are parameterized as (``sigma``, ``theta_logit``) or (``sigma_sp``, ``sigma_nsp``).


   .. py:attribute:: var_prior_model
      :type:  Literal['none', 'gamma', 'inv_gamma']

      The prior model on the total variance ``sigma``.


   .. py:attribute:: var_prior_model_params
      :type:  dict

      The parameters for the prior model on the total variance ``sigma``.


