Commit 7141dbdc authored by Jerome Wuerf

Merge branch 'word-mover-distance' into 'main'

Add reranking based on word mover distance

See merge request !19
parents 95328b2b 150c1bd4
......@@ -19,4 +19,5 @@ repos:
rev: '3.8.3'
hooks:
- id: flake8
args: ['--max-complexity=10', '--max-line-length=100', '--ignore=F401,W504', '--exclude=tests/*']
args: ['--max-complexity=10', '--max-line-length=100', '--ignore=F401,W504,E722',
'--exclude=tests/*,notebooks/*']
......@@ -30,9 +30,9 @@ services:
- "/bin/sh"
- -ecx
- |
python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
/app/src/prototype/app.py retrieval \
--reuse-unranked /data/unranked_dump_final.pkl \
--reranking word-mover-distance --reuse-unranked /data/unranked_dump_final.pkl \
word_mover_distance_token_length_factor_1_75 /data/topics.xml /data/results/
# entrypoint:
......@@ -41,7 +41,7 @@ services:
# - |
# python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
# /app/src/prototype/app.py assessment /data/sentences.parquet.snappy \
# /data/results/TESTRUN_WITH_LENGTHFACTOR_1_17.trec /data/topics.xml
# /data/results/topic51_word_mover_distance_token_length_factor_1_75.trec /data/topics.xml
# entrypoint:
# - "/bin/sh"
# - -ecx
......
#%%
import pickle
import pandas as pd
import faiss
#%%
sentences = pd.read_parquet('../data/sentences.parquet.snappy').drop_duplicates('sent_text')
embeddings = pickle.load(open('../data/embeddings.pkl', 'rb'))[sentences.index]
sentences.reset_index(drop=True, inplace=True)
# %%
# Cluster the premise embeddings into as many clusters as there are conclusions.
k = len(sentences[sentences['sent_type'] == 'CONC'])
kmeans = faiss.Kmeans(d=embeddings.shape[1], k=k, niter=20)
kmeans.train(embeddings[sentences[sentences['sent_type'] == 'PREMISE'].index])
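A hypothetical follow-up cell (not part of this commit): after training, faiss keeps the centroids behind kmeans.index, so each premise embedding can be mapped to its nearest centroid. This assumes embeddings is the float32 numpy array that faiss expects.
# %%
# Hypothetical: assign each premise embedding to its nearest of the k centroids.
premise_embeddings = embeddings[sentences[sentences['sent_type'] == 'PREMISE'].index]
distances, centroid_ids = kmeans.index.search(premise_embeddings, 1)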
......@@ -22,13 +22,10 @@ duplicates_indices_premises = premises['sent_text'].duplicated(keep=False)
a = conclusions[duplicates_indices_conclusions]['sent_text'].drop_duplicates()
b = premises[duplicates_indices_premises]['sent_text'].drop_duplicates()
# %%
a = [IndicesClient(es).analyze(body=dict(text=chunk)) for chunk in np.array_split(a.to_numpy(),
5_000)]
a = [len(res['tokens']) for res in a]
# %%
b = [IndicesClient(es).analyze(body=dict(text=chunk)) for chunk in np.array_split(b.to_numpy(),
5_000)]
b = [len(res['tokens']) for res in b]
a = [len(IndicesClient(es).analyze(body=dict(text=chunk))['tokens'])
     for chunk in np.array_split(a.to_numpy(), 5_000)]
b = [len(IndicesClient(es).analyze(body=dict(text=chunk))['tokens'])
     for chunk in np.array_split(b.to_numpy(), 5_000)]
# %%
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
......@@ -39,6 +36,6 @@ ax1.xaxis.set_ticklabels([])
ax2.boxplot(b)
ax2.set_xlabel("Premises")
ax2.xaxis.set_ticklabels([])
plt.savefig('conclusions_premise_duplicate.svg')
plt.savefig('plots/boxplots_amount_of_tokens_of_premise_and_conclusion_duplicates.svg')
# %%
# %%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from elasticsearch import Elasticsearch
from elasticsearch.client import IndicesClient
# %%
es = Elasticsearch('localhost:9200')
df = pd.read_parquet('../data/sentences.parquet.snappy').drop_duplicates('sent_text')
df.reset_index(drop=True, inplace=True)
df = df[df['sent_text'].str.len().between(42, 420)]
# %%
def plot_token_length(sentences, title_text_arg_type):
token_lengths = [len(IndicesClient(es).analyze(body=dict(text=chunk))['tokens']) for chunk in
np.array_split(sentences.to_numpy(), 50_000)]
fig, ax = plt.subplots()
# Plot.
plt.rc('figure', figsize=(8, 6))
plt.rc('font', size=14)
plt.rc('lines', linewidth=2)
# plt.rc('axes', color_cycle=('#377eb8','#e41a1c','#4daf4a',
# '#984ea3','#ff7f00','#ffff33'))
# Histogram.
ax.hist(token_lengths, bins=20)
# Average length.
avg_len = sum(token_lengths) / float(len(token_lengths))
plt.axvline(avg_len, color='#e41a1c')
trans = ax.get_xaxis_transform()
plt.title(f'Histogram of amount of tokens in {title_text_arg_type}.')
plt.xlabel('# Tokens')
plt.text(avg_len+avg_len*0.2, 0.5, 'mean = %.2f' % avg_len, transform=trans)
# Save before show(): show() flushes the active figure, so saving afterwards
# can write an empty file in script contexts.
plt.savefig(f'plots/histogram_amount_of_tokens_{title_text_arg_type}.svg')
plt.show()
# %%
plot_token_length(df[df['sent_type'] == 'CONC']['sent_text'], 'Conclusions')
# %%
plot_token_length(df[df['sent_type'] == 'PREMISE']['sent_text'], 'Premises')
from collections import OrderedDict
from itertools import chain
import logging
from statistics import harmonic_mean
import struct
from typing import Callable
from utils import structural_distance, mmr_generic, word_mover_distance
from .argument_graph import ArgumentGraph
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from abc import abstractmethod
class Reranking:
......@@ -29,12 +28,16 @@ class Reranking:
def _rerank(self,
calculate_scores: Callable,
reranking_strategy_conclusions: Callable,
reranking_strategy_premises: Callable
reranking_strategy_premises: Callable,
min_max_normalization: bool = False
):
self.logger.info('Starting to calculate reranking scores...')
calculate_scores()
self.logger.info('Finished calculating reranking scores!')
if min_max_normalization:
    self._min_max_normalization()
self.logger.info('Starting reranking...')
self._sort(reranking_strategy_conclusions, reranking_strategy_premises)
self.logger.info('Finished reranking...')
......@@ -47,18 +50,25 @@ class Reranking:
reranked_premises_per_conclusions_per_topics = OrderedDict()
# Sort dictionary entries according to the reranking strategies
for topic_nrb, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
for premises_per_conclusion in premises_pre_conclusions.values():
premises_per_conclusion['conclusion']['final_reranking_score'] = \
reranking_strategy_conclusions(
premises_per_conclusion)
for premise in premises_per_conclusion['premises']:
premise['final_reranking_score'] = reranking_strategy_premises(premise)
premises_per_conclusion['premises'] = sorted(
premises_per_conclusion['premises'],
key=lambda x: x['final_reranking_score'],
reverse=True)
reranked_premises_per_conclusions_per_topics[topic_nrb] = OrderedDict(
enumerate(sorted(premises_pre_conclusions.values(),
key=reranking_strategy_conclusions,
key=lambda x: x['conclusion']['final_reranking_score'],
reverse=True)
)
)
for premises_per_conclusion in \
reranked_premises_per_conclusions_per_topics[topic_nrb].values():
premises_per_conclusion['premises'] = sorted(
premises_per_conclusion['premises'],
key=reranking_strategy_premises,
reverse=True)
self.premises_per_conclusions_per_topics = reranked_premises_per_conclusions_per_topics
......@@ -76,6 +86,32 @@ class Reranking:
premises_per_conclusion['conclusion']['_source']['sentence_text'])
return specific_score_calcualation
def _min_max_normalization(self):
    def _min_max_norm(max_value, min_value, value):
        return (value - min_value) / (max_value - min_value)

    for premises_per_conclusions in self.premises_per_conclusions_per_topics.values():
        conclusion_scores = [c['conclusion'][self.reranking_key]
                             for c in premises_per_conclusions.values()]
        min_conclusion_score = min(conclusion_scores)
        max_conclusion_score = max(conclusion_scores)
        for premises_per_conclusion in premises_per_conclusions.values():
            premises_per_conclusion['conclusion'][self.reranking_key] = _min_max_norm(
                max_conclusion_score, min_conclusion_score,
                premises_per_conclusion['conclusion'][self.reranking_key])
            premise_scores = [p[self.reranking_key]
                              for p in premises_per_conclusion['premises']]
            min_premise_score = min(premise_scores)
            max_premise_score = max(premise_scores)
            for p in premises_per_conclusion['premises']:
                p[self.reranking_key] = _min_max_norm(
                    max_premise_score, min_premise_score, p[self.reranking_key])
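For intuition: _min_max_norm maps a score x to (x - min) / (max - min), so per topic the conclusion scores, and each conclusion's premise scores, are rescaled to [0, 1] before the weighted blending in the rerankers below. A minimal standalone sketch with toy numbers (illustrative only; like the method above, it divides by zero when all scores in a group are equal):

def min_max_norm_all(values):
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) for v in values]

min_max_norm_all([0.2, 0.5, 0.8])  # -> [0.0, 0.5, 1.0]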
def generate_trec_rows(self):
rows_per_topic = OrderedDict()
for topic_nrb, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
......@@ -93,7 +129,7 @@ class Reranking:
conclusion['_source']['sentence_stance'],
f'{conclusion["_id"]},{premise["_id"]}',
len(trec_style_rows),
conclusion[self.reranking_key],
f'{conclusion["final_reranking_score"]:.3f}',
self.run_name))
encountered_premise_ids.add(premise['_id'])
break
......@@ -145,10 +181,10 @@ class StructuralDistanceReranking(Reranking):
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
return x['conclusion'][self.reranking_key] * 0.1 + x['conclusion']['_score'] / 2 * 0.9
def reranking_strategy_premises(x):
return x[self.reranking_key]
return x[self.reranking_key] * 0.25 + x['_score'] / 2 * 0.75
super()._rerank(
super()._generic_score_calculation(structural_distance),
......@@ -165,15 +201,15 @@ class WordMoverDistanceReranking(Reranking):
def rerank(self) -> None:
def reranking_strategy_conclusions(x):
return x['conclusion'][self.reranking_key]
return x['conclusion'][self.reranking_key] * 0.1 + x['conclusion']['_score'] / 2 * 0.9
def reranking_strategy_premises(x):
return x[self.reranking_key]
return x[self.reranking_key] * 0.25 + x['_score'] / 2 * 0.75
super()._rerank(
self._generic_score_calculation(word_mover_distance),
reranking_strategy_conclusions,
reranking_strategy_premises)
reranking_strategy_premises, min_max_normalization=True)
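Both rerankers now blend the reranking score with the Elasticsearch retrieval _score instead of sorting on the reranking score alone. A toy illustration of the premise blend; the 0.25/0.75 weights and the halving of _score are copied from the diff, and reading the halving as a rough rescaling of the ES score is an assumption:

def blended_premise_score(reranking_score, es_score):
    # weights as in reranking_strategy_premises above
    return reranking_score * 0.25 + es_score / 2 * 0.75

blended_premise_score(0.8, 1.2)  # -> 0.65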
class ArgumentGraphReranking(Reranking):
......
......@@ -10,6 +10,8 @@ import libwmdrelax
# ~~~~~~~~~~~~~~~~~ ATTENTION ~~~~~~~~~~~~~~~~~
# The source code of this class is mostly copy-pasted from
# https://github.com/src-d/wmd-relax/blob/master/wmd/__init__.py
# Modified the WMD to a normalized implementation as described on SO:
# https://stackoverflow.com/questions/56822056/unnormalized-result-of-word-movers-distance-with-spacy
class SpacySimilarityHook:
"""
......@@ -67,7 +69,15 @@ class SpacySimilarityHook:
dists = evec_sqr - 2 * evec.dot(evec.T) + evec_sqr[:, numpy.newaxis]
dists[dists < 0] = 0
dists = numpy.sqrt(dists)
return libwmdrelax.emd(w1, w2, dists)
# libwmdrelax.emd() raises when a document is empty after the
# preprocessing pipeline; as a dirty fix, such junk documents get
# distance 0 so they end up ranked lower.
try:
    result = libwmdrelax.emd(w1, w2, dists)
except Exception:
    result = 0
return result
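Per the Stack Overflow reference above, the "normalized implementation" amounts to making each document's word-weight vector sum to 1 before calling the EMD solver, so the result is a weighted average of ground distances rather than an unnormalized mass flow. A minimal numpy sketch of the idea (not the exact code of this class):

import numpy

w1 = numpy.array([3.0, 1.0], dtype=numpy.float32)
w2 = numpy.array([1.0, 1.0], dtype=numpy.float32)
w1 /= w1.sum()  # -> [0.75, 0.25]
w2 /= w2.sum()  # -> [0.50, 0.50]
# libwmdrelax.emd(w1, w2, dists) then returns a normalized distance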
def _convert_document(self, doc):
words = defaultdict(int)
......
flake8
pre-commit
autopep8
\ No newline at end of file
autopep8