Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import shutil
3import logging
class ModelGenerator:
    """Generate a Rasa project (config, domain and training data) from plugin configs.

    Each config handed to :meth:`generate` is a dict describing one plugin.
    Intent and action names are namespaced with the plugin name (see
    :meth:`get_intent` and :meth:`get_action`) so several plugins can be merged
    into one Rasa model without name clashes.
    """

    # Set to True once any processed config contains a 'stories' section.
    uses_stories: bool = False
    # When True, generate() writes all files but skips the rasa.train() call.
    dry_run: bool = False

    # Slot types accepted in a slot definition.
    _SLOT_TYPES = ('text', 'bool', 'categorical', 'float', 'list', 'any')

    def __init__(self, dry_run: bool = False):
        """Create a generator; with *dry_run* set, generate() does not train."""
        self.dry_run = dry_run

    def get_intent(self, intent: str, plugin: str) -> str:
        """Return the namespaced intent name.

        Intents of the ``main`` plugin keep their name; all others receive a
        ``_<plugin>`` suffix.
        """
        if plugin == "main":
            return intent
        return f'{intent}_{plugin}'

    def get_action(self, action: str, plugin: str) -> str:
        """Return the namespaced action name.

        ``utter_*`` and ``validate_*`` names receive a ``_<plugin>`` suffix; any
        other name is additionally prefixed with ``action_``. Unlike intents,
        names of the ``main`` plugin are suffixed as well.
        """
        if action.startswith(('utter_', 'validate_')):
            return f'{action}_{plugin}'
        return f'action_{action}_{plugin}'

    @staticmethod
    def _quote(text) -> str:
        """Return *text* as a YAML double-quoted scalar, escaping backslashes and quotes."""
        escaped = str(text).replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'

    def _format_step(self, step_type: str, name: str, plugin: str) -> str:
        """Return one ``- <type>: <name>`` step line of a rule or story."""
        if step_type == 'intent':
            value = self.get_intent(name, plugin)
        else:
            value = self.get_action(name, plugin)
        return f'      - {step_type}: {value}\n'

    @staticmethod
    def _write_examples(f, kind: str, name: str, examples) -> None:
        """Write one ``- <kind>: <name>`` NLU entry followed by its examples block."""
        f.write(f'  - {kind}: {name}\n    examples: |\n')
        f.writelines([f'      - {example}\n' for example in examples])
        f.write('\n')

    def _write_slot(self, f, slot) -> None:
        """Validate one slot definition and write it to the open domain file *f*.

        Raises:
            RuntimeError: if ``slot_id``/``type`` is missing, the type is unknown,
                or a categorical slot lacks a ``values`` list.
        """
        if 'slot_id' not in slot:
            raise RuntimeError(
                "Each slot definition must contain a 'slot_id' field.")
        if 'type' not in slot:
            raise RuntimeError(
                "Each slot definition must contain a 'type' field.")
        if slot['type'] not in self._SLOT_TYPES:
            raise RuntimeError(
                f"Unknown slot type '{slot['type']}'.")

        f.write(f"  {slot['slot_id']}:\n")
        f.write(f"    type: {slot['type']}\n")
        if 'influence_conversation' in slot:
            f.write(
                f"    influence_conversation: {slot['influence_conversation']}\n")

        if slot['type'] == 'float':
            # Each key goes on its own line so the mapping stays valid YAML.
            if 'min_value' in slot:
                f.write(f"    min_value: {slot['min_value']}\n")
            if 'max_value' in slot:
                f.write(f"    max_value: {slot['max_value']}\n")

        if slot['type'] == 'categorical':
            if 'values' not in slot or not isinstance(slot['values'], list):
                raise RuntimeError(
                    "Slot of type 'categorical' must have a 'values' list of acceptable slot values.")
            # Rasa expects the accepted values nested under a 'values' key.
            f.write("    values:\n")
            f.writelines([f"      - {value}\n" for value in slot['values']])

    def generate_config_yml(self, location: str, rasa_language: str, spacy_model: str):
        """Write the NLU pipeline and policy configuration to ``<location>/config.yml``."""
        with open(os.path.join(location, 'config.yml'), 'w', encoding='utf-8') as f:
            f.write(f"""language: {rasa_language}
pipeline:
  - name: SpacyNLP
    model: "{spacy_model}"
    case_sensitive: False
  - name: SpacyTokenizer
  - name: SpacyFeaturizer
  - name: RegexFeaturizer
    case_sensitive: False
  - name: LexicalSyntacticFeaturizer
  - name: CountVectorsFeaturizer
  - name: CountVectorsFeaturizer
    analyzer: char_wb
    min_ngram: 1
    max_ngram: 4
  - name: DIETClassifier
    epochs: 100
  - name: EntitySynonymMapper
  - name: SpacyEntityExtractor
  - name: RegexEntityExtractor
    case_sensitive: False
    use_lookup_tables: True
    use_regexes: True
    "use_word_boundaries": True
  - name: ResponseSelector
    epochs: 100
  - name: FallbackClassifier
    threshold: 0.75
    ambiguity_threshold: 0.1

policies:
  - name: AugmentedMemoizationPolicy
#  - name: TEDPolicy
#    max_history: 5
#    epochs: 100
  - name: RulePolicy

""")

    def generate_credentials_yml(self, location):
        """Write ``<location>/credentials.yml`` with the REST channel URL."""
        with open(os.path.join(location, 'credentials.yml'), 'w', encoding='utf-8') as f:
            f.write('rasa:\n  url: "http://localhost:5002/api"')

    def generate_endpoints_yml(self, location):
        """Write ``<location>/endpoints.yml`` pointing at the action server webhook."""
        with open(os.path.join(location, 'endpoints.yml'), 'w', encoding='utf-8') as f:
            f.write('action_endpoint:\n  url: "http://localhost:5055/webhook"')

    def generate_domain(self, config, location, plugin):
        """Write the domain file ``<location>/<plugin>.yml`` for one plugin config.

        Sections written: intents (required), entities, slots, actions,
        responses, and a fixed session_config.

        Raises:
            RuntimeError: if the config has no intents, or an intent, slot or
                response definition is missing a required field.
        """
        with open(os.path.join(location, plugin + '.yml'), 'w', encoding='utf-8') as f:
            f.write('version: "2.0"\n')

            # Intents
            if 'intents' not in config:
                raise RuntimeError("No intents have been registered.")
            f.write('\nintents:\n')
            for intent in config['intents']:
                if 'intent_id' not in intent:
                    raise RuntimeError(
                        "Each intent definition must contain an 'intent_id' field.")
                f.write(f"  - {self.get_intent(intent['intent_id'], plugin)}\n")
            f.write('\n')

            # Entities
            if 'entities' in config:
                f.write('\nentities:\n')
                f.writelines([f'  - {entity}\n' for entity in config['entities']])
                f.write('\n')

            # Slots
            if 'slots' in config:
                f.write('\nslots:\n')
                for slot in config['slots']:
                    self._write_slot(f, slot)
                f.write('\n')

            # Actions
            if 'actions' in config:
                f.write('\nactions:\n')
                f.writelines([f'  - {self.get_action(action, plugin)}\n'
                              for action in config['actions']])
                f.write('\n\n')

            # Responses
            if 'responses' in config:
                f.write('\nresponses:\n')
                for response in config['responses']:
                    if 'text' not in response:
                        raise RuntimeError(
                            "There is no 'text' field defined for a response.")
                    if 'response_id' not in response:
                        raise RuntimeError(
                            "Each response must contain a 'response_id' field.")
                    texts = response['text']
                    # A single string is one text variant, not a list of characters.
                    if isinstance(texts, str):
                        texts = [texts]
                    f.write(f"  {self.get_action(response['response_id'], plugin)}: \n")
                    f.writelines([f'    - text: {self._quote(text)}\n' for text in texts])
                f.write('\n')

            # Session config
            f.write("""
session_config:
  session_expiration_time: 20
  carry_over_slots_to_new_session: false
""")

    def generate_training_data(self, config, location, plugin):
        """Write the training data file ``<location>/<plugin>.yml`` for one plugin config.

        Writes NLU data (intents, synonyms, regexes, lookups), rules (from both
        ``skills`` and ``rules``) and stories. Sets ``self.uses_stories`` when
        the config contains stories.

        Raises:
            RuntimeError: if the config has no intents or an NLU entry lacks its
                id or examples.
        """
        with open(os.path.join(location, plugin + '.yml'), 'w', encoding='utf-8') as f:
            f.write('version: "2.0"\n')

            if 'intents' not in config:
                raise RuntimeError("Configurations must have intents.")

            # NLU
            f.write('\nnlu:\n')
            for intent in config['intents']:
                if 'intent_id' not in intent or 'examples' not in intent:
                    raise RuntimeError("Intents must have an id and examples.")
                self._write_examples(f, 'intent',
                                     self.get_intent(intent['intent_id'], plugin),
                                     intent['examples'])

            # Synonyms, regexes and lookups share one layout; their names are not namespaced.
            for section, kind in (('synonyms', 'synonym'), ('regexes', 'regex'), ('lookups', 'lookup')):
                for entry in config.get(section, []):
                    if f'{kind}_id' not in entry or 'examples' not in entry:
                        raise RuntimeError(
                            f"A {kind} must have both a '{kind}_id' field and a list of examples under 'examples'.")
                    self._write_examples(f, kind, entry[f'{kind}_id'], entry['examples'])

            # Rules: skills are shorthand for "intent followed by actions".
            if 'rules' in config or 'skills' in config:
                f.write('\nrules:\n')

            for skill in config.get('skills', []):
                f.write(f"  - rule: {skill.get('description', '')}\n    steps:\n")
                f.write(self._format_step('intent', skill['intent'], plugin))
                f.writelines([self._format_step('action', action, plugin)
                              for action in skill['actions']])
                f.write('\n')

            for rule in config.get('rules', []):
                f.write(f"  - rule: {rule.get('description', '')}\n    steps:\n")
                f.writelines([self._format_step(step['type'], step['value'], plugin)
                              for step in rule['steps']])
                f.write('\n')

            # Stories
            if 'stories' in config:
                self.uses_stories = True
                f.write('\nstories:\n')
                for story in config['stories']:
                    f.write(f"  - story: {story.get('description', '')}\n    steps:\n")
                    f.writelines([self._format_step(step['type'], step['step_id'], plugin)
                                  for step in story['steps']])
                    f.write('\n')

    def generate(self, configs, config_location, output_location, rasa_language, spacy_model):
        """Build the Rasa project under *config_location* and train a model.

        *config_location* is deleted and recreated from scratch. A domain file
        and a training data file are written per plugin config (plugin name
        defaults to ``main``).

        Returns:
            ``None`` if generating any file failed (the error is logged), ``0``
            in dry-run mode, otherwise the result of ``rasa.train``.
        """
        if os.path.exists(config_location):
            shutil.rmtree(config_location)
        os.makedirs(config_location)

        if not os.path.exists(output_location):
            os.makedirs(output_location)

        self.generate_config_yml(config_location, rasa_language, spacy_model)
        self.generate_credentials_yml(config_location)
        self.generate_endpoints_yml(config_location)

        # config_location was recreated empty above, so these cannot exist yet.
        domain_path = os.path.join(config_location, 'domain')
        training_path = os.path.join(config_location, 'training')
        os.makedirs(domain_path)
        os.makedirs(training_path)

        try:
            for config in configs:
                plugin = config.get('plugin', 'main')
                self.generate_domain(config, domain_path, plugin)
                self.generate_training_data(config, training_path, plugin)
        except Exception as e:
            # Best effort: log the failure with its traceback and signal it via None.
            logging.exception(str(e))
            return None

        if self.dry_run:
            return 0

        # Imported lazily so dry runs do not require rasa to be importable.
        from rasa import train

        return train(
            domain=domain_path,
            config=os.path.join(config_location, 'config.yml'),
            training_files=training_path,
            output=output_location,
        )