科大讯飞-群聊对话角色要素提取:不微调范式模拟官网评分

news/2024/7/18 0:29:05 标签: python

不微调范式模拟官网评分

    • step1: 模型api配置及加载测试
    • step2: 数据加载与数据分析:
    • 测试集分析:
    • step3: prompt设计:
    • step4 :大模型推理:
    • step 5: 结果评分测试:
      • 评分细则:
      • 评估指标
  • 参考:

比赛说明:
#AI夏令营 #Datawhale #夏令营
主要参考datawhale夏令营活动:零基础入门大模型技术竞赛。
连接:https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS
比赛网址:
https://challenge.xfyun.cn/topic/info?type=role-element-extraction&ch=dw24_y0SCtd
在这里插入图片描述

说明:
- 1,主要适用于不微调的范式。
- 2,针对每次在修改prompt,或者COT之后,想要查看性能如何时,都要提交到官网等待。但是受限于官网每个人每天只能提交3次,无法得到更多的反馈。
- 3,在这里主要从训练集train.json中,随机挑选数据,作为验证集,模仿官网的评分细则,用于验证性能指标。当验证性能满意后,再放到test.json数据进行推理,并提交官网。

步骤:
- 1,模型api配置及加载测试:
- 2,数据加载:加载训练集,数据预处理,数据分析,可设置验证集比例
- 3,prompt设计:提示工程或者COT的方式,根据数据分析设计提示;
- 4,模型推理:输出符合格式预测;
- 5,结果测试:采用与讯飞比赛官网相同的得分计算策略。

为了方便展示,参考群里某大佬画的不微调范式的概要图:
在这里插入图片描述

step1: 模型api配置及加载测试

python"># api配置

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json

#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
python"># 模型对话测试

def get_completions(text):
    messages = [ChatMessage(
        role="user",
        content=text
    )]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好,请问你是谁?"
get_completions(text)

step2: 数据加载与数据分析:

加载训练集,部分化为验证集,可设置验证集比例

python">def read_json(json_file_path):
    """读取json文件"""
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    """写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
print('done!')
python"># 查看数据格式
print(train_data[1]['chat_text'])
python"># 简单的数据清洗:将对话的无关信息删除。不要让“[图片]”这种信息干扰;
# 如:[链接],[图片],[玫瑰],以及:????上线功能,H5红包,【收集表】2023年度满意度评价等与内容无关的字段去除掉。
import re

def clean_chat_text(chat_text):
    # 定义正则表达式用于匹配链接、图片、特殊表情和无关字段
    patterns = [
        r"【收集表】 2023年度服务满意度评价",
        r"https?://\S+",
        r"\{[\w\W]*?\}",
        r'\[.*?\]'
    ]

    # 移除匹配到的内容
    for pattern in patterns:
        chat_text = re.sub(pattern, '', chat_text)
    # 移除多余的空格和换行符
    chat_text = re.sub(r'\n+', '\n', chat_text).strip()
    return chat_text

# 遍历每个样本,清洗chat_text字段
for sample in train_data:
    if "chat_text" in sample:
        sample["chat_text"] = clean_chat_text(sample["chat_text"])
python"># 验证集划分
import json
import random

def split_data(data, validation_size):
    #  # 随机打乱数据
    # random.shuffle(data)
    # 划分数据
    validation_data = data[:validation_size]
    train_data = data[validation_size:]
    
    return train_data, validation_data

validation_size = 50
# 划分数据
train_data, validation_data = split_data(train_data, validation_size)

测试集分析:

训练集数据分析,假设训练集与测试集同分布。
参考:https://qixiangxingqiu.feishu.cn/wiki/V4duwxVzkipnHjk7djzcFykbnbf
这里面进行了详细的数据分析:
要点:

  1. 姓名都是做选择题的,可以先人工提取所有的姓名再让模型做选择;
  2. 很多类别都是多分类任务,如:咨询类型、意向产品、购买异议点、客户购买阶段、客户是否有意向、客户是否有卡点,能出现的都是固定的那几个。
  3. 在训练集中,年龄、生日、竞品信息都是100%为空的,测试集就不确定了。(我猜测也是空)
    在这里插入图片描述
  4. 一些栏目的数据分布:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    咨询类型:
    在这里插入图片描述
    意向产品
    在这里插入图片描述
    购买异议点:
    在这里插入图片描述

step3: prompt设计:

提示工程或者COT的方式;

python"># prompt 设计
PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。

表单格式如下:
info: Array<Dict(
    "基本信息-姓名": string | "",  // 客户的姓名。
    "基本信息-手机号码": string | "",  // 客户的手机号码。
    "基本信息-邮箱": string | "",  // 客户的电子邮箱地址。
    "基本信息-地区": string | "",  // 客户所在的地区或城市。
    "基本信息-详细地址": string | "",  // 客户的详细地址。
    "基本信息-性别": string | "",  // 客户的性别。
    "基本信息-年龄": string | "",  // 客户的年龄。
    "基本信息-生日": string | "",  // 客户的生日。
    "咨询类型": string[] | [],  // 客户的咨询类型,如询价、答疑等。
    "意向产品": string[] | [],  // 客户感兴趣的产品。
    "购买异议点": string[] | [],  // 客户在购买过程中提出的异议或问题。
    "客户预算-预算是否充足": string | "",  // 客户的预算是否充足。示例:充足, 不充足
    "客户预算-总体预算金额": string | "",  // 客户的总体预算金额。
    "客户预算-预算明细": string | "",  // 客户预算的具体明细。
    "竞品信息": string | "",  // 竞争对手的信息。
    "客户是否有意向": string | "",  // 客户是否有购买意向。示例:有意向, 无意向
    "客户是否有卡点": string | "",  // 客户在购买过程中是否遇到阻碍或卡点。示例:有卡点, 无卡点
    "客户购买阶段": string | "",  // 客户当前的购买阶段,如合同中、方案交流等。
    "下一步跟进计划-参与人": string[] | [],  // 下一步跟进计划中涉及的人员(客服人员)。
    "下一步跟进计划-时间点": string | "",  // 下一步跟进的时间点。
    "下一步跟进计划-具体事项": string | ""  // 下一步需要进行的具体事项。
)>

请分析以下群聊对话记录,并根据上述格式提取信息:

对话记录:
{content}
请将提取的信息以JSON格式输出。
不要添加任何澄清信息。
输出必须遵循上面的模式。
不要添加任何没有出现在模式中的附加字段。
不要随意删除字段。

**输出:**
[{{
    "基本信息-姓名": "姓名",
    "基本信息-手机号码": "手机号码",
    "基本信息-邮箱": "邮箱",
    "基本信息-地区": "地区",
    "基本信息-详细地址": "详细地址",
    "基本信息-性别": "性别",
    "基本信息-年龄": "年龄",
    "基本信息-生日": "生日",
    "咨询类型": ["咨询类型"],
    "意向产品": ["意向产品"],
    "购买异议点": ["购买异议点"],
    "客户预算-预算是否充足": "充足或不充足",
    "客户预算-总体预算金额": "总体预算金额",
    "客户预算-预算明细": "预算明细",
    "竞品信息": "竞品信息",
    "客户是否有意向": "有意向或无意向",
    "客户是否有卡点": "有卡点或无卡点",
    "客户购买阶段": "购买阶段",
    "下一步跟进计划-参与人": ["跟进计划参与人"],
    "下一步跟进计划-时间点": "跟进计划时间点",
    "下一步跟进计划-具体事项": "跟进计划具体事项"
}}]

改进一下:在原来的基础上加上下面这段,根据数据分析得来的。

请分析以下群聊对话记录,并根据上述格式提取信息。根据表单的形式,他们具有不同的提取方式,如下:
1,"基本信息-姓名":直接从对话记录中提取询问问题的人物姓名;
2,"基本信息-手机号码":直接从对话记录中提取询问问题的人的手机号码;
3,"基本信息-邮箱":直接从对话记录中提取询问问题的人的邮箱地址,如果不存在则返回空字符串;
4,"基本信息-地区":直接从对话记录中提取询问问题的人所在的城市或地区;
5,"基本信息-详细地址":直接从对话记录中提取询问问题的人的详细地址,如果不存在则返回空字符串;
5,基本信息-性别":客服的性别,一般是空字符;
6,"基本信息-年龄":尝试提取年龄相关的描述,一般为空字符串;
7,"基本信息-生日":尝试提取生日相关的描述,一般为空字符串;
8,"咨询类型":从对话记录中提取关键信息并判断是下面哪一种:答疑,询价,吐槽,答疑和吐槽;
9,"意向产品":从对话记录中提取客户感兴趣的产品,并从下面选项中选择:会话存档,高级版,CRM,开放接口,商城,标准版,定制版,AI,运营服务,会话存档、标准版。
10,"购买异议点":从对话记录中提取客户在购买过程中提出的问题,并从下面选项中选择:产品功能,客户内部问题,价格,工时,竞品。
11,"客户预算-预算是否充足":从对话记录中提取客户预算是否充足,并从下面选项中选择:充足, 不充足。
12,"客户预算-总体预算金额":从对话记录中提取客户预算的总额,如果存在则返回数字,如果不存在则返回空字符串。
13,"客户预算-预算明细":从对话记录中提取客户预算的具体明细,如果不存在则返回空字符串。
14,"竞品信息":从对话记录中提取竞争对手的信息,如果不存在则返回空字符串,一般是没有竞品信息。
15,"客户是否有意向":从对话记录中提取客户是否有购买意向,并从下面选项中选择:有意向, 无意愿。
16,"客户是否有卡点":从对话记录中提取客户在购买过程中是否遇到阻碍或卡点,并从下面选项中选择:有卡点, 无卡点。
17,"客户购买阶段":从对话记录中提取客户当前的购买阶段,如果客户在购买阶段,则从下面选项中选择:赢单,方案交流,续费,项目搁置,合同中,报价, 需求调研,等待结果, 输单。
18,"下一步跟进计划-参与人":从对话记录中提取下一步跟进计划中涉及的人员(客服人员),如果不存在则返回空字符串。
19,"下一步跟进计划-时间点":从对话记录中提取下一步跟进的时间点,如果不存在则返回空字符串。
20,"下一步跟进计划-具体事项":从对话记录中提取下一步进行的具体事项,如果不存在则返回空字符串。

step4 :大模型推理:

输出符合格式预测;以方便与标签对比;

python"># 大模型输出格式转换
import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def convert_all_json_in_text_to_dict(text):
    """提取LLM输出文本中的json字符串"""
    dicts, stack = [], []
    for i in range(len(text)):
        if text[i] == '{':
            stack.append(i)
        elif text[i] == '}':
            begin = stack.pop()
            if not stack:
                dicts.append(json.loads(text[begin:i+1]))
    return dicts

# 查看对话标签
def print_json_format(data):
    """格式化输出json格式"""
    print(json.dumps(data, indent=4, ensure_ascii=False))

# 对大模型抽取的结果进行字段格式的检查以及缺少的字段进行补全
def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        "基本信息-邮箱": str,
        "基本信息-地区": str,
        "基本信息-详细地址": str,
        "基本信息-性别": str,
        "基本信息-年龄": str,
        "基本信息-生日": str,
        "咨询类型": list,
        "意向产品": list,
        "购买异议点": list,
        "客户预算-预算是否充足": str,
        "客户预算-总体预算金额": str,
        "客户预算-预算明细": str,
        "竞品信息": str,
        "客户是否有意向": str,
        "客户是否有卡点": str,
        "客户购买阶段": str,
        "下一步跟进计划-参与人": list,
        "下一步跟进计划-时间点": str,
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data
python"># 大模型推理

from tqdm import tqdm

retry_count = 5 # 重试次数
result = []
error_data = []
labels = []

for index, data in tqdm(enumerate(validation_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            label = data['infos']

            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({
                "infos": infos,
                "index": index
            })
            labels.append({
                "infos":label,
                "index": index
            })
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)

step 5: 结果评分测试:

结果测试:采用与讯飞比赛官网相同的得分计算策略。

评分细则:

满分36分。
按照各类字段提取的难易程度,共设置了1、2、3三种难度分数。
具体待提取的字段以及提取正确时的得分规则,如下链接:https://challenge.xfyun.cn/topic/info?type=role-element-extraction&ch=j4XWs7V

评估指标

测试集的每条数据同样包含共21个字段, 按照各字段难易程度划分总计满分36分。每个提取正确性的判定标准如下:
1)对于答案唯一字段,将使用完全匹配的方式计算提取是否正确,提取正确得到相应分数,否则为0分
2)对于答案不唯一字段,将综合考虑提取完整性、语义相似度等维度判定提取的匹配分数,最终该字段得分为 “匹配分数 * 该字段难度分数”
每条测试数据的最终得分为各字段累计得分。最终测试集上的分数为所有测试数据的平均得分。

python"># 按照给定的评分规则计算每个样本的得分,然后计算所有样本平均得分
import json

# 定义字段及其难度分数,是否单值;
fields = [
    ('基本信息-姓名', 1, True),
    ('基本信息-手机号码', 1, True),
    ('基本信息-邮箱', 1, True),
    ('基本信息-地区', 1, True),
    ('基本信息-详细地址', 1, True),
    ('基本信息-性别', 1, True),
    ('基本信息-年龄', 1, True),
    ('基本信息-生日', 1, True),
    ('咨询类型', 2, False),
    ('意向产品', 3, False),
    ('购买异议点', 3, False),
    ('客户预算-预算是否充足', 2, True),
    ('客户预算-总体预算金额', 2, True),
    ('客户预算-预算明细', 3, True),
    ('竞品信息', 2, True),
    ('客户是否有意向', 1, True),
    ('客户是否有卡点', 1, True),
    ('客户购买阶段', 2, True),
    ('下一步跟进计划-参与人', 2, False),
    ('下一步跟进计划-时间点', 2, True),
    ('下一步跟进计划-具体事项', 3, True)
]
python">from difflib import SequenceMatcher

# 计算单值字段得分
def calculate_single_value_score(pred, true, score):
    return score if pred == true else 0

# 计算文本相似度得分
def calculate_text_similarity_score(pred, true):
    return SequenceMatcher(None, pred, true).ratio()

# 计算多值字段的得分
def calculate_multi_value_score(pred, true, score):
    if isinstance(pred, list) and isinstance(true, list):
        pred_str = ' '.join(pred)
        true_str = ' '.join(true)
    else:
        pred_str = str(pred)
        true_str = str(true)
    
    similarity_score = calculate_text_similarity_score(pred_str, true_str)
    return similarity_score * score

# 计算样本得分
def calculate_sample_score(pred, true):
    total_score = 0
    for field, score, is_single_value in fields:
        pred_value = pred.get(field, "")
        true_value = true.get(field, "")
        
        if is_single_value:
            total_score += calculate_single_value_score(pred_value, true_value, score)
        else:
            total_score += calculate_multi_value_score(pred_value, true_value, score)
    return total_score

# 
# 计算平均得分
def calculate_average_score(predictions, truths):
    total_score = 0
    num_samples = len(predictions)
    

    for i in range(num_samples):
        pred = predictions[i]
        true = truths[i]
        if not pred['infos'] or not true['infos']:
            continue
        # print(pred)
        total_score += calculate_sample_score(pred['infos'][0], true['infos'][0])

    # for pred, true in zip(predictions, truths):
    #     total_score += calculate_sample_score(pred['infos'][0], true['infos'][0])
    
    return total_score / num_samples
python"># 计算
average_score = calculate_average_score(result, labels)

print("average score:", average_score)

计算得分是21分:
在这里插入图片描述

官网评分:20分,感觉相差3分以内吧。
模拟官网的评分策略还是需要优化的。
在这里插入图片描述

参考:

datawhale夏令营活动:
https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS


http://www.niftyadmin.cn/n/5543952.html

相关文章

WEB开发-页面跳转

1 需求 HTML a标签的href属性form标签的action属性 JavaScript window.location.hrefwindow.location.replace()window.open() PHP header() 2 接口 3 示例 在HTML中&#xff0c;实现页面跳转的方法有多种。以下是一些常用的方法&#xff1a; 使用 <a> 标签的 hre…

python项目转化为so文件-最终版

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、python项目转化为so文件1.引入库 总结 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 提示&#xff1a;以下是本篇文章正文内容&am…

洛杉矶裸机云大宽带服务器的特性和优势

洛杉矶裸机云大宽带服务器是结合了物理服务器性能和云服务灵活性的高性能计算服务&#xff0c;为用户提供高效、安全的计算和存储能力。在了解如何使用洛杉矶裸机云大宽带服务器之前&#xff0c;需要了解其基本特性和优势。以下是对洛杉矶裸机云大宽带服务器的具体分析&#xf…

动态数据炼金术:在Mojo模型中自定义数据预处理的艺术

动态数据炼金术&#xff1a;在Mojo模型中自定义数据预处理的艺术 在机器学习工作流程中&#xff0c;数据预处理是至关重要的一步&#xff0c;它直接影响到模型训练的效果和最终性能。Mojo模型&#xff0c;作为H2O.ai提供的一种模型部署格式&#xff0c;主要用于模型的序列化和…

巴图自动化Profinet协议转Modbus协议模块接称重模块与PLC通讯

巴图自动化Profinet协议转Modbus协议模块&#xff08;BT-MDPN10&#xff09;是一种能够实现Modbus协议和Profinet协议之间转换的设备。Profinet协议转Modbus协议模块可提供单个或多个RS485接口&#xff0c;使得不同设备之间可以顺利进行通信&#xff0c;进一步提升了工业自动化…

PDM系统中物料分类与编码规则生成方案

在企业管理软件中&#xff0c;PDM系统是企业管理的前端软件&#xff0c;用于管理研发图纸、BOM等数据&#xff0c;然后生成相关物料表或BOM&#xff0c;递交给后端ERP系统进行生产管理。在PDM系统中&#xff0c;有两种方式可以生成物料编码。 1第一种是用户可以通过软件接口将…

如何在Linux上删除Systemd服务

Systemd是Linux 操作系统的系统和服务管理器&#xff0c;提供控制系统启动时启动哪些服务的标准流程。 有时&#xff0c;您可能出于各种原因需要删除systemd服务&#xff0c;例如不再需要、与其他服务冲突&#xff0c;或者您只是想清理系统。 Systemd使用单元文件来管理服务&…

机械键盘如何挑选

机械键盘的选择是一个关键的决策&#xff0c;因为它直接影响到我们每天的打字体验。在选择机械键盘时&#xff0c;有几个关键因素需要考虑。首先是键盘的键轴类型。常见的键轴类型包括蓝轴、红轴、茶轴和黑轴等。不同的键轴类型具有不同的触发力、触发点和声音。蓝轴通常具有明…