cosine similarity

WanderingPenwing 2024-09-26 10:50:50 +02:00
parent 1510e04c6c
commit 9fd246cb01

main.py (39 changed lines)

@@ -1,19 +1,38 @@
+import warnings
 from transformers import AutoTokenizer, AutoModel
 import torch
+import torch.nn.functional as F
+
+# Suppress FutureWarnings and other warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
 
 # Load the tokenizer and the model
 tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
 model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
 
-# Prepare a test input sentence (e.g., "Hello, world!")
-input_text = "Hello, world!"
-
-# Tokenize the input text and convert it to input IDs
-inputs = tokenizer(input_text, return_tensors="pt")  # Return tensors in PyTorch format
-
-# Forward pass through the model
-with torch.no_grad():  # Disable gradient calculation since we are only doing inference
-    outputs = model(**inputs)
-
-# Output model's hidden states (for the last layer)
-print(outputs.last_hidden_state)
+# Function to compute sentence embeddings by pooling token embeddings (CLS token)
+def get_sentence_embedding(text):
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    # Pooling strategy: use the hidden state of the [CLS] token as the sentence embedding
+    cls_embedding = outputs.last_hidden_state[:, 0, :]  # Shape: (batch_size, hidden_size)
+    return cls_embedding
+
+# Example subject and abstract
+subject = "Artificial Intelligence in Healthcare"
+abstract = """
+Artificial intelligence (AI) is transforming healthcare with its ability to analyze complex medical data and assist in diagnosis.
+AI models, especially in medical imaging, have shown promise in detecting diseases like cancer and predicting patient outcomes.
+"""
+
+# Get embeddings
+subject_embedding = get_sentence_embedding(subject)
+abstract_embedding = get_sentence_embedding(abstract)
+
+# Measure semantic similarity: compute cosine similarity between the embeddings
+similarity = F.cosine_similarity(subject_embedding, abstract_embedding)
+print(f"Cosine Similarity: {similarity.item():.4f}")