Commit 892a227c authored by Jerome Wuerf

Adapt program to TIRA requirements

parent 5f0c2cbc
from pathlib import Path
from assessment import Assessment
from deduplication import Deduplication
from indexing import Indexing
from retrieval import (Retrieval, MaximalMarginalRelevanceReranking, StructuralDistanceReranking,
ArgumentGraphReranking, WordMoverDistanceReranking, NoReranking)
@@ -29,10 +27,6 @@ class App:
self._indexing()
elif self.command == SubCommands.RETRIEVAL:
self._retrieval()
elif self.command == SubCommands.ASSESSMENT:
self._assessment()
elif self.command == SubCommands.DEDUPLICATION:
self._deduplication()
def _indexing(self) -> None:
self.logger.info('Reading sentences and embeddings from disk.')
@@ -44,7 +38,10 @@ class App:
).index_to_es()
def _retrieval(self) -> None:
topics = read_topics(self.config['TOPICS_PATH'], self.config['TOPIC_NRB'])
topics = read_topics(
Path(self.config['INPUT_PATH'],
'topics.xml'),
self.config['TOPIC_NRB'])
self.logger.info('Read topics!')
retrieved_results = None
if self.config['REUSE_UNRANKED']:
@@ -60,7 +57,6 @@ class App:
self.config['NRB_CONCLUSIONS_PER_TOPIC'],
self.config['NRB_PREMISES_PER_CONCLUSION']
).retrieve()
pickle.dump(retrieved_results, open('/data/baseline_dump.pkl', 'wb'))
# Structural pattern matching is only available from Python 3.10 onwards :(
reranker = None
@@ -84,19 +80,9 @@ class App:
reranker.rerank()
write_terc_file(reranker.generate_trec_rows(),
Path(self.config['TREC_OUTPUT_PATH'],
f'{ self.config["RUN_NAME"]}.trec')
Path(self.config['OUTPUT_PATH'], 'run.txt')
)
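# Note: generate_trec_rows() is assumed to emit rows in the standard TREC run
# format, i.e. "<topic_id> Q0 <document_id> <rank> <score> <run_name>";
# TIRA expects the result file under OUTPUT_PATH with the fixed name run.txt.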
def _assessment(self) -> None:
sentences = read_sentences(self.config['SENTENCES_PATH'])
results = read_results(self.config['RESULT_PATH'])
topics = read_topics(self.config['TOPICS_PATH'])
Assessment(results, sentences, topics).cmdloop()
def _deduplication(self):
Deduplication(self.config['ELASTIC_HOST'], self.config['INDICES']).deduplicate()
if __name__ == '__main__':
logger = logging.getLogger(__name__)
......
from .assessment import Assessment
from cmd import Cmd
from os import listdir
from pathlib import Path
from typing import Tuple
import pandas as pd
class Assessment(Cmd):
"""
Interactive command line loop to manually assess retrieved conclusion/premise sentence pairs.
"""
def __init__(self, results: pd.DataFrame, sentences: pd.DataFrame, topics: pd.DataFrame):
super().__init__()
sentences['id'] = sentences['arg_id'] + '_' \
+ sentences['sent_type'] + '_' + sentences['sent_pos']
sentences = sentences.drop(['arg_id', 'sent_type', 'sent_pos'], axis='columns')
self.sentences = sentences
self.results = results
self.topics = topics
self.idx = 0
self._set_current_topic()
def do_1(self, arg):
pass
def do_2(self, arg):
pass
def do_3(self, arg):
pass
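# The do_1/do_2/do_3 handlers are placeholders: cmd.Cmd dispatches the typed
# commands "1", "2" and "3" to these methods, presumably to record a graded
# judgement for the sentence pair printed by precmd/preloop.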
def precmd(self, line):
self._print_sentence_pair()
self.idx += 1
return line
def postcmd(self, stop: bool, line: str) -> bool:
self._set_current_topic()
return stop
def preloop(self):
self._print_sentence_pair()
self.idx += 1
def _print_sentence_pair(self) -> None:
result_row = self.results.iloc[self.idx]
conclusion_sent_id, premise_sent_id = result_row['pair'].split(',')
conclusion_sent = self._get_sentence(conclusion_sent_id)
premise_sent = self._get_sentence(premise_sent_id)
print(f'\n\tConclusion: {conclusion_sent}\n\tPremise: {premise_sent}\n')
def _get_sentence(self, sent_id: str) -> str:
return self.sentences[self.sentences['id'] == sent_id].iloc[0]['sent_text']
def _set_current_topic(self) -> None:
self.prompt = f'(Topic: {self.topics.loc[self.results.iloc[self.idx]["qid"]]["topic"]})>'
from .deduplication import Deduplication
import hashlib
import logging
from elasticsearch import Elasticsearch
from elasticsearch.helpers import streaming_bulk
from tqdm import tqdm
class Deduplication:
"""
Remove duplicate documents from the given Elasticsearch indices by hashing selected fields.
Approach based on:
https://www.elastic.co/de/blog/how-to-find-and-remove-duplicate-documents-in-elasticsearch
"""
TOTAL_DUPLICATES = 786383
def __init__(self, elastic_host: str = 'elastic', indices: list = ['conc', 'premise']):
self.logger = logging.getLogger(__name__)
self.es = Elasticsearch(elastic_host)
self.indices = indices
self.keys_to_hash = ['sentence_text']
def deduplicate(self) -> None:
"""
Delete all but the first occurrence of every duplicated document from the configured indices.
"""
self.logger.info('Starting deduplication...')
with tqdm(total=Deduplication.TOTAL_DUPLICATES) as progress_bar:
for ok, action in streaming_bulk(client=self.es,
actions=self.gen_delete_actions()):
progress_bar.update()
self.logger.info('Finished deduplication!')
def gen_delete_actions(self):
for index in self.indices:
for id_list in self.get_duplicates(index).values():
# Skip the first element, delete all following documents with the same hash
for id in id_list[1:]:
yield {"_op_type": "delete", "_index": index, "_id": id}
def get_duplicates(self, index: str) -> dict:
"""
Scroll over all documents of the given index and group their ids by the MD5 hash of the fields listed in keys_to_hash.
"""
duplicates = {}
data = self.es.search(index=index, scroll='1m', body={"size": 10000, "query": {
"match_all": {}}, "fields": self.keys_to_hash, "_source": False})
sid = data['_scroll_id']
scroll_size = len(data['hits']['hits'])
while scroll_size > 0:
self._add_duplicates_of_batch(data['hits']['hits'], duplicates)
data = self.es.scroll(scroll_id=sid, scroll='2m')
sid = data['_scroll_id']
scroll_size = len(data['hits']['hits'])
return duplicates
def _add_duplicates_of_batch(self, hits, duplicates: dict) -> dict:
"""
Extend the duplicates mapping with the documents of one scroll batch.
"""
for item in hits:
key_values = ''.join([item['fields'][k][0] for k in self.keys_to_hash])
_id = item["_id"]
hashval = hashlib.md5(key_values.encode('utf-8')).digest()
duplicates.setdefault(hashval, []).append(_id)
return duplicates
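# Sketch of the resulting mapping (document ids are illustrative only):
#   {md5(b'some sentence text').digest(): ['id_1', 'id_7', 'id_42'], ...}
# gen_delete_actions() keeps the first id of each group and deletes the rest.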
@@ -4,8 +4,6 @@ from enum import Enum
class SubCommands(Enum):
INDEXING = 'indexing'
RETRIEVAL = 'retrieval'
ASSESSMENT = 'assessment'
DEDUPLICATION = 'deduplication'
class RerankingOptions(Enum):
......
@@ -16,8 +16,8 @@ class Configuration():
'EMBEDDINGS_PATH',
'CREATE_INDEX',
'ELASTIC_HOST'],
SubCommands.RETRIEVAL: ['TOPICS_PATH',
'TREC_OUTPUT_PATH',
SubCommands.RETRIEVAL: ['INPUT_PATH',
'OUTPUT_PATH',
'RUN_NAME',
'TOPIC_NRB',
'NRB_CONCLUSIONS_PER_TOPIC',
@@ -27,11 +27,6 @@ class Configuration():
'REUSE_UNRANKED',
'LAMBDA_CONCLUSIONS',
'LAMBDA_PREMISES'],
SubCommands.ASSESSMENT: ['TOPICS_PATH',
'SENTENCES_PATH',
'RESULT_PATH'],
SubCommands.DEDUPLICATION: ['ELASTIC_HOST',
'INDICES']
}
def __init__(self, args: Namespace):
@@ -54,8 +49,8 @@ class Configuration():
args.elastic_host]
elif self.command == SubCommands.RETRIEVAL:
args_list = [Path(args.topics_path),
Path(args.trec_output_path),
args_list = [Path(args.input_path),
Path(args.output_path),
args.run_name,
args.topic_nrb,
args.nrb_conclusions_per_topic,
@@ -67,14 +62,7 @@ class Configuration():
args.lambda_premises
]
elif self.command == SubCommands.ASSESSMENT:
args_list = [Path(args.topics_path),
Path(args.sentences_path),
args.result_path]
elif self.command == SubCommands.DEDUPLICATION:
args_list = [args.elastic_host, args.indices]
config = dict(zip(self.keys[self.command], args_list))
self.logger.info(config)
return config
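# For the retrieval sub command the returned dict would look roughly like
# (values are illustrative only):
#   {'INPUT_PATH': PosixPath('/input'), 'OUTPUT_PATH': PosixPath('/output'),
#    'RUN_NAME': 'my-run', 'TOPIC_NRB': ..., ...}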
@@ -11,27 +11,18 @@ class Text:
Help and description texts for the command line interface.
"""
description = '$$$$ Graph based Argument Mining on Sentence Embeddings $$$$'
sub = 'Lorem Ipsum dolor' # TODO
indexing = 'This sub command creates a semantic index with elastic search for an initial set' \
'retrieval.'
elastic_host = 'The hostname of the server that runs elastic search.'
create = 'If flag is present two new indices are created, one for conclusions, one for ' \
'premises. If there is already an existing index it will be overridden.'
indexing = 'Sub command to create a semantic index with elastic search.'
elastic_host = 'The hostname of the server/docker container that runs elastic search.'
create = 'If flag is present two new indices are created, overriding existing ones.'
sentences_path = 'The file path to the csv file containing the sentences.'
embeddings_path = 'The file path to the embeddings of the argument units.' \
'Overrides the EMBEDDINGS_PKL_PATH environment variable.'
retrieval = 'This sub command is intended for a run on the TIRA evaluation system.'
topic_nrb = 'Lorem ipsum' # TODO
embeddings_path = 'The file path to the embeddings of the argument units.'
retrieval = 'Retrieve sentence pairs from the index.'
topic_nrb = 'Restrict the current indexing and/or reranking to a given topic number.'
nrb_conclusions = 'The number of conclusions that should be retrieved from the index per topic.'
nrb_premises = 'The number of premises that should be retrieved from the index per conclusion.'
run_name = 'The run name that will be included in the last column of the trec output file'
topics_path = 'The file path to the xml file containing the topics.' \
'Overrides the TOPICS_PATH environment variable.'
trec_output_path = 'The file path to the output file.' \
'Overrides the TREC_OUTPUT_PATH environment variable.'
assessment = 'Starts the interactive assessment of a result file.'
result_path = 'The path to a result file for an interactive assessment.'
deduplication = 'Deduplicates the indices given.'
run_name = 'The run name that will be included in the last column of the trec file.'
output_path = 'The file path to the directory containing the output files.'
input_path = 'The file path to the directory containing the input files.'
def parse_cli_args() -> argparse.Namespace:
@@ -39,7 +30,7 @@ def parse_cli_args() -> argparse.Namespace:
Define and parse the command line arguments of all sub commands.
"""
parser = argparse.ArgumentParser(description=Text.description)
subparsers = parser.add_subparsers(required=True, dest='subparser_name', help=Text.sub)
subparsers = parser.add_subparsers(required=True, dest='subparser_name')
# Indexing
@@ -95,23 +86,7 @@ def parse_cli_args() -> argparse.Namespace:
required=False,
default=0.5)
parser_retrieval.add_argument('run_name', type=str, help=Text.run_name)
parser_retrieval.add_argument('topics_path', type=str, help=Text.topics_path)
parser_retrieval.add_argument('trec_output_path', type=str, help=Text.trec_output_path)
# Assessment
parser_assessment = subparsers.add_parser(SubCommands.ASSESSMENT.value, help=Text.assessment)
parser_assessment.add_argument('sentences_path',
type=str,
help=Text.sentences_path)
parser_assessment.add_argument('result_path', type=str, help=Text.result_path)
parser_assessment.add_argument('topics_path', type=str, help=Text.topics_path)
# Deduplication
parser_deduplicate = subparsers.add_parser(SubCommands.DEDUPLICATION.value,
help=Text.deduplication)
parser_deduplicate.add_argument('--elastic-host', help=Text.elastic_host)
parser_deduplicate.add_argument("--indices", nargs="+")
parser_retrieval.add_argument('input_path', type=str, help=Text.input_path)
parser_retrieval.add_argument('output_path', type=str, help=Text.output_path)
return parser.parse_args()
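# Example invocation of the adapted retrieval sub command (entry-point name and
# paths are assumptions for illustration; further options omitted):
#   python -m app retrieval my-run-name /path/to/input-directory /path/to/output-directory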