2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  re  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  collections  import  namedtuple  
						 
					
						
							
								
									
										
										
										
											2022-10-05 22:11:30 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  typing  import  List  
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  lark  
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']

# Lark grammar for the prompt-editing syntax.
#   start       - a sequence of prompts; runs of stray "[", "]", "(", ")" or ":" characters
#                 are also accepted so that prompts with unbalanced brackets can still parse.
#   emphasized  - "(text)", "(text:weight)" and "[text]"; the "!" prefix keeps the bracket
#                 tokens in the tree, so these are reproduced verbatim in the rendered text.
#   scheduled   - "[from:to:when]" or "[to:when]" (the "from" part is optional); NUMBER is
#                 a signed number, with optional whitespace allowed before it.
#   alternate   - "[a|b|...]": two or more alternatives separated by "|".
#   plain       - any run of characters other than \ [ ] ( ) : |, or a backslash-escaped
#                 character (so e.g. "\(" is literal text).
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
  
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  get_learned_conditioning_prompt_schedules ( prompts ,  steps ) :  
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    >> >  g  =  lambda  p :  get_learned_conditioning_prompt_schedules ( [ p ] ,  10 ) [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " test " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 10 ,  ' test ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [b:3] " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 3 ,  ' a  ' ] ,  [ 10 ,  ' a b ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [b: 3] " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 3 ,  ' a  ' ] ,  [ 10 ,  ' a b ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [[[b]]:2] " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 2 ,  ' a  ' ] ,  [ 10 ,  ' a [[b]] ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " [(a:2):3] " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 3 ,  ' ' ] ,  [ 10 ,  ' (a:2) ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [b : c : 1] d " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 1 ,  ' a b  d ' ] ,  [ 10 ,  ' a  c  d ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a[b:[c:d:2]:1]e " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 1 ,  ' abe ' ] ,  [ 2 ,  ' ace ' ] ,  [ 10 ,  ' ade ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [unbalanced " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 10 ,  ' a [unbalanced ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [b:.5] c " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 5 ,  ' a  c ' ] ,  [ 10 ,  ' a b c ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " a [ { b|d { :.5] c " )   # not handling this right now 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 5 ,  ' a  c ' ] ,  [ 10 ,  ' a  { b|d {  c ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  g ( " ((a][:b:c [d:3] " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ 3 ,  ' ((a][:b:c  ' ] ,  [ 10 ,  ' ((a][:b:c d ' ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 18:02:01 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  collect_steps ( steps ,  tree ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        l  =  [ steps ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        class  CollectSteps ( lark . Visitor ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            def  scheduled ( self ,  tree ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                tree . children [ - 1 ]  =  float ( tree . children [ - 1 ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                if  tree . children [ - 1 ]  <  1 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    tree . children [ - 1 ]  * =  steps 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                tree . children [ - 1 ]  =  min ( steps ,  int ( tree . children [ - 1 ] ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                l . append ( tree . children [ - 1 ] ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-05 19:10:39 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            def  alternate ( self ,  tree ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                l . extend ( range ( 1 ,  steps + 1 ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        CollectSteps ( ) . visit ( tree ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  sorted ( set ( l ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 18:02:01 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  at_step ( step ,  tree ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        class  AtStep ( lark . Transformer ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            def  scheduled ( self ,  args ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                before ,  after ,  _ ,  when  =  args 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                yield  before  or  ( )  if  step  < =  when  else  after 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-05 19:10:39 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            def  alternate ( self ,  args ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                yield  next ( args [ ( step  -  1 ) % len ( args ) ] ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            def  start ( self ,  args ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                def  flatten ( x ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    if  type ( x )  ==  str : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                        yield  x 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                        for  gen  in  x : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                            yield from  flatten ( gen ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                return  ' ' . join ( flatten ( args ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            def  plain ( self ,  args ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                yield  args [ 0 ] . value 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            def  __default__ ( self ,  data ,  children ,  meta ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                for  child  in  children : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    yield from  child 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  AtStep ( ) . transform ( tree ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  get_schedule ( prompt ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        try : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            tree  =  schedule_parser . parse ( prompt ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        except  lark . exceptions . LarkError  as  e : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                import  traceback 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                traceback . print_exc ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  [ [ steps ,  prompt ] ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-03 19:25:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  [ [ t ,  at_step ( t ,  tree ) ]  for  t  in  collect_steps ( steps ,  tree ) ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 18:02:01 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    promptdict  =  { prompt :  get_schedule ( prompt )  for  prompt  in  set ( prompts ) } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  [ promptdict [ prompt ]  for  prompt  in  prompts ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								ScheduledPromptConditioning  =  namedtuple ( " ScheduledPromptConditioning " ,  [ " end_at_step " ,  " cond " ] )  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  get_learned_conditioning ( model ,  prompts ,  steps ) :  
						 
					
						
							
								
									
										
										
										
											2022-10-05 23:16:27 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    """ converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    and  the  sampling  step  at  which  this  condition  is  to  be  replaced  by  the  next  one . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    Input : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ( model ,  [ ' a red crown ' ,  ' a [blue:green:5] jeweled crown ' ] ,  20 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    Output : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        [ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ScheduledPromptConditioning ( end_at_step = 20 ,  cond = tensor ( [ [ - 0.3886 ,   0.0229 ,  - 0.0523 ,   . . . ,  - 0.4901 ,  - 0.3066 ,   0.0674 ] ,  . . . ,  [  0.3317 ,  - 0.5102 ,  - 0.4066 ,   . . . ,   0.4119 ,  - 0.7647 ,  - 1.0160 ] ] ,  device = ' cuda:0 ' ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        [ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ScheduledPromptConditioning ( end_at_step = 5 ,  cond = tensor ( [ [ - 0.3886 ,   0.0229 ,  - 0.0522 ,   . . . ,  - 0.4901 ,  - 0.3067 ,   0.0673 ] ,  . . . ,  [ - 0.0192 ,   0.3867 ,  - 0.4644 ,   . . . ,   0.1135 ,  - 0.3696 ,  - 0.4625 ] ] ,  device = ' cuda:0 ' ) ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ScheduledPromptConditioning ( end_at_step = 20 ,  cond = tensor ( [ [ - 0.3886 ,   0.0229 ,  - 0.0522 ,   . . . ,  - 0.4901 ,  - 0.3067 ,   0.0673 ] ,  . . . ,  [ - 0.7352 ,  - 0.4356 ,  - 0.7888 ,   . . . ,   0.6994 ,  - 0.4312 ,  - 1.2593 ] ] ,  device = ' cuda:0 ' ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    res  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    prompt_schedules  =  get_learned_conditioning_prompt_schedules ( prompts ,  steps ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cache  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  prompt ,  prompt_schedule  in  zip ( prompts ,  prompt_schedules ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        cached  =  cache . get ( prompt ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  cached  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res . append ( cached ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 18:05:42 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            continue 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        texts  =  [ x [ 1 ]  for  x  in  prompt_schedule ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        conds  =  model . get_learned_conditioning ( texts ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        cond_schedule  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  i ,  ( end_at_step ,  text )  in  enumerate ( prompt_schedule ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            cond_schedule . append ( ScheduledPromptConditioning ( end_at_step ,  conds [ i ] ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        cache [ prompt ]  =  cond_schedule 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        res . append ( cond_schedule ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-05 23:16:27 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    return  res 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
# Used to split a prompt into sub-prompts on the standalone word "AND"; the
# word boundaries mean words that merely contain "AND" (e.g. "BAND") are not split.
re_AND = re.compile(r"\bAND\b")
# Separates an optional trailing ":<number>" (presumably the sub-prompt's
# weight) from a sub-prompt: group 1 is the text (lazy, so it stops before the
# number and excludes trailing whitespace); group 2 is the signed integer or
# decimal number, or None when there is no such suffix.
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
						 
					
						
							
								
									
										
										
										
											2022-10-05 23:16:27 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def get_multicond_prompt_list(prompts):
								    """Split each prompt on AND into weighted sub-prompts and de-duplicate their texts.

								    A sub-prompt may end in ``:<weight>``; when no weight is given it is 1.0.

								    Returns a tuple ``(res_indexes, prompt_flat_list, prompt_indexes)``:
								      - res_indexes: for every input prompt, a list of ``(index, weight)`` pairs,
								        where ``index`` points into ``prompt_flat_list``;
								      - prompt_flat_list: each distinct sub-prompt text, in first-seen order;
								      - prompt_indexes: maps a sub-prompt text to its position in ``prompt_flat_list``.
								    """
								    res_indexes = []
								    prompt_flat_list = []
								    prompt_indexes = {}

								    for prompt in prompts:
								        weighted_refs = []

								        for part in re_AND.split(prompt):
								            found = re_weight.search(part)
								            if found is None:
								                # e.g. a part spanning a newline: "." in re_weight does not cross lines
								                text, weight = part, 1.0
								            else:
								                text, raw_weight = found.groups()
								                weight = 1.0 if raw_weight is None else float(raw_weight)

								            # Identical sub-prompt texts share one slot in the flat list.
								            if text not in prompt_indexes:
								                prompt_indexes[text] = len(prompt_flat_list)
								                prompt_flat_list.append(text)

								            weighted_refs.append((prompt_indexes[text], weight))

								        res_indexes.append(weighted_refs)

								    return res_indexes, prompt_flat_list, prompt_indexes
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class ComposableScheduledPromptConditioning:
								    """One AND-separated sub-prompt: its conditioning schedule together with its weight."""

								    def __init__(self, schedules, weight=1.0):
								        # Scheduled conditionings for this sub-prompt.
								        self.schedules: List[ScheduledPromptConditioning] = schedules
								        # Weight parsed from the sub-prompt's ":<weight>" suffix; defaults to 1.0.
								        self.weight: float = weight
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class MulticondLearnedConditioning:
								    """Conditioning for a batch of composable prompts (one list of sub-prompts per prompt)."""

								    def __init__(self, shape, batch):
								        # A shape attribute is required so this object can be passed to DDIM/PLMS.
								        self.shape: tuple = shape
								        # batch[i] holds the weighted sub-prompt conditionings of prompt i.
								        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-05 23:16:27 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
								    """Like get_learned_conditioning, but each prompt yields a list of weighted sub-prompt schedules.

								    Every prompt is split on the AND separator; each part becomes a
								    ComposableScheduledPromptConditioning holding its schedule and weight.

								    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
								    """
								    res_indexes, prompt_flat_list, _ = get_multicond_prompt_list(prompts)

								    # Each distinct sub-prompt text is encoded once and then referenced by index.
								    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)

								    batch = [
								        [ComposableScheduledPromptConditioning(learned_conditioning[index], weight) for index, weight in refs]
								        for refs in res_indexes
								    ]

								    return MulticondLearnedConditioning(shape=(len(prompts),), batch=batch)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-05 22:11:30 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
								    """Select each prompt's conditioning for ``current_step`` and combine them into one tensor.

								    For every schedule in ``c``, the first entry with ``current_step <= end_at_step`` is used.
								    If no entry qualifies, the first entry is used. The result has the device and dtype
								    of ``c[0][0].cond``.

								    Prompts of very different lengths beyond the token limit give conditionings with
								    different row counts. Shorter ones are padded by repeating their last row, the same
								    way reconstruct_multicond_batch does. Without this, assigning into a tensor pre-shaped
								    from the first conditioning fails.
								    """
								    param = c[0][0].cond

								    selected = []
								    for cond_schedule in c:
								        target_index = 0
								        for current, (end_at, _cond) in enumerate(cond_schedule):
								            if current_step <= end_at:
								                target_index = current
								                break
								        selected.append(cond_schedule[target_index].cond)

								    token_count = max(cond.shape[0] for cond in selected)
								    res = torch.zeros((len(c), token_count) + tuple(param.shape[1:]), device=param.device, dtype=param.dtype)
								    for i, cond in enumerate(selected):
								        if cond.shape[0] != token_count:
								            last_vector_repeated = cond[-1:].repeat([token_count - cond.shape[0], 1])
								            cond = torch.vstack([cond, last_vector_repeated])
								        # Assignment copies/casts into res's device and dtype.
								        res[i] = cond

								    return res
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-05 23:16:27 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
								    """Flatten a composable-prompt batch into ``(conds_list, tensor)`` for ``current_step``.

								    Returns:
								        conds_list: one list per batch entry of ``(row, weight)`` pairs. ``row`` indexes the
								            first dimension of the returned tensor; ``weight`` is the sub-prompt's weight.
								        tensor: the conditioning selected for each sub-prompt, stacked in order and moved to
								            the device/dtype of ``c.batch[0][0].schedules[0].cond``.
								    """
								    param = c.batch[0][0].schedules[0].cond

								    tensors = []
								    conds_list = []

								    for composable_prompts in c.batch:
								        conds_for_batch = []

								        for composable_prompt in composable_prompts:
								            schedules = composable_prompt.schedules
								            # First schedule entry with current_step <= end_at; entry 0 when none matches.
								            target_index = next(
								                (idx for idx, (end_at, _cond) in enumerate(schedules) if current_step <= end_at),
								                0,
								            )

								            # The row index is taken before the tensor is appended.
								            conds_for_batch.append((len(tensors), composable_prompt.weight))
								            tensors.append(schedules[target_index].cond)

								        conds_list.append(conds_for_batch)

								    # If prompts have wildly different lengths above the limit, the tensors have different
								    # shapes and torch.stack cannot combine them. Pad each shorter tensor by repeating
								    # its last row.
								    token_count = max(t.shape[0] for t in tensors)
								    for i, t in enumerate(tensors):
								        missing = token_count - t.shape[0]
								        if missing != 0:
								            tensors[i] = torch.vstack([t, t[-1:].repeat([missing, 1])])

								    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 11:31:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								# Tokenizer used by parse_prompt_attention. Alternatives are tried in order:
								#   escaped "(", ")", "[", "]", an escaped backslash, and a lone backslash;
								#   the openers "(" and "[";
								#   ":<number>)" closing a weighted group (the number is captured in group 1);
								#   the closers ")" and "]";
								#   runs of plain text;
								#   a stray ":" on its own.
								# The weight must look like a real number ("1", "1.", ".5", "-0.25", ...). Malformed
								# values such as "(a:.)" or "(a:1.2.3)" are then tokenized as plain text instead of
								# being captured and passed to float() in parse_prompt_attention, which raised ValueError.
								re_attention = re.compile(r"""
								\\\(|
								\\\)|
								\\\[|
								\\]|
								\\\\|
								\\|
								\(|
								\[|
								:([+-]?(?:\d+\.?\d*|\.\d+))\)|
								\)|
								]|
								[^\\()\[\]:]+|
								:
								""", re.X)
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  parse_prompt_attention ( text ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-19 02:18:56 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    Parses  a  string  with  attention  tokens  and  returns  a  list  of  pairs :  text  and  its  associated  weight . 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 11:31:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    Accepted  tokens  are : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ( abc )  -  increases  attention  to  abc  by  a  multiplier  of  1.1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ( abc : 3.12 )  -  increases  attention  to  abc  by  a  multiplier  of  3.12 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      [ abc ]  -  decreases  attention  to  abc  by  a  multiplier  of  1.1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      \(  -  literal  character  ' ( ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      \[  -  literal  character  ' [ ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      \)  -  literal  character  ' ) ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      \]  -  literal  character  ' ] ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      \\ -  literal  character  ' \' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      anything  else  -  just  text 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    >> >  parse_prompt_attention ( ' normal text ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ ' normal text ' ,  1.0 ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  parse_prompt_attention ( ' an (important) word ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ ' an  ' ,  1.0 ] ,  [ ' important ' ,  1.1 ] ,  [ '  word ' ,  1.0 ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  parse_prompt_attention ( ' (unbalanced ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ ' unbalanced ' ,  1.1 ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  parse_prompt_attention ( ' \ (literal \ ] ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ ' (literal] ' ,  1.0 ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  parse_prompt_attention ( ' (unnecessary)(parens) ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ ' unnecessaryparens ' ,  1.1 ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    >> >  parse_prompt_attention ( ' a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))). ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    [ [ ' a  ' ,  1.0 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ ' house ' ,  1.5730000000000004 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ '   ' ,  1.1 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ ' on ' ,  1.0 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ '  a  ' ,  1.1 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ ' hill ' ,  0.55 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ ' , sun,  ' ,  1.1 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ ' sky ' ,  1.4641000000000006 ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     [ ' . ' ,  1.1 ] ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 11:31:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-15 13:10:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 11:31:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    res  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    round_brackets  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    square_brackets  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    round_bracket_multiplier  =  1.1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    square_bracket_multiplier  =  1  /  1.1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  multiply_range ( start_position ,  multiplier ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  p  in  range ( start_position ,  len ( res ) ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res [ p ] [ 1 ]  * =  multiplier 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  m  in  re_attention . finditer ( text ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        text  =  m . group ( 0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        weight  =  m . group ( 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  text . startswith ( ' \\ ' ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res . append ( [ text [ 1 : ] ,  1.0 ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        elif  text  ==  ' ( ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            round_brackets . append ( len ( res ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        elif  text  ==  ' [ ' : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            square_brackets . append ( len ( res ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        elif  weight  is  not  None  and  len ( round_brackets )  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            multiply_range ( round_brackets . pop ( ) ,  float ( weight ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        elif  text  ==  ' ) '  and  len ( round_brackets )  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            multiply_range ( round_brackets . pop ( ) ,  round_bracket_multiplier ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        elif  text  ==  ' ] '  and  len ( square_brackets )  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            multiply_range ( square_brackets . pop ( ) ,  square_bracket_multiplier ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res . append ( [ text ,  1.0 ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  pos  in  round_brackets : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        multiply_range ( pos ,  round_bracket_multiplier ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  pos  in  square_brackets : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        multiply_range ( pos ,  square_bracket_multiplier ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 11:39:55 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  len ( res )  ==  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        res  =  [ [ " " ,  1.0 ] ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    # merge runs of identical weights 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    i  =  0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    while  i  +  1  <  len ( res ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  res [ i ] [ 1 ]  ==  res [ i  +  1 ] [ 1 ] : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res [ i ] [ 0 ]  + =  res [ i  +  1 ] [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res . pop ( i  +  1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            i  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 11:31:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    return  res 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 09:49:51 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  __name__  ==  " __main__ " :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    import  doctest 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    doctest . testmod ( optionflags = doctest . NORMALIZE_WHITESPACE ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								else :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    import  torch   # doctest faster