Commit 0c3b31b5 authored by Jerome Wuerf's avatar Jerome Wuerf
Browse files

Add word mover distance reranking

parent f482d928
......@@ -30,10 +30,11 @@ services:
- "/bin/sh"
- -ecx
- |
python \
python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
/app/src/prototype/app.py retrieval \
--nrb-conclusions-per-topic 100 --nrb-premises-per-conclusion 50 \
baseline /data/topics.xml /data/results/
--reranking word-mover-distance --reuse-unranked /data/unranked_dump_final.pkl \
word_mover_distance_token_length_factor_1_75 /data/topics.xml /data/results/
# entrypoint:
# - "/bin/sh"
# - -ecx
......
......@@ -10,3 +10,4 @@ textdistance
spacy
networkx
matplotlib
wmd
......@@ -3,7 +3,7 @@ from assessment import Assessment
from deduplication import Deduplication
from indexing import Indexing
from retrieval import (Retrieval, MaximalMarginalRelevanceReranking, StructuralDistanceReranking,
ArgumentGraphReranking, NoReranking)
ArgumentGraphReranking, WordMoverDistanceReranking, NoReranking)
from utils import (Configuration, SubCommands, RerankingOptions, parse_cli_args, read_data_to_index,
read_results, read_unranked, read_sentences, read_topics, write_output)
import logging
......@@ -71,6 +71,8 @@ class App:
results = StructuralDistanceReranking(results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.ARGUMENT_GRAPH.value:
results = ArgumentGraphReranking(results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.WORD_MOVER_DISTACNE:
results = WordMoverDistanceReranking(results, self.config['RUN_NAME'], topics)
else:
results = NoReranking(results, self.config['RUN_NAME'])
write_output(results.get_trec_rows(),
......
......@@ -2,4 +2,5 @@ from .retrieval import Retrieval
from .reranking import (MaximalMarginalRelevanceReranking,
StructuralDistanceReranking,
ArgumentGraphReranking,
WordMoverDistanceReranking,
NoReranking)
......@@ -2,11 +2,14 @@ from collections import OrderedDict
from itertools import chain
import logging
from statistics import harmonic_mean
import struct
from typing import Callable
from utils import structural_distance, mmr_generic
from utils import structural_distance, mmr_generic, word_mover_distance
from .argument_graph import ArgumentGraph
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
class Reranking:
......@@ -19,18 +22,30 @@ class Reranking:
self.reranking_key = reranking_key
self.run_name = run_name
@abstractmethod
def rerank(self):
    """Recompute scores and reorder the stored results; each subclass
    supplies its own scoring and sorting strategy (typically by calling
    :meth:`_rerank`)."""
    pass
def _rerank(self,
            calculate_scores: Callable,
            reranking_strategy_conclusions: Callable,
            reranking_strategy_premises: Callable
            ) -> None:
    """
    Template method: compute reranking scores, then reorder the results.

    :param calculate_scores: callable that annotates every conclusion and
        premise with a score stored under ``self.reranking_key``.
    :param reranking_strategy_conclusions: key function used to sort the
        conclusions of a topic.
    :param reranking_strategy_premises: key function used to sort the
        premises of a conclusion.
    """
    self.logger.info('Starting to calculate reranking scores...')
    calculate_scores()
    self.logger.info('Finished to calculate reranking scores!')
    self.logger.info('Starting reranking...')
    # _sort builds and installs the reordered OrderedDict itself; the old
    # unused local accumulator here was dead code and has been removed.
    self._sort(reranking_strategy_conclusions, reranking_strategy_premises)
    self.logger.info('Finished reranking...')
def _sort(
self,
reranking_strategy_conclusions: Callable,
reranking_strategy_premises: Callable
) -> None:
reranked_premises_per_conclusions_per_topics = OrderedDict()
# Sort directory entries according to the reranking strategies
for topic_nrb, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
reranked_premises_per_conclusions_per_topics[topic_nrb] = OrderedDict(
enumerate(sorted(premises_pre_conclusions.values(),
......@@ -46,12 +61,22 @@ class Reranking:
reverse=True)
self.premises_per_conclusions_per_topics = reranked_premises_per_conclusions_per_topics
self.logger.info('Finished reranking...')
def _get_track_rows(self):
    """
    TODO: not implemented yet; currently returns None.
    NOTE(review): name presumably means "trec rows" — compare
    ``generate_trec_rows`` below; confirm before implementing.
    """
def _generic_score_calculation(self, scoring_function: Callable) -> Callable:
    """
    Build the score-calculation callback passed to :meth:`_rerank`.

    The returned no-argument callable scores every conclusion against its
    topic and every premise against its conclusion, storing the result
    under ``self.reranking_key``.

    :param scoring_function: callable taking two sentence strings and
        returning a numeric score.
    :return: the callback closure (fixes the previous ``-> None``
        annotation, which contradicted the actual return value).
    """
    def specific_score_calculation() -> None:
        for topic_nrb, premises_per_conclusions in tqdm(
                self.premises_per_conclusions_per_topics.items()):
            for premises_per_conclusion in tqdm(premises_per_conclusions.values()):
                conclusion = premises_per_conclusion['conclusion']
                # Conclusions are scored against the topic text...
                conclusion[self.reranking_key] = scoring_function(
                    conclusion['_source']['sentence_text'],
                    self.topics.loc[topic_nrb]['topic'])
                # ...and premises against their conclusion's text.
                for premise in premises_per_conclusion['premises']:
                    premise[self.reranking_key] = scoring_function(
                        premise['_source']['sentence_text'],
                        conclusion['_source']['sentence_text'])
    return specific_score_calculation
def generate_trec_rows(self):
rows_per_topic = OrderedDict()
for topic_nrb, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
trec_style_rows = []
......@@ -87,8 +112,8 @@ class MaximalMarginalRelevanceReranking(Reranking):
self.conclusion_lambda = conclusions_lambda
self.premises_lambda = premises_lambda
def get_trec_rows(self):
# Just takes the assigned mmr scores.
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
......@@ -100,8 +125,6 @@ class MaximalMarginalRelevanceReranking(Reranking):
reranking_strategy_conclusions,
reranking_strategy_premises)
return super()._get_track_rows()
def _mmr_conclusions_and_premises(self):
"""
TODO
......@@ -117,10 +140,10 @@ class StructuralDistanceReranking(Reranking):
def __init__(self, premises_per_conclusions_per_topics: dict,
run_name: str,
topics):
super().__init__(premises_per_conclusions_per_topics, 'struc_dissim', run_name)
super().__init__(premises_per_conclusions_per_topics, 'sd', run_name)
self.topics = topics
def get_trec_rows(self):
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
......@@ -128,26 +151,29 @@ class StructuralDistanceReranking(Reranking):
return x[self.reranking_key]
super()._rerank(
self._structural_distance_conclusions_and_premises,
super()._generic_score_calculation(structural_distance),
reranking_strategy_conclusions,
reranking_strategy_premises)
return super()._get_track_rows()
def _structural_distance_conclusions_and_premises(self) -> None:
class WordMoverDistanceReranking(Reranking):
    """Rerank results by Word Mover's Distance between sentence texts
    (reconstructed cleanly from the interleaved diff lines)."""

    def __init__(self, premises_per_conclusions_per_topics: dict,
                 run_name: str,
                 topics):
        super().__init__(premises_per_conclusions_per_topics, 'wmd', run_name)
        self.topics = topics

    def rerank(self) -> None:
        """Score conclusions/premises with WMD and reorder the results."""
        def reranking_strategy_conclusions(x):
            return x['conclusion'][self.reranking_key]

        def reranking_strategy_premises(x):
            return x[self.reranking_key]

        super()._rerank(
            self._generic_score_calculation(word_mover_distance),
            reranking_strategy_conclusions,
            reranking_strategy_premises)
class ArgumentGraphReranking(Reranking):
    """Rerank results using scores derived from an argument graph
    (reconstructed cleanly from the interleaved diff lines)."""

    def __init__(self, premises_per_conclusions_per_topics: dict,
                 run_name: str,
                 topics):
        super().__init__(premises_per_conclusions_per_topics, 'arg_rank', run_name)
        self.graphs = ArgumentGraph(premises_per_conclusions_per_topics).create()

    def rerank(self) -> None:
        """Reorder results by the (to-be-computed) argument-graph scores."""
        def reranking_strategy_conclusions(x):
            return x['conclusion'][self.reranking_key]

        def reranking_strategy_premises(x):
            return x[self.reranking_key]

        super()._rerank(
            self._arg_rank,
            reranking_strategy_conclusions,
            reranking_strategy_premises)

    def _arg_rank(self):
        # TODO: derive per-node scores from the argument graphs.
        for topic_nrb, g in self.graphs.items():
            pass
......@@ -182,5 +207,5 @@ class NoReranking(Reranking):
run_name: str):
super().__init__(premises_per_conclusions_per_topics, '_score', run_name)
def get_trec_rows(self):
return super()._get_track_rows()
def rerank(self):
    """Intentionally a no-op: no reranking is applied, so the results keep
    their original order and scores."""
    pass
......@@ -3,6 +3,7 @@ from utils.commands import SubCommands, RerankingOptions
from utils.configuration import Configuration
from utils.structural_distance import structural_distance
from utils.maximal_marginal_relevance import mmr_score, mmr_generic
from utils.word_mover_distance import word_mover_distance
from utils.reader import (read_data_to_index,
read_results, read_unranked, read_sentences, read_topics)
from utils.write_output import write_output
from utils.write_trec_file import write_terc_file
......@@ -12,3 +12,4 @@ class RerankingOptions(Enum):
MAXIMAL_MARGINAL_RELEVANCE = 'maximal-marginal-relevance'
STRUCTURAL_DISTANCE = 'structural-distance'
ARGUMENT_GRAPH = 'argument-graph'
WORD_MOVER_DISTACNE = 'word-mover-distance'
......@@ -3,6 +3,8 @@ from dataclasses import dataclass
from numpy import require
from python.src.prototype.retrieval.reranking import Reranking
from .commands import SubCommands, RerankingOptions
......@@ -79,7 +81,8 @@ def parse_cli_args() -> argparse.Namespace:
required=False,
choices=[RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE.value,
RerankingOptions.STRUCTURAL_DISTANCE.value,
RerankingOptions.ARGUMENT_GRAPH.value],
RerankingOptions.ARGUMENT_GRAPH.value,
RerankingOptions.WORD_MOVER_DISTACNE],
default=None)
parser_retrieval.add_argument('--reuse-unranked',
type=str,
......
from collections import defaultdict
import numpy
import wmd
import spacy
from spacy.language import Language
import libwmdrelax
# ~~~~~~~~~~~~~~~~~ ATTENTION ~~~~~~~~~~~~~~~~~
# The source code of this class is mostly copy-pasted from
# https://github.com/src-d/wmd-relax/blob/master/wmd/__init__.py
class SpacySimilarityHook:
    """
    spaCy pipeline component that reroutes document similarity through the
    Word Mover's Distance (WMD).

    Usage::

        nlp = spacy.load('en_core_web_md')
        nlp.add_pipe(wmd.WMD.SpacySimilarityHook(nlp), last=True)

    Once installed, ``doc.similarity(other)`` (and span similarity) is
    answered by :func:`compute_similarity`.
    """

    def __init__(self, nlp, ignore_stops, only_alpha, frequency_processor):
        """
        Initializes a new instance of SpacySimilarityHook class.

        :param nlp: `spaCy language object <https://spacy.io/docs/api/language>`_
            whose vocabulary supplies the word vectors.
        :param ignore_stops: Indicates whether stop words are skipped.
        :param only_alpha: Indicates whether only alphabetic tokens are kept.
        :param frequency_processor: Callable applied to (token, raw count)
            pairs to produce the nBOW weight.
        :type ignore_stops: bool
        :type only_alpha: bool
        :type frequency_processor: callable
        """
        self.nlp = nlp
        self.ignore_stops = ignore_stops
        self.only_alpha = only_alpha
        self.frequency_processor = frequency_processor

    def __call__(self, doc):
        # Install the hook for both whole-document and span similarity.
        doc.user_hooks["similarity"] = self.compute_similarity
        doc.user_span_hooks["similarity"] = self.compute_similarity
        return doc

    def compute_similarity(self, doc1, doc2):
        """
        Calculates the similarity between two spaCy documents: extracts
        their nBOW representations and evaluates the WMD.

        :return: The calculated similarity.
        :rtype: float.
        """
        nbow_first = self._convert_document(doc1)
        nbow_second = self._convert_document(doc2)
        # Joint vocabulary over both documents, in deterministic order.
        joint_words = sorted(set(nbow_first).union(nbow_second))
        vocabulary = {word: index for index, word in enumerate(joint_words)}
        weights_first = self._generate_weights(nbow_first, vocabulary)
        weights_second = self._generate_weights(nbow_second, vocabulary)
        # Embedding matrix: one vector row per joint-vocabulary entry.
        evec = numpy.zeros((len(vocabulary), self.nlp.vocab.vectors_length),
                           dtype=numpy.float32)
        for w, i in vocabulary.items():
            evec[i] = self.nlp.vocab[w].vector
        # Pairwise Euclidean distances via ||a||^2 - 2ab + ||b||^2.
        evec_sqr = (evec * evec).sum(axis=1)
        dists = evec_sqr - 2 * evec.dot(evec.T) + evec_sqr[:, numpy.newaxis]
        dists[dists < 0] = 0  # clamp tiny negatives caused by rounding
        dists = numpy.sqrt(dists)
        return libwmdrelax.emd(weights_first, weights_second, dists)

    def _convert_document(self, doc):
        # Build the {token id: processed frequency} nBOW for one document,
        # filtering tokens according to the configured flags.
        counts = defaultdict(int)
        for token in doc:
            if self.only_alpha and not token.is_alpha:
                continue
            if self.ignore_stops and token.is_stop:
                continue
            counts[token.orth] += 1
        return {t: self.frequency_processor(t, c) for t, c in counts.items()}

    def _generate_weights(self, doc, vocabulary):
        # Dense, L1-normalised weight vector aligned with `vocabulary`.
        weights = numpy.zeros(len(vocabulary), dtype=numpy.float32)
        for token, value in doc.items():
            weights[vocabulary[token]] = value
        weights /= weights.sum()
        return weights
@Language.factory('wmd_component',
                  default_config={
                      "ignore_stops": True,
                      "only_alpha": True,
                      "frequency_processor": lambda t, f: numpy.log(1 + f)}
                  )
def wmd_component(nlp, name, ignore_stops, only_alpha, frequency_processor):
    # spaCy factory: build the hook that reroutes `Doc.similarity`
    # through the Word Mover's Distance.
    hook = SpacySimilarityHook(nlp, ignore_stops, only_alpha,
                               frequency_processor)
    return hook
# The pipeline is created lazily so that merely importing this module does
# not load the large 'en_core_web_md' model; previously the load happened
# at import time even when WMD reranking was never used.
_nlp = None


def _get_nlp():
    """Load (once) and cache the spaCy pipeline with the WMD hook added."""
    global _nlp
    if _nlp is None:
        _nlp = spacy.load('en_core_web_md')
        _nlp.add_pipe('wmd_component', last=True)
    return _nlp


def word_mover_distance(first_sent: str, second_sent: str) -> float:
    """
    Return the Word Mover's Distance between two sentences.

    The value is the earth mover's distance computed by the pipeline's
    similarity hook (a distance, i.e. lower means more similar).

    :param first_sent: first sentence as raw text.
    :param second_sent: second sentence as raw text.
    """
    nlp = _get_nlp()
    return nlp(first_sent).similarity(nlp(second_sent))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment