Skip to content

Commit c1fb802

Browse files
committed
fix: #473 normalize retrieve codes before lookup
1 parent 1d03337 commit c1fb802

2 files changed

Lines changed: 45 additions & 3 deletions

File tree

apps/base/views.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ async def create_file_code(code, **kwargs):
133133
return await FileCodes.create(code=code, **kwargs)
134134

135135

136+
def normalize_share_code(code: str) -> str:
137+
return str(code or "").strip()
138+
139+
136140
@share_api.post("/text/", dependencies=[Depends(share_required_login)])
137141
async def share_text(
138142
text: str = Form(...),
@@ -196,7 +200,10 @@ async def share_file(
196200
async def get_code_file_by_code(
197201
code: str, check: bool = True
198202
) -> Tuple[bool, Union[FileCodes, str]]:
199-
file_code = await FileCodes.filter(code=code).first()
203+
normalized_code = normalize_share_code(code)
204+
if not normalized_code:
205+
return False, "文件不存在"
206+
file_code = await FileCodes.filter(code=normalized_code).first()
200207
if not file_code:
201208
return False, "文件不存在"
202209
if await file_code.is_expired() and check:
@@ -298,10 +305,11 @@ async def select_file(data: SelectFileModel, ip: str = Depends(ip_limit["error"]
298305
@share_api.get("/download")
299306
async def download_file(key: str, code: str, ip: str = Depends(ip_limit["error"])):
300307
file_storage: FileStorageInterface = storages[settings.file_storage]()
301-
if await get_select_token(code) != key:
308+
normalized_code = normalize_share_code(code)
309+
if await get_select_token(normalized_code) != key:
302310
ip_limit["error"].add_ip(ip)
303311
raise HTTPException(status_code=403, detail="下载鉴权失败")
304-
has, file_code = await get_code_file_by_code(code, False)
312+
has, file_code = await get_code_file_by_code(normalized_code, False)
305313
if not has:
306314
return APIResponse(code=404, detail="文件不存在")
307315
assert isinstance(file_code, FileCodes)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import asyncio
2+
import unittest
3+
4+
import apps.base.views as views
5+
from apps.base.views import normalize_share_code
6+
7+
8+
class FakeFileCodes:
9+
seen_code = None
10+
11+
@classmethod
12+
def filter(cls, **kwargs):
13+
cls.seen_code = kwargs.get("code")
14+
return cls()
15+
16+
async def first(self):
17+
return None
18+
19+
20+
class RetrieveCodeTests(unittest.TestCase):
21+
def test_normalize_share_code_strips_surrounding_whitespace(self):
22+
self.assertEqual(normalize_share_code(" 12345\n"), "12345")
23+
24+
def test_get_code_file_by_code_queries_normalized_code(self):
25+
original_file_codes = views.FileCodes
26+
views.FileCodes = FakeFileCodes
27+
try:
28+
has_file, message = asyncio.run(views.get_code_file_by_code(" 12345\n"))
29+
finally:
30+
views.FileCodes = original_file_codes
31+
32+
self.assertFalse(has_file)
33+
self.assertEqual(message, "文件不存在")
34+
self.assertEqual(FakeFileCodes.seen_code, "12345")

0 commit comments

Comments
 (0)