mirror of
https://github.com/langgenius/dify.git
synced 2025-12-28 18:42:18 +00:00
one example of Session (#24135)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
96a0b9991e
commit
25c69ac540
152
api/commands.py
152
api/commands.py
@ -10,6 +10,7 @@ from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account.email = new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
|
||||
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = db.session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz # pip install pytz
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
query = db.select(Conversation).where(
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
|
||||
@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
@ -211,13 +212,13 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
elif sort == "created_at":
|
||||
|
||||
@ -910,7 +910,7 @@ class AppDatasetJoin(Base):
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
@ -931,7 +931,7 @@ class DatasetQuery(Base):
|
||||
source_app_id = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetKeywordTable(Base):
|
||||
|
||||
@ -1731,7 +1731,7 @@ class MessageChain(Base):
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
input = mapped_column(sa.Text, nullable=True)
|
||||
output = mapped_column(sa.Text, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAgentThought(Base):
|
||||
@ -1769,7 +1769,7 @@ class MessageAgentThought(Base):
|
||||
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def files(self) -> list[Any]:
|
||||
@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base):
|
||||
index_node_hash = mapped_column(sa.Text, nullable=True)
|
||||
retriever_from = mapped_column(sa.Text, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from configs import dify_config
|
||||
@ -65,7 +66,7 @@ class AppService:
|
||||
return None
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
sa.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
page=args["page"],
|
||||
per_page=args["limit"],
|
||||
error_out=False,
|
||||
|
||||
@ -115,12 +115,12 @@ class DatasetService:
|
||||
# Check if permitted_dataset_ids is not empty to avoid WHERE false condition
|
||||
if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
|
||||
query = query.where(
|
||||
db.or_(
|
||||
sa.or_(
|
||||
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
|
||||
db.and_(
|
||||
sa.and_(
|
||||
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
|
||||
),
|
||||
db.and_(
|
||||
sa.and_(
|
||||
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
Dataset.id.in_(permitted_dataset_ids),
|
||||
),
|
||||
@ -128,9 +128,9 @@ class DatasetService:
|
||||
)
|
||||
else:
|
||||
query = query.where(
|
||||
db.or_(
|
||||
sa.or_(
|
||||
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
|
||||
db.and_(
|
||||
sa.and_(
|
||||
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
|
||||
),
|
||||
)
|
||||
@ -1879,7 +1879,7 @@ class DocumentService:
|
||||
# for notion_info in notion_info_list:
|
||||
# workspace_id = notion_info.workspace_id
|
||||
# data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
# db.and_(
|
||||
# sa.and_(
|
||||
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
# DataSourceOauthBinding.provider == "notion",
|
||||
# DataSourceOauthBinding.disabled == False,
|
||||
|
||||
@ -471,7 +471,7 @@ class PluginMigration:
|
||||
total_failed_tenant = 0
|
||||
while True:
|
||||
# paginate
|
||||
tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
|
||||
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
|
||||
if tenants.items is None or len(tenants.items) == 0:
|
||||
break
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -18,7 +19,7 @@ class TagService:
|
||||
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
if keyword:
|
||||
query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -51,7 +52,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
db.and_(
|
||||
sa.and_(
|
||||
DataSourceOauthBinding.tenant_id == document.tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user