Commit 452781e0 authored by Jerome Wuerf's avatar Jerome Wuerf
Browse files

Merge branch 'modify-arg-rank' into 'tira'

Modify arg rank

See merge request !22
parents b7427dba 40c4c2d2
......@@ -5,24 +5,53 @@ import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import torch
graphs = pickle.load(open('../data/graphs.pkl', 'rb'))
graphs_wdm = pickle.load(open('../data/wdm_arg_rank.pkl', 'rb'))
graphs_cosine = pickle.load(open('../data/cosine_arg_rank.pkl', 'rb'))
#%%
def all_sim(graphs):
return [attrs['weight'] for g in graphs.values() for a, b, attrs in g.edges(data=True)]
all_sim_wdm = pd.Series(all_sim(graphs_wdm))
all_sim_cosine = pd.Series(torch.stack(all_sim(graphs_cosine)))
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.set_xlabel("Word Mover's Similarity")
ax2.set_xlabel('Cosine Similarity')
ax1.boxplot(all_sim_wdm)
ax2.boxplot(all_sim_cosine)
plt.tight_layout()
plt.savefig('./plots/cosine_wdm_sim_scores.pdf')
#%%
graphs = graphs_wdm
graph_type = 'wdm'
#%%
for g in graphs.values():
to_remove = [(a,b) for a, b, attrs in g.edges(data=True) if attrs["weight"] < 0.35]
g.remove_edges_from(to_remove)
#%%
argument_nodes = pd.Series(
[len(g.nodes) for g in graphs.values()],
name='Argument Nodes per Topic', index=graphs.keys())
argument_edges = pd.Series(
[len(g.edges) for g in graphs.values()],
name='Argument Edges per Topic', index=graphs.keys())
# %%
fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [4, 1]}, figsize=(10, 5))
fig.suptitle('Argument Edges')
# fig.suptitle('Argument Edges')
ax1.bar(argument_edges.index, argument_edges)
ax1.set_xlabel('Topic number')
ax1.set_ylabel('# Edges')
ax2.boxplot(argument_edges)
ax2.yaxis.set_ticklabels([])
ax2.xaxis.set_ticklabels([])
plt.savefig('argument_edges.png')
plt.savefig(f'./plots/{graph_type}_argument_edges.pdf')
# %%
fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [4, 1]}, figsize=(10, 5))
......@@ -33,13 +62,13 @@ ax1.set_xlabel('Topic number')
ax1.set_ylabel('# Nodes')
ax2.yaxis.set_ticklabels([])
ax2.xaxis.set_ticklabels([])
plt.savefig('argument_nodes.png')
plt.savefig(f'./plots/{graph_type}_argument_nodes.pdf')
# %%
fig, axis = plt.subplots(1, 5, figsize=(13, 5))
fig.suptitle('Argument Graphs per Topic')
# fig.suptitle('Argument Graphs per Topic')
half_graphs = list(graphs.items())[:5]
for ax, elem in zip(axis.flatten(), half_graphs):
topic_nrb = elem[0]
......@@ -49,7 +78,7 @@ for ax, elem in zip(axis.flatten(), half_graphs):
pos = nx.spring_layout(Gcc, seed=10396953)
nx.draw_networkx_nodes(Gcc, pos, ax=ax, node_size=20)
nx.draw_networkx_edges(Gcc, pos, ax=ax, alpha=0.4)
plt.savefig('argument_graphs.png')
plt.savefig(f'./plots/{graph_type}_argument_graphs.pdf')
# %%
......@@ -68,13 +97,13 @@ for ax, elem in zip(axis.flatten(), graphs.items()):
degree_list = [g.degree() for g in graphs.values()]
degree_list = [item[1] for sublist in degree_list for item in sublist]
fig, ax = plt.subplots()
ax.set_title('Degrees Histogramm of Argument Graphs (all Topics)')
# ax.set_title('Degrees Histogramm of Argument Graphs (all Topics)')
ax.set_yscale('log')
ax.set_xlabel("Degree")
ax.set_ylabel("# of Nodes")
ax.hist(degree_list)
plt.savefig('degree_histogramm.png')
plt.savefig(f'./plots/{graph_type}_degree_histogramm.pdf')
# %%
# %%
import pandas as pd
import numpy as np
import pyarrow
import time
import openpyxl
import matplotlib.pyplot as plt
import math
# %%
......
from pathlib import Path
from indexing import Indexing
from retrieval import (Retrieval, MaximalMarginalRelevanceReranking, StructuralDistanceReranking,
ArgumentGraphReranking, WordMoverDistanceReranking, NoReranking)
from utils import (Configuration, SubCommands, RerankingOptions, parse_cli_args, read_data_to_index,
read_results, read_unranked, read_sentences, read_topics, write_terc_file)
ArgumentRankReranking, WordMoverDistanceReranking, NoReranking)
from utils import (Configuration, SubCommands, RerankingOptions, parse_cli_args,
read_data_to_index, read_unranked, read_topics, write_terc_file)
import logging
import time
......@@ -76,8 +76,8 @@ class App:
elif self.config['RERANKING'] == RerankingOptions.STRUCTURAL_DISTANCE.value:
reranker = StructuralDistanceReranking(retrieved_results, self.config['RUN_NAME'],
topics)
elif self.config['RERANKING'] == RerankingOptions.ARGUMENT_GRAPH.value:
reranker = ArgumentGraphReranking(retrieved_results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.ARGUMENT_RANK.value:
reranker = ArgumentRankReranking(retrieved_results, self.config['RUN_NAME'], topics)
elif self.config['RERANKING'] == RerankingOptions.WORD_MOVER_DISTANCE.value:
reranker = WordMoverDistanceReranking(retrieved_results, self.config['RUN_NAME'],
topics)
......
from .retrieval import Retrieval
from .reranking import (MaximalMarginalRelevanceReranking,
StructuralDistanceReranking,
ArgumentGraphReranking,
ArgumentRankReranking,
WordMoverDistanceReranking,
NoReranking)
from charset_normalizer import utils
import networkx as nx
import numpy as np
import logging
from sentence_transformers import util
from torch import embedding
from utils import word_mover_distance
class ArgumentGraph:
def __init__(self, premises_per_conclusions_per_topics: dict, threshold: float = 0.99):
def __init__(self, premises_per_conclusions_per_topics: dict, threshold: float = 0.20):
self.logger = logging.getLogger(__name__)
self.premises_per_conclusions_per_topics = premises_per_conclusions_per_topics
self.threshold = threshold
......@@ -21,10 +18,13 @@ class ArgumentGraph:
sim_matrices = {}
for conc_nrb, premises_per_conclusion in premises_per_conclusions.items():
conclusion = premises_per_conclusion['conclusion']
sim_matrices[conc_nrb] = util.cos_sim(
np.array(conclusion['_source']['sentence_vector']),
np.array([p['_source']['sentence_vector'] for p in
premises_per_conclusion['premises']]))[0]
sim_matrices[conc_nrb] = [
1 /
(1 +
word_mover_distance(
conclusion['_source']['sentence_text'],
p['_source']['sentence_text']))
for p in premises_per_conclusion['premises']]
for premises_per_conclusion in premises_per_conclusions.values():
conclusions_argument_id = conclusion["_id"].split('_')[0]
......@@ -35,8 +35,8 @@ class ArgumentGraph:
premise_argument_id = premise["_id"].split('_')[0]
if premise_argument_id not in graph:
graph.add_node(premise_argument_id)
else:
graph.nodes[premise_argument_id]['premises']
# else: TODO add a list of sentence_text to the nodes
# graph.nodes[premise_argument_id]['premises']
for conc_idx in range(len(premises_per_conclusions)):
sim = sim_matrices[conc_idx][prem_idx]
if sim >= self.threshold:
......
......@@ -211,7 +211,7 @@ class WordMoverDistanceReranking(Reranking):
reranking_strategy_premises, True)
class ArgumentGraphReranking(Reranking):
class ArgumentRankReranking(Reranking):
def __init__(self, premises_per_conclusions_per_topics: dict,
run_name: str,
......
......@@ -9,5 +9,5 @@ class SubCommands(Enum):
class RerankingOptions(Enum):
MAXIMAL_MARGINAL_RELEVANCE = 'maximal-marginal-relevance'
STRUCTURAL_DISTANCE = 'structural-distance'
ARGUMENT_GRAPH = 'argument-graph'
ARGUMENT_RANK = 'argument-rank'
WORD_MOVER_DISTANCE = 'word-mover-distance'
......@@ -70,7 +70,7 @@ def parse_cli_args() -> argparse.Namespace:
choices=[option.value for option in
[RerankingOptions.MAXIMAL_MARGINAL_RELEVANCE,
RerankingOptions.STRUCTURAL_DISTANCE,
RerankingOptions.ARGUMENT_GRAPH,
RerankingOptions.ARGUMENT_RANK,
RerankingOptions.WORD_MOVER_DISTANCE]
],
default=None)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment