"""Relabel the ``corrected_with_wang`` query/response data using rule groups.

Pipeline (run as a script):

1. Load ``corrected_with_wang.json`` (rows with ``query``, ``response`` and a
   numeric ``label``) and bucket each query's responses by label:
   ``label > 0.9`` -> positive (1), ``0.4 < label <= 0.9`` -> partial (0.5),
   anything else -> negative (0).
2. Load two rule files (``../hunningtu_rule`` and ``../nantong_rule``), each a
   JSON object mapping a rule id to a list of texts. For every hand-curated
   pair of rule ids the two lists are treated as one group; if one of a
   query's positives appears in a group, every text of that group that is not
   itself a positive is removed from the query's negatives and added to its
   partials (0.5).
3. Write the relabelled flat rows to ``corrected_with_wang_modified.json``.
4. Write one listwise row per query (``query``/``docs``/``labels``) to
   ``corrected_with_wang_lambda.json``. When a query has at least
   ``GROUP_SIZE`` distinct candidates, all positives are kept and equal
   numbers of partials and negatives are sampled, so the group holds at most
   ``GROUP_SIZE`` documents (more only if positives alone exceed it).
"""
import json
import random
from typing import Iterable, Optional, Sequence, Tuple

# Target number of documents per listwise group (see ``sample_group``).
GROUP_SIZE = 16
# Seed for the listwise down-sampling. ``None`` keeps the unseeded,
# non-reproducible behaviour; set an int to make the output reproducible.
RANDOM_SEED = None

# Pairs of rule ids in ``../hunningtu_rule`` whose text lists form one group.
HUNNINGTU_PAIRS = [
    ("6.2.1.1", "6.3.1.1"),
    ("6.2.1.2", "6.3.1.2"),
    ("6.2.1.3", "6.3.1.3"),
    ("6.2.1.4", "6.3.1.4"),
    ("6.2.1.5", "6.3.1.5"),
    ("6.2.1.6", "6.3.1.6"),
    ("6.2.2.1", "6.3.2.1"),
    ("6.2.2.2", "6.3.2.2"),
    ("6.2.3.1", "6.3.3.1"),
    ("6.2.3.2", "6.3.3.2"),
    ("6.2.3.3", "6.3.3.3"),
    ("6.2.3.4", "6.3.3.4"),
    ("6.2.3.5", "6.3.3.5"),
    ("6.2.3.6", "6.3.3.6"),
]

# Pairs of rule ids in ``../nantong_rule`` whose text lists form one group.
NANTONG_PAIRS = [
    ("1.1.7", "nantong1.1.7"),
    ("nantong2.1.2", "2.1.2"),
    ("nantong3.1.2", "3.1.2"),
    ("nantong3.1.4", "3.1.4"),
    ("nantong3.1.5", "3.1.5"),
    ("3.2.10", "nantong3.2.10"),
    ("nantong4.1.1", "4.1.1"),
    ("nantong4.1.2", "4.1.2"),
    ("nantong4.1.3", "4.1.3"),
    ("4.1.4", "nantong4.1.4"),
    ("4.1.5", "nantong4.1.5"),
    ("4.1.7", "nantong4.1.7"),
    ("4.4", "nantong4.4"),
    ("nantong6", "6"),
    ("7.5", "nantong7.5"),
    ("nantong7.8", "7.8"),
    ("10.1.5", "nantong10.1.5"),
    ("10.1.2", "nantong10.1.2"),
    ("10.1.1", "nantong10.1.1"),
    ("nantong10.1.1.2", "10.1.1.2"),
    ("10.1.1.3", "nantong10.1.1.3"),
    ("nantong11.1.2.1", "11.1.2.1"),
    ("nantong11.1.2.2", "11.1.2.2"),
    ("nantong11.1.1", "11.1.1"),
    ("12.7", "nantong12.7"),
    ("12.6", "nantong12.6"),
    ("nantong12.5", "12.5"),
    ("nantong13.1.1", "13.1.1"),
    ("nantong13.1.2", "13.1.2"),
    ("nantong13.1.3", "13.1.3"),
    ("nantong13.2.2", "13.2.2"),
    ("nantong13.3.1", "13.3.1"),
    ("nantong13.3.2", "13.3.2"),
    ("13.3.3", "nantong13.3.3"),
    ("13.4.4", "nantong13.4.4"),
    ("nantong13.5.1", "13.5.1"),
    ("13.5.4", "nantong13.5.4"),
    ("nantong14.3.8", "14.3.8"),
    ("14.4.4", "nantong14.4.4"),
    ("14.4.6", "nantong14.4.6"),
    ("nantong15.3.1", "15.3.1"),
    ("16.2", "nantong16.2"),
    ("17.1.3.2", "nantong17.1.3.2"),
    ("17.1.3.3", "nantong17.1.3.3"),
    ("17.1.3.4", "nantong17.1.3.4"),
    ("18.3.3", "nantong18.3.3"),
    ("18.3.2", "nantong18.3.2"),
    ("18.5", "nantong18.5"),
    ("18.6", "nantong18.6"),
    ("18.15", "nantong18.15"),
    ("20.1.1", "nantong20.1.1"),
    ("20.1.2.1", "nantong20.1.2.1"),
    ("20.1.2.3", "nantong20.1.2.3"),
    ("20.1.2.5", "nantong20.1.2.5"),
    ("21.1.1.1", "nantong21.1.1.1"),
    ("21.1.1.2", "nantong21.1.1.2"),
    ("21.1.3.1", "nantong21.1.3.1"),
    ("21.1.3.2", "nantong21.1.3.2"),
    ("21.1.3.3", "nantong21.1.3.3"),
    ("21.1.5", "nantong21.1.5"),
    ("21.1.6", "nantong21.1.6"),
    ("21.1.7", "nantong21.1.7"),
    ("23.1.2", "nantong23.1.2"),
]


def bucket_by_label(records: Iterable[dict]) -> Tuple[dict, dict, dict]:
    """Group responses per query into positive / partial / negative lists.

    Args:
        records: Rows with ``query``, ``response`` and numeric ``label`` keys.

    Returns:
        ``(positives, partials, negatives)``: three dicts with identical keys
        (every query, in first-seen order), each mapping to a list of
        responses in input order. ``label > 0.9`` is positive, otherwise
        ``label > 0.4`` is partial, and everything else is negative.
    """
    positives, partials, negatives = {}, {}, {}
    for record in records:
        query = record["query"]
        # Register the query in all three buckets so later lookups never miss.
        for bucket in (positives, partials, negatives):
            bucket.setdefault(query, [])
        label = record["label"]
        if label > 0.9:
            target = positives
        elif label > 0.4:
            target = partials
        else:
            target = negatives
        target[query].append(record["response"])
    return positives, partials, negatives


def propagate_rule_matches(
    positives: dict,
    partials: dict,
    negatives: dict,
    rule_sources: Sequence[Tuple[dict, Sequence[Tuple[str, str]]]],
) -> Tuple[dict, dict]:
    """Promote rule-group siblings of each query's positives to partial (0.5).

    Args:
        positives, partials, negatives: Buckets as returned by
            ``bucket_by_label``; all three must share the same keys.
        rule_sources: ``(rules, pairs)`` entries, where ``rules`` maps a rule
            id to a list of texts and ``pairs`` lists ``(left, right)`` ids
            whose two lists form one group. Groups are visited in this order.

    Returns:
        New ``(partials, negatives)`` dicts; the inputs are not mutated. For
        each query, every text of a group containing one of its positives
        (other than the positives themselves) is removed from its negatives
        and moved to the end of its partials, once, in first-match order.

    Raises:
        KeyError: If a pair references a rule id missing from its ``rules``.
    """
    # Build every group once: a set for O(1) membership tests plus the
    # ordered list (left texts, then right texts) that gets promoted.
    groups = []
    for rules, pairs in rule_sources:
        for left, right in pairs:
            texts = rules[left] + rules[right]
            groups.append((set(texts), texts))

    new_partials = dict(partials)
    new_negatives = dict(negatives)
    for query, correct in positives.items():
        promoted = []
        for text in correct:
            for members, texts in groups:
                if text in members:
                    promoted.extend(texts)
        correct_set = set(correct)
        # A text can be reached through several groups or several positives;
        # dict.fromkeys keeps it once, in first-seen order, so no duplicate
        # 0.5 rows are emitted.
        promoted = list(dict.fromkeys(t for t in promoted if t not in correct_set))
        promoted_set = set(promoted)
        new_negatives[query] = [x for x in negatives[query] if x not in promoted_set]
        new_partials[query] = [
            x for x in partials[query] if x not in promoted_set
        ] + promoted
    return new_partials, new_negatives


def flatten(positives: dict, partials: dict, negatives: dict) -> dict:
    """Return column-oriented ``query``/``response``/``label`` rows.

    All positive rows come first (label 1), then partial rows (0.5), then
    negative rows (0); within each bucket the dict and list order is kept.
    """
    columns = {"query": [], "response": [], "label": []}
    for bucket, value in ((positives, 1), (partials, 0.5), (negatives, 0)):
        for query, responses in bucket.items():
            for response in responses:
                columns["query"].append(query)
                columns["response"].append(response)
                columns["label"].append(value)
    return columns


def sample_group(
    positives: Sequence,
    partials: Sequence,
    negatives: Sequence,
    group_size: int = GROUP_SIZE,
    rng: Optional[random.Random] = None,
) -> Tuple[list, list]:
    """Build one listwise group of documents and labels for a query.

    Duplicates within each input are dropped (first occurrence wins). If the
    distinct candidates number fewer than ``group_size``, everything is kept.
    Otherwise all distinct positives are kept, and ``k`` partials and ``k``
    negatives are sampled, where ``k = (group_size - n_positives) // 2``,
    floored at 0 and capped by the length of the shorter of the two lists.

    Returns:
        ``(docs, labels)``: positives (label 1), then partials (0.5), then
        negatives (0).
    """
    rng = random.Random() if rng is None else rng
    one = list(dict.fromkeys(positives))
    half = list(dict.fromkeys(partials))
    zero = list(dict.fromkeys(negatives))
    if len(one) + len(half) + len(zero) >= group_size:
        # Keep partials and negatives balanced. Capping k avoids the
        # ValueError random.sample raises when k exceeds the population
        # (too few partials/negatives) or is negative (too many positives).
        k = min(max(0, (group_size - len(one)) // 2), len(half), len(zero))
        half = rng.sample(half, k)
        zero = rng.sample(zero, k)
    docs = one + half + zero
    labels = [1] * len(one) + [0.5] * len(half) + [0] * len(zero)
    return docs, labels


def build_listwise(
    positives: dict,
    partials: dict,
    negatives: dict,
    group_size: int = GROUP_SIZE,
    rng: Optional[random.Random] = None,
) -> dict:
    """Return column-oriented ``query``/``docs``/``labels`` rows, one per query.

    Each group is built with ``sample_group``. Prints every query and the
    size of its group as progress output.
    """
    rng = random.Random() if rng is None else rng
    columns = {"query": [], "docs": [], "labels": []}
    for query in positives:
        docs, labels = sample_group(
            positives[query], partials[query], negatives[query], group_size, rng
        )
        columns["query"].append(query)
        columns["docs"].append(docs)
        columns["labels"].append(labels)
        print(query)
        print(len(labels))
    return columns


def _load_json(path: str):
    """Parse the JSON document at *path*."""
    # JSON interchange text is UTF-8 (RFC 8259); don't rely on the platform's
    # default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main() -> None:
    """Run the full relabel-and-export pipeline described in the module docstring."""
    # Imported here so the helpers above stay importable without `datasets`.
    from datasets import Dataset, load_dataset

    ds = load_dataset("json", data_files="corrected_with_wang.json", split="train")
    positives, partials, negatives = bucket_by_label(ds)

    hunningtu_rules = _load_json("../hunningtu_rule")
    nantong_rules = _load_json("../nantong_rule")
    # Debug aid: list every top-level key (rule id) of the nantong rule file.
    for rule_id in nantong_rules:
        print(rule_id)

    partials, negatives = propagate_rule_matches(
        positives,
        partials,
        negatives,
        [(hunningtu_rules, HUNNINGTU_PAIRS), (nantong_rules, NANTONG_PAIRS)],
    )

    Dataset.from_dict(flatten(positives, partials, negatives)).to_json(
        "corrected_with_wang_modified.json"
    )

    rng = random.Random(RANDOM_SEED)
    Dataset.from_dict(build_listwise(positives, partials, negatives, rng=rng)).to_json(
        "corrected_with_wang_lambda.json"
    )


if __name__ == "__main__":
    main()