diff --git a/doc/source/conf.py b/doc/source/conf.py
index d0d00db..f7d0af2 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -34,7 +34,8 @@ extensions = [
     'sphinx.ext.autodoc',
     'sphinx.ext.viewcode',
     "sphinx_rtd_theme",
-    'sphinx.ext.napoleon',
+    'sphinx.ext.mathjax',
+    'numpydoc',
 ]
 
 # Add any paths that contain templates here, relative to this directory.
@@ -59,4 +60,5 @@ html_theme = "sphinx_rtd_theme"
 html_static_path = ['_static']
 source_suffix = ['.rst', '.md']
 
-# -- Extension configuration -------------------------------------------------
\ No newline at end of file
+# -- Extension configuration -------------------------------------------------
+numpydoc_show_class_members = False
diff --git a/parakeet/modules/attention.py b/parakeet/modules/attention.py
index cb0ae63..923ff36 100644
--- a/parakeet/modules/attention.py
+++ b/parakeet/modules/attention.py
@@ -36,10 +36,13 @@ def scaled_dot_product_attention(q,
 
     q: Tensor [shape=(*, T_q, d)]
         the query tensor.
+
     k: Tensor [shape=(*, T_k, d)]
         the key tensor.
+
     v: Tensor [shape=(*, T_k, d_v)]
         the value tensor.
+
     mask: Tensor, [shape=(*, T_q, T_k) or broadcastable shape], optional
         the mask tensor, zeros correspond to paddings. Defaults to None.
 
@@ -47,6 +50,7 @@ def scaled_dot_product_attention(q,
     ----------
     out: Tensor [shape(*, T_q, d_v)]
         the context vector.
+
     attn_weights [Tensor shape(*, T_q, T_k)]
         the attention weights.
     """
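
Usage sketch (not part of the patch): the reworked docstring documents q, k, v sharing leading dimensions, an optional mask, and a two-value return (context vectors plus attention weights). The snippet below illustrates a call consistent with those documented shapes; it assumes Paddle tensors, that the function is importable from parakeet.modules.attention as the diff header indicates, and that mask can be passed as a keyword argument — any further parameters of the real signature are not visible in this hunk.

import paddle
from parakeet.modules.attention import scaled_dot_product_attention

# Illustrative shapes only; batch size and sequence lengths are arbitrary.
batch, T_q, T_k, d, d_v = 2, 5, 7, 64, 64
q = paddle.randn([batch, T_q, d])      # queries: (*, T_q, d)
k = paddle.randn([batch, T_k, d])      # keys:    (*, T_k, d)
v = paddle.randn([batch, T_k, d_v])    # values:  (*, T_k, d_v)
mask = paddle.ones([batch, T_q, T_k])  # zeros would mark padded positions

# Per the docstring, the function returns the context vectors and the attention weights.
out, attn_weights = scaled_dot_product_attention(q, k, v, mask=mask)
print(out.shape)           # [2, 5, 64] -> (*, T_q, d_v)
print(attn_weights.shape)  # [2, 5, 7]  -> (*, T_q, T_k)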