Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions compyle/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,27 @@ def has_return(code):
"""Returns True of the node has a return statement.
"""
return has_node(code, ast.Return)


def is_str_node(node):
"""Return True if the AST node represents a string literal.

Works on both older ASTs (ast.Str) and newer ones (ast.Constant).
"""
if isinstance(node, ast.Constant):
return isinstance(node.value, str)
Str = getattr(ast, 'Str', None)
return Str is not None and isinstance(node, Str)


def get_str_value(node):
"""Return the string value for a string AST node.

Returns None if the node is not a string node.
"""
if isinstance(node, ast.Constant):
return node.value if isinstance(node.value, str) else None
Str = getattr(ast, 'Str', None)
if Str is not None and isinstance(node, Str):
return node.s
return None
9 changes: 5 additions & 4 deletions compyle/cython_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .types import KnownType, Undefined, get_declare_info
from .config import get_config
from .ast_utils import get_assigned, has_return
from .ast_utils import get_assigned, has_return, get_str_value, is_str_node
from .utils import getsourcelines

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -247,11 +247,12 @@ def parse_declare(code):
if call.func.id != 'declare':
raise CodeGenerationError('Unknown declare statement: %s' % code)
arg0 = call.args[0]
if not isinstance(arg0, ast.Str):
err = 'Type should be a string, given :%r' % arg0.s
typestr = get_str_value(arg0)
if typestr is None:
err = 'Type should be a string, given :%r' % getattr(arg0, 's', getattr(arg0, 'value', arg0))
raise CodeGenerationError(err)

return get_declare_info(arg0.s)
return get_declare_info(typestr)


class CythonGenerator(object):
Expand Down
19 changes: 15 additions & 4 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .extern import Extern
from .utils import getsourcelines
from .profile import profile
from .ast_utils import get_str_value

from . import array
from . import parallel
Expand Down Expand Up @@ -198,15 +199,18 @@ def warn(self, message, node):
warnings.warn(msg)

def visit_declare(self, node):
if not isinstance(node.args[0], ast.Str):
arg0 = node.args[0]
type_str = get_str_value(arg0)
if type_str is None:
self.error("Argument to declare should be a string.", node)
type_str = node.args[0].s
return self.get_declare_type(type_str)

def visit_cast(self, node):
if not isinstance(node.args[1], ast.Str):
arg1 = node.args[1]
typestr = get_str_value(arg1)
if typestr is None:
self.error("Cast type should be a string.", node)
return node.args[1].s
return typestr

def visit_address(self, node):
base_type = self.visit(node.args[0])
Expand Down Expand Up @@ -294,6 +298,13 @@ def visit_BinOp(self, node):
def visit_Num(self, node):
return get_ctype_from_arg(node.n)

def visit_Constant(self, node):
val = node.value
if isinstance(val, (int, float)):
return get_ctype_from_arg(val)
# For other constants (e.g., strings/None/bool), we return None
return None

def visit_UnaryOp(self, node):
return self.visit(node.operand)

Expand Down
11 changes: 9 additions & 2 deletions compyle/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .types import kwtype_to_annotation
import mako.template
from .ast_utils import get_str_value


getfullargspec = inspect.getfullargspec
Expand Down Expand Up @@ -45,8 +46,14 @@ def _get_code(self):
args += extra_args
arg_string = ', '.join(args)
body = m.body[0].body
template = body[-1].value.s
docstring = body[0].value.s if len(body) == 2 else ''
# Extract template and docstring in an AST-version-agnostic way
last_val = body[-1].value
template = get_str_value(last_val) or ''

docstring = ''
if len(body) == 2:
first_val = body[0].value
docstring = get_str_value(first_val) or ''
name = self.name
sig = 'def {name}({args}):\n """{docs}\n """'.format(
name=name, args=arg_string, docs=docstring
Expand Down
39 changes: 31 additions & 8 deletions compyle/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
CodeGenerationError, KnownType, Undefined, all_numeric
)
from .utils import getsource
from .ast_utils import is_str_node, get_str_value

PY_VER = sys.version_info.major

Expand Down Expand Up @@ -234,11 +235,11 @@ def _indent_block(self, code):
return '\n'.join(pad + x for x in lines)

def _remove_docstring(self, body):
if body and isinstance(body[0], ast.Expr) and \
isinstance(body[0].value, ast.Str):
return body[1:]
else:
return body
if body and isinstance(body[0], ast.Expr):
val = body[0].value
if is_str_node(val):
return body[1:]
return body

def _get_local_info(self, obj):
return None
Expand Down Expand Up @@ -351,9 +352,11 @@ def visit_Assign(self, node):
left, right = node.targets[0], node.value
if isinstance(right, ast.Call) and \
isinstance(right.func, ast.Name) and right.func.id == 'declare':
if not isinstance(right.args[0], ast.Str):
arg0 = right.args[0]
s = get_str_value(arg0)
if s is None:
self.error("Argument to declare should be a string.", node)
type = right.args[0].s
type = s
if isinstance(left, ast.Name):
self._known.add(left.id)
return self._get_variable_declaration(type, [self.visit(left)])
Expand Down Expand Up @@ -395,7 +398,11 @@ def visit_Call(self, node):
elif 'atomic' in node.func.id:
return self.render_atomic(node.func.id, node.args[0])
elif node.func.id == 'cast':
return '(%s) (%s)' % (node.args[1].s, self.visit(node.args[0]))
arg1 = node.args[1]
typestr = get_str_value(arg1)
if typestr is None:
self.error("Argument to cast should be a string.", node)
return '(%s) (%s)' % (typestr, self.visit(node.args[0]))
else:
return '{func}({args})'.format(
func=node.func.id,
Expand Down Expand Up @@ -691,6 +698,22 @@ def visit_NotEq(self, node):
def visit_Num(self, node):
return literal_to_float(node.n, self._use_double)

def visit_Constant(self, node):
val = node.value
# Handle booleans explicitly first
if isinstance(val, bool):
return self._replacements[val]
# Numbers: int/float
if isinstance(val, (int, float)):
return literal_to_float(val, self._use_double)
# Strings
if isinstance(val, str):
return r'"%s"' % val
# None and other constants
if val in self._replacements:
return self._replacements[val]
return repr(val)

def visit_Or(self, node):
return '||'

Expand Down