"""Command-line entry point for training the bot's NLU and dialogue models."""
import argparse
import logging
import sys

from rasa_core import utils
from rasa_core.agent import Agent
from rasa_core.policies.memoization import MemoizationPolicy

from policy import PhysicistPolicy

logger = logging.getLogger(__name__)


def train_dialogue(domain_file="domain.yml",
                   model_path="models/dialogue",
                   training_data_file="stories.md"):
    """Train the dialogue model and persist it to ``model_path``.

    The agent is built with two policies: a ``MemoizationPolicy``
    (``max_history=3``) and the project's ``PhysicistPolicy``
    (``batch_size=100``, ``epochs=400``, ``validation_split=0.2``).

    :param domain_file: path to the domain definition file.
    :param model_path: directory the trained dialogue model is written to.
    :param training_data_file: path to the story training data.
    :return: the trained ``Agent``.
    """
    agent = Agent(
        domain_file,
        policies=[
            MemoizationPolicy(max_history=3),
            PhysicistPolicy(batch_size=100, epochs=400, validation_split=0.2),
        ],
    )
    training_data = agent.load_data(training_data_file)
    agent.train(training_data)
    agent.persist(model_path)
    logger.info("Dialogue model persisted to %s", model_path)
    return agent


def train_nlu(training_data_file="nlu.md",
              config_file="nlu_config.yml",
              model_dir="models/nlu/",
              fixed_model_name="current"):
    """Train the NLU model and persist it under ``model_dir``.

    All arguments default to the values this script has always used, so
    calling ``train_nlu()`` with no arguments behaves as before.

    :param training_data_file: path to the NLU training data.
    :param config_file: path to the NLU pipeline configuration.
    :param model_dir: parent directory for the persisted NLU model.
    :param fixed_model_name: name of the model sub-directory to write.
    :return: the model directory returned by ``Trainer.persist``.
    """
    # rasa_nlu is imported here rather than at module level, so the other
    # tasks do not import it; presumably intentional — keep it local.
    from rasa_nlu import config
    from rasa_nlu.model import Trainer
    from rasa_nlu.training_data import load_data

    training_data = load_data(training_data_file)
    trainer = Trainer(config.load(config_file))
    trainer.train(training_data)
    model_directory = trainer.persist(model_dir,
                                      fixed_model_name=fixed_model_name)
    logger.info("NLU model persisted to %s", model_directory)
    return model_directory


def main():
    """Parse the command-line task and run the matching training step."""
    utils.configure_colored_logging(loglevel="INFO")

    parser = argparse.ArgumentParser(description='starts the bot')
    parser.add_argument(
        'task',
        choices=["train-nlu", "train-dialogue", "run"],
        help="what the bot should do - e.g. run or train?",
    )
    task = parser.parse_args().task

    # Decide what to do based on the first positional argument.
    handlers = {
        "train-nlu": train_nlu,
        "train-dialogue": train_dialogue,
    }
    handler = handlers.get(task)
    if handler is None:
        # "run" is accepted by the parser but has no implementation here.
        # Previously it silently did nothing and exited 0; fail loudly instead.
        logger.error("Task '%s' is not implemented.", task)
        sys.exit(1)
    handler()


if __name__ == '__main__':
    main()