o
    i;U                     @   s>  d dl Z d dlmZ d dlmZmZ d dlmZ d dlm	Z	 d dl
mZmZmZ d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ ddlmZmZmZ dd Ze jdg dg dg dgdd Ze jj e jddge jddddgddgddfdddgd d!gddfd"dd#ddd$fd"dd%ddd$fge jd&d'd(d)d*dd+fd,d-d)d*dd+fd'd.d)d/d0fd,d1d)d/d0fgd2d3 Z!d4d5 Z"d6Z#d7g d8d9d:d;d<fd=g d>d:d9d;d<fgZ$e jd?d@dAdB Z%dCdD Z&dEdF Z'dGdH Z(dIdJ Z)dKdL Z*dMZ+dNdO Z,dPZ-dQdR Z.dSdT Z/dUdV Z0dS )W    N)assert_array_equal)Configget_current_ops)util)English)MaxoutWindowEncoderMultiHashEmbedbuild_Tok2Vec_model)Tok2VecTok2VecListener)Doc)Example)registry)Vocab   )add_vecs_to_vocab	get_batchmake_tempdirc                  C   s   d} d}t  }t|g d}tt| ||||gdg ddt| ddd	d
}|  ||g\}}t|dks8J |d jd| fksCJ d S )N     )wordsFNORMPREFIXSUFFIXSHAPEwidthrowsinclude_static_vectorsattrs         r   depthwindow_sizemaxout_piecesr   )	r   r   r	   r   r   
initializebegin_updatelenshape)r   
embed_sizevocabdoctok2vecvectorsbackprop r2   U/home/ubuntu/.local/lib/python3.10/site-packages/spacy/tests/pipeline/test_tok2vec.pytest_empty_doc   s"   
	r4   zbatch_size,width,embed_size)r"   r   r   )r   r   r   )r#      ?   c           	      C   s   t | }tt||gd dg ddt|dddd}|  ||\}}t|t|ks/J t||D ]\}}|jt||fksCJ q4d S )Nr!   Fr   r   r"   r#   r$   )	r   r	   r   r   r(   r)   r*   zipr+   )	
batch_sizer   r,   batchr/   r0   r1   doc_vecr.   r2   r2   r3   test_tok2vec_batch_sizes)   s    	r;   r   r5   zembed_arch,embed_configzspacy.MultiHashEmbed.v1d   r   LOWERF)r   r    r      ORTHr   zspacy.CharacterEmbed.v1@   )r   nMnCr      z&tok2vec_arch,encode_arch,encode_configzspacy.Tok2Vec.v1zspacy.MaxoutWindowEncoder.v1r"   r#   )r&   r'   r%   spacy.Tok2Vec.v2zspacy.MaxoutWindowEncoder.v2zspacy.MishWindowEncoder.v1   )r&   r%   zspacy.MishWindowEncoder.v2c                 C   s   t d|}t d|}t d|}| |d< | |d< td}	||di ||di |}
|
|	 |
|	\}}t|t|	ksCJ |d jt|	d | fksRJ || d S )Narchitecturesr   r#   r   r2   )r   getr   r(   r)   r*   r+   )r   tok2vec_arch
embed_archembed_configencode_archencode_configembedencodetok2vec_modeldocsr/   r0   r1   r2   r2   r3   test_tok2vec_configs>   s   
rQ   c                  C   s:   t  } | d}|jg ksJ |   |jdsJ d S )Nr/   nO)r   add_pipe	listenersr(   modelget_dim)nlpr/   r2   r2   r3   test_init_tok2vech   s
   
rX   a  
    [nlp]
    lang = "en"
    pipeline = ["tok2vec","tagger"]

    [components]

    [components.tagger]
    factory = "tagger"

    [components.tagger.model]
    @architectures = "spacy.Tagger.v2"
    nO = null

    [components.tagger.model.tok2vec]
    @architectures = "spacy.Tok2VecListener.v1"
    width = ${components.tok2vec.model.encode.width}

    [components.tok2vec]
    factory = "tok2vec"

    [components.tok2vec.model]
    @architectures = "spacy.Tok2Vec.v2"

    [components.tok2vec.model.embed]
    @architectures = "spacy.MultiHashEmbed.v1"
    width = ${components.tok2vec.model.encode.width}
    rows = [2000, 1000, 1000, 1000]
    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
    include_static_vectors = false

    [components.tok2vec.model.encode]
    @architectures = "spacy.MaxoutWindowEncoder.v2"
    width = 96
    depth = 4
    window_size = 1
    maxout_pieces = 3
    I like green eggsNVJr[         ?        )
preference
imperative)tagscatsEat blue hamr\   r]   r[   with_vectors)FTc              	      s  t  t}| |d d d d d< tj|ddd}| rOt }d|g d	fd
|g dfd|g dfd|g dfd|g dfg}t|j| |j	ddgksXJ |
d}|
d}|jd}t|tsoJ t|tsvJ g  tD ] } t||d |d  |d d D ]}	||	 qqz| fdd}
|j|gksJ tdD ]}i }|j |
|d q|d}||gd }t }t||j|| |d}|jdd |j	dgksJ |d d S )N
componentsr/   rU   rM   r   T	auto_fillvalidateapple)r"   r   r#   orange)and)rm   rm   rm   juice)   rr   
   pie)   g333333@g!@taggerr   r"   rb   c                          S Nr2   r2   train_examplesr2   r3   <lambda>       z'test_tok2vec_listener.<locals>.<lambda>rr   sgdlossesz Running the pipeline as a whole. )disablez9Running the pipeline with the Tok2Vec component disabled.)r   from_str
cfg_stringr   load_model_from_configr   asarrayr   r-   
pipe_namesget_piperU   get_ref
isinstancer
   r   
TRAIN_DATAappendr   	from_dictmake_doc	add_labelr(   rT   rangeupdatepredictr   to_numpytensorselect_pipes)rf   orig_configrW   opsr0   rv   r/   tagger_tok2vecttag	optimizerir   r.   
doc_tensorr2   ry   r3   test_tok2vec_listener   sP   

 r   c            	      C   s   t  t} tj| ddd}|jddgksJ |d}|d}|dg}|jj	|d dd |D }|jj
j|d	d
g}|jj	||d |dg}|dd |D  |j|\}}||d usgJ d S )NTrh   r/   rv   zA random sentence)Xc                 S   s   g | ]	}d d dD qS )c                 S   s   g | ]}d qS )r^   r2   ).0r   r2   r2   r3   
<listcomp>   s    z=test_tok2vec_listener_callback.<locals>.<listcomp>.<listcomp>)r\   Zr2   )r   wordr2   r2   r3   r      s    z2test_tok2vec_listener_callback.<locals>.<listcomp>float32)dtype)r   Yz Another entirely random sentencec                 S   s   g | ]}t |i qS r2   )r   r   )r   xr2   r2   r3   r      s    )r   r   r   r   r   r   r   r   rU   r(   r   r   r   r)   )	r   rW   rv   r/   rP   
gold_arraylabel_sampler   get_dXr2   r2   r3   test_tok2vec_listener_callback   s   

r   c               	      sx  t  t} tj| ddd}g  tD ]} t|	|d |d  q|j
 fddd}tdD ]}i }|j ||d	gd
 q3|d dk sJJ d}||}|d jdksYJ |d jdksbJ |d jdkskJ |d jdkstJ t ;}|| t|}	|	|}
|
d jdksJ |
d jdksJ |
d jdksJ |
d jdksJ W d   dS 1 sw   Y  dS )ziTest that a pipeline with a listener properly overfits, even if 'tok2vec' is in the annotating componentsTrh   r   r"   c                      rw   rx   r2   r2   ry   r2   r3   r{      r|   z3test_tok2vec_listener_overfitting.<locals>.<lambda>get_examples2   r/   )r~   r   	annotatesrv   gh㈵>I like blue eggsr[   r\   r   r]   r#   Nr   r   r   r   r   r   r   r   r   r   r(   r   r   tag_r   to_diskload_model_from_pathr   rW   r   r   r   r   	test_textr.   tmp_dirnlp2doc2r2   ry   r3   !test_tok2vec_listener_overfitting   s2   "

"r   c               	      s   t  t} tj| ddd}g  tD ]} t|	|d |d  q|j
 fddd}tdD ]%}i }tjtd	d
 |j ||dgd W d   n1 sSw   Y  q3dS )z]Test that a pipeline with a frozen tok2vec raises an error when the tok2vec is not annotatingTrh   r   r"   c                      rw   rx   r2   r2   ry   r2   r3   r{     r|   z4test_tok2vec_frozen_not_annotating.<locals>.<lambda>r   r   z*the tok2vec embedding layer is not updated)matchr/   )r~   r   excludeN)r   r   r   r   r   r   r   r   r   r   r(   r   pytestraises
ValueErrorr   )r   rW   r   r   r   r   r2   ry   r3   "test_tok2vec_frozen_not_annotating  s"   "
r   c               	      s|  t  t} tj| ddd}g  tD ]} t|	|d |d  q|j
 fddd}tdD ]}i }|j ||d	gd	gd
 q3|d dk sLJ d}||}|d jdks[J |d jdksdJ |d jdksmJ |d jdksvJ t ;}|| t|}	|	|}
|
d jdksJ |
d jdksJ |
d jdksJ |
d jdksJ W d   dS 1 sw   Y  dS )zITest that a pipeline with a frozen & annotating tok2vec can still overfitTrh   r   r"   c                      rw   rx   r2   r2   ry   r2   r3   r{   &  r|   z1test_tok2vec_frozen_overfitting.<locals>.<lambda>r   r<   r/   )r~   r   r   r   rv   g-C6?r   r[   r\   r   r]   r#   Nr   r   r2   ry   r3   test_tok2vec_frozen_overfitting  s>   "

"r   c                     s  t  t} tj| ddd}t|ddddgig | fdd |	d	}|	d
}t
|jjd ts:J |jd
 d |jjd ksIJ |jd d	 d d dksXJ |jd d
 d d	 d dksiJ |d	d
dg t
|jjd tr|J |jd d	 d }|d dksJ |jd d
 d d	 |ksJ tt |dd
dg W d    n1 sw   Y  tt |d	ddg W d    n1 sw   Y  tt |d	d
dg W d    n1 sw   Y  tt |d	d
ddg W d    n	1 sw   Y  | fdd}tdD ]}i }|j ||d |d	 dks4J |d
 dks=J qd S )NTrh   x yrb   r\   r   c                      rw   rx   r2   r2   examplesr2   r3   r{   J  r|   z(test_replace_listeners.<locals>.<lambda>r/   rv   r   rg   rU   @architecturesrD   spacy.Tok2VecListener.v1model.tok2vecinvalidparserz
model.yoloc                      rw   rx   r2   r2   r   r2   r3   r{   e  r|   r   r}   r_   )r   r   r   r   r   r   r   r   r(   r   r   rU   layersr   listener_mapconfigreplace_listenersr   r   r   r   r   )r   rW   r/   rv   t2v_cfgr   r   r   r2   r   r3   test_replace_listenersF  sL   

r   a  
    [nlp]
    lang = "en"
    pipeline = ["tok2vec","tagger", "ner"]

    [components]

    [components.tagger]
    factory = "tagger"

    [components.tagger.model]
    @architectures = "spacy.Tagger.v2"
    nO = null

    [components.tagger.model.tok2vec]
    @architectures = "spacy.Tok2VecListener.v1"
    width = ${components.tok2vec.model.encode.width}

    [components.ner]
    factory = "ner"

    [components.ner.model]
    @architectures = "spacy.TransitionBasedParser.v2"

    [components.ner.model.tok2vec]
    @architectures = "spacy.Tok2VecListener.v1"
    width = ${components.tok2vec.model.encode.width}

    [components.tok2vec]
    factory = "tok2vec"

    [components.tok2vec.model]
    @architectures = "spacy.Tok2Vec.v2"

    [components.tok2vec.model.embed]
    @architectures = "spacy.MultiHashEmbed.v1"
    width = ${components.tok2vec.model.encode.width}
    rows = [2000, 1000, 1000, 1000]
    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
    include_static_vectors = false

    [components.tok2vec.model.encode]
    @architectures = "spacy.MaxoutWindowEncoder.v2"
    width = 96
    depth = 4
    window_size = 1
    maxout_pieces = 3
    c                     s\  t  t} tj| dd}ddgddgd}t|d|g | fd	d
 |	d}|	d}|	d}|j
ddgksAJ tdd |j D sOJ tdd |j D s]J t 4}|| t|}dg ddd|i|ddgd|dd|dddd}tj|dd}	W d    n1 sw   Y  |	 fdd
 |		d}|		d}|		d}d|	jvsJ d|	jvsJ |j
ddgksJ tdd |j D sJ tdd |j D rJ |	jd d d  }
|
d! d"ksJ |	jd d d  d |
ksJ |	jd d d  d d! d#ksJ |	jd d d  d d! d#ks,J d S )$NT)ri   r\   r   )r   r"   A)r"   r   B)rb   entitiesr   c                      rw   rx   r2   r2   r   r2   r3   r{     r|   z4test_replace_listeners_from_config.<locals>.<lambda>r/   rv   nerc                 s       | ]}t |tV  qd S rx   r   r   r   noder2   r2   r3   	<genexpr>      z5test_replace_listeners_from_config.<locals>.<genexpr>c                 s   r   rx   r   r   r2   r2   r3   r     r   en)r/   tagger2ner3tagger4)langpipelinesourcer   )r   	componentr   )r   r   )rW   rg   c                      rw   rx   r2   r2   r   r2   r3   r{     r|   r   r   r   c                 s   r   rx   r   r   r2   r2   r3   r     r   c                 s   r   rx   r   r   r2   r2   r3   r     r   rg   rU   r   rD   r   )r   r   cfg_string_multir   r   r   r   r   r(   r   listening_componentsanyrU   walkr   r   strr   r   )r   rW   annotsr/   rv   r   dir_path
base_model
new_confignew_nlpr   r2   r   r3   "test_replace_listeners_from_config  sd   






 
r   a  
    [nlp]
    lang = "en"
    pipeline = ["tok2vec","textcat_multilabel","tagger"]

    [components]

    [components.textcat_multilabel]
    factory = "textcat_multilabel"

    [components.textcat_multilabel.model]
    @architectures = "spacy.TextCatEnsemble.v2"
    nO = null

    [components.textcat_multilabel.model.tok2vec]
    @architectures = "spacy.Tok2VecListener.v1"
    width = ${components.tok2vec.model.encode.width}

    [components.textcat_multilabel.model.linear_model]
    @architectures = "spacy.TextCatBOW.v1"
    exclusive_classes = false
    ngram_size = 1
    no_output_layer = false

    [components.tagger]
    factory = "tagger"

    [components.tagger.model]
    @architectures = "spacy.Tagger.v2"
    nO = null

    [components.tagger.model.tok2vec]
    @architectures = "spacy.Tok2VecListener.v1"
    width = ${components.tok2vec.model.encode.width}

    [components.tok2vec]
    factory = "tok2vec"

    [components.tok2vec.model]
    @architectures = "spacy.Tok2Vec.v2"

    [components.tok2vec.model.embed]
    @architectures = "spacy.MultiHashEmbed.v1"
    width = ${components.tok2vec.model.encode.width}
    rows = [2000, 1000, 1000, 1000]
    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
    include_static_vectors = false

    [components.tok2vec.model.encode]
    @architectures = "spacy.MaxoutWindowEncoder.v2"
    width = 96
    depth = 4
    window_size = 1
    maxout_pieces = 3
    c               	      s  t  t} tj| ddd}|jg dksJ |d}|d}|d}|jd}|jd}t	|t
s9J t	|ts@J t	|tsGJ g  tD ]} t||d |d  qK| fd	d
}tdD ]}	i }
|j ||
d qkt|ddg}|d j}|d dk sJ |d dksJ |d j}|d dksJ |d dk sJ dd |d D g dksJ dd |d D g dksJ d S )NTrh   )r/   textcat_multilabelrv   rv   r   r/   r   r"   c                      rw   rx   r2   r2   ry   r2   r3   r{   $  r|   z0test_tok2vec_listeners_textcat.<locals>.<lambda>r   r}   rd   rY   r`   g?ra   g?c                 S      g | ]}|j qS r2   r   r   r   r2   r2   r3   r   0      z2test_tok2vec_listeners_textcat.<locals>.<listcomp>re   c                 S   r   r2   r   r   r2   r2   r3   r   1  r   rZ   )r   r   cfg_string_multi_textcatr   r   r   r   rU   r   r   r
   r   r   r   r   r   r   r(   r   r   listpiperc   )r   rW   rv   textcatr/   r   textcat_tok2vecr   r   r   r   rP   cats0cats1r2   ry   r3   test_tok2vec_listeners_textcat  s6   


"

"r   c                  C   sX  t  t} tj| ddd}|djddgksJ t }|jd|d |jdd|d |dj	|dj	  kr?dksBJ  J |djdgksMJ |jdd	|d |djdd	gksaJ |
d	 |djdgksqJ |
d |djg ksJ |djg ksJ |d
 |djddgksJ |d
 |djg ksJ dS )zvThe component's internal name and the tok2vec listener map correspond
    to the most recently modified pipeline.
    Trh   r/   rv   r   r   r   namer   r   sentencizerN)r   r   r   r   r   r   r   r   rS   r   remove_piper   nlp1r   r2   r2   r3   &test_tok2vec_listener_source_link_name4  s&   ,



r   c                  C   s   t  t} tj| ddd}|djddgksJ |dddg |djdgks-J t }|j	d|d |djg ksAJ |j	d|d |djg ksRJ |j	dd|d	 |djdgkseJ d S )
NTrh   r/   rv   r   r   r   ner2r   )
r   r   r   r   r   r   r   r   r   rS   r   r2   r2   r3   .test_tok2vec_listener_source_replace_listenersZ  s   r  )1r   numpy.testingr   	thinc.apir   r   spacyr   spacy.lang.enr   spacy.ml.models.tok2vecr   r   r	   spacy.pipeline.tok2vecr
   r   spacy.tokensr   spacy.trainingr   
spacy.utilr   spacy.vocabr   r   r   r   r4   markparametrizer;   slowrQ   rX   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r2   r2   r2   r3   <module>   st    
	)
4!''2<9 &