---
name: Fix TRT encoder bypass bug
overview: The TRT encoder is NEVER actually called in production. A method-binding bug causes every inference to fall back to the original PyTorch encoder. Adding one override method should unlock the real 2.84x encoder speedup.
todos:
  - id: fix-cache-aware
    content: Add cache_aware_stream_step override to TRTEncoderWrapper in nemotron_asr_trt.py
    status: in_progress
  - id: redeploy
    content: Redeploy hindi-nemotron-asr-1 and verify encoder=tensorrt in /health
    status: pending
  - id: benchmark
    content: Run bench_modal.py at c=1,10,50,100,200,300,400,500 and compare FT p50 with previous results
    status: pending
---

# Fix: TRT Encoder Is Never Called (Method-Binding Bypass Bug)

## The Bug

**The TRT engine runs zero times during inference. Every single encoder forward pass uses the original unoptimized PyTorch encoder.**

Here is the exact call chain that proves it:

```mermaid
sequenceDiagram
    participant Pipeline as NeMo Pipeline
    participant Wrapper as TRTEncoderWrapper
    participant Original as OriginalEncoder
    participant TRT as TRT Engine

    Pipeline->>Wrapper: encoder.cache_aware_stream_step(...)
    Note over Wrapper: __getattr__ delegates to Original<br/>because TRTWrapper has no<br/>cache_aware_stream_step method
    Wrapper->>Original: original.cache_aware_stream_step(...)
    Note over Original: Inside method, self = Original<br/>self() calls Original.__call__()
    Original->>Original: self.forward(audio, length, caches...)
    Note over TRT: NEVER CALLED
    Original-->>Pipeline: PyTorch output (not TRT)
```

The root cause in [nemotron_asr/nemotron_asr_trt.py](nemotron_asr/nemotron_asr_trt.py):

- `TRTEncoderWrapper` overrides `forward()` to use TRT (line 183)
- But NeMo's pipeline calls `encoder.cache_aware_stream_step()`, NOT `encoder.forward()`
- `cache_aware_stream_step` is not defined on `TRTEncoderWrapper`, so `__getattr__` (line 176) resolves it from the **original** encoder
- Inside the original's `cache_aware_stream_step`, `self()` calls `OriginalEncoder.__call__()` -> `OriginalEncoder.forward()` -- completely bypassing TRT

**This is why the benchmarks show identical latency/throughput between TRT and torch.compile -- both are running vanilla PyTorch.**

(Note: `torch.compile` has the same delegation issue via `OptimizedModule.__getattr__`, but torch.compile may partially compensate through JIT tracing of the `__call__` path. TRT gets zero benefit.)
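The delegation trap is plain Python attribute lookup, nothing NeMo-specific. A minimal reproduction with toy classes (names are illustrative, not the real encoder API):

```python
class Original:
    """Stands in for the PyTorch encoder."""
    def forward(self, x):
        return "pytorch"

    def __call__(self, x):
        return self.forward(x)

    def stream_step(self, x):
        # `self` here is ALWAYS the Original instance -- a bound method
        # remembers the object it was looked up on.
        return self(x)


class Wrapper:
    """Stands in for TRTEncoderWrapper."""
    def __init__(self, original):
        self._original = original

    def forward(self, x):
        return "trt"

    def __getattr__(self, name):
        # Only invoked for names NOT found on Wrapper itself.
        return getattr(self._original, name)


w = Wrapper(Original())
assert w.stream_step(None) == "pytorch"  # delegated: TRT path never runs
assert w.forward(None) == "trt"          # only a direct forward() hits TRT
```

Defining `stream_step` directly on `Wrapper` makes normal attribute lookup succeed, so `__getattr__` never fires and `self` stays bound to the wrapper.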

## The Fix

Add an explicit `cache_aware_stream_step` method to `TRTEncoderWrapper` that calls `self.forward()` (TRT path) instead of delegating to the original:

```mermaid
sequenceDiagram
    participant Pipeline as NeMo Pipeline
    participant Wrapper as TRTEncoderWrapper
    participant Original as OriginalEncoder
    participant TRT as TRT Engine

    Pipeline->>Wrapper: encoder.cache_aware_stream_step(...)
    Note over Wrapper: Method exists on Wrapper!<br/>No __getattr__ delegation
    Wrapper->>Wrapper: self.forward(audio, length, caches...)
    Wrapper->>TRT: _trt_execute(audio, length, caches...)
    TRT-->>Wrapper: encoded, caches (FP16 optimized)
    Wrapper->>Original: streaming_post_process(output)
    Note over Original: Slice caches and encoded<br/>(tensor ops, not compute)
    Original-->>Pipeline: TRT output with post-processing
```

### What the new method does

In `TRTEncoderWrapper` (around line 183 in `nemotron_asr_trt.py`), add:

```python
def cache_aware_stream_step(self, processed_signal, processed_signal_length=None,
                             cache_last_channel=None, cache_last_time=None,
                             cache_last_channel_len=None, keep_all_outputs=True,
                             drop_extra_pre_encoded=None, bypass_pre_encode=False):
    orig = self._original

    # Setup streaming params if needed (first call only)
    if orig.streaming_cfg is None:
        orig.setup_streaming_params()

    # Save/restore drop_extra_pre_encoded (matches original behavior)
    if drop_extra_pre_encoded is not None:
        prev = orig.streaming_cfg.drop_extra_pre_encoded
        orig.streaming_cfg.drop_extra_pre_encoded = drop_extra_pre_encoded
    else:
        prev = None

    if processed_signal_length is None:
        # Use an integer dtype: new_full would otherwise inherit the
        # floating-point dtype of the audio features.
        processed_signal_length = processed_signal.new_full(
            (processed_signal.size(0),), processed_signal.size(-1),
            dtype=torch.long,
        )

    # KEY FIX: self.forward() routes to TRT via _trt_execute()
    # Original code: self() on original encoder -> PyTorch forward (BYPASSED TRT)
    encoder_output = self.forward(
        audio_signal=processed_signal,
        length=processed_signal_length,
        cache_last_channel=cache_last_channel,
        cache_last_time=cache_last_time,
        cache_last_channel_len=cache_last_channel_len,
    )

    # Post-process: cache truncation + output slicing (from original encoder)
    encoder_output = orig.streaming_post_process(
        encoder_output, keep_all_outputs=keep_all_outputs
    )

    if prev is not None:
        orig.streaming_cfg.drop_extra_pre_encoded = prev

    return encoder_output
```
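A cheap regression guard against this class of bug, using only standard attribute mechanics (the helper name is ours, not from the codebase): assert at startup or in a unit test that the streaming method is defined on the wrapper class itself, so `__getattr__` delegation can never silently reclaim it.

```python
def assert_not_delegated(wrapper_cls: type, name: str) -> None:
    # __getattr__ only fires when normal attribute lookup fails, so a method
    # defined anywhere in the wrapper's own MRO wins over delegation to the
    # wrapped encoder.
    assert any(name in klass.__dict__ for klass in wrapper_cls.__mro__), (
        f"{name} is not defined on {wrapper_cls.__name__}; "
        "it would be delegated via __getattr__"
    )
```

For example, `assert_not_delegated(TRTEncoderWrapper, "cache_aware_stream_step")` would have failed before this fix and passes after it.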

### Why streaming_post_process is needed separately

The TRT engine contains only the encoder's `forward()` computation (pre-encode plus 24 conformer layers). `streaming_post_process` performs tensor slicing AFTER the forward pass (truncating cache channels, clamping output lengths). These ops are cheap, but they must run at call time because `keep_all_outputs` varies per call, so they cannot be baked into the engine.
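To make the division of labor concrete, here is a hypothetical sketch of the kind of slicing involved -- this is NOT NeMo's actual `streaming_post_process` implementation, just an illustration of why these ops are cheap and call-dependent:

```python
import torch

def post_process_sketch(encoded, encoded_len, cache_last_channel,
                        cache_keep=64, keep_all_outputs=True):
    # Hypothetical: keep only the most recent `cache_keep` frames of the
    # channel cache -- pure slicing, no conformer compute.
    cache_last_channel = cache_last_channel[:, :, -cache_keep:, :]
    # Hypothetical: when the caller does not want all outputs, clamp the
    # encoded sequence to the longest valid length in the batch. This
    # depends on keep_all_outputs, which varies per call.
    if not keep_all_outputs:
        max_len = int(encoded_len.max().item())
        encoded = encoded[..., :max_len]
    return encoded, encoded_len, cache_last_channel
```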

## Expected Impact

- The encoder is ~60-70% of total inference time (24-layer conformer vs lightweight RNNT decoder)
- By Amdahl's law, a real 2.84x encoder speedup translates to roughly **1.6-1.8x single-stream pipeline speedup**; under load, reduced GPU contention should push the observed gain higher
- At c=300 where torch.compile shows FT p50=1156ms, TRT should show ~550-650ms
- Capacity ceiling should move well beyond 500 streams per L40S
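A quick Amdahl's-law check on the numbers above (2.84x applied to a 60-70% encoder share):

```python
def amdahl(frac_encoder: float, encoder_speedup: float) -> float:
    # Overall speedup when only the encoder fraction of total
    # inference time is accelerated.
    return 1.0 / ((1.0 - frac_encoder) + frac_encoder / encoder_speedup)

for f in (0.60, 0.70):
    print(f"encoder share {f:.0%}: overall {amdahl(f, 2.84):.2f}x")
# encoder share 60%: overall 1.64x
# encoder share 70%: overall 1.83x
```

This is the pure single-stream compute bound; gains beyond it under concurrency would come from second-order effects such as reduced GPU queueing.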

## Verification

After redeploying, run two quick sanity checks: `/health` should report `encoder=tensorrt`, and a single-stream benchmark should show FT p50 at c=1 dropping from ~600ms to ~350-400ms. That drop is the definitive signal that TRT is actually executing.
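The health check can be scripted; this sketch assumes the `/health` payload exposes an `encoder` field (per the redeploy todo), and the base URL is a placeholder for the deployed hindi-nemotron-asr-1 endpoint:

```python
import json
import urllib.request

def trt_active(health: dict) -> bool:
    # Assumes /health exposes an "encoder" field; adjust the key if the
    # deployed payload differs.
    return health.get("encoder") == "tensorrt"

def check_deployment(base_url: str) -> bool:
    # base_url is a placeholder, e.g. the hindi-nemotron-asr-1 endpoint.
    with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
        return trt_active(json.load(resp))
```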