cosine similarity
This commit is contained in:
parent
1510e04c6c
commit
9fd246cb01
39
main.py
39
main.py
|
@ -1,19 +1,38 @@
|
|||
import warnings
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Suppress FutureWarnings and other warnings
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
|
||||
# Load the tokenizer and the model
|
||||
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
|
||||
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
|
||||
|
||||
# Prepare a test input sentence (e.g., "Hello, world!")
|
||||
input_text = "Hello, world!"
|
||||
|
||||
# Tokenize the input text and convert it to input IDs
|
||||
inputs = tokenizer(input_text, return_tensors="pt") # Return tensors in PyTorch format
|
||||
|
||||
# Forward pass through the model
|
||||
with torch.no_grad(): # Disable gradient calculation since we are only doing inference
|
||||
# Function to compute sentence embeddings by pooling token embeddings (CLS token)
|
||||
def get_sentence_embedding(text):
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Output model's hidden states (for the last layer)
|
||||
print(outputs.last_hidden_state)
|
||||
# Pooling strategy: Use the hidden state of the [CLS] token as the sentence embedding
|
||||
cls_embedding = outputs.last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size)
|
||||
return cls_embedding
|
||||
|
||||
# Example subject and abstract
|
||||
subject = "Artificial Intelligence in Healthcare"
|
||||
abstract = """
|
||||
Artificial intelligence (AI) is transforming healthcare with its ability to analyze complex medical data and assist in diagnosis.
|
||||
AI models, especially in medical imaging, have shown promise in detecting diseases like cancer and predicting patient outcomes.
|
||||
"""
|
||||
|
||||
# Get embeddings
|
||||
subject_embedding = get_sentence_embedding(subject)
|
||||
abstract_embedding = get_sentence_embedding(abstract)
|
||||
|
||||
# 2. **Measure Semantic Similarity Using Cosine Similarity**
|
||||
|
||||
# Compute cosine similarity between subject and abstract embeddings
|
||||
similarity = F.cosine_similarity(subject_embedding, abstract_embedding)
|
||||
print(f"Cosine Similarity: {similarity.item():.4f}")
|
||||
|
|
Loading…
Reference in a new issue