Commit 22205d8e authored by Jerome Wuerf

Fix experiments

parent 9991dd1a
@@ -25,28 +25,28 @@ services:
     #   - |
     #     python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
     #       /app/src/prototype/app.py indexing \
-    #       --create /data/sentences.parquet.snappy /data/embeddings.pkl
+    #       --create /data/sentences.csv /data/embeddings.csv
     entrypoint:
       - "/bin/sh"
       - -ecx
       - |
-        python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
+        python \
           /app/src/prototype/app.py retrieval \
-            --nrb-conclusions-per-topic 100 --nrb-premises-per-conclusion 1 \
-            baseline-double-index-semantic-search /data/topics.xml /data/results/
-    # entrypoint:
-    #   - "/bin/sh"
-    #   - -ecx
-    #   - |
-    #     python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
-    #       /app/src/prototype/app.py assessment /data/sentences.parquet.snappy \
-    #       /data/results/double-index-semantic-search.trec /data/topics.xml
-    # entrypoint:
-    #   - "/bin/sh"
-    #   -ecx
-    #   - |
-    #     python -m debugpy --wait-for-client --listen 0.0.0.0:5678 /app/src/prototype/app.py \
-    #       deduplication --elastic-host elastic --indices conc premise
+            --nrb-conclusions-per-topic 100 --nrb-premises-per-conclusion 50 \
+            baseline /data/topics.xml /data/results/
+    # entrypoint:
+    #   - "/bin/sh"
+    #   - -ecx
+    #   - |
+    #     python -m debugpy --wait-for-client --listen 0.0.0.0:5678 \
+    #       /app/src/prototype/app.py assessment /data/sentences.parquet.snappy \
+    #       /data/results/TESTRUN_WITH_LENGTHFACTOR_1_17.trec /data/topics.xml
+    # entrypoint:
+    #   - "/bin/sh"
+    #   -ecx
+    #   - |
+    #     python -m debugpy --wait-for-client --listen 0.0.0.0:5678 /app/src/prototype/app.py \
+    #       deduplication --elastic-host elastic --indices conc premise
   elastic:
     image: "docker.elastic.co/elasticsearch/elasticsearch:7.15.2"
     restart: always
......
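The debugpy wrapper removed from the entrypoint above (`python -m debugpy --wait-for-client --listen 0.0.0.0:5678 ...`) makes the container block at startup until a debugger attaches. For reference, the same behaviour is available from inside a script via debugpy's programmatic API; a minimal sketch, not part of this commit:

import debugpy

debugpy.listen(("0.0.0.0", 5678))  # the port the compose service exposes for debugging
debugpy.wait_for_client()          # block until an IDE attaches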
 #!/bin/bash
-test_points=(1,0 0.75,0.25 0.5,0.5 0.25,0.75 0,1)
+test_points=(0,0 0.25,0.25 0.5,0.5 0.75,0.75 1,1)
 for i in "${test_points[@]}"
 do IFS=","; set -- $i;
   python \
     /app/src/prototype/app.py retrieval \
       --nrb-conclusions-per-topic 100 --nrb-premises-per-conclusion 50 \
       --topic-nrb 51 \
-      --reranking maximal-marginal-relevance --lambda-conclusions $1 --lambda-premises $2 \
-      test-run /data/topics.xml /data/results/
+      --reranking maximal-marginal-relevance --reuse-unranked /data/unranked_dump_final.pkl \
+      --lambda-conclusions $1 --lambda-premises $2 \
+      mrr_$1_$2_token_factor_1_75 /data/topics.xml /data/results/mmr_symmetric
 done
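The `IFS=","; set -- $i` idiom splits each `x,y` test point into the positional parameters `$1` and `$2`, which feed `--lambda-conclusions`, `--lambda-premises`, and the run name. The same pair expansion restated as a Python sketch for readability (nothing here is taken from the repo itself):

# Mirror of the bash loop: each "x,y" point yields the two MMR lambdas.
test_points = ["0,0", "0.25,0.25", "0.5,0.5", "0.75,0.75", "1,1"]
for point in test_points:
    lambda_conclusions, lambda_premises = point.split(",")
    run_name = f"mrr_{lambda_conclusions}_{lambda_premises}_token_factor_1_75"
    print(run_name, lambda_conclusions, lambda_premises)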
@@ -9,20 +9,20 @@ class Indexing:
     """
     TODO
     """
+    EMBEDDING_LENGTH = 384
+    TOTAL_DUPLICATED_SENTENCES = 5337410
 
-    def __init__(self, sentences: ndarray,
-                 embeddings: ndarray,
+    def __init__(self, input_generator,
                  elastic_host: str = 'elastic',
                  create_indices: bool = True):
         """
         TODO
         """
         self.logger = logging.getLogger(__name__)
-        self.sentences = sentences
-        self.embeddings = embeddings
+        self.input_generator = input_generator
         self.es = Elasticsearch(elastic_host)
         self.create_indices = create_indices
-        self.index_schema = self.get_index_schema(len(self.embeddings[0]))
+        self.index_schema = self.get_index_schema(self.EMBEDDING_LENGTH)
 
     def get_index_schema(self, dims: int) -> None:
         """
@@ -67,8 +67,8 @@ class Indexing:
         if self.create_indices:
             self.recreate_index('premise')
             self.recreate_index('conc')
-        self.logger('Start indexing...')
-        with tqdm(total=len(self.sentences)) as progress_bar:
+        self.logger.info('Start indexing...')
+        with tqdm(total=self.TOTAL_DUPLICATED_SENTENCES) as progress_bar:
             for ok, action in streaming_bulk(client=self.es, actions=self._generate_actions()):
                 progress_bar.update()
         self.logger('Finished indexing!')
@@ -77,8 +77,7 @@ class Indexing:
         """
         TODO
         """
-        for sentence, embedding in zip(self.sentences, self.embeddings):
-            # if sentence[1].lower() == 'conc':
+        for sentence, embedding in self.input_generator:
             yield {"_index": sentence[1].lower(),
                    "_id": '_'.join(sentence[:3]),
                    "_source": {
......
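Two notes on the indexing hunks: a generator has no len(), so the tqdm total is now the hard-coded TOTAL_DUPLICATED_SENTENCES; and the unchanged context line `self.logger('Finished indexing!')` still calls the Logger object directly (a TypeError at runtime), presumably meant to be `self.logger.info(...)` like the line fixed above it. The generator-driven bulk pattern in isolation, as a minimal sketch with assumed names:

from elasticsearch import Elasticsearch
from elasticsearch.helpers import streaming_bulk
from tqdm import tqdm

def bulk_index(es: Elasticsearch, actions, total: int) -> None:
    # `actions` is a generator of bulk action dicts, so the progress-bar
    # total must be supplied up front rather than derived with len().
    with tqdm(total=total) as progress_bar:
        for ok, result in streaming_bulk(client=es, actions=actions):
            if not ok:
                print('failed:', result)  # streaming_bulk yields (ok, result) per document
            progress_bar.update()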
@@ -5,6 +5,7 @@ from statistics import harmonic_mean
 from typing import Callable
 from utils import structural_distance, mmr_generic
 from .argument_graph import ArgumentGraph
+from tqdm import tqdm
 
 
 class Reranking:
@@ -102,23 +103,13 @@ class MaximalMarginalRelevanceReranking(Reranking):
         return super()._get_track_rows()
 
     def _mmr_conclusions_and_premises(self):
-        self._mmr_conclusions()
-        self._mmr_premises()
-
-    def _mmr_conclusions(self) -> None:
-        """
-        TODO
-        """
-        for premises_per_conclusions in self.premises_per_conclusions_per_topics.values():
+        for _, premises_per_conclusions in tqdm(self.premises_per_conclusions_per_topics.items()):
             conclusions = [element['conclusion'] for element in premises_per_conclusions.values()]
             mmr_generic(conclusions, self.conclusion_lambda)
-
-    def _mmr_premises(self) -> None:
-        """
-        TODO
-        """
-        for _, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
-            for premises_per_conclusion in premises_pre_conclusions.values():
+            for premises_per_conclusion in tqdm(premises_per_conclusions.values()):
                 mmr_generic(premises_per_conclusion['premises'], self.premises_lambda)
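`mmr_generic` itself is not shown in this diff. For orientation, Maximal Marginal Relevance greedily selects the candidate maximizing lambda * relevance - (1 - lambda) * max-similarity-to-already-selected; a hypothetical in-place variant, assuming relevance comes from the Elasticsearch `_score` and a pairwise similarity function is supplied:

def mmr_rerank(hits: list, lam: float, pairwise_sim) -> None:
    # Reorder `hits` in place by Maximal Marginal Relevance (sketch only).
    selected, remaining = [], list(hits)
    while remaining:
        best = max(
            remaining,
            key=lambda d: lam * d['_score'] - (1 - lam) * max(
                (pairwise_sim(d, s) for s in selected), default=0.0))
        remaining.remove(best)
        selected.append(best)
    hits[:] = selected  # mutate in place, as mmr_generic appears to do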
@@ -131,10 +122,10 @@ class StructuralDistanceReranking(Reranking):
     def get_trec_rows(self):
         def reranking_strategy_conclusions(x):
-            return harmonic_mean([x['conclusion'][self.reranking_key], x['conclusion']['_score']])
+            return x['conclusion'][self.reranking_key]
 
         def reranking_strategy_premises(x):
-            return harmonic_mean([x[self.reranking_key], x['_score']])
+            return x[self.reranking_key]
 
         super()._rerank(
             self._structural_distance_conclusions_and_premises,
@@ -144,31 +135,20 @@ class StructuralDistanceReranking(Reranking):
         return super()._get_track_rows()
 
     def _structural_distance_conclusions_and_premises(self) -> None:
-        self._structural_distance_conclusions()
-        self._structural_distance_premises()
-
-    def _structural_distance_premises(self) -> None:
-        """
-        TODO Refactor to an immutable transformation
-        """
-        for _, premises_pre_conclusions in self.premises_per_conclusions_per_topics.items():
-            for premises_per_conclusion in premises_pre_conclusions.values():
+        for topic_nrb, premises_pre_conclusions in tqdm(
+                self.premises_per_conclusions_per_topics.items()):
+            for premises_per_conclusion in tqdm(premises_pre_conclusions.values()):
                 conclusion_text = premises_per_conclusion['conclusion']['_source']['sentence_text']
-                for premise in premises_per_conclusion['premises']:
-                    premise[self.reranking_key] = structural_distance(
-                        premise['_source']['sentence_text'], conclusion_text)
-
-    def _structural_distance_conclusions(self) -> None:
-        """
-        TODO Refactor to an immutable transformation
-        """
-        for topic_nrb, premises_per_conclusions in self.premises_per_conclusions_per_topics.items():
-            for premises_per_conclusion in premises_per_conclusions.values():
                 conclusion = premises_per_conclusion['conclusion']
                 conclusion[self.reranking_key] = structural_distance(
                     conclusion['_source']['sentence_text'],
                     self.topics.loc[topic_nrb]['topic'])
+                for premise in premises_per_conclusion['premises']:
+                    premise[self.reranking_key] = structural_distance(
+                        premise['_source']['sentence_text'], conclusion_text)
 
 
 class ArgumentGraphReranking(Reranking):
......
@@ -77,7 +77,7 @@ class Retrieval:
         for topic_nrb, conclusions_response in tqdm(conclusion_per_topic.items()):
             premise_per_conclusion_per_topic[topic_nrb] = {}
             conclusions = conclusions_response['hits']['hits']
-            with ThreadPoolExecutor(8) as executor:
+            with ThreadPoolExecutor(10) as executor:
                 premises_per_conclusion = executor.map(lambda x: (x, self._get_premises(x)),
                                                        conclusions)
@@ -95,10 +95,10 @@ class Retrieval:
         body = self._get_query_body(self.nrb_premises_per_conclusion,
                                     conclusion['_source']['sentence_text'],
                                     conclusion['_source']['sentence_vector'])
-        body['query']['script_score']['query'] = {'bool': {'must': {
+        body['query']['script_score']['query']['bool']['must'] = {
             'match': {
                 'sentence_stance': conclusion['_source']['sentence_stance']
-            }}}
+            }
         }
         return self.es.search(index='premise', body=body)
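The new assignment into `['bool']['must']` only works because `_get_query_body` (next hunk) now builds a `bool` query instead of `match_all`. Put together, the premise query has roughly this shape; the `size` key, stance value, and vector below are illustrative assumptions:

body = {
    "size": 50,  # assumed to come from nrb_premises_per_conclusion
    "query": {
        "script_score": {
            "query": {
                "bool": {
                    "must": {"match": {"sentence_stance": "PRO"}},  # injected per conclusion
                    "filter": {"range": {"sentence_text.length": {"gte": 12}}}  # length filter, next hunk
                }
            },
            "script": {
                "source": "cosineSimilarity(params.queryVector,'sentence_vector') + 1.0",
                "params": {"queryVector": [0.0] * 384}  # the conclusion's sentence_vector
            }
        }
    }
}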
@@ -112,21 +112,18 @@ class Retrieval:
             "query": {
                 "script_score": {
                     "query": {
-                        "match_all": {}
-                        # "bool": {
-                        #     "filter": {
-                        #         "script": {
-                        #             "script": {
-                        #                 "source": "doc['sentence_text'].size() >= params.
-                        #                            min_amount_of_tokens",
-                        #                 "params": {
-                        #                     "min_amount_of_tokens": int(len(self._analyze(query)
-                        #                                                  ['tokens']) * self.length_factor),
-                        #                 }
-                        #             }
-                        #         }
-                        #     }
-                        # }
+                        "bool": {
+                            "filter": {
+                                "range": {
+                                    "sentence_text.length": {
+                                        "gte": int(
+                                            len(self._analyze(query)['tokens']) *
+                                            self.length_factor
+                                        )
+                                    }
+                                }
+                            }
+                        }
                     }
                 },
                 "script": {
                     "source": "cosineSimilarity(params.queryVector,'sentence_vector') + 1.0",
......
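The `range` filter on `sentence_text.length` presupposes that the index mapping exposes a token-count sub-field under that name; Elasticsearch's `token_count` field type provides exactly this. A mapping sketch that would support the query (assumed, not taken from the repo):

mapping = {
    "mappings": {
        "properties": {
            "sentence_text": {
                "type": "text",
                "fields": {
                    # addressable in queries as "sentence_text.length"
                    "length": {"type": "token_count", "analyzer": "standard"}
                }
            },
            "sentence_vector": {"type": "dense_vector", "dims": 384}  # matches EMBEDDING_LENGTH
        }
    }
}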