消失了好久好久,这次换了一家公司,然后又在忙于秋招,因此很久没有更新,最近事情也告一段落,因此终于有空回来水博客,今天给大家带来最近的工作,NL2SQL数据集,我们的工作是利用代码生成大模型(类似CodeFuse系列,CodeLlama系列)进行fine-tune,通过用户query和query涉及的数据库表的Schema作为输入,使用fine-tune后的LLM进行推理来得到最后的生成SQL,当然为了工作的方便,所以我们试图将所有的开源数据集进行整合,因此在此处的NL2SQL数据集中,提供了经过模型翻译的Wiki_SQL数据集,Cspider数据集,Du_SQL数据集,如果有大佬有追一科技的数据集请告诉我,需要一些帮助,接下来首先给出NL2SQL数据集的处理脚本:
1、数据集生成(耗时13h,把8w条WikiSQL翻译了)
Data_deal_Script.py
"""codeer:Jinzhanglifunction:数据集处理和构建relation:2035877994@qq.comtime:2023/11/21 15:23"""import json,reclass Cspider_Data_make: def Cspider_Schema_load_deal(self): Schema={} All_DB=self.Cspider_Data_load("Data/Cspider/tables.json") for i in range(len(All_DB)): DB={} column_names=All_DB[i]["column_names"] table_names=All_DB[i]['table_names'] for j in range(len(table_names)): DB["_".join(re.split(" ",table_names[j]))]=[column_names[k][1] for k in range(len(column_names)) if column_names[k][0]==j] Schema[All_DB[i]["db_id"]]=DB return Schema def Cspider_Data_load(self,file_path:str): dict_data=json.loads(open(file_path,"r",encoding="utf-8").read()) return dict_data def Cspider_Schema_pipe(self,db_name:str,Table_list:list): All_Schema=self.Cspider_Schema_load_deal() result=[] Table_list=[i for i in Table_list if i not in ["("]] for i in range(len(Table_list)): result.append(All_Schema[db_name][Table_list[i]]) return result def Table_get(self,SQL_token:list)->list: Table_list=[SQL_token[i] for i in range(len(SQL_token)) if SQL_token[i-1] in ["from","join"]] return Table_list def Dict_deal(self,one_dict:dict)->dict: query=one_dict["question"] SQL=one_dict["query"] db_name=one_dict["db_id"] return {"query":query,"SQL":SQL,"table_name":"","column_name":"","db_name":db_name} def Cspider_Datas_Get(self,Cspider_data): Result=[] for i in range(len(Cspider_data)): if i not in [3097,3153]: print("=========正在处理第"+str(i)+",总共有"+str(len(Cspider_data))+"个=========") one_dict = self.Dict_deal(Cspider_data[i]) Table_list = list(set(self.Table_get(Cspider_data[i]["query_toks"]))) result = self.Cspider_Schema_pipe(one_dict["db_name"], Table_list) one_dict["table_name"] = Table_list one_dict["column_name"] = result Result.append(one_dict) return Result def Csipder_main(self): Cspider_train_data = self.Cspider_Data_load("Data/Cspider/train.json") Cspider_dev_data=self.Cspider_Data_load("Data/Cspider/dev.json") Cspider_Result=self.Cspider_Datas_Get(Cspider_train_data)+self.Cspider_Datas_Get(Cspider_dev_data) return Cspider_Resultclass wikiSQL_Data_make: def wiki_load(self,file_path): file_str=open(file_path,"r",encoding="utf-8").readlines() Dict_Data=[eval(file_str[i]) for i in range(len(file_str))] return Dict_Data def wiki_deal(self,data_path,table_path): Dict_data=self.wiki_load(data_path) Table_data=self.wiki_load(table_path) Wiki_Result,Index=[],0 Table_dict={Table_data[i]["id"]:[Table_data[i]["header"],Table_data[i]['caption']] for i in range(len(Table_data)) if "caption" in Table_data[i].keys()} for i in range(len(Dict_data)): table_id=Dict_data[i]["table_id"] all_table=Table_dict.keys() if table_id in all_table: #print("正在处理第" + str(Index) + ",总共有" + str(len(Dict_data)) + "个") Index+=1 query=Dict_data[i]["question"] table_name="_".join(re.split(" ",Table_dict[Dict_data[i]["table_id"]][1])) SQL=Dict_data[i]["sql"] column_name=Table_dict[Dict_data[i]["table_id"]][0] for j in range(len(column_name)): column=[] if "/" in column_name[j] and "(" not in column_name[j]: column_name[j]=re.split("/",column_name[j])[0] elif "(" in column_name[j]: for k in column_name[j]: if k!="(": column.append(k) else: column_name[j]=re.split(" ","".join(column)) if column_name[j][-1]=="": column_name[j]="_".join(column_name[j][0:-1]) else: column_name[j] = "_".join(column_name[j]) break elif " " in column_name[j]: column_name[j]="_".join(re.split(" ",column_name[j])) elif type(column_name[j])==list: column_name[j]=column_name[j][0] SQL=self.SQL_make(SQL,column_name,table_name) one_dict={"query": query, "SQL": SQL, "table_name": table_name, "column_name":column_name, "db_name": ""} Wiki_Result.append(one_dict) return Wiki_Result def SQL_make(self,SQL_token,column_name,table_name): agg_Action, conds_Acction= ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'],['=', '>', '<', 'OP'] SQL="SELECT "+agg_Action[SQL_token["agg"]]+" ( "+column_name[SQL_token["sel"]]+" ) "+"FROM "+table_name if len(SQL_token["conds"])==1: if type(SQL_token["conds"][0][2])!=str: SQL_token["conds"][0][2]=str(SQL_token["conds"][0][2]) SQL_token["conds"][0][1]=conds_Acction[SQL_token["conds"][0][1]] SQL_token["conds"][0][0]=column_name[SQL_token["conds"][0][0]] SQL+=" WHERE "+" ".join(SQL_token["conds"][0]) else: conds_list=SQL_token["conds"] for i in range(len(conds_list)): if type(conds_list[i][2])!=str: conds_list[i][2]=str(conds_list[i][2]) conds_list[i][0]=column_name[conds_list[i][0]] conds_list[i][1]=conds_Acction[conds_list[i][1]] for i in range(len(conds_list)): if i==len(conds_list)-1: SQL+="and "+" ".join(conds_list[i]) elif i==0: SQL+="WHERE "+" ".join(conds_list[i])+" " else: SQL+="and "+" ".join(conds_list[i])+" " return SQL def wiki_main(self): Wiki_Result=self.wiki_deal("Data/WikiSQL/train.json","Data/WikiSQL/train_tables.json") return Wiki_Resultclass DuSQL_Data_make: def DuSQL_load(self,file_path): DuSQL_data=json.loads(open(file_path,"r",encoding="utf-8").read()) return DuSQL_data def Schema_deal(self,DuSQL_schema:list[dict]): Schema_dict={} for i in range(len(DuSQL_schema)): table_names=DuSQL_schema[i]["table_names"] column_names=DuSQL_schema[i]["column_names"] Schema_dict[DuSQL_schema[i]["db_id"]]={table_names[j]:[column_names[k][1] for k in range(len(column_names)) if column_names[k][0]==j] for j in range(len(table_names))} return Schema_dict def TableGetFromSQL(self,SQL): SQL_List=re.split(" ",SQL) Table=list(set([SQL_List[i] for i in range(len(SQL_List)) if i!=0 and SQL_List[i-1] in ["from","join"]])) return Table def Query_SQL_Schema(self,DUSQL_data:list[dict],DuSQL_Schema): Result=[] for i in range(len(DUSQL_data)): print("=========正在处理第" + str(i) + ",总共有" + str(len(DUSQL_data)) + "个=========") SQL=DUSQL_data[i]["sql_query"] query=DUSQL_data[i]["question"] db_name=DUSQL_data[i]["db_id"] table=self.TableGetFromSQL(SQL)[0] column=DuSQL_Schema[db_name][table] Result.append({"query":query,"SQL":SQL,"table_name":table,"column_name":column,"db_name":db_name}) return Result def DuSQL_main(self): DuSQL_data=self.DuSQL_load("Data/DuSQL/sample-data.json") DUSQL_Schema=self.DuSQL_load("Data/DuSQL/db-schema.json") DUSQL_Schema=self.Schema_deal(DUSQL_Schema) DuSQL_Result=self.Query_SQL_Schema(DuSQL_data,DUSQL_Schema) return DuSQL_Result
用于翻译的数据接口,这里用了通义千问14B
OutAPI.py
"""codeer:Jinzhanglifunction:接入外部API服务relation:2035877994@qq.comtime:2023/11/30 15:49"""import requests,jsondef Qwen14BChat(text,history): url="http://172.16.158.247:9899/Qwen14B" data=json.dumps({"prompt":text,"history":history}) response=requests.post(url=url,data=data) response=eval(response.text) return response
接下来是主控脚本,Tune_main.py
"""codeer:Jinzhanglifunction:主控文件relation:2035877994@qq.comtime:2023/11/30 18:05"""import jsonfrom Data_Deal_Script import *from OutAPI import *def LearningDataJson_build(): wikiSQL_Data = wikiSQL_Data_make() print("开始处理WIKI_SQL") WIKI_SQL = wikiSQL_Data.wiki_main() #英文数据集翻译 for i in range(len(WIKI_SQL)): print("====翻译第"+str(i)+"个句子====") WIKI_SQL[i]["query"] = Qwen14BChat("请帮我将以下文本翻译为中文,只输出结果,不要任何解释\n"+WIKI_SQL[i]["query"],[])["response"] print(WIKI_SQL[i]["query"]) Cspider_Data = Cspider_Data_make() Dusql_Data = DuSQL_Data_make() print("开始处理DU_SQL") DU_SQL = Dusql_Data.DuSQL_main() print("开始处理Cspider") Cspider = Cspider_Data.Csipder_main() Result=DU_SQL+Cspider+WIKI_SQL with open("result.json", "w", encoding="utf-8") as json_file: json.dump(Result,json_file,ensure_ascii=False)
2、基于Swift框架的加载LoRA微调
接下来是LLM微调脚本(基于Swift框架)
首先安装阿里巴巴Swift框架
git clone https://github.com/modelscope/swift.gitcd swiftpip install -e .
然后进入Clone下来的Swift文件夹
cd ../swift/examples/pytorch/llm
使用llm下自带的脚本,也可以自己写,我比较懒直接os.system()来修改
import oscommand="""CUDA_VISIBLE_DEVICES=0 \python llm_sft.py \ --model_type qwen-14b \ --model_cache_dir /home/gpu-user1/JinzhangLi/Qwen-14B \ --sft_type lora \ --template_type default-generation \ --dtype bf16 \ --output_dir output \ --dataset dureader-robust-zh \ --train_dataset_sample -1 \ --num_train_epochs 1 \ --max_length 2048 \ --quantization_bit 4 \ --bnb_4bit_comp_dtype bf16 \ --lora_rank 8 \ --lora_alpha 32 \ --lora_dropout_p 0. \ --lora_target_modules ALL \ --gradient_checkpointing true \ --batch_size 1 \ --weight_decay 0. \ --learning_rate 1e-4 \ --gradient_accumulation_steps 16 \ --max_grad_norm 0.5 \ --warmup_ratio 0.03 \ --eval_steps 100 \ --save_steps 100 \ --save_total_limit 2 \ --logging_steps 10 \ --use_flash_attn false \ --push_to_hub false \ --hub_model_id qwen-14b-qlora \ --hub_private_repo true \ --hub_token 'your-sdk-token' """os.system(command)
3、数据集样式和链接(根据自己使用的框架微调,不出意外,后面数据集还会变大)
最后给出搞定后的NL2SQL数据集(当然数据集还得调整,只是将数据格式整理如下)
{ "query": "创刊时间不早于1989年10月10日的期刊,按出版刊数降序排列给出期刊的名称以及语言", "SQL": "select 名称 , 语言 from 期刊 where 创刊时间 >= '1989-10-10' order by 出版刊数 desc", "table_name": "期刊", "column_name": ["词条id", "名称", "语言", "类别", "主办单位", "创刊时间", "国家", "出版刊数"], "db_name": "期刊"}
如想获取数据,请访问我们在modelscope的开源地址
Text2SQL-英文-150K · 数据集 (modelscope.cn)
Text2SQL-中文-180K · 数据集 (modelscope.cn)