import os

import torch
import _k2
from torch import Tensor
from typing import Tuple, Optional, Sequence, Union, List


class MutualInformationRecursionFunction(torch.autograd.Function):
    """A recursion that is useful in computing mutual information between two
    sequences of real vectors, but may be useful more generally in
    sequence-to-sequence tasks where monotonic alignment between pairs of
    sequences is desired.
    """

    @staticmethod
    def forward(
        ctx,
        px: torch.Tensor,
        py: torch.Tensor,
        pxy_grads: List[Optional[torch.Tensor]],
        boundary: Optional[torch.Tensor] = None,
        return_grad: bool = False,
    ) -> torch.Tensor:
        """
        Computing mutual information between two sequences of real vectors.
        Args:
          px:
            A torch.Tensor of some floating point type, with shape ``[B][S][T]``
            if modified, ``[B][S][T+1]`` if not modified,
            where ``B`` is the batch size, ``S`` is the
            length of the ``x`` sequence (including representations of
            ``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
            length of the ``y`` sequence (including representations of
            ``EOS`` symbols but not  ``BOS`` symbols).  In the mutual
            information application, ``px[b][s][t]`` would represent the
            following log odds ratio; ignoring the b index on the right
            to make the notation more
            compact::

              px[b][s][t] =  log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]

            This expression also implicitly includes the log-probability of
            choosing to generate an ``x`` value as opposed to a ``y`` value.  In
            practice it might be computed as ``a + b``, where ``a`` is the log
            probability of choosing to extend the sequence of length ``(s,t)``
            with an ``x`` as opposed to a ``y`` value; and ``b`` might in
            practice be of the form::

                log(N exp f(x_s, y_{t-1}) / sum_t'  exp f(x_s, y_t'))

            where ``N`` is the number of terms that the sum over ``t'``
            included, which might include some or all of the other sequences as
            well as this one.

            Note:
              we don't require ``px`` and ``py`` to be contiguous, but the
              code assumes for optimization purposes that the ``T`` axis has
              stride 1.

          py:
            A torch.Tensor of the same dtype as ``px``, with shape
            ``[B][S+1][T]``, representing::

              py[b][s][t] =  log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]

            This function does not treat ``x`` and ``y`` differently; the only
            difference is that for optimization purposes we assume the last axis
            (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
            are contiguous.

          pxy_grads:
            A list in which to store the computed grads of ``px`` and ``py``
            if ``return_grad == True``; it is left unchanged if
            ``return_grad == False``.

            See `this PR <https://github.com/k2-fsa/k2/pull/924>`_ for more
            information about why we add this parameter.

            Note:
              the length of the list must be 2, where the first element
              represents the grads of ``px`` and the second one represents
              the grads of ``py``.

          boundary:
            If supplied, a torch.LongTensor of shape ``[B][4]``, where each
            row contains ``[s_begin, t_begin, s_end, t_end]``,
            with ``0 <= s_begin <= s_end <= S`` and
            ``0 <= t_begin <= t_end <= T``
            (this implies that empty sequences are allowed).
            If not supplied, the values ``[0, 0, S, T]`` will be assumed.
            These are the beginning and one-past-the-last positions in the
            ``x`` and ``y`` sequences respectively, and can be used if not
            all sequences are
            of the same length.

          return_grad:
            Whether to return grads of ``px`` and ``py``.  This grad stands
            for the occupation probability; it is the output of the backward
            pass with a ``fake gradient``, which is the same as the gradient
            you'd get if you did
            ``torch.autograd.grad((scores.sum()), [px, py])``.
            This is useful for implementing the pruned version of the RNN-T
            loss.

        Returns:
          Returns a torch.Tensor of shape ``[B]``, containing the log of
          the mutual information between the b'th pair of sequences.  This is
          defined by the following recursion on ``p[b,s,t]`` (where ``p``
          is of shape ``[B,S+1,T+1]``), representing a mutual information
          between sub-sequences of lengths ``s`` and ``t``::

                 p[b,0,0] = 0.0
                 p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                    p[b,s,t-1] + py[b,s,t-1])
                           (if s > 0 or t > 0)

          where we handle edge cases by treating quantities with negative
          indexes as **-infinity**.  The extension to cases where the
          boundaries are specified should be obvious; it just works on
          shorter sequences with offsets into ``px`` and ``py``.
        """
        (B, S, T1) = px.shape
        T = py.shape[-1]
        assert T1 in [T, T + 1]
        assert py.shape == (B, S + 1, T)
        if boundary is not None:
            assert boundary.shape == (B, 4)

        # p is of shape (B, S+1, T+1); p[b][s][t] is the mutual information
        # of the pair of subsequences of x and y of lengths s and t.
        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

        ans = _k2.mutual_information_forward(px, py, boundary, p)

        px_grad, py_grad = None, None
        if return_grad or px.requires_grad or py.requires_grad:
            ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
            (px_grad, py_grad) = _k2.mutual_information_backward(
                px, py, boundary, p, ans_grad
            )
            ctx.save_for_backward(px_grad, py_grad)
            assert len(pxy_grads) == 2
            pxy_grads[0] = px_grad
            pxy_grads[1] = py_grad

        return ans

    @staticmethod
    def backward(
        ctx, ans_grad: Tensor
    ) -> Tuple[Tensor, Tensor, None, None, None]:
        (px_grad, py_grad) = ctx.saved_tensors
        (B,) = ans_grad.shape
        ans_grad = ans_grad.reshape(B, 1, 1)
        px_grad = px_grad * ans_grad
        py_grad = py_grad * ans_grad
        return (px_grad, py_grad, None, None, None)


def mutual_information_recursion(
    px: Tensor,
    py: Tensor,
    boundary: Optional[Tensor] = None,
    return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
    """A recursion that is useful in computing mutual information between two
    sequences of real vectors, but may be useful more generally in
    sequence-to-sequence tasks where monotonic alignment between pairs of
    sequences is desired.  The definitions of the arguments are definitions that
    would be used when computing this type of mutual information, but you can
    also view them as arbitrary quantities and just make use of the formula
    computed by this function.

    Args:
      px:
        A torch.Tensor of some floating point type, with shape ``[B][S][T]``
        if modified, ``[B][S][T+1]`` if not modified,
        where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
        (including representations of ``EOS`` symbols but not ``BOS`` symbols),
        and ``T`` is the length of the ``y`` sequence (including representations
        of ``EOS`` symbols but not ``BOS`` symbols).  In the mutual information
        application, ``px[b][s][t]`` would represent the following log odds
        ratio; ignoring the b index on the right to make the notation more
        compact::

          px[b][s][t] =  log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]

        This expression also implicitly includes the log-probability of
        choosing to generate an ``x`` value as opposed to a ``y`` value.  In
        practice it might be computed as ``a + b``, where ``a`` is the log
        probability of choosing to extend the sequence of length ``(s,t)``
        with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice
        be of the form::

            log(N exp f(x_s, y_{t-1}) / sum_t'  exp f(x_s, y_t'))

        where ``N`` is the number of terms that the sum over ``t'`` included,
        which might include some or all of the other sequences as well as this
        one.

        Note:
          we don't require ``px`` and ``py`` to be contiguous, but the
          code assumes for optimization purposes that the ``T`` axis has
          stride 1.

      py:
        A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``,
        representing::

          py[b][s][t] =  log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]

        This function does not treat ``x`` and ``y`` differently; the only
        difference is that for optimization purposes we assume the last axis
        (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are
        contiguous.

      boundary:
        If supplied, a torch.LongTensor of shape ``[B][4]``, where each
        row contains ``[s_begin, t_begin, s_end, t_end]``,
        with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
        (this implies that empty sequences are allowed).
        If not supplied, the values ``[0, 0, S, T]`` will be assumed.
        These are the beginning and one-past-the-last positions in the ``x`` and
        ``y`` sequences respectively, and can be used if not all sequences are
        of the same length.

      return_grad:
        Whether to return grads of ``px`` and ``py``.  This grad stands for
        the occupation probability; it is the output of the backward pass
        with a ``fake gradient``, which is the same as the gradient you'd
        get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
        This is useful for implementing the pruned version of the RNN-T loss.

    Returns:
      Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
      information between the b'th pair of sequences.  This is defined by
      the following recursion on ``p[b,s,t]`` (where ``p`` is of shape
      ``[B,S+1,T+1]``), representing a mutual information between sub-sequences
      of lengths ``s`` and ``t``::

             p[b,0,0] = 0.0
        if !modified:
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
        if modified:
             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                p[b,s,t-1] + py[b,s,t-1])

      where we handle edge cases by treating quantities with negative indexes
      as **-infinity**.  The extension to cases where the boundaries are
      specified should be obvious; it just works on shorter sequences with
      offsets into ``px`` and ``py``.
    """
    assert px.ndim == 3, px.shape
    (B, S, T1) = px.shape
    T = py.shape[-1]
    assert px.shape[-1] in [T, T + 1], (px.shape, T)
    assert py.shape == (B, S + 1, T), (py.shape, B, S, T)
    assert px.dtype == py.dtype, (px.dtype, py.dtype)
    if boundary is not None:
        assert boundary.dtype == torch.int64, boundary.dtype
        assert boundary.shape == (B, 4), (boundary.shape, B)
        for s_begin, t_begin, s_end, t_end in boundary.tolist():
            assert 0 <= s_begin <= s_end <= S, (s_begin, s_end, S)
            assert 0 <= t_begin <= t_end <= T, (t_begin, t_end, T)
    # The underlying kernels assume the T axis has stride 1.
    px, py = px.contiguous(), py.contiguous()

    pxy_grads = [None, None]
    scores = MutualInformationRecursionFunction.apply(
        px, py, pxy_grads, boundary, return_grad
    )
    px_grad, py_grad = pxy_grads
    return (scores, (px_grad, py_grad)) if return_grad else scores


def _inner_product(a: Tensor, b: Tensor) -> Tensor:
    """
    Does inner product on the last dimension, with expected broadcasting,
    i.e. equivalent to (a * b).sum(dim=-1)
    without creating a large temporary.
    """
    assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
    a = a.unsqueeze(-2)  # (..., 1, K)
    b = b.unsqueeze(-1)  # (..., K, 1)
    c = torch.matmul(a, b)  # (..., 1, 1)
    return c.squeeze(-1).squeeze(-1)


def joint_mutual_information_recursion(
    px: Sequence[Tensor],
    py: Sequence[Tensor],
    boundary: Optional[Tensor] = None,
) -> Sequence[Tensor]:
    """A recursion that is useful for modifications of RNN-T and similar loss
    functions, where the recursion probabilities have a number of terms and you
    want them reported separately.  See mutual_information_recursion() for more
    documentation of the basic aspects of this.

    Args:
      px:
        a sequence of Tensors, each of the same shape [B][S][T+1]
      py:
        a sequence of Tensors, each of the same shape [B][S+1][T];
        the sequence must be the same length as px.
      boundary:
        optionally, a LongTensor of shape [B][4] containing rows
        [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
        and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
        These are the beginning and one-past-the-last positions in the x
        and y sequences respectively, and can be used if not all
        sequences are of the same length.
    Returns:
      a Tensor of shape (len(px), B),
      whose sum over dim 0 is the total log-prob of the recursion mentioned
      below, per sequence. The first element of the sequence of length len(px)
      is "special", in that it has an offset term reflecting the difference
      between sum-of-log and log-of-sum; for more interpretable loss values,
      the "main" part of your loss function should be first.

      The recursion below applies if boundary == None, in which case it
      defaults to (0, 0, S, T); px_sum and py_sum are the sums of the
      elements of px and py::

          p = tensor of shape (B, S+1, T+1), containing -infinity
          p[b,0,0] = 0.0
          # do the following in loop over s and t:
          p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
                              p[b,s,t-1] + py_sum[b,s,t-1])
                      (if s > 0 or t > 0)
          return p[:][S][T]

    This function lets you implement the above recursion efficiently, except
    that it gives you a breakdown of the contribution from all the elements of
    px and py separately.  As noted above, the first element of the
    sequence is "special".
    """
    N = len(px)
    assert len(py) == N and N > 0, (len(py), N)
    (B, S, T1) = px[0].shape
    T = py[0].shape[-1]
    assert T1 in [T, T + 1], (T1, T)
    assert py[0].shape == (B, S + 1, T), (py[0].shape, B, S, T)
    assert px[0].dtype == py[0].dtype, (px[0].dtype, py[0].dtype)

    px_cat = torch.stack(px, dim=0)  # (N, B, S, T+1)
    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)

    if boundary is not None:
        assert boundary.dtype == torch.int64, boundary.dtype
        assert boundary.shape == (B, 4), (boundary.shape, B)
        for s_begin, t_begin, s_end, t_end in boundary.tolist():
            assert 0 <= s_begin <= s_end <= S, (s_begin, s_end, S)
            assert 0 <= t_begin <= t_end <= T, (t_begin, t_end, T)

    # The underlying kernels assume the T axis has stride 1.
    px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
    assert px_tot.ndim == 3, px_tot.ndim
    assert py_tot.ndim == 3, py_tot.ndim

    p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

    # tot_probs here means the "total log-probs".
    tot_probs = _k2.mutual_information_forward(px_tot, py_tot, boundary, p)

    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
    (px_grad, py_grad) = _k2.mutual_information_backward(
        px_tot, py_tot, boundary, p, ans_grad
    )

    px_grad = px_grad.reshape(1, B, -1)
    py_grad = py_grad.reshape(1, B, -1)
    px_cat = px_cat.reshape(N, B, -1)
    py_cat = py_cat.reshape(N, B, -1)
    # Get rid of -inf in px_cat/py_cat; it would otherwise generate nan
    # when multiplied by a zero gradient.
    px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
    py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)

    x_prods = _inner_product(px_grad, px_cat)  # (N, B)
    y_prods = _inner_product(py_grad, py_cat)  # (N, B)

    # If all the occupation counts were exactly 1.0, "prods" would equal
    # "tot_probs"; in general "tot_probs" will be more positive due to the
    # difference between log-of-sum and sum-of-log.
    prods = x_prods + y_prods  # (N, B)
    with torch.no_grad():
        offset = tot_probs - prods.sum(dim=0)  # (B,)
    prods[0] += offset
    return prods  # (N, B)
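# The recursion documented above can be sketched in pure PyTorch as a naive
# O(S*T) Python loop over the lattice (the real implementation calls the
# compiled `_k2` kernels instead).  The sketch below assumes the unmodified
# case with boundary fixed at [0, 0, S, T]; the function name is
# illustrative only, not part of the k2 API:

```python
import math

import torch


def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    """Naive reference for the documented recursion (unmodified case,
    boundary = [0, 0, S, T]); for illustration and testing only."""
    B, S, T1 = px.shape
    T = py.shape[-1]
    assert T1 == T + 1 and py.shape == (B, S + 1, T)
    neg_inf = torch.tensor(float("-inf"), dtype=px.dtype)
    ans = torch.empty(B, dtype=px.dtype)
    for b in range(B):
        # p[s][t] is the mutual information of the pair of subsequences
        # of x and y of lengths s and t respectively.
        p = torch.full((S + 1, T + 1), float("-inf"), dtype=px.dtype)
        p[0, 0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                # Quantities with negative indexes are treated as -infinity.
                x_term = p[s - 1, t] + px[b, s - 1, t] if s > 0 else neg_inf
                y_term = p[s, t - 1] + py[b, s, t - 1] if t > 0 else neg_inf
                p[s, t] = torch.logaddexp(x_term, y_term)
        ans[b] = p[S, T]
    return ans


# With px and py all zero, every monotonic path through the lattice
# contributes probability 1, so the result is the log of the number of
# paths from (0, 0) to (S, T):
px = torch.zeros(1, 2, 3)  # B=1, S=2, T=2 (last dim is T+1)
py = torch.zeros(1, 3, 2)
print(naive_mutual_information(px, py))  # log(C(4, 2)) = log(6), about 1.7918
```

# This agrees with mutual_information_recursion() on the same inputs, but is
# far too slow for real use; it exists only to make the recursion concrete.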