Skip to content

Commit

Permalink
evaluation added
Browse files Browse the repository at this point in the history
  • Loading branch information
AswiN-7 committed May 26, 2022
1 parent 5608de8 commit 5ca3c1e
Show file tree
Hide file tree
Showing 26 changed files with 1,844 additions and 4,093 deletions.
1,199 changes: 1,017 additions & 182 deletions ODQA_utils/driver.ipynb

Large diffs are not rendered by default.

247 changes: 247 additions & 0 deletions evaluation results/eval.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('eval_data.csv')\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
"connections.connect()\n",
"from tqdm.autonotebook import tqdm\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"TABLE_NAME = 'eval_question_answering'\n",
"collection = None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Deleting previouslny stored table for clean run\n",
"def create_mqa():\n",
" if utility.has_collection(TABLE_NAME):\n",
" collection = Collection(name=TABLE_NAME)\n",
" collection.drop()\n",
"\n",
" field1 = FieldSchema(name=\"id\", dtype=DataType.INT64, descrition=\"int64\", is_primary=True)\n",
" field3 = FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, descrition=\"float vector\",dim=1024, is_primary=False)\n",
" schema = CollectionSchema(fields=[field1, field3], description=\"collection description\")\n",
" collection = Collection(name=TABLE_NAME, schema=schema)\n",
" \n",
" default_index = {\"index_type\": \"IVF_FLAT\", \"metric_type\": 'IP', \"params\": {\"nlist\": 200}}\n",
" collection.create_index(field_name=\"embedding\", index_params=default_index)\n",
"\n",
"if utility.has_collection(TABLE_NAME):\n",
" global collection\n",
" collection = Collection(name=TABLE_NAME)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"create_mqa()\n",
"print(collection)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer(\"AswiN037/sentence-t-roberta-large-wechsel-tamil\")\n",
"print(\"Retriever model loaded\")\n",
"def encode(text):\n",
" embeddings = model.encode(text)\n",
" return [embeddings.tolist()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# new \n",
"def push_context_to_milvus():\n",
" i = collection.num_entities\n",
" size = collection.num_entities \n",
" batch = 50\n",
" while i < len(df) and i < size + batch:\n",
" emb = encode(df['context'][i])\n",
" ids = [int(df['id'][i])]\n",
" collection.insert([ids, emb])\n",
" i+=1\n",
" return collection.num_entities"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"push_context_to_milvus()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def find_similar(emb):\n",
" collection.load()\n",
" return collection.search(\n",
"\tdata=emb, \n",
"\tanns_field=\"embedding\", \n",
"\tparam={\"metric_type\": \"IP\", \"params\": {\"nprobe\": 10}}, \n",
"\tlimit=10, \n",
"\texpr=None,\n",
"\toutput_fields = [\"id\"],\n",
"\tconsistency_level=\"Strong\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rertieve_id_for_question():\n",
" # 0 - question id 1- retrieved context id\n",
" result=[]\n",
" for i in range(len(df)):\n",
" question_emb = encode(df['question'][i])\n",
" similar_ids = find_similar(question_emb)\n",
" sim_id = similar_ids[0].ids[0]\n",
" result.append((i, sim_id))\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever_result = rertieve_id_for_question()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever_result\n",
"df_retrieved = pd.DataFrame(retriever_result, columns=['question_id', 'context_id'])\n",
"df_retrieved.to_json(\"weschel_encoder_result.json\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"model_name = \"AswiN037/xlm-roberta-squad-tamil\"\n",
"answer_extract = pipeline('question-answering', model=model_name, tokenizer=model_name)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"answer_extract_result = []\n",
"for i in range(len(df_retrieved)):\n",
" r_q_id = df_retrieved['question_id'][i]\n",
" question = df['question'][r_q_id]\n",
" r_c_id = df_retrieved['context_id'][i]\n",
" context = df['context'][r_c_id]\n",
" original_answer = df['answer_text'][r_q_id]\n",
" qc = {\n",
" \"context\" : context, \n",
" \"question\" : question \n",
" }\n",
" predicted_answer = answer_extract(qc)['answer']\n",
" # original answer, predicted answer\n",
" answer_extract_result.append((original_answer, predicted_answer))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"answer_extract_result\n",
"df_extracted_answer = pd.DataFrame(answer_extract_result, columns=['Actual', 'Predicted'])\n",
"df_extracted_answer.to_json('weschel_encoder_xlm_robert.json')"
]
}
],
"metadata": {
"interpreter": {
"hash": "2043299c89c8cd0b4d1a6f5cf4529bd58e6a4e0fe3181a25e0d328c821cdc5c5"
},
"kernelspec": {
"display_name": "Python 3.9.7 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 5ca3c1e

Please sign in to comment.