← 返回博客列表

transformer模型进行语义搜索和语言生成

演示使用 transformer 模型进行语义搜索和语言生成,其中使用 huggingface 的 transformers python 包需要自动下载语言模型,文件有几个G

第一部分

主要目标是在信息检索上下文中试验不同的嵌入技术

https://huggingface.co/sentence-transformers 如果安装了 python3.9+,则可以使用 pip3 install -U sentence-transformers 如果没有 python,则可以从 https://anaconda.org/ 安装 conda,然后使用 conda install -c conda-forge sentence-transformers 安装

代码如下:

from sentence_transformers import SentenceTransformer
from numpy import dot
from math import sqrt
import json

#使用 json python 包从 tweets-utf-8.json 读取推文,并生成包含每条推文文本的字符串列表。
def get_tweets():
    tweets = []
    with open('tweets-utf-8.json', 'r', encoding='utf-8') as file:
        for line in file:
            tweet = json.loads(line.strip())
            tweets.append(tweet['text'])
    return tweets

#该函数采用查询文档的嵌入、文档嵌入列表和相应文档的列表,并返回表单对 (similarity,document) 的列表,
# 根据每个文档与查询之间的余弦相似度按降序排序。您可以使用任何您喜欢的包;请注意,NumPy 有一个 DOT 函数
def sort_by_sim(query_embedding,document_embeddings,documents):
    # Calculate the cosine similarity between the query and each document
    similarities = []
    for i in range(len(document_embeddings)):
        similarity = dot(query_embedding, document_embeddings[i]) / (sqrt(dot(query_embedding, query_embedding)) * sqrt(dot(document_embeddings[i], document_embeddings[i])))
        similarities.append((similarity, documents[i]))
    # Sort the similarities in descending order
    similarities.sort(key=lambda x: x[0], reverse=True)
    # Return the sorted list of tuples
    return similarities


#,返回与查询 “I am looking for a job.” 最相似的 25 条推文(as (similarity,document) 对)。
# 使用此处定义的基于手套的句子嵌入:https://huggingface.co/sentence-transformers/average_word_embeddings_glove.840B.300d
def glove_top25(query,documents):
    # Load the GloVe model
    model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.840B.300d')
    # Encode the query
    query_embedding = model.encode(query)
    # Encode the documents
    document_embeddings = model.encode(documents)
    # Sort the documents by similarity to the query
    sorted_similarities = sort_by_sim(query_embedding, document_embeddings, documents)
    # Return the top 25 most similar documents
    top_25 = sorted_similarities[:25]
    # Return the list of tuples (similarity, document)
    return top_25

#返回与查询 “I am looking for a job.” 最相似的前 25 条推文(as (similarity,document) 对)。
# 使用此处定义的基于 MiniLM(派生自 BERT)的句子嵌入:https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
def minilm_top25(query,documents):
    # Load the MiniLM model
    model = SentenceTransformer('all-MiniLM-L6-v2')
    # Encode the query
    query_embedding = model.encode(query)
    # Encode the documents
    document_embeddings = model.encode(documents)
    # Sort the documents by similarity to the query
    sorted_similarities = sort_by_sim(query_embedding, document_embeddings, documents)
    # Return the top 25 most similar documents
    top_25 = sorted_similarities[:25]
    # Return the list of tuples (similarity, document)
    return top_25
        
## Test Code

tweets = get_tweets()

print("**************GLOVE*****************")
for p in glove_top25("I am looking for a job.",tweets): print(p)

print("**************MINILM*****************")
for p in minilm_top25("I am looking for a job.",tweets): print(p)

第二部分

使用生成语言模型,使用语义搜索更加精准

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def generate_story():
    # Set random seed for reproducibility (确保每次生成相同结果)
    torch.manual_seed(42)

    # Load pre-trained GPT-2 model and tokenizer (加载本地模型)
    model_name = "gpt2"  # 也可用 "gpt2"(小模型)或 "gpt2-large"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)

    # Define prompt and generation parameters (设置输入和生成参数)
    prompt = "Once upon a time, in a land far away,"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Generate story with controlled randomness (控制生成参数)
    output = model.generate(
        input_ids,
        max_length=100,
        num_return_sequences=1,
        no_repeat_ngram_size=3,  # 避免3-gram重复
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode and post-process the generated text (解码并后处理)
    story = tokenizer.decode(output[0], skip_special_tokens=True)

    # Ensure word count > 100 (检查字数)
    words = story.split()
    if len(words) < 100:
        print("Warning: Story is shorter than 100 words. Adjust max_length.")

    return story


if __name__ == "__main__":
    story = generate_story()
    print("Generated Story:\n", story)