extraction sqlquery (#4027)
clone https://github.com/infiniflow/ragflow/pull/4023 improve the information extraction, most llm return results in markdown format ````sql ___ query `____ ```
This commit is contained in:
parent
4a7bc4df92
commit
7ddccbb952
@ -20,7 +20,7 @@ import pymysql
|
|||||||
import psycopg2
|
import psycopg2
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
import pyodbc
|
import pyodbc
|
||||||
|
import logging
|
||||||
|
|
||||||
class ExeSQLParam(ComponentParamBase):
|
class ExeSQLParam(ComponentParamBase):
|
||||||
"""
|
"""
|
||||||
@ -65,13 +65,26 @@ class ExeSQL(ComponentBase, ABC):
|
|||||||
self._loop += 1
|
self._loop += 1
|
||||||
|
|
||||||
ans = self.get_input()
|
ans = self.get_input()
|
||||||
|
|
||||||
|
|
||||||
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
|
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
|
||||||
ans = re.sub(r'^.*?SELECT ', 'SELECT ', repr(ans), flags=re.IGNORECASE)
|
if self._param.db_type == 'mssql':
|
||||||
|
# improve the information extraction, most llm return results in markdown format ```sql query ```
|
||||||
|
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
ans = match.group(1) # Query content
|
||||||
|
print(ans)
|
||||||
|
else:
|
||||||
|
print("no markdown")
|
||||||
|
ans = re.sub(r'^.*?SELECT ', 'SELECT ', (ans), flags=re.IGNORECASE)
|
||||||
|
else:
|
||||||
|
ans = re.sub(r'^.*?SELECT ', 'SELECT ', repr(ans), flags=re.IGNORECASE)
|
||||||
ans = re.sub(r';.*?SELECT ', '; SELECT ', ans, flags=re.IGNORECASE)
|
ans = re.sub(r';.*?SELECT ', '; SELECT ', ans, flags=re.IGNORECASE)
|
||||||
ans = re.sub(r';[^;]*$', r';', ans)
|
ans = re.sub(r';[^;]*$', r';', ans)
|
||||||
if not ans:
|
if not ans:
|
||||||
raise Exception("SQL statement not found!")
|
raise Exception("SQL statement not found!")
|
||||||
|
|
||||||
|
logging.info("db_type: ",self._param.db_type)
|
||||||
if self._param.db_type in ["mysql", "mariadb"]:
|
if self._param.db_type in ["mysql", "mariadb"]:
|
||||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||||
port=self._param.port, password=self._param.password)
|
port=self._param.port, password=self._param.password)
|
||||||
@ -96,11 +109,12 @@ class ExeSQL(ComponentBase, ABC):
|
|||||||
if not single_sql:
|
if not single_sql:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
logging.info("single_sql: ",single_sql)
|
||||||
cursor.execute(single_sql)
|
cursor.execute(single_sql)
|
||||||
if cursor.rowcount == 0:
|
if cursor.rowcount == 0:
|
||||||
sql_res.append({"content": "\nTotal: 0\n No record in the database!"})
|
sql_res.append({"content": "\nTotal: 0\n No record in the database!"})
|
||||||
continue
|
continue
|
||||||
single_res = pd.DataFrame([i for i in cursor.fetchmany(size=self._param.top_n)])
|
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
|
||||||
single_res.columns = [i[0] for i in cursor.description]
|
single_res.columns = [i[0] for i in cursor.description]
|
||||||
sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()})
|
sql_res.append({"content": "\nTotal: " + str(cursor.rowcount) + "\n" + single_res.to_markdown()})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user