Skip to content

Latest commit

 

History

History
105 lines (93 loc) · 4.47 KB

File metadata and controls

105 lines (93 loc) · 4.47 KB

.\SentEval\senteval\utils.py

# 引入未来版本的特性允许 Python 2/3 兼容
from __future__ import absolute_import, division, unicode_literals

# 引入需要的库
import numpy as np  # 导入数值计算库 numpy
import re  # 导入正则表达式库 re
import inspect  # 导入检查模块信息的 inspect
from torch import optim  # 从 torch 库中导入优化器模块 optim

# 定义函数 create_dictionary,接受一个句子列表作为参数
def create_dictionary(sentences):
    words = {}  # 创建空字典 words
    # 遍历句子列表
    for s in sentences:
        # 遍历句子中的单词
        for word in s:
            # 如果单词已经在字典中,增加其出现次数;否则将其添加到字典并置计数为1
            if word in words:
                words[word] += 1
            else:
                words[word] = 1
    # 将特殊符号添加到字典中,赋予它们较高的计数以确保排序时排在前面
    words['<s>'] = 1e9 + 4
    words['</s>'] = 1e9 + 3
    words['<p>'] = 1e9 + 2
    # words['<UNK>'] = 1e9 + 1  # 可选的未知单词标记
    sorted_words = sorted(words.items(), key=lambda x: -x[1])  # 对字典按值(计数)进行逆排序
    id2word = []  # 创建空列表 id2word,用于存储单词索引到单词的映射关系
    word2id = {}  # 创建空字典 word2id,用于存储单词到索引的映射关系
    # 遍历排序后的单词列表
    for i, (w, _) in enumerate(sorted_words):
        id2word.append(w)  # 将单词添加到 id2word 列表中
        word2id[w] = i  # 将单词映射到其索引的字典中
    # 返回单词索引到单词的列表和单词到索引的字典
    return id2word, word2id


# 定义函数 cosine,计算两个向量的余弦相似度
def cosine(u, v):
    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))


# 定义 dotdict 类,继承自 dict,实现点号访问字典属性
class dotdict(dict):
    """ dot.notation access to dictionary attributes """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


# 定义函数 get_optimizer,解析优化器参数并返回相应的优化器函数和参数字典
def get_optimizer(s):
    """
    Parse optimizer parameters.
    Input should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"
    """
    # 如果参数字符串包含逗号
    if "," in s:
        method = s[:s.find(',')]  # 提取优化方法名
        optim_params = {}  # 创建空字典,用于存储优化器参数
        # 遍历以逗号分隔的参数部分
        for x in s[s.find(',') + 1:].split(','):
            split = x.split('=')  # 根据等号分隔参数名和参数值
            assert len(split) == 2  # 确保分隔结果为参数名和参数值两部分
            assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None  # 使用正则表达式检查参数值格式
            optim_params[split[0]] = float(split[1])  # 将参数名和参数值添加到参数字典中
    else:
        method = s  # 如果参数字符串不包含逗号,则整个字符串为优化方法名
        optim_params = {}  # 空字典

    # 根据优化方法名选择相应的优化器函数
    if method == 'adadelta':
        optim_fn = optim.Adadelta
    elif method == 'adagrad':
        optim_fn = optim.Adagrad
    elif method == 'adam':
        optim_fn = optim.Adam
    elif method == 'adamax':
        optim_fn = optim.Adamax
    elif method == 'asgd':
        optim_fn = optim.ASGD
    elif method == 'rmsprop':
        optim_fn = optim.RMSprop
    elif method == 'rprop':
        optim_fn = optim.Rprop
    elif method == 'sgd':
        optim_fn = optim.SGD
        assert 'lr' in optim_params  # 确保 SGD 方法中包含学习率参数
    else:
        raise Exception('Unknown optimization method: "%s"' % method)  # 抛出异常,未知的优化方法名

    # 检查优化器函数初始化时的参数
    expected_args = inspect.getargspec(optim_fn.__init__)[0]  # 获取优化器函数初始化方法的参数列表
    assert expected_args[:2] == ['self', 'params']  # 确保参数列表的前两个参数为 self 和 params
    if not all(k in expected_args[2:] for k in optim_params.keys()):
        raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
            str(expected_args[2:]), str(optim_params.keys())))  # 检查传入的参数是否符合预期

    return optim_fn, optim_params  # 返回选择的优化器函数和参数字典