cosine similarity
This commit is contained in:
parent
1510e04c6c
commit
9fd246cb01
39
main.py
39
main.py
|
@ -1,19 +1,38 @@
|
||||||
|
import warnings
|
||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# Suppress FutureWarning messages only (other warning categories are still shown)
|
||||||
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
# Load the tokenizer and the model
|
# Load the tokenizer and the model
|
||||||
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
|
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
|
||||||
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
|
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
|
||||||
|
|
||||||
# Prepare a test input sentence (e.g., "Hello, world!")
|
# Compute a sentence embedding by taking the last-layer hidden state of the [CLS] token
|
||||||
input_text = "Hello, world!"
|
def get_sentence_embedding(text):
|
||||||
|
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
||||||
# Tokenize the input text and convert it to input IDs
|
with torch.no_grad():
|
||||||
inputs = tokenizer(input_text, return_tensors="pt") # Return tensors in PyTorch format
|
|
||||||
|
|
||||||
# Forward pass through the model
|
|
||||||
with torch.no_grad(): # Disable gradient calculation since we are only doing inference
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# Output model's hidden states (for the last layer)
|
# Pooling strategy: Use the hidden state of the [CLS] token as the sentence embedding
|
||||||
print(outputs.last_hidden_state)
|
cls_embedding = outputs.last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size)
|
||||||
|
return cls_embedding
|
||||||
|
|
||||||
|
# Example subject and abstract
|
||||||
|
subject = "Artificial Intelligence in Healthcare"
|
||||||
|
abstract = """
|
||||||
|
Artificial intelligence (AI) is transforming healthcare with its ability to analyze complex medical data and assist in diagnosis.
|
||||||
|
AI models, especially in medical imaging, have shown promise in detecting diseases like cancer and predicting patient outcomes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
subject_embedding = get_sentence_embedding(subject)
|
||||||
|
abstract_embedding = get_sentence_embedding(abstract)
|
||||||
|
|
||||||
|
# Measure semantic similarity using cosine similarity
|
||||||
|
|
||||||
|
# Compute cosine similarity between subject and abstract embeddings
|
||||||
|
similarity = F.cosine_similarity(subject_embedding, abstract_embedding)
|
||||||
|
print(f"Cosine Similarity: {similarity.item():.4f}")
|
||||||
|
|
Loading…
Reference in a new issue