| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
from datasets import load_dataset

# Load every record of the annotation file as a single "train" split.
ds = load_dataset("json", data_files="corrected_with_wang.json", split="train")

# Query text of every record, in dataset order (repeats included).
query = [row['query'] for row in ds]

# One response list per distinct query for each label tier:
#   mapping   -> label  > 0.9
#   mapping05 -> label  > 0.4 (and <= 0.9)
#   mapping0  -> everything else
mapping = {text: [] for text in query}
mapping05 = {text: [] for text in query}
mapping0 = {text: [] for text in query}

# Route each response into the tier selected by its label.
for row in ds:
    score = row['label']
    if score > 0.9:
        tier = mapping
    elif score > 0.4:
        tier = mapping05
    else:
        tier = mapping0
    tier[row['query']].append(row['response'])
import json

# Read and parse the two rule files, one after the other. Both are later
# indexed by rule-id strings, so each is expected to decode to a JSON object.
parsed_rules = []
for rule_path in ("../hunningtu_rule", "../nantong_rule"):
    with open(rule_path, "r") as rule_file:
        content = rule_file.read()
    parsed_rules.append(json.loads(content))
obj, obj2 = parsed_rules

# Echo every top-level entry of the second rule file to stdout.
for entry in obj2:
    print(entry)
# Pairs of rule ids looked up in `obj`: every "6.2.<a>.<b>" id is paired
# with the "6.3.<a>.<b>" id carrying the same trailing numbers.
# Sub-section 1 has 6 items, sub-section 2 has 2, sub-section 3 has 6.
match = [
    (f"6.2.{section}.{item}", f"6.3.{section}.{item}")
    for section, item_count in ((1, 6), (2, 2), (3, 6))
    for item in range(1, item_count + 1)
]

# Pairs of rule ids looked up in `obj2`: each plain id is paired with its
# "nantong"-prefixed counterpart. Which side comes first varies per entry
# and is kept exactly as authored, because it fixes the order in which the
# two groups are concatenated downstream.
match2 = [
    ("1.1.7", "nantong1.1.7"),
    ("nantong2.1.2", "2.1.2"),
    ("nantong3.1.2", "3.1.2"),
    ("nantong3.1.4", "3.1.4"),
    ("nantong3.1.5", "3.1.5"),
    ("3.2.10", "nantong3.2.10"),
    ("nantong4.1.1", "4.1.1"),
    ("nantong4.1.2", "4.1.2"),
    ("nantong4.1.3", "4.1.3"),
    ("4.1.4", "nantong4.1.4"),
    ("4.1.5", "nantong4.1.5"),
    ("4.1.7", "nantong4.1.7"),
    ("4.4", "nantong4.4"),
    ("nantong6", "6"),
    ("7.5", "nantong7.5"),
    ("nantong7.8", "7.8"),
    ("10.1.5", "nantong10.1.5"),
    ("10.1.2", "nantong10.1.2"),
    ("10.1.1", "nantong10.1.1"),
    ("nantong10.1.1.2", "10.1.1.2"),
    ("10.1.1.3", "nantong10.1.1.3"),
    ("nantong11.1.2.1", "11.1.2.1"),
    ("nantong11.1.2.2", "11.1.2.2"),
    ("nantong11.1.1", "11.1.1"),
    ("12.7", "nantong12.7"),
    ("12.6", "nantong12.6"),
    ("nantong12.5", "12.5"),
    ("nantong13.1.1", "13.1.1"),
    ("nantong13.1.2", "13.1.2"),
    ("nantong13.1.3", "13.1.3"),
    ("nantong13.2.2", "13.2.2"),
    ("nantong13.3.1", "13.3.1"),
    ("nantong13.3.2", "13.3.2"),
    ("13.3.3", "nantong13.3.3"),
    ("13.4.4", "nantong13.4.4"),
    ("nantong13.5.1", "13.5.1"),
    ("13.5.4", "nantong13.5.4"),
    ("nantong14.3.8", "14.3.8"),
    ("14.4.4", "nantong14.4.4"),
    ("14.4.6", "nantong14.4.6"),
    ("nantong15.3.1", "15.3.1"),
    ("16.2", "nantong16.2"),
    ("17.1.3.2", "nantong17.1.3.2"),
    ("17.1.3.3", "nantong17.1.3.3"),
    ("17.1.3.4", "nantong17.1.3.4"),
    ("18.3.3", "nantong18.3.3"),
    ("18.3.2", "nantong18.3.2"),
    ("18.5", "nantong18.5"),
    ("18.6", "nantong18.6"),
    ("18.15", "nantong18.15"),
    ("20.1.1", "nantong20.1.1"),
    ("20.1.2.1", "nantong20.1.2.1"),
    ("20.1.2.3", "nantong20.1.2.3"),
    ("20.1.2.5", "nantong20.1.2.5"),
    ("21.1.1.1", "nantong21.1.1.1"),
    ("21.1.1.2", "nantong21.1.1.2"),
    ("21.1.3.1", "nantong21.1.3.1"),
    ("21.1.3.2", "nantong21.1.3.2"),
    ("21.1.3.3", "nantong21.1.3.3"),
    ("21.1.5", "nantong21.1.5"),
    ("21.1.6", "nantong21.1.6"),
    ("21.1.7", "nantong21.1.7"),
    ("23.1.2", "nantong23.1.2"),
]
# For every query, find the rule pairs that contain one of its top-tier
# responses, then move all other members of those pairs into the 0.5 tier
# (and out of the 0 tier).
for question in mapping:
    gold = mapping[question]
    related = []
    for answer in gold:
        # `match` indexes into `obj`, `match2` into `obj2`; for each answer
        # the `match` pairs are scanned first, then the `match2` pairs.
        for rules, pairs in ((obj, match), (obj2, match2)):
            for first_id, second_id in pairs:
                first_group = rules[first_id]
                second_group = rules[second_id]
                merged = first_group + second_group
                if answer in first_group or answer in second_group:
                    # No de-duplication: a pair hit by several answers is
                    # appended once per hit.
                    related = related + merged

    # Responses already in the top tier are never demoted.
    related = [item for item in related if item not in gold]

    remaining_zero = [item for item in mapping0[question] if item not in related]
    kept_half = [item for item in mapping05[question] if item not in related]

    # Existing 0.5-tier entries come first, followed by the promoted ones.
    mapping05[question] = kept_half + related
    mapping0[question] = remaining_zero
from datasets import Dataset

# Flatten the three tiers into parallel columns: all label-1 rows first,
# then all label-0.5 rows, then all label-0 rows; within a tier the rows
# follow the tier dict's own key order.
query = []
response = []
label = []
for tier, score in ((mapping, 1), (mapping05, 0.5), (mapping0, 0)):
    for text, answers in tier.items():
        for answer in answers:
            query.append(text)
            response.append(answer)
            label.append(score)

# Write the re-labelled pointwise dataset.
flat_rows = {"query": query, "response": response, "label": label}
Dataset.from_dict(flat_rows).to_json("corrected_with_wang_modified.json")
import random

# Size threshold for one query's candidate list: queries with at least this
# many distinct candidates are down-sampled, smaller ones are kept whole.
LIST_SIZE = 16

# Build one row per query: the query text, its candidate documents, and the
# per-document labels (1, 0.5 or 0) in the same order.
query = []
docs = []
labels = []
for i in mapping:
    query.append(i)

    # De-duplicate each tier. Going through a set means the resulting list
    # order is arbitrary.
    one = list(set(mapping[i]))
    zero = list(set(mapping0[i]))
    half = list(set(mapping05[i]))

    if len(one) + len(zero) + len(half) >= LIST_SIZE:
        # Balance the 0 and 0.5 tiers by shrinking the larger one to the
        # size of the smaller one.
        if len(zero) >= len(half):
            zero = random.sample(zero, len(half))
        else:
            half = random.sample(half, len(zero))

        # Split the room left after the label-1 documents evenly between
        # the two lower tiers.
        target = (LIST_SIZE - len(one)) // 2
        # random.sample raises ValueError for a negative sample size or one
        # larger than the population, so clamp the request to what is
        # actually available. This matters when there are more than
        # LIST_SIZE label-1 documents or when the smaller tier is short.
        target = max(0, min(target, len(zero), len(half)))
        zero = random.sample(zero, target)
        half = random.sample(half, target)

    doc = []
    label = []
    for entry in one:
        doc.append(entry)
        label.append(1)
    for entry in half:
        # random.random() returns a value in [0.0, 1.0), so this test is
        # true except when exactly 0.0 is drawn. NOTE(review): presumably a
        # leftover sub-sampling knob; kept so the sequence of random draws
        # is unchanged.
        if random.random() > 0:
            doc.append(entry)
            label.append(0.5)
    for entry in zero:
        if random.random() > 0:
            doc.append(entry)
            label.append(0)

    docs.append(doc)
    labels.append(label)
    # Progress output: the query followed by its candidate count.
    print(i)
    print(len(label))

# Write the listwise dataset.
Dataset.from_dict({"query": query, "docs": docs, "labels": labels}).to_json("corrected_with_wang_lambda.json")
|