Source code for doatools.estimation.esprit

import numpy as np
from ..model.sources import FarField1DSourcePlacement
from .core import ensure_n_resolvable_sources

def get_default_row_weights(m):
    """Builds the default row-weight vector used by the ESPRIT estimator.

    Entry ``i`` (0-based) of the result equals ``sqrt(min(i + 1, m - i))``,
    so the weights rise from 1 at both ends towards the middle of the
    vector: ``[1, sqrt(2), sqrt(3), ..., sqrt(3), sqrt(2), 1]``.

    Args:
        m (int): Number of rows.

    Returns:
        A 1D float ndarray of length ``m`` holding the weights.
    """
    # Allocate with the requested shape first so that an invalid ``m`` is
    # rejected by NumPy's own shape validation.
    squared = np.empty((m,))
    # 1, 2, ..., m and its mirror image; the element-wise minimum gives the
    # symmetric ramp 1, 2, ..., 2, 1 (with a single peak when m is odd).
    rising = np.arange(1, m + 1)
    squared[:] = np.minimum(rising, rising[::-1])
    return np.sqrt(squared)

class Esprit1D:
    """Creates an ESPRIT estimator for 1D uniform linear arrays.

    Args:
        wavelength (float): Wavelength of the carrier wave.

    References:
        [1] R. Roy and T. Kailath, "ESPRIT-estimation of signal parameters
        via rotational invariance techniques," IEEE Transactions on
        Acoustics, Speech and Signal Processing, vol. 37, no. 7,
        pp. 984–995, Jul. 1989.

        [2] H. L. Van Trees, Optimum array processing. New York: Wiley, 2002.
    """

    def __init__(self, wavelength):
        self._wavelength = wavelength

    def estimate(self, R, k, d0=None, displacement=1, formulation='ls',
                 row_weights='default', unit='rad'):
        r"""Estimate the directions-of-arrival (DOAs) using ESPRIT.

        Args:
            R (~numpy.ndarray): Covariance matrix input. This covariance
                matrix must be obtained using a uniform linear array.
            k (int): Expected number of sources.
            d0 (float): Inter-element spacing of the uniform linear array
                used to obtain ``R``. If not specified, it will be set to
                one half of the ``wavelength`` used when creating this
                estimator. Default value is ``None``.
            displacement (int): The displacement between the two overlapping
                subarrays measured in number of minimal inter-element
                spacings. Must be at least 1. Default value is 1.
                Increasing this value will lead to **smaller** unambiguous
                range and number of resolvable sources. Make sure your DOAs
                fall within the unambiguous range.
            formulation (str): Method used to estimate the rotation matrix.
                Either ``'tls'`` (Total Least Squares) or ``'ls'`` (Least
                Squares). Default value is ``'ls'``.
            row_weights (str or ~numpy.ndarray): Specifies the row weights
                with a vector or a string. Default value is ``'default'``,
                which generates the following weight vector:

                .. math::

                    \lbrack 1\ \sqrt{2}\ \sqrt{3}\ \cdots\ \sqrt{3}\ \sqrt{2}\ 1 \rbrack

                You can disable row weighting by passing in ``'none'``, or
                specify your own row weights with a 1D
                :class:`~numpy.ndarray` whose length equals the number of
                rows of ``R`` minus ``displacement``.
            unit (str): Unit of the estimates. Default value is ``'rad'``.
                See :class:`~doatools.model.sources.FarField1DSourcePlacement`
                for more details on valid units.

        Returns:
            A tuple with the following elements.

            * resolved (:class:`bool`): Always ``True`` in this
              implementation. This flag does **not** guarantee that the
              estimated source locations are correct. The estimated source
              locations may be completely wrong!
            * estimates (:class:`~doatools.model.sources.FarField1DSourcePlacement`):
              A :class:`~doatools.model.sources.FarField1DSourcePlacement`
              built from the eigenvalues of the estimated rotation matrix.

        Raises:
            ValueError: If ``displacement`` is smaller than 1, if
                ``row_weights`` is neither ``'default'``, ``'none'`` nor a
                1D array of the required length, or if ``formulation`` is
                neither ``'ls'`` nor ``'tls'``.
        """
        m = R.shape[0]
        if displacement < 1:
            raise ValueError('Displacement must be a positive integer.')
        # Each of the two shifted subarrays keeps this many rows.
        m_reduced = m - displacement
        ensure_n_resolvable_sources(k, m_reduced)
        if d0 is None:
            d0 = self._wavelength / 2.0
        if isinstance(row_weights, str):
            if row_weights == 'none':
                row_weights = None
            elif row_weights == 'default':
                row_weights = get_default_row_weights(m_reduced)
            else:
                raise ValueError("When specified using a string, row weights must be either 'none' or 'default'.")
        elif isinstance(row_weights, np.ndarray):
            if row_weights.ndim != 1 or row_weights.size != m_reduced:
                raise ValueError('Row weights must be a vector of length {0}.'.format(m_reduced))
        else:
            raise ValueError("Row weights must be 'default', 'none', or a compatible numpy vector.")
        # Extract the signal subspace. eigh returns eigenvalues in ascending
        # order, so the last k eigenvectors belong to the k largest ones.
        _, E = np.linalg.eigh(R)
        Es = E[:, -k:]
        # Separation into the two displaced (overlapping) subarrays.
        Es1 = Es[:-displacement, :]
        Es2 = Es[displacement:, :]
        # Apply row weights.
        if row_weights is not None:
            # Es1 and Es2 are overlapping views of the same array, so the
            # weighting must NOT be done in place: scaling Es1 in place
            # would also rescale the rows it shares with Es2, and those
            # rows would then be weighted twice.
            weights_column = row_weights[:, np.newaxis]
            Es1 = Es1 * weights_column
            Es2 = Es2 * weights_column
        # Estimate the rotation matrix.
        if formulation == 'tls':
            # Total least-squares
            C = np.hstack((Es1, Es2))
            C = C.conj().T @ C
            _, V = np.linalg.eigh(C)
            V = np.fliplr(V)  # Now in descending order
            V12 = V[:k, k:]
            V22 = V[k:, k:]
            # Phi = -V12 V22^{-1}, computed by solving the transposed system
            # instead of forming the inverse explicitly.
            Phi = -np.linalg.solve(V22.T, V12.T).T
        elif formulation == 'ls':
            # Least-squares via the normal equations:
            # (Es1^H Es1) Phi = Es1^H Es2
            Es1_H = Es1.conj().T
            Phi = np.linalg.solve(Es1_H @ Es1, Es1_H @ Es2)
        else:
            raise ValueError("Formulation must be either 'ls' or 'tls'.")
        # Recover the DOAs from the eigenvalues of the rotation matrix. The
        # effective spacing between the two subarrays is d0 * displacement.
        z = np.linalg.eigvals(Phi)
        return True, FarField1DSourcePlacement.from_z(z, self._wavelength, d0 * displacement, unit)