|
|
@@ -0,0 +1,151 @@
|
|
|
+from celery_app import celery_app
|
|
|
+import json
|
|
|
+import chromadb
|
|
|
+client = chromadb.HttpClient(host='47.101.198.30',port=8000)
|
|
|
+collection = client.get_or_create_collection(name="tj_de_bge")
|
|
|
+from FlagEmbedding import FlagModel
|
|
|
+model = FlagModel('/Users/zxp/Downloads/test2_encoder_only_base_bge-large-zh-v1.5')
|
|
|
+from sentence_transformers import CrossEncoder
|
|
|
+ce = CrossEncoder('/Users/zxp/Downloads/reranker')
|
|
|
+with open("hunningtu_rule", "r") as f:
|
|
|
+ content = f.read()
|
|
|
+obj = json.loads(content)
|
|
|
+with open("nantong_rule", "r") as f:
|
|
|
+ content = f.read()
|
|
|
+obj2 = json.loads(content)
|
|
|
+with open("basic_rule", "r") as f:
|
|
|
+ content = f.read()
|
|
|
+basic = json.loads(content)
|
|
|
+with open("incremental_rule", "r") as f:
|
|
|
+ content = f.read()
|
|
|
+incremental = json.loads(content)
|
|
|
+with open("label_name", "r") as f:
|
|
|
+ content = f.read()
|
|
|
+label_name = json.loads(content)
|
|
|
+THRESHOLD=0.9####adjust it
|
|
|
+
|
|
|
+
|
|
|
+@celery_app.task
|
|
|
+def process_data(data:dict)-> dict:
|
|
|
+ label = data['mc'] + ' ' + data['tz']
|
|
|
+ sentences = [label]
|
|
|
+ embeddings = model.encode(sentences)
|
|
|
+ result = collection.query(query_embeddings=embeddings,n_results=25)
|
|
|
+ d = result['documents'][0]
|
|
|
+ print(d)
|
|
|
+ ranks = ce.rank(label, d)
|
|
|
+ ranks = ranks[:10]
|
|
|
+ match = [("6.2.1.1","6.3.1.1"),
|
|
|
+ ("6.2.1.2", "6.3.1.2"),
|
|
|
+ ("6.2.1.3", "6.3.1.3"),
|
|
|
+ ("6.2.1.4", "6.3.1.4"),
|
|
|
+ ("6.2.1.5", "6.3.1.5"),
|
|
|
+ ("6.2.1.6", "6.3.1.6"),
|
|
|
+ ("6.2.2.1", "6.3.2.1"),
|
|
|
+ ("6.2.2.2", "6.3.2.2"),
|
|
|
+ ("6.2.3.1", "6.3.3.1"),
|
|
|
+ ("6.2.3.2", "6.3.3.2"),
|
|
|
+ ("6.2.3.3", "6.3.3.3"),
|
|
|
+ ("6.2.3.4", "6.3.3.4"),
|
|
|
+ ("6.2.3.5", "6.3.3.5"),
|
|
|
+ ("6.2.3.6", "6.3.3.6")]
|
|
|
+ match2=[
|
|
|
+ ("1.1.7", "nantong1.1.7"),
|
|
|
+ ("nantong2.1.2", "2.1.2"),
|
|
|
+ ("nantong3.1.2", "3.1.2"),
|
|
|
+ ("nantong3.1.4", "3.1.4"),
|
|
|
+ ("nantong3.1.5", "3.1.5"),
|
|
|
+ ("3.2.10", "nantong3.2.10"),
|
|
|
+ ("nantong4.1.1", "4.1.1"),
|
|
|
+ ("nantong4.1.2", "4.1.2"),
|
|
|
+ ("nantong4.1.3", "4.1.3"),
|
|
|
+ ("4.1.4", "nantong4.1.4"),
|
|
|
+ ("4.1.5", "nantong4.1.5"),
|
|
|
+ ("4.1.7", "nantong4.1.7"),
|
|
|
+ ("4.4", "nantong4.4"),
|
|
|
+ ("nantong6", "6"),
|
|
|
+ ("7.5", "nantong7.5"),
|
|
|
+ ("nantong7.8", "7.8"),
|
|
|
+ ("10.1.5", "nantong10.1.5"),
|
|
|
+ ("10.1.2", "nantong10.1.2"),
|
|
|
+ ("10.1.1", "nantong10.1.1"),
|
|
|
+ ("nantong10.1.1.2", "10.1.1.2"),
|
|
|
+ ("10.1.1.3", "nantong10.1.1.3"),
|
|
|
+ ("nantong11.1.2.1", "11.1.2.1"),
|
|
|
+ ("nantong11.1.2.2", "11.1.2.2"),
|
|
|
+ ("nantong11.1.1", "11.1.1"),
|
|
|
+ ("12.7", "nantong12.7"),
|
|
|
+ ("12.6", "nantong12.6"),
|
|
|
+ ("nantong12.5", "12.5"),
|
|
|
+ ("nantong13.1.1", "13.1.1"),
|
|
|
+ ("nantong13.1.2" , "13.1.2"),
|
|
|
+ ("nantong13.1.3", "13.1.3"),
|
|
|
+ ("nantong13.2.2", "13.2.2"),
|
|
|
+ ("nantong13.3.1", "13.3.1"),
|
|
|
+ ("nantong13.3.2", "13.3.2"),
|
|
|
+ ("13.3.3" ,"nantong13.3.3"),
|
|
|
+ ("13.4.4", "nantong13.4.4"),
|
|
|
+ ("nantong13.5.1", "13.5.1"),
|
|
|
+ ("13.5.4", "nantong13.5.4"),
|
|
|
+ ("nantong14.3.8", "14.3.8"),
|
|
|
+ ("14.4.4", "nantong14.4.4"),
|
|
|
+ ("14.4.6", "nantong14.4.6"),
|
|
|
+ ("nantong15.3.1", "15.3.1"),
|
|
|
+ ("16.2", "nantong16.2"),
|
|
|
+ ("17.1.3.2", "nantong17.1.3.2"),
|
|
|
+ ("17.1.3.3","nantong17.1.3.3"),
|
|
|
+ ("17.1.3.4","nantong17.1.3.4"),
|
|
|
+ ("18.3.3","nantong18.3.3"),
|
|
|
+ ("18.3.2","nantong18.3.2"),
|
|
|
+ ("18.5","nantong18.5"),
|
|
|
+ ("18.6","nantong18.6"),
|
|
|
+ ("18.15","nantong18.15"),
|
|
|
+ ("20.1.1","nantong20.1.1"),
|
|
|
+ ("20.1.2.1","nantong20.1.2.1"),
|
|
|
+ ("20.1.2.3","nantong20.1.2.3"),
|
|
|
+ ("20.1.2.5","nantong20.1.2.5"),
|
|
|
+ ("21.1.1.1","nantong21.1.1.1"),
|
|
|
+ ("21.1.1.2","nantong21.1.1.2"),
|
|
|
+ ("21.1.3.1","nantong21.1.3.1"),
|
|
|
+ ("21.1.3.2","nantong21.1.3.2"),
|
|
|
+ ("21.1.3.3","nantong21.1.3.3"),
|
|
|
+ ("21.1.5","nantong21.1.5"),
|
|
|
+ ("21.1.6","nantong21.1.6"),
|
|
|
+ ("21.1.7","nantong21.1.7"),
|
|
|
+ ("23.1.2","nantong23.1.2")
|
|
|
+ ]
|
|
|
+ selected=[]
|
|
|
+ notselected=[]
|
|
|
+ for entry in incremental:
|
|
|
+ notselected = notselected + incremental[entry]
|
|
|
+ for rank in ranks:
|
|
|
+ if rank['score']<THRESHOLD:
|
|
|
+ continue
|
|
|
+ if d[rank['corpus_id']] in notselected:
|
|
|
+ continue
|
|
|
+ print(f"{rank['score']} {d[rank['corpus_id']]}")
|
|
|
+ selected.append(d[rank['corpus_id']])
|
|
|
+ hunningtu_group = []
|
|
|
+ for entry in obj:
|
|
|
+ if d[rank['corpus_id']] in obj[entry]:
|
|
|
+ hunningtu_group=[entry]
|
|
|
+ if len(hunningtu_group) > 0:
|
|
|
+ for entry in match:
|
|
|
+ if entry[0]==hunningtu_group[0] or entry[1] == hunningtu_group[0]:
|
|
|
+ notselected = notselected + obj[entry[0]]
|
|
|
+ notselected = notselected + obj[entry[1]]
|
|
|
+ nantong_group = []
|
|
|
+ for entry in obj2:
|
|
|
+ if d[rank['corpus_id']] in obj2[entry]:
|
|
|
+ nantong_group=[entry]
|
|
|
+ if len(nantong_group) > 0:
|
|
|
+ for entry in match2:
|
|
|
+ if entry[0]==nantong_group[0] or entry[1] == nantong_group[0]:
|
|
|
+ notselected = notselected + obj2[entry[0]]
|
|
|
+ notselected = notselected + obj2[entry[1]]
|
|
|
+ for entry in basic:
|
|
|
+ if d[rank['corpus_id']] in basic[entry]:
|
|
|
+ notselected = notselected + basic[entry]
|
|
|
+ notselected = [x for x in notselected if x not in selected]
|
|
|
+ result = [label_name[x] for x in selected]
|
|
|
+ return {"result": result}
|