Commit 240906

ssocean committed Sep 5, 2024
1 parent 6f1668b commit 24af212
Showing 3 changed files with 24 additions and 32 deletions.
3 changes: 2 additions & 1 deletion .gitignore

@@ -48,4 +48,5 @@ official_runs*
 *.pt
 *.pth
 /.idea
-prune.py
+prune.py
+/.idea
3 changes: 3 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default.

50 changes: 19 additions & 31 deletions single_pred.py
@@ -1,46 +1,34 @@
-from transformers import pipeline
-import torch
-from datetime import datetime
-import torch
-from torch.nn.functional import mse_loss
-from torch.utils.data import Dataset, DataLoader
-from transformers import RobertaTokenizer, AdamW, RobertaForSequenceClassification, AutoTokenizer, \
-    AutoModelForSequenceClassification, FlaxLlamaForCausalLM, LlamaForSequenceClassification
-import pandas as pd
-from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data.dataset import random_split
-from tqdm import tqdm
-import os
-import argparse
-import json
-import torch
-from peft import PeftModel, PeftModelForTokenClassification
-import os
 
 import torch.nn as nn
 from tools.order_metrics import *
 import transformers.models.qwen2
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSequenceClassification, AutoPeftModelForTokenClassification
 
+from peft import AutoPeftModelForSequenceClassification
+from transformers import AutoTokenizer
+import torch
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
+
+
 
 title = '''xxxx'''
 abstract = ''' xxx '''
-
-
-
-
-model_pth = r"xxx"
-
-
+model_pth = r"xxx"  # Warning: you must update "base_model_name_or_path" in adapter_config.json to point at your base model. We will fix this error after the rebuttal.
 model = AutoPeftModelForSequenceClassification.from_pretrained(model_pth, num_labels=1, load_in_8bit=True)
 tokenizer = AutoTokenizer.from_pretrained(model_pth)
 
 model = model.to("cuda")
 model.eval()
-# Default Prompt Template
-text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
-inputs = tokenizer(text, return_tensors="pt")
-
-outputs = model(input_ids=inputs["input_ids"].to("cuda"))
-print(outputs['logits'])
 
+while True:
+    title = input("Enter a title: ")
+    abstract = input("Enter an abstract: ")
+    title = title.replace("\n", "").strip()
+    abstract = abstract.replace("\n", "").strip()
+    # Default Prompt Template
+    text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
+    inputs = tokenizer(text, return_tensors="pt")
+    outputs = model(input_ids=inputs["input_ids"].to("cuda"))
+    # If you haven't modified the LLaMA source, the raw logit needs a sigmoid:
+    print(nn.Sigmoid()(outputs['logits']))
+    # Otherwise: print(outputs['logits'])
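For convenience, here is a minimal sketch of the one-time adapter_config.json edit that the warning comment above refers to, assuming the standard PEFT adapter folder layout; the adapter directory and the base-model id below are placeholders, not values confirmed by this repository:

import json
import os

adapter_dir = r"xxx"  # placeholder: the same directory passed to from_pretrained above
cfg_path = os.path.join(adapter_dir, "adapter_config.json")

# Read the adapter config written at training time.
with open(cfg_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)

# Placeholder: point this at the base model the adapter was trained on.
cfg["base_model_name_or_path"] = "meta-llama/Llama-2-7b-hf"

# Write the patched config back in place.
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2)

Once the config points at a locally available base model, AutoPeftModelForSequenceClassification.from_pretrained should resolve the base weights without further manual edits.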
