Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import shutil 

3import logging 

4 

5 

class ModelGenerator:
    """Generate a Rasa 2.x project (config, credentials, endpoints, domain and
    training data) from a list of plugin configuration dicts, then optionally train it.

    Intent and action names are namespaced with the plugin name (see
    :meth:`get_intent` and :meth:`get_action`) so several plugins can share one
    model without name clashes.
    """

    # Class-level defaults; __init__ sets per-instance values for both.
    uses_stories: bool = False
    dry_run: bool = False

    # Slot types accepted by generate_domain.
    _SLOT_TYPES = ('text', 'bool', 'categorical', 'float', 'list', 'any')

    # (config key, id field, singular label) for the NLU example sections that
    # share one layout in the training data.
    _EXAMPLE_SECTIONS = (
        ('synonyms', 'synonym_id', 'synonym'),
        ('regexes', 'regex_id', 'regex'),
        ('lookups', 'lookup_id', 'lookup'),
    )

    def __init__(self, dry_run: bool = False):
        """Create a generator.

        :param dry_run: when True, :meth:`generate` writes all files but skips training.
        """
        self.dry_run = dry_run
        # Set to True by generate_training_data when any config defines stories.
        self.uses_stories = False

    # ------------------------------------------------------------------ naming

    def get_intent(self, intent: str, plugin: str):
        """Return the intent name namespaced for *plugin*.

        Intents of the ``main`` plugin keep their bare name; other plugins get
        ``<intent>_<plugin>``.
        """
        if plugin == "main":
            return intent
        return f'{intent}_{plugin}'

    def get_action(self, action: str, plugin: str):
        """Return the action name namespaced for *plugin*.

        ``utter_*`` responses and ``validate_*`` form validators keep their prefix
        and only get the ``_<plugin>`` suffix. Every other action becomes
        ``action_<action>_<plugin>``. Unlike intents, the ``main`` plugin is
        suffixed as well.
        """
        if action.startswith(('utter_', 'validate_')):
            return f'{action}_{plugin}'
        return f'action_{action}_{plugin}'

    # ------------------------------------------------------------ YAML helpers

    @staticmethod
    def _yaml_quote(text) -> str:
        """Return *text* as a YAML double-quoted scalar.

        Backslashes, double quotes and newline/tab characters are escaped, so
        free text cannot break the surrounding YAML structure.
        """
        escaped = (str(text)
                   .replace('\\', '\\\\')
                   .replace('"', '\\"')
                   .replace('\n', '\\n')
                   .replace('\r', '\\r')
                   .replace('\t', '\\t'))
        return f'"{escaped}"'

    @staticmethod
    def _yaml_value(value) -> str:
        """Render a scalar for YAML output; Python bools become ``true``/``false``."""
        if isinstance(value, bool):
            return 'true' if value else 'false'
        return str(value)

    # ------------------------------------------------------- project-level files

    def generate_config_yml(self, location: str, rasa_language: str, spacy_model: str):
        """Write ``config.yml`` (NLU pipeline and policies) into *location*.

        :param rasa_language: value written to the ``language`` key.
        :param spacy_model: spaCy model name passed to the ``SpacyNLP`` component.
        """
        with open(os.path.join(location, 'config.yml'), 'w', encoding='utf-8') as f:
            f.write(f"""language: {rasa_language}
pipeline:
  - name: SpacyNLP
    model: "{spacy_model}"
    case_sensitive: False
  - name: SpacyTokenizer
  - name: SpacyFeaturizer
  - name: RegexFeaturizer
    case_sensitive: False
  - name: LexicalSyntacticFeaturizer
  - name: CountVectorsFeaturizer
  - name: CountVectorsFeaturizer
    analyzer: char_wb
    min_ngram: 1
    max_ngram: 4
  - name: DIETClassifier
    epochs: 100
  - name: EntitySynonymMapper
  - name: SpacyEntityExtractor
  - name: RegexEntityExtractor
    case_sensitive: False
    use_lookup_tables: True
    use_regexes: True
    "use_word_boundaries": True
  - name: ResponseSelector
    epochs: 100
  - name: FallbackClassifier
    threshold: 0.75
    ambiguity_threshold: 0.1

policies:
  - name: AugmentedMemoizationPolicy
#  - name: TEDPolicy
#    max_history: 5
#    epochs: 100
  - name: RulePolicy

""")

    def generate_credentials_yml(self, location):
        """Write ``credentials.yml`` pointing the ``rasa`` channel at localhost:5002."""
        with open(os.path.join(location, 'credentials.yml'), 'w', encoding='utf-8') as f:
            f.write('rasa:\n  url: "http://localhost:5002/api"\n')

    def generate_endpoints_yml(self, location):
        """Write ``endpoints.yml`` pointing the action server at localhost:5055."""
        with open(os.path.join(location, 'endpoints.yml'), 'w', encoding='utf-8') as f:
            f.write('action_endpoint:\n  url: "http://localhost:5055/webhook"\n')

    # ------------------------------------------------------------------ domain

    def _write_slot(self, f, slot):
        """Validate one slot definition and write it under the domain's ``slots:`` key.

        :raises RuntimeError: if ``slot_id``/``type`` is missing, the type is
            unknown, or a categorical slot has no ``values`` list.
        """
        if 'slot_id' not in slot:
            raise RuntimeError(
                "Each slot definition must contain a 'slot_id' field.")
        if 'type' not in slot:
            raise RuntimeError(
                "Each slot definition must contain a 'type' field.")
        if slot['type'] not in self._SLOT_TYPES:
            raise RuntimeError(
                f"Unknown slot type '{slot['type']}'.")

        f.write(f"  {slot['slot_id']}:\n")
        f.write(f"    type: {slot['type']}\n")

        if 'influence_conversation' in slot:
            f.write(f"    influence_conversation: "
                    f"{self._yaml_value(slot['influence_conversation'])}\n")

        if slot['type'] == 'float':
            # Each bound goes on its own line; without the newline they merge into one line.
            for key in ('min_value', 'max_value'):
                if key in slot:
                    f.write(f"    {key}: {slot[key]}\n")

        if slot['type'] == 'categorical':
            values = slot.get('values')
            if not isinstance(values, list):
                raise RuntimeError(
                    "Slot of type 'categorical' must have a 'values' list of acceptable slot values.")
            # The values must sit under an explicit 'values:' key, one item per line.
            f.write("    values:\n")
            f.writelines(f"      - {value}\n" for value in values)

    def generate_domain(self, config, location, plugin):
        """Write the domain file ``<location>/<plugin>.yml`` for one plugin config.

        Emits intents, plus entities, slots, actions and responses when present,
        followed by a fixed ``session_config``. Intent and action names are
        namespaced via :meth:`get_intent` / :meth:`get_action`.

        :raises RuntimeError: if the config has no intents or a slot/response is malformed.
        """
        # Validate before opening the file, so a missing-intents error does not leave a stub file.
        if 'intents' not in config:
            raise RuntimeError("No intents have been registered.")

        with open(os.path.join(location, plugin + '.yml'), 'w', encoding='utf-8') as f:
            f.write('version: "2.0"\n')

            # Intents
            f.write('\nintents:\n')
            f.writelines(f"  - {self.get_intent(intent['intent_id'], plugin)}\n"
                         for intent in config['intents'])
            f.write('\n')

            if 'entities' in config:
                f.write('\nentities:\n')
                f.writelines(f'  - {entity}\n' for entity in config['entities'])
                f.write('\n')

            if 'slots' in config:
                f.write('\nslots:\n')
                for slot in config['slots']:
                    self._write_slot(f, slot)
                f.write('\n')

            if 'actions' in config:
                f.write('\nactions:\n')
                f.writelines(f'  - {self.get_action(action, plugin)}\n'
                             for action in config['actions'])
                f.write('\n\n')

            # Responses
            if 'responses' in config:
                f.write('\nresponses:\n')
                for response in config['responses']:
                    if 'text' not in response:
                        raise RuntimeError(
                            "There is no 'text' field defined for a response.")

                    texts = response['text']
                    # Accept a single string as one text variant instead of iterating its characters.
                    if isinstance(texts, str):
                        texts = [texts]

                    f.write(f"  {self.get_action(response['response_id'], plugin)}:\n")
                    # Quote and escape so texts containing '"', ':' or '\' stay valid YAML.
                    f.writelines(f'    - text: {self._yaml_quote(text)}\n' for text in texts)
                    f.write('\n')

            # Session config
            f.write("""
session_config:
  session_expiration_time: 20
  carry_over_slots_to_new_session: false
""")

    # ----------------------------------------------------------- training data

    @staticmethod
    def _write_examples(f, kind, name, examples):
        """Write one NLU entry (``- <kind>: <name>``) with its block-literal example list."""
        f.write(f'  - {kind}: {name}\n    examples: |\n')
        f.writelines(f'      - {example}\n' for example in examples)
        f.write('\n')

    def _format_step(self, step_type, step_value, plugin):
        """Return one rule/story step line; intent steps use get_intent, all others get_action."""
        if step_type == 'intent':
            name = self.get_intent(step_value, plugin)
        else:
            name = self.get_action(step_value, plugin)
        return f'      - {step_type}: {name}\n'

    def generate_training_data(self, config, location, plugin):
        """Write the training file ``<location>/<plugin>.yml`` for one plugin config.

        Emits NLU data (intents, synonyms, regexes, lookups), rules (from both
        ``skills`` and ``rules``) and stories. Sets ``self.uses_stories`` to True
        when the config defines stories.

        :raises RuntimeError: if intents are missing or an NLU entry lacks its id/examples.
        """
        # Validate before opening the file, so a missing-intents error does not leave a stub file.
        if 'intents' not in config:
            raise RuntimeError("Configurations must have intents.")

        with open(os.path.join(location, plugin + '.yml'), 'w', encoding='utf-8') as f:
            f.write('version: "2.0"\n')

            # NLU
            f.write('\nnlu:\n')
            for intent in config['intents']:
                if 'intent_id' not in intent or 'examples' not in intent:
                    raise RuntimeError("Intents must have an id and examples.")
                self._write_examples(
                    f, 'intent', self.get_intent(intent['intent_id'], plugin), intent['examples'])

            # Synonyms, regexes and lookups share one layout; only their keys differ.
            for section, id_key, label in self._EXAMPLE_SECTIONS:
                for entry in config.get(section, []):
                    if id_key not in entry or 'examples' not in entry:
                        raise RuntimeError(
                            f"A {label} must have both a '{id_key}' field and a list of examples under 'examples'.")
                    self._write_examples(f, label, entry[id_key], entry['examples'])

            # Rules: each skill becomes a one-intent rule, followed by any explicit rules.
            if 'rules' in config or 'skills' in config:
                f.write('\nrules:\n')

            for skill in config.get('skills', []):
                f.write(f"  - rule: {self._yaml_quote(skill.get('description', ''))}\n"
                        f"    steps:\n"
                        f"      - intent: {self.get_intent(skill['intent'], plugin)}\n")
                f.writelines(f'      - action: {self.get_action(action, plugin)}\n'
                             for action in skill['actions'])
                f.write('\n')

            for rule in config.get('rules', []):
                f.write(f"  - rule: {self._yaml_quote(rule.get('description', ''))}\n"
                        f"    steps:\n")
                f.writelines(self._format_step(step['type'], step['value'], plugin)
                             for step in rule['steps'])
                f.write('\n')

            # Stories
            if 'stories' in config:
                self.uses_stories = True
                f.write('\nstories:\n')
                for story in config['stories']:
                    f.write(f"  - story: {self._yaml_quote(story.get('description', ''))}\n"
                            f"    steps:\n")
                    # Each step is built as one line. The old expression mis-parsed,
                    # dropping the newline on intent steps and the prefix on action steps.
                    f.writelines(self._format_step(step['type'], step['step_id'], plugin)
                                 for step in story['steps'])
                    f.write('\n')

    # -------------------------------------------------------------- entry point

    def generate(self, configs, config_location, output_location, rasa_language, spacy_model):
        """Regenerate the Rasa project in *config_location* and train it into *output_location*.

        *config_location* is deleted and recreated. Each config in *configs*
        produces ``domain/<plugin>.yml`` and ``training/<plugin>.yml``; a config
        without a ``plugin`` key is treated as ``main``.

        :returns: ``None`` if generating the files failed (the error is logged),
            ``0`` in dry-run mode, otherwise whatever ``rasa.train`` returns.
        """
        # Start from a clean directory so stale plugin files are never trained on.
        if os.path.exists(config_location):
            shutil.rmtree(config_location)
        os.makedirs(config_location)
        os.makedirs(output_location, exist_ok=True)

        self.generate_config_yml(config_location, rasa_language, spacy_model)
        self.generate_credentials_yml(config_location)
        self.generate_endpoints_yml(config_location)

        # config_location was just recreated empty, so these directories cannot already exist.
        domain_path = os.path.join(config_location, 'domain')
        training_path = os.path.join(config_location, 'training')
        os.makedirs(domain_path)
        os.makedirs(training_path)

        try:
            for config in configs:
                plugin = config.get('plugin', 'main')
                self.generate_domain(config, domain_path, plugin)
                self.generate_training_data(config, training_path, plugin)
        except Exception as e:
            # Best-effort contract kept (None signals failure), but keep the traceback in the log.
            logging.exception("%s", e)
            return None

        if self.dry_run:
            return 0

        # Imported lazily so that dry runs do not require rasa to be installed.
        from rasa import train

        return train(
            domain=domain_path,
            config=os.path.join(config_location, 'config.yml'),
            training_files=training_path,
            output=output_location,
        )