Commit 45f47d8f authored by Jerome Wuerf's avatar Jerome Wuerf
Browse files

Refactor reranking and fix configuration error

parent 0c3b31b5
......@@ -32,7 +32,7 @@ services:
- |
python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
/app/src/prototype/app.py retrieval \
--reranking word-mover-distance --reuse-unranked /data/unranked_dump_final.pkl \
--reuse-unranked /data/unranked_dump_final.pkl \
word_mover_distance_token_length_factor_1_75 /data/topics.xml /data/results/
# entrypoint:
......
......@@ -4,4 +4,4 @@ RUN mkdir /app
WORKDIR /app
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
RUN python -m spacy download en_core_web_sm
RUN python -m spacy download en_core_web_md
......@@ -5,7 +5,7 @@ from indexing import Indexing
from retrieval import (Retrieval, MaximalMarginalRelevanceReranking, StructuralDistanceReranking,
ArgumentGraphReranking, WordMoverDistanceReranking, NoReranking)
from utils import (Configuration, SubCommands, RerankingOptions, parse_cli_args, read_data_to_index,
read_results, read_unranked, read_sentences, read_topics, write_output)
read_results, read_unranked, read_sentences, read_topics, write_terc_file)
import logging
import pickle
......@@ -46,39 +46,47 @@ class App:
def _retrieval(self) -> None:
topics = read_topics(self.config['TOPICS_PATH'], self.config['TOPIC_NRB'])
self.logger.info('Read topics!')
results = None
retrieved_results = None
if self.config['REUSE_UNRANKED']:
self.logger.info('Reading unranked...')
results = read_unranked(self.config['REUSE_UNRANKED'], self.config['TOPIC_NRB'])
retrieved_results = read_unranked(self.config['REUSE_UNRANKED'],
self.config['TOPIC_NRB'])
self.logger.info('Read unranked!')
else:
results = Retrieval(topics,
self.config['RUN_NAME'],
self.config['MIN_LENGTH_FACTOR'],
self.config['ELASTIC_HOST'],
self.config['NRB_CONCLUSIONS_PER_TOPIC'],
self.config['NRB_PREMISES_PER_CONCLUSION']
).retrieve()
pickle.dump(results, open('/data/baseline_dump.pkl', 'wb'))
retrieved_results = Retrieval(topics,
self.config['RUN_NAME'],
self.config['MIN_LENGTH_FACTOR'],
self.config['ELASTIC_HOST'],
self.config['NRB_CONCLUSIONS_PER_TOPIC'],
self.config['NRB_PREMISES_PER_CONCLUSION']
).retrieve()
pickle.dump(retrieved_results, open('/data/baseline_dump.pkl', 'wb'))
# Pattern matching is only available pre Python 3.10 :(
reranker = None
if self.config['RERANKING'] == RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE.value:
results = MaximalMarginalRelevanceReranking(
results,
reranker = MaximalMarginalRelevanceReranking(
retrieved_results,
self.config['RUN_NAME'],
self.config['LAMBDA_CONCLUSIONS'],
self.config['LAMBDA_PREMISES']
)
elif self.config['RERANKING'] == RerankingOptions.STRUCTURAL_DISTANCE.value:
results = StructuralDistanceReranking(results, self.config['RUN_NAME'], topics)
reranker = StructuralDistanceReranking(retrieved_results, self.config['RUN_NAME'],
topics)
elif self.config['RERANKING'] == RerankingOptions.ARGUMENT_GRAPH.value:
results = ArgumentGraphReranking(results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.WORD_MOVER_DISTACNE:
results = WordMoverDistanceReranking(results, self.config['RUN_NAME'], topics)
reranker = ArgumentGraphReranking(retrieved_results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.WORD_MOVER_DISTANCE.value:
reranker = WordMoverDistanceReranking(retrieved_results, self.config['RUN_NAME'],
topics)
else:
results = NoReranking(results, self.config['RUN_NAME'])
write_output(results.get_trec_rows(),
Path(self.config['TREC_OUTPUT_PATH'],
f'{ self.config["RUN_NAME"]}.trec')
)
reranker = NoReranking(retrieved_results, self.config['RUN_NAME'])
reranker.rerank()
write_terc_file(reranker.generate_trec_rows(),
Path(self.config['TREC_OUTPUT_PATH'],
f'{ self.config["RUN_NAME"]}.trec')
)
def _assessment(self) -> None:
sentences = read_sentences(self.config['SENTENCES_PATH'])
......
......@@ -46,5 +46,5 @@ class ArgumentGraph:
graph.add_edge(conc_arg_id, premise_argument_id, weight=sim)
graphs[topic_nrb] = graph
logging.info(f'Created argument graph for topic nrb {topic_nrb}')
logging.info('Finished creating the arguemnt graphs!')
logging.info('Finished creating the argument graphs!')
return graphs
......@@ -12,4 +12,4 @@ class RerankingOptions(Enum):
MAXIMAL_MARGINAL_RELEVANCE = 'maximal-marginal-relevance'
STRUCTURAL_DISTANCE = 'structural-distance'
ARGUMENT_GRAPH = 'argument-graph'
WORD_MOVER_DISTACNE = 'word-mover-distance'
WORD_MOVER_DISTANCE = 'word-mover-distance'
import argparse
from dataclasses import dataclass
from numpy import require
from python.src.prototype.retrieval.reranking import Reranking
from .commands import SubCommands, RerankingOptions
......@@ -79,10 +76,12 @@ def parse_cli_args() -> argparse.Namespace:
parser_retrieval.add_argument('--reranking',
type=str,
required=False,
choices=[RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE.value,
RerankingOptions.STRUCTURAL_DISTANCE.value,
RerankingOptions.ARGUMENT_GRAPH.value,
RerankingOptions.WORD_MOVER_DISTACNE],
choices=[option.value for option in
[RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE,
RerankingOptions.STRUCTURAL_DISTANCE,
RerankingOptions.ARGUMENT_GRAPH,
RerankingOptions.WORD_MOVER_DISTANCE]
],
default=None)
parser_retrieval.add_argument('--reuse-unranked',
type=str,
......
from cgitb import text
import spacy
import textdistance
nlp = spacy.load("en_core_web_sm")
nlp = spacy.load("en_core_web_md")
def structural_distance(first_sent: str, second_sent: str):
......
......@@ -3,7 +3,7 @@ from pathlib import Path
from typing import OrderedDict
def write_output(result: list, output_path: Path) -> None:
def write_terc_file(result: list, output_path: Path) -> None:
"""
TODO
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment