Commit 95328b2b authored by Jerome Wuerf's avatar Jerome Wuerf

Merge branch 'word-mover-distance' into 'main'

Word mover distance

See merge request !18
parents f482d928 45f47d8f
......@@ -30,10 +30,11 @@ services:
- "/bin/sh"
- -ecx
- |
python \
python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
/app/src/prototype/app.py retrieval \
--nrb-conclusions-per-topic 100 --nrb-premises-per-conclusion 50 \
baseline /data/topics.xml /data/results/
--reuse-unranked /data/unranked_dump_final.pkl \
word_mover_distance_token_length_factor_1_75 /data/topics.xml /data/results/
# entrypoint:
# - "/bin/sh"
# - -ecx
......
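Note on the entrypoint change above: the app now runs under debugpy, which listens on port 5678 and, because of `--wait-for-client`, blocks start-up until a debugger attaches. A minimal sketch of the equivalent in-process setup, using debugpy's standard API (not part of this diff):

```python
import debugpy

# Equivalent to: python -m debugpy --wait-for-client --listen 0.0.0.0:5678 app.py ...
debugpy.listen(("0.0.0.0", 5678))  # open the debug-adapter port
debugpy.wait_for_client()          # block until an IDE attaches
# ...normal application start-up continues once a client is connected...
```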
......@@ -4,4 +4,4 @@ RUN mkdir /app
WORKDIR /app
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
RUN python -m spacy download en_core_web_sm
RUN python -m spacy download en_core_web_md
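# en_core_web_md ships static word vectors, which the WMD hook needs; en_core_web_sm does not.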
......@@ -10,3 +10,4 @@ textdistance
spacy
networkx
matplotlib
wmd
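# wmd is the src-d/wmd-relax package; it backs utils/word_mover_distance.py below.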
......@@ -3,9 +3,9 @@ from assessment import Assessment
from deduplication import Deduplication
from indexing import Indexing
from retrieval import (Retrieval, MaximalMarginalRelevanceReranking, StructuralDistanceReranking,
ArgumentGraphReranking, NoReranking)
ArgumentGraphReranking, WordMoverDistanceReranking, NoReranking)
from utils import (Configuration, SubCommands, RerankingOptions, parse_cli_args, read_data_to_index,
read_results, read_unranked, read_sentences, read_topics, write_output)
read_results, read_unranked, read_sentences, read_topics, write_trec_file)
import logging
import pickle
......@@ -46,37 +46,47 @@ class App:
def _retrieval(self) -> None:
topics = read_topics(self.config['TOPICS_PATH'], self.config['TOPIC_NRB'])
self.logger.info('Read topics!')
results = None
retrieved_results = None
if self.config['REUSE_UNRANKED']:
self.logger.info('Reading unranked...')
results = read_unranked(self.config['REUSE_UNRANKED'], self.config['TOPIC_NRB'])
retrieved_results = read_unranked(self.config['REUSE_UNRANKED'],
self.config['TOPIC_NRB'])
self.logger.info('Read unranked!')
else:
results = Retrieval(topics,
self.config['RUN_NAME'],
self.config['MIN_LENGTH_FACTOR'],
self.config['ELASTIC_HOST'],
self.config['NRB_CONCLUSIONS_PER_TOPIC'],
self.config['NRB_PREMISES_PER_CONCLUSION']
).retrieve()
pickle.dump(results, open('/data/baseline_dump.pkl', 'wb'))
retrieved_results = Retrieval(topics,
self.config['RUN_NAME'],
self.config['MIN_LENGTH_FACTOR'],
self.config['ELASTIC_HOST'],
self.config['NRB_CONCLUSIONS_PER_TOPIC'],
self.config['NRB_PREMISES_PER_CONCLUSION']
).retrieve()
pickle.dump(retrieved_results, open('/data/baseline_dump.pkl', 'wb'))
# Pattern matching is only available from Python 3.10 onwards :(
reranker = None
if self.config['RERANKING'] == RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE.value:
results = MaximalMarginalRelevanceReranking(
results,
reranker = MaximalMarginalRelevanceReranking(
retrieved_results,
self.config['RUN_NAME'],
self.config['LAMBDA_CONCLUSIONS'],
self.config['LAMBDA_PREMISES']
)
elif self.config['RERANKING'] == RerankingOptions.STRUCTURAL_DISTANCE.value:
results = StructuralDistanceReranking(results, self.config['RUN_NAME'], topics)
reranker = StructuralDistanceReranking(retrieved_results, self.config['RUN_NAME'],
topics)
elif self.config['RERANKING'] == RerankingOptions.ARGUMENT_GRAPH.value:
results = ArgumentGraphReranking(results, self.config['RUN_NAME'], topics)
reranker = ArgumentGraphReranking(retrieved_results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.WORD_MOVER_DISTANCE.value:
reranker = WordMoverDistanceReranking(retrieved_results, self.config['RUN_NAME'],
topics)
else:
results = NoReranking(results, self.config['RUN_NAME'])
write_output(results.get_trec_rows(),
Path(self.config['TREC_OUTPUT_PATH'],
f'{ self.config["RUN_NAME"]}.trec')
)
reranker = NoReranking(retrieved_results, self.config['RUN_NAME'])
reranker.rerank()
write_trec_file(reranker.generate_trec_rows(),
Path(self.config['TREC_OUTPUT_PATH'],
f'{self.config["RUN_NAME"]}.trec')
)
def _assessment(self) -> None:
sentences = read_sentences(self.config['SENTENCES_PATH'])
......
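A side note on the `# Pattern matching` comment in `_retrieval`: structural pattern matching only arrives in Python 3.10, hence the `if`/`elif` chain. For comparison, a minimal sketch of what the dispatch would look like on 3.10+ (same names as in the diff, otherwise hypothetical):

```python
# Hypothetical Python 3.10+ equivalent of the if/elif dispatch in _retrieval.
match self.config['RERANKING']:
    case RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE.value:
        reranker = MaximalMarginalRelevanceReranking(
            retrieved_results, self.config['RUN_NAME'],
            self.config['LAMBDA_CONCLUSIONS'], self.config['LAMBDA_PREMISES'])
    case RerankingOptions.STRUCTURAL_DISTANCE.value:
        reranker = StructuralDistanceReranking(retrieved_results, self.config['RUN_NAME'], topics)
    case RerankingOptions.ARGUMENT_GRAPH.value:
        reranker = ArgumentGraphReranking(retrieved_results, self.config['RUN_NAME'], topics)
    case RerankingOptions.WORD_MOVER_DISTANCE.value:
        reranker = WordMoverDistanceReranking(retrieved_results, self.config['RUN_NAME'], topics)
    case _:
        reranker = NoReranking(retrieved_results, self.config['RUN_NAME'])
```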
......@@ -2,4 +2,5 @@ from .retrieval import Retrieval
from .reranking import (MaximalMarginalRelevanceReranking,
StructuralDistanceReranking,
ArgumentGraphReranking,
WordMoverDistanceReranking,
NoReranking)
......@@ -46,5 +46,5 @@ class ArgumentGraph:
graph.add_edge(conc_arg_id, premise_argument_id, weight=sim)
graphs[topic_nrb] = graph
logging.info(f'Created argument graph for topic nrb {topic_nrb}')
logging.info('Finished creating the arguemnt graphs!')
logging.info('Finished creating the argument graphs!')
return graphs
......@@ -2,11 +2,14 @@ from collections import OrderedDict
from itertools import chain
import logging
from statistics import harmonic_mean
from typing import Callable
from utils import structural_distance, mmr_generic
from utils import structural_distance, mmr_generic, word_mover_distance
from .argument_graph import ArgumentGraph
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
class Reranking:
......@@ -19,18 +22,30 @@ class Reranking:
self.reranking_key = reranking_key
self.run_name = run_name
@abstractmethod
def rerank(self):
pass
def _rerank(self,
calc_reranking_scores: Callable,
calculate_scores: Callable,
reranking_strategy_conclusions: Callable,
reranking_strategy_premises: Callable
):
self.logger.info('Starting to calculate reranking scores...')
calc_reranking_scores()
calculate_scores()
self.logger.info('Finished calculating reranking scores!')
self.logger.info('Starting reranking...')
reranked_premises_per_conclusions_per_topics = OrderedDict()
self._sort(reranking_strategy_conclusions, reranking_strategy_premises)
self.logger.info('Finished reranking...')
def _sort(
self,
reranking_strategy_conclusions: Callable,
reranking_strategy_premises: Callable
) -> None:
reranked_premises_per_conclusions_per_topics = OrderedDict()
# Sort dictionary entries according to the reranking strategies
for topic_nrb, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
reranked_premises_per_conclusions_per_topics[topic_nrb] = OrderedDict(
enumerate(sorted(premises_pre_conclusions.values(),
......@@ -46,12 +61,22 @@ class Reranking:
reverse=True)
self.premises_per_conclusions_per_topics = reranked_premises_per_conclusions_per_topics
self.logger.info('Finished reranking...')
def _get_track_rows(self):
"""
TODO
"""
def _generic_score_calculation(self, scoring_function: Callable) -> None:
def specific_score_calculation() -> None:
for topic_nrb, premises_per_conclusions in tqdm(
self.premises_per_conclusions_per_topics.items()):
for premises_per_conclusion in tqdm(premises_per_conclusions.values()):
premises_per_conclusion['conclusion'][self.reranking_key] = scoring_function(
premises_per_conclusion['conclusion']['_source']['sentence_text'],
self.topics.loc[topic_nrb]['topic'])
for premise in premises_per_conclusion['premises']:
premise[self.reranking_key] = scoring_function(
premise['_source']['sentence_text'],
premises_per_conclusion['conclusion']['_source']['sentence_text'])
return specific_score_calculation
def generate_trec_rows(self):
rows_per_topic = OrderedDict()
for topic_nrb, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
trec_style_rows = []
......@@ -87,8 +112,8 @@ class MaximalMarginalRelevanceReranking(Reranking):
self.conclusion_lambda = conclusions_lambda
self.premises_lambda = premises_lambda
def get_trec_rows(self):
# Just takes the assigned mmr scores.
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
......@@ -100,8 +125,6 @@ class MaximalMarginalRelevanceReranking(Reranking):
reranking_strategy_conclusions,
reranking_strategy_premises)
return super()._get_track_rows()
def _mmr_conclusions_and_premises(self):
"""
TODO
......@@ -117,10 +140,10 @@ class StructuralDistanceReranking(Reranking):
def __init__(self, premises_per_conclusions_per_topics: dict,
run_name: str,
topics):
super().__init__(premises_per_conclusions_per_topics, 'struc_dissim', run_name)
super().__init__(premises_per_conclusions_per_topics, 'sd', run_name)
self.topics = topics
def get_trec_rows(self):
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
......@@ -128,26 +151,29 @@ class StructuralDistanceReranking(Reranking):
return x[self.reranking_key]
super()._rerank(
self._structural_distance_conclusions_and_premises,
super()._generic_score_calculation(structural_distance),
reranking_strategy_conclusions,
reranking_strategy_premises)
return super()._get_track_rows()
def _structural_distance_conclusions_and_premises(self) -> None:
class WordMoverDistanceReranking(Reranking):
def __init__(self, premises_per_conclusions_per_topics: dict,
run_name: str,
topics):
super().__init__(premises_per_conclusions_per_topics, 'wmd', run_name)
self.topics = topics
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
for topic_nrb, premises_pre_conclusions in tqdm(
self.premises_per_conclusions_per_topics.items()):
for premises_per_conclusion in tqdm(premises_pre_conclusions.values()):
conclusion = premises_per_conclusion['conclusion']
conclusion_text = conclusion['_source']['sentence_text']
conclusion[self.reranking_key] = structural_distance(
conclusion_text,
self.topics.loc[topic_nrb]['topic'])
def reranking_strategy_premises(x):
return x[self.reranking_key]
for premise in premises_per_conclusion['premises']:
premise[self.reranking_key] = structural_distance(
premise['_source']['sentence_text'], conclusion_text)
super()._rerank(
self._generic_score_calculation(word_mover_distance),
reranking_strategy_conclusions,
reranking_strategy_premises)
class ArgumentGraphReranking(Reranking):
......@@ -155,25 +181,24 @@ class ArgumentGraphReranking(Reranking):
def __init__(self, premises_per_conclusions_per_topics: dict,
run_name: str,
topics):
super().__init__(premises_per_conclusions_per_topics, 'arg_graph_score', run_name)
super().__init__(premises_per_conclusions_per_topics, 'arg_rank', run_name)
self.graphs = ArgumentGraph(premises_per_conclusions_per_topics).create()
def get_trec_rows(self):
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return harmonic_mean([x['conclusion'][self.reranking_key], x['conclusion']['arg_rank']])
return x['conclusion'][self.reranking_key]
def reranking_strategy_premises(x):
return harmonic_mean([x[self.reranking_key], x['_arg_rank']])
return x[self.reranking_key]
super()._rerank(
self._arg_rank,
reranking_strategy_conclusions,
reranking_strategy_premises)
return super()._get_track_rows()
def _arg_rank(self):
for topic_nrb, g in self.graphs.items():
# TODO
pass
......@@ -182,5 +207,5 @@ class NoReranking(Reranking):
run_name: str):
super().__init__(premises_per_conclusions_per_topics, '_score', run_name)
def get_trec_rows(self):
return super()._get_track_rows()
def rerank(self):
pass
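Taken together, the new interface is: construct a reranker, call `rerank()`, then `generate_trec_rows()` — `App._retrieval` above does exactly this. A minimal driver sketch, assuming the imports from app.py and that `retrieved_results` and `topics` already come from `Retrieval(...).retrieve()` and `read_topics()`:

```python
# Hypothetical driver mirroring App._retrieval; run name and output path are invented.
reranker = WordMoverDistanceReranking(retrieved_results, 'wmd_run', topics)
reranker.rerank()                     # score conclusions/premises with WMD, then sort
rows = reranker.generate_trec_rows()  # OrderedDict of TREC-style rows per topic
write_trec_file(rows, Path('/data/results', 'wmd_run.trec'))
```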
......@@ -3,6 +3,7 @@ from utils.commands import SubCommands, RerankingOptions
from utils.configuration import Configuration
from utils.structural_distance import structural_distance
from utils.maximal_marginal_relevance import mmr_score, mmr_generic
from utils.word_mover_distance import word_mover_distance
from utils.reader import (read_data_to_index,
read_results, read_unranked, read_sentences, read_topics)
from utils.write_output import write_output
from utils.write_trec_file import write_trec_file
......@@ -12,3 +12,4 @@ class RerankingOptions(Enum):
MAXIMAL_MARGINAL_RELEVANCE = 'maximal-marginal-relevance'
STRUCTURAL_DISTANCE = 'structural-distance'
ARGUMENT_GRAPH = 'argument-graph'
WORD_MOVER_DISTANCE = 'word-mover-distance'
import argparse
from dataclasses import dataclass
from .commands import SubCommands, RerankingOptions
......@@ -77,9 +76,12 @@ def parse_cli_args() -> argparse.Namespace:
parser_retrieval.add_argument('--reranking',
type=str,
required=False,
choices=[RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE.value,
RerankingOptions.STRUCTURAL_DISTANCE.value,
RerankingOptions.ARGUMENT_GRAPH.value],
choices=[option.value for option in
[RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE,
RerankingOptions.STRUCTURAL_DISTANCE,
RerankingOptions.ARGUMENT_GRAPH,
RerankingOptions.WORD_MOVER_DISTANCE]
],
default=None)
parser_retrieval.add_argument('--reuse-unranked',
type=str,
......
import spacy
import textdistance
nlp = spacy.load("en_core_web_sm")
nlp = spacy.load("en_core_web_md")
def structural_distance(first_sent: str, second_sent: str):
......
from collections import defaultdict
import numpy
import wmd
import spacy
from spacy.language import Language
import libwmdrelax
# ~~~~~~~~~~~~~~~~~ ATTENTION ~~~~~~~~~~~~~~~~~
# The source code of this class is mostly copy-pasted from
# https://github.com/src-d/wmd-relax/blob/master/wmd/__init__.py
class SpacySimilarityHook:
"""
This hook is needed for the integration with `spaCy <https://spacy.io>`_.
Use it like this:
::
nlp = spacy.load('en_core_web_md')
nlp.add_pipe(wmd.WMD.SpacySimilarityHook(nlp), last=True)
It defines :func:`~wmd.WMD.SpacySimilarityHook.compute_similarity()` \
method which is called by spaCy over pairs of
`documents <https://spacy.io/docs/api/doc>`_.
.. automethod:: wmd::WMD.SpacySimilarityHook.__init__
"""
def __init__(self, nlp, ignore_stops, only_alpha, frequency_processor):
"""
Initializes a new instance of SpacySimilarityHook class.
:param nlp: `spaCy language object <https://spacy.io/docs/api/language>`_.
:param ignore_stops: Indicates whether to ignore the stop words.
:param only_alpha: Indicates whether only alpha tokens must be used.
:param frequency_processor: The function which is applied to raw \
token frequencies.
:type ignore_stops: bool
:type only_alpha: bool
:type frequency_processor: callable
"""
self.nlp = nlp
self.ignore_stops = ignore_stops
self.only_alpha = only_alpha
self.frequency_processor = frequency_processor
def __call__(self, doc):
doc.user_hooks["similarity"] = self.compute_similarity
doc.user_span_hooks["similarity"] = self.compute_similarity
return doc
def compute_similarity(self, doc1, doc2):
"""
Calculates the similarity between two spaCy documents. Extracts the
nBOW from them and evaluates the WMD.
:return: The calculated similarity.
:rtype: float.
"""
doc1 = self._convert_document(doc1)
doc2 = self._convert_document(doc2)
vocabulary = {
w: i for i, w in enumerate(sorted(set(doc1).union(doc2)))}
w1 = self._generate_weights(doc1, vocabulary)
w2 = self._generate_weights(doc2, vocabulary)
evec = numpy.zeros((len(vocabulary), self.nlp.vocab.vectors_length),
dtype=numpy.float32)
for w, i in vocabulary.items():
evec[i] = self.nlp.vocab[w].vector
evec_sqr = (evec * evec).sum(axis=1)
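# Pairwise Euclidean distances via the expansion ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2,
# clamped at zero to absorb floating-point error before the sqrt.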
dists = evec_sqr - 2 * evec.dot(evec.T) + evec_sqr[:, numpy.newaxis]
dists[dists < 0] = 0
dists = numpy.sqrt(dists)
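# Solve the optimal-transport problem: earth mover's distance between the
# two nBOW weight distributions under the embedding-space cost matrix.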
return libwmdrelax.emd(w1, w2, dists)
def _convert_document(self, doc):
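# Count tokens by their integer orth id (honoring the alpha/stop-word filters),
# then map raw counts through frequency_processor (log-scaled by default).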
words = defaultdict(int)
for t in doc:
if self.only_alpha and not t.is_alpha:
continue
if self.ignore_stops and t.is_stop:
continue
words[t.orth] += 1
return {t: self.frequency_processor(t, v) for t, v in words.items()}
def _generate_weights(self, doc, vocabulary):
w = numpy.zeros(len(vocabulary), dtype=numpy.float32)
for t, v in doc.items():
w[vocabulary[t]] = v
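# Normalize counts to a probability distribution over the shared vocabulary (nBOW).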
w /= w.sum()
return w
@Language.factory('wmd_component',
default_config={
"ignore_stops": True,
"only_alpha": True,
"frequency_processor": lambda t, f: numpy.log(1 + f)}
)
def wmd_component(nlp, name, ignore_stops, only_alpha, frequency_processor):
return SpacySimilarityHook(nlp, ignore_stops, only_alpha, frequency_processor)
nlp = spacy.load('en_core_web_md')
nlp.add_pipe('wmd_component', last=True)
def word_mover_distance(first_sent: str, second_sent: str):
return nlp(first_sent).similarity(nlp(second_sent))
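For orientation, a small usage sketch (sentences invented for illustration). Although the value is routed through spaCy's `similarity` hook, `libwmdrelax.emd` returns a distance, so smaller values mean more similar sentences:

```python
# Hypothetical example; with en_core_web_md vectors, a paraphrase should
# typically score lower (closer) than an unrelated sentence.
close = word_mover_distance('Nuclear energy is safe.',
                            'Atomic power carries little risk.')
far = word_mover_distance('Nuclear energy is safe.',
                          'My cat sleeps all day.')
print(close, far)  # expect close < far
```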
......@@ -3,7 +3,7 @@ from pathlib import Path
from typing import OrderedDict
def write_output(result: list, output_path: Path) -> None:
def write_trec_file(result: list, output_path: Path) -> None:
"""
TODO
"""
......